Blob Blame Raw
From 11c32a684b20b92f800d3ffa670dc8773e22bb92 Mon Sep 17 00:00:00 2001
From: Chris Lalancette <clalancette@gmail.com>
Date: Thu, 2 Mar 2017 18:52:06 -0500
Subject: [PATCH] Support file:// URLs again.

The switch to requests broke support for file:// URLs.
Re-implement it here by implementing our own Adapter for
requests which allows file:// to work again.  While in
here, I also fixed another problem having to do with
how things are printed out during the download; this
should look better now.

Signed-off-by: Chris Lalancette <clalancette@gmail.com>
---
 oz/ozutil.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 66 insertions(+), 10 deletions(-)

diff --git a/oz/ozutil.py b/oz/ozutil.py
index 2523ea7..09cd124 100644
--- a/oz/ozutil.py
+++ b/oz/ozutil.py
@@ -30,6 +30,7 @@
 import time
 import select
 import contextlib
+import urllib
 try:
     import configparser
 except ImportError:
@@ -744,6 +745,57 @@ def default_screenshot_dir():
     """
     return os.path.join(default_data_dir(), "screenshots")

+class LocalFileAdapter(requests.adapters.BaseAdapter):
+    @staticmethod
+    def _chkpath(method, path):
+        """Return an HTTP status for the given filesystem path."""
+        if method.lower() in ('put', 'delete'):
+            return 501, "Not Implemented"  # TODO
+        elif method.lower() not in ('get', 'head', 'post'):
+            return 405, "Method Not Allowed"
+        elif os.path.isdir(path):
+            return 400, "Path Not A File"
+        elif not os.path.isfile(path):
+            return 404, "File Not Found"
+        elif not os.access(path, os.R_OK):
+            return 403, "Access Denied"
+        else:
+            return 200, "OK"
+
+    def send(self, req, **kwargs):  # pylint: disable=unused-argument
+        """Return the file specified by the given request
+
+        @type req: C{PreparedRequest}
+        @todo: Should I bother filling `response.headers` and processing
+               If-Modified-Since and friends using `os.stat`?
+        """
+        path = os.path.normcase(os.path.normpath(urllib.url2pathname(req.path_url)))
+        response = requests.Response()
+
+        response.status_code, response.reason = self._chkpath(req.method, path)
+        if response.status_code == 200 and req.method.lower() != 'head':
+            try:
+                response.raw = open(path, 'rb')
+            except (OSError, IOError), err:
+                response.status_code = 500
+                response.reason = str(err)
+
+        if isinstance(req.url, bytes):
+            response.url = req.url.decode('utf-8')
+        else:
+            response.url = req.url
+
+        response.headers['Content-Length'] = os.path.getsize(path)
+        response.headers['Accept-Ranges'] = 'bytes'
+        response.headers['Redirect-URL'] = req.url
+        response.request = req
+        response.connection = self
+
+        return response
+
+    def close(self):
+        pass
+
 def http_get_header(url, redirect=True):
     """
     Function to get the HTTP headers from a URL.  The available headers will be
@@ -755,11 +807,13 @@ def http_get_header(url, redirect=True):
     'Redirect-URL' will always be None in the redirect=True case, and may be
     None in the redirect=True case if no redirects were required.
     """
-    with contextlib.closing(requests.post(url, allow_redirects=redirect, stream=True, timeout=10)) as r:
-        info = r.headers
-        info['HTTP-Code'] = r.status_code
+    with requests.Session() as requests_session:
+        requests_session.mount('file://', LocalFileAdapter())
+        response = requests_session.post(url, allow_redirects=redirect, stream=True, timeout=10)
+        info = response.headers
+        info['HTTP-Code'] = response.status_code
         if not redirect:
-            info['Redirect-URL'] = r.headers.get('Location')
+            info['Redirect-URL'] = response.headers.get('Location')
         else:
             info['Redirect-URL'] = None

@@ -769,15 +823,17 @@ def http_download_file(url, fd, show_progress, logger):
     """
     Function to download a file from url to file descriptor fd.
     """
-    with contextlib.closing(requests.get(url, stream=True, allow_redirects=True)) as r:
-        file_size = int(r.headers.get('Content-Length'))
+    with requests.Session() as requests_session:
+        requests_session.mount('file://', LocalFileAdapter())
+        response = requests_session.get(url, stream=True, allow_redirects=True)
+        file_size = int(response.headers.get('Content-Length'))
         chunk_size = 10*1024*1024
-        i = 0
-        for chunk in r.iter_content(chunk_size):
-            i = i + 1
+        done = 0
+        for chunk in response.iter_content(chunk_size):
             write_bytes_to_fd(fd, chunk)
+            done += len(chunk)
             if show_progress:
-                logger.debug("%dkB of %dkB" % ((i * chunk_size) / 1024, file_size / 1024))
+                logger.debug("%dkB of %dkB" % (done / 1024, file_size / 1024))

 def ftp_download_directory(server, username, password, basepath, destination):
     """