Browse Source

[core] Improve HTTP redirect handling (#7094)

Aligns HTTP redirect handling with what browsers commonly do and RFC standards. 

Fixes issues https://github.com/yt-dlp/yt-dlp/commit/afac4caa7db30804bebac33e53c3cb0237958224 missed.

Authored by: coletdjnz
coletdjnz 1 year ago
parent
commit
08916a49c7
3 changed files with 283 additions and 74 deletions
  1. 0 6
      test/test_YoutubeDL.py
  2. 262 30
      test/test_http.py
  3. 21 38
      yt_dlp/utils/_utils.py

+ 0 - 6
test/test_YoutubeDL.py

@@ -10,7 +10,6 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import copy
 import json
-import urllib.error
 
 from test.helper import FakeYDL, assertRegexpMatches
 from yt_dlp import YoutubeDL
@@ -1097,11 +1096,6 @@ class TestYoutubeDL(unittest.TestCase):
         test_selection({'playlist_items': '-15::2'}, INDICES[1::2], True)
         test_selection({'playlist_items': '-15::15'}, [], True)
 
-    def test_urlopen_no_file_protocol(self):
-        # see https://github.com/ytdl-org/youtube-dl/issues/8227
-        ydl = YDL()
-        self.assertRaises(urllib.error.URLError, ydl.urlopen, 'file:///etc/passwd')
-
     def test_do_not_override_ie_key_in_url_transparent(self):
         ydl = YDL()
 

+ 262 - 30
test/test_http.py

@@ -7,40 +7,163 @@ import unittest
 
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-
+import gzip
+import http.cookiejar
 import http.server
+import io
+import pathlib
 import ssl
+import tempfile
 import threading
+import urllib.error
 import urllib.request
 
 from test.helper import http_server_port
 from yt_dlp import YoutubeDL
+from yt_dlp.utils import sanitized_Request, urlencode_postdata
+
+from .helper import FakeYDL
 
 TEST_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
+    protocol_version = 'HTTP/1.1'
+
     def log_message(self, format, *args):
         pass
 
+    def _headers(self):
+        payload = str(self.headers).encode('utf-8')
+        self.send_response(200)
+        self.send_header('Content-Type', 'application/json')
+        self.send_header('Content-Length', str(len(payload)))
+        self.end_headers()
+        self.wfile.write(payload)
+
+    def _redirect(self):
+        self.send_response(int(self.path[len('/redirect_'):]))
+        self.send_header('Location', '/method')
+        self.send_header('Content-Length', '0')
+        self.end_headers()
+
+    def _method(self, method, payload=None):
+        self.send_response(200)
+        self.send_header('Content-Length', str(len(payload or '')))
+        self.send_header('Method', method)
+        self.end_headers()
+        if payload:
+            self.wfile.write(payload)
+
+    def _status(self, status):
+        payload = f'<html>{status} NOT FOUND</html>'.encode()
+        self.send_response(int(status))
+        self.send_header('Content-Type', 'text/html; charset=utf-8')
+        self.send_header('Content-Length', str(len(payload)))
+        self.end_headers()
+        self.wfile.write(payload)
+
+    def _read_data(self):
+        if 'Content-Length' in self.headers:
+            return self.rfile.read(int(self.headers['Content-Length']))
+
+    def do_POST(self):
+        data = self._read_data()
+        if self.path.startswith('/redirect_'):
+            self._redirect()
+        elif self.path.startswith('/method'):
+            self._method('POST', data)
+        elif self.path.startswith('/headers'):
+            self._headers()
+        else:
+            self._status(404)
+
+    def do_HEAD(self):
+        if self.path.startswith('/redirect_'):
+            self._redirect()
+        elif self.path.startswith('/method'):
+            self._method('HEAD')
+        else:
+            self._status(404)
+
+    def do_PUT(self):
+        data = self._read_data()
+        if self.path.startswith('/redirect_'):
+            self._redirect()
+        elif self.path.startswith('/method'):
+            self._method('PUT', data)
+        else:
+            self._status(404)
+
     def do_GET(self):
         if self.path == '/video.html':
+            payload = b'<html><video src="/vid.mp4" /></html>'
             self.send_response(200)
             self.send_header('Content-Type', 'text/html; charset=utf-8')
+            self.send_header('Content-Length', str(len(payload)))  # required for persistent connections
             self.end_headers()
-            self.wfile.write(b'<html><video src="/vid.mp4" /></html>')
+            self.wfile.write(payload)
         elif self.path == '/vid.mp4':
+            payload = b'\x00\x00\x00\x00\x20\x66\x74[video]'
             self.send_response(200)
             self.send_header('Content-Type', 'video/mp4')
+            self.send_header('Content-Length', str(len(payload)))
             self.end_headers()
-            self.wfile.write(b'\x00\x00\x00\x00\x20\x66\x74[video]')
+            self.wfile.write(payload)
         elif self.path == '/%E4%B8%AD%E6%96%87.html':
+            payload = b'<html><video src="/vid.mp4" /></html>'
             self.send_response(200)
             self.send_header('Content-Type', 'text/html; charset=utf-8')
+            self.send_header('Content-Length', str(len(payload)))
+            self.end_headers()
+            self.wfile.write(payload)
+        elif self.path == '/%c7%9f':
+            payload = b'<html><video src="/vid.mp4" /></html>'
+            self.send_response(200)
+            self.send_header('Content-Type', 'text/html; charset=utf-8')
+            self.send_header('Content-Length', str(len(payload)))
+            self.end_headers()
+            self.wfile.write(payload)
+        elif self.path.startswith('/redirect_'):
+            self._redirect()
+        elif self.path.startswith('/method'):
+            self._method('GET')
+        elif self.path.startswith('/headers'):
+            self._headers()
+        elif self.path == '/trailing_garbage':
+            payload = b'<html><video src="/vid.mp4" /></html>'
+            self.send_response(200)
+            self.send_header('Content-Type', 'text/html; charset=utf-8')
+            self.send_header('Content-Encoding', 'gzip')
+            buf = io.BytesIO()
+            with gzip.GzipFile(fileobj=buf, mode='wb') as f:
+                f.write(payload)
+            compressed = buf.getvalue() + b'trailing garbage'
+            self.send_header('Content-Length', str(len(compressed)))
+            self.end_headers()
+            self.wfile.write(compressed)
+        elif self.path == '/302-non-ascii-redirect':
+            new_url = f'http://127.0.0.1:{http_server_port(self.server)}/中文.html'
+            self.send_response(301)
+            self.send_header('Location', new_url)
+            self.send_header('Content-Length', '0')
             self.end_headers()
-            self.wfile.write(b'<html><video src="/vid.mp4" /></html>')
         else:
-            assert False
+            self._status(404)
+
+    def send_header(self, keyword, value):
+        """
+        Forcibly allow HTTP server to send non percent-encoded non-ASCII characters in headers.
+        This is against what is defined in RFC 3986, however we need to test we support this
+        since some sites incorrectly do this.
+        """
+        if keyword.lower() == 'connection':
+            return super().send_header(keyword, value)
+
+        if not hasattr(self, '_headers_buffer'):
+            self._headers_buffer = []
+
+        self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
 
 
 class FakeLogger:
@@ -56,36 +179,128 @@ class FakeLogger:
 
 class TestHTTP(unittest.TestCase):
     def setUp(self):
-        self.httpd = http.server.HTTPServer(
+        # HTTP server
+        self.http_httpd = http.server.ThreadingHTTPServer(
             ('127.0.0.1', 0), HTTPTestRequestHandler)
-        self.port = http_server_port(self.httpd)
-        self.server_thread = threading.Thread(target=self.httpd.serve_forever)
-        self.server_thread.daemon = True
-        self.server_thread.start()
-
-
-class TestHTTPS(unittest.TestCase):
-    def setUp(self):
+        self.http_port = http_server_port(self.http_httpd)
+        self.http_server_thread = threading.Thread(target=self.http_httpd.serve_forever)
+        # FIXME: we should probably stop the http server thread after each test
+        # See: https://github.com/yt-dlp/yt-dlp/pull/7094#discussion_r1199746041
+        self.http_server_thread.daemon = True
+        self.http_server_thread.start()
+
+        # HTTPS server
         certfn = os.path.join(TEST_DIR, 'testcert.pem')
-        self.httpd = http.server.HTTPServer(
+        self.https_httpd = http.server.ThreadingHTTPServer(
             ('127.0.0.1', 0), HTTPTestRequestHandler)
         sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
         sslctx.load_cert_chain(certfn, None)
-        self.httpd.socket = sslctx.wrap_socket(self.httpd.socket, server_side=True)
-        self.port = http_server_port(self.httpd)
-        self.server_thread = threading.Thread(target=self.httpd.serve_forever)
-        self.server_thread.daemon = True
-        self.server_thread.start()
+        self.https_httpd.socket = sslctx.wrap_socket(self.https_httpd.socket, server_side=True)
+        self.https_port = http_server_port(self.https_httpd)
+        self.https_server_thread = threading.Thread(target=self.https_httpd.serve_forever)
+        self.https_server_thread.daemon = True
+        self.https_server_thread.start()
 
     def test_nocheckcertificate(self):
-        ydl = YoutubeDL({'logger': FakeLogger()})
-        self.assertRaises(
-            Exception,
-            ydl.extract_info, 'https://127.0.0.1:%d/video.html' % self.port)
-
-        ydl = YoutubeDL({'logger': FakeLogger(), 'nocheckcertificate': True})
-        r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port)
-        self.assertEqual(r['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port)
+        with FakeYDL({'logger': FakeLogger()}) as ydl:
+            with self.assertRaises(urllib.error.URLError):
+                ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
+
+        with FakeYDL({'logger': FakeLogger(), 'nocheckcertificate': True}) as ydl:
+            r = ydl.urlopen(sanitized_Request(f'https://127.0.0.1:{self.https_port}/headers'))
+            self.assertEqual(r.status, 200)
+            r.close()
+
+    def test_percent_encode(self):
+        with FakeYDL() as ydl:
+            # Unicode characters should be encoded with uppercase percent-encoding
+            res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/中文.html'))
+            self.assertEqual(res.status, 200)
+            res.close()
+            # don't normalize existing percent encodings
+            res = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/%c7%9f'))
+            self.assertEqual(res.status, 200)
+            res.close()
+
+    def test_unicode_path_redirection(self):
+        with FakeYDL() as ydl:
+            r = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
+            self.assertEqual(r.url, f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html')
+            r.close()
+
+    def test_redirect(self):
+        with FakeYDL() as ydl:
+            def do_req(redirect_status, method):
+                data = b'testdata' if method in ('POST', 'PUT') else None
+                res = ydl.urlopen(sanitized_Request(
+                    f'http://127.0.0.1:{self.http_port}/redirect_{redirect_status}', method=method, data=data))
+                return res.read().decode('utf-8'), res.headers.get('method', '')
+
+            # A 303 must either use GET or HEAD for subsequent request
+            self.assertEqual(do_req(303, 'POST'), ('', 'GET'))
+            self.assertEqual(do_req(303, 'HEAD'), ('', 'HEAD'))
+
+            self.assertEqual(do_req(303, 'PUT'), ('', 'GET'))
+
+            # 301 and 302 turn POST only into a GET
+            self.assertEqual(do_req(301, 'POST'), ('', 'GET'))
+            self.assertEqual(do_req(301, 'HEAD'), ('', 'HEAD'))
+            self.assertEqual(do_req(302, 'POST'), ('', 'GET'))
+            self.assertEqual(do_req(302, 'HEAD'), ('', 'HEAD'))
+
+            self.assertEqual(do_req(301, 'PUT'), ('testdata', 'PUT'))
+            self.assertEqual(do_req(302, 'PUT'), ('testdata', 'PUT'))
+
+            # 307 and 308 should not change method
+            for m in ('POST', 'PUT'):
+                self.assertEqual(do_req(307, m), ('testdata', m))
+                self.assertEqual(do_req(308, m), ('testdata', m))
+
+            self.assertEqual(do_req(307, 'HEAD'), ('', 'HEAD'))
+            self.assertEqual(do_req(308, 'HEAD'), ('', 'HEAD'))
+
+            # These should not redirect and instead raise an HTTPError
+            for code in (300, 304, 305, 306):
+                with self.assertRaises(urllib.error.HTTPError):
+                    do_req(code, 'GET')
+
+    def test_content_type(self):
+        # https://github.com/yt-dlp/yt-dlp/commit/379a4f161d4ad3e40932dcf5aca6e6fb9715ab28
+        with FakeYDL({'nocheckcertificate': True}) as ydl:
+            # method should be auto-detected as POST
+            r = sanitized_Request(f'https://localhost:{self.https_port}/headers', data=urlencode_postdata({'test': 'test'}))
+
+            headers = ydl.urlopen(r).read().decode('utf-8')
+            self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
+
+            # test http
+            r = sanitized_Request(f'http://localhost:{self.http_port}/headers', data=urlencode_postdata({'test': 'test'}))
+            headers = ydl.urlopen(r).read().decode('utf-8')
+            self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
+
+    def test_cookiejar(self):
+        with FakeYDL() as ydl:
+            ydl.cookiejar.set_cookie(http.cookiejar.Cookie(
+                0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
+                False, '/headers', True, False, None, False, None, None, {}))
+            data = ydl.urlopen(sanitized_Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
+            self.assertIn(b'Cookie: test=ytdlp', data)
+
+    def test_no_compression_compat_header(self):
+        with FakeYDL() as ydl:
+            data = ydl.urlopen(
+                sanitized_Request(
+                    f'http://127.0.0.1:{self.http_port}/headers',
+                    headers={'Youtubedl-no-compression': True})).read()
+            self.assertIn(b'Accept-Encoding: identity', data)
+            self.assertNotIn(b'youtubedl-no-compression', data.lower())
+
+    def test_gzip_trailing_garbage(self):
+        # https://github.com/ytdl-org/youtube-dl/commit/aa3e950764337ef9800c936f4de89b31c00dfcf5
+        # https://github.com/ytdl-org/youtube-dl/commit/6f2ec15cee79d35dba065677cad9da7491ec6e6f
+        with FakeYDL() as ydl:
+            data = ydl.urlopen(sanitized_Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode('utf-8')
+            self.assertEqual(data, '<html><video src="/vid.mp4" /></html>')
 
 
 class TestClientCert(unittest.TestCase):
@@ -112,8 +327,8 @@ class TestClientCert(unittest.TestCase):
             'nocheckcertificate': True,
             **params,
         })
-        r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port)
-        self.assertEqual(r['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port)
+        r = ydl.extract_info(f'https://127.0.0.1:{self.port}/video.html')
+        self.assertEqual(r['url'], f'https://127.0.0.1:{self.port}/vid.mp4')
 
     def test_certificate_combined_nopass(self):
         self._run_test(client_certificate=os.path.join(self.certdir, 'clientwithkey.crt'))
@@ -188,5 +403,22 @@ class TestProxy(unittest.TestCase):
         self.assertEqual(response, 'normal: http://xn--fiq228c.tw/')
 
 
+class TestFileURL(unittest.TestCase):
+    # See https://github.com/ytdl-org/youtube-dl/issues/8227
+    def test_file_urls(self):
+        tf = tempfile.NamedTemporaryFile(delete=False)
+        tf.write(b'foobar')
+        tf.close()
+        url = pathlib.Path(tf.name).as_uri()
+        with FakeYDL() as ydl:
+            self.assertRaisesRegex(
+                urllib.error.URLError, 'file:// URLs are explicitly disabled in yt-dlp for security reasons', ydl.urlopen, url)
+        with FakeYDL({'enable_file_urls': True}) as ydl:
+            res = ydl.urlopen(url)
+            self.assertEqual(res.read(), b'foobar')
+            res.close()
+        os.unlink(tf.name)
+
+
 if __name__ == '__main__':
     unittest.main()

+ 21 - 38
yt_dlp/utils/_utils.py

@@ -1664,61 +1664,44 @@ class YoutubeDLRedirectHandler(urllib.request.HTTPRedirectHandler):
 
     The code is based on HTTPRedirectHandler implementation from CPython [1].
 
-    This redirect handler solves two issues:
-     - ensures redirect URL is always unicode under python 2
-     - introduces support for experimental HTTP response status code
-       308 Permanent Redirect [2] used by some sites [3]
+    This redirect handler fixes and improves the logic to better align with RFC7261
+     and what browsers tend to do [2][3]
 
     1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py
-    2. https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308
-    3. https://github.com/ytdl-org/youtube-dl/issues/28768
+    2. https://datatracker.ietf.org/doc/html/rfc7231
+    3. https://github.com/python/cpython/issues/91306
     """
 
     http_error_301 = http_error_303 = http_error_307 = http_error_308 = urllib.request.HTTPRedirectHandler.http_error_302
 
     def redirect_request(self, req, fp, code, msg, headers, newurl):
-        """Return a Request or None in response to a redirect.
-
-        This is called by the http_error_30x methods when a
-        redirection response is received.  If a redirection should
-        take place, return a new Request to allow http_error_30x to
-        perform the redirect.  Otherwise, raise HTTPError if no-one
-        else should try to handle this url.  Return None if you can't
-        but another Handler might.
-        """
-        m = req.get_method()
-        if (not (code in (301, 302, 303, 307, 308) and m in ("GET", "HEAD")
-                 or code in (301, 302, 303) and m == "POST")):
+        if code not in (301, 302, 303, 307, 308):
             raise urllib.error.HTTPError(req.full_url, code, msg, headers, fp)
-        # Strictly (according to RFC 2616), 301 or 302 in response to
-        # a POST MUST NOT cause a redirection without confirmation
-        # from the user (of urllib.request, in this case).  In practice,
-        # essentially all clients do redirect in this case, so we do
-        # the same.
-
-        # Be conciliant with URIs containing a space.  This is mainly
-        # redundant with the more complete encoding done in http_error_302(),
-        # but it is kept for compatibility with other callers.
-        newurl = newurl.replace(' ', '%20')
-
-        CONTENT_HEADERS = ("content-length", "content-type")
-        # NB: don't use dict comprehension for python 2.6 compatibility
-        newheaders = {k: v for k, v in req.headers.items() if k.lower() not in CONTENT_HEADERS}
 
+        new_method = req.get_method()
+        new_data = req.data
+        remove_headers = []
         # A 303 must either use GET or HEAD for subsequent request
         # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
-        if code == 303 and m != 'HEAD':
-            m = 'GET'
+        if code == 303 and req.get_method() != 'HEAD':
+            new_method = 'GET'
         # 301 and 302 redirects are commonly turned into a GET from a POST
         # for subsequent requests by browsers, so we'll do the same.
         # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
         # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
-        if code in (301, 302) and m == 'POST':
-            m = 'GET'
+        elif code in (301, 302) and req.get_method() == 'POST':
+            new_method = 'GET'
+
+        # only remove payload if method changed (e.g. POST to GET)
+        if new_method != req.get_method():
+            new_data = None
+            remove_headers.extend(['Content-Length', 'Content-Type'])
+
+        new_headers = {k: v for k, v in req.headers.items() if k.lower() not in remove_headers}
 
         return urllib.request.Request(
-            newurl, headers=newheaders, origin_req_host=req.origin_req_host,
-            unverifiable=True, method=m)
+            newurl, headers=new_headers, origin_req_host=req.origin_req_host,
+            unverifiable=True, method=new_method, data=new_data)
 
 
 def extract_timezone(date_str):