test_websockets.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. #!/usr/bin/env python3
  2. # Allow direct execution
  3. import os
  4. import sys
  5. import time
  6. import pytest
  7. from test.helper import verify_address_availability
  8. from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. import http.client
  11. import http.cookiejar
  12. import http.server
  13. import json
  14. import random
  15. import ssl
  16. import threading
  17. from yt_dlp import socks, traverse_obj
  18. from yt_dlp.cookies import YoutubeDLCookieJar
  19. from yt_dlp.dependencies import websockets
  20. from yt_dlp.networking import Request
  21. from yt_dlp.networking.exceptions import (
  22. CertificateVerifyError,
  23. HTTPError,
  24. ProxyError,
  25. RequestError,
  26. SSLError,
  27. TransportError,
  28. )
  29. from yt_dlp.utils.networking import HTTPHeaderDict
  30. TEST_DIR = os.path.dirname(os.path.abspath(__file__))
  31. def websocket_handler(websocket):
  32. for message in websocket:
  33. if isinstance(message, bytes):
  34. if message == b'bytes':
  35. return websocket.send('2')
  36. elif isinstance(message, str):
  37. if message == 'headers':
  38. return websocket.send(json.dumps(dict(websocket.request.headers)))
  39. elif message == 'path':
  40. return websocket.send(websocket.request.path)
  41. elif message == 'source_address':
  42. return websocket.send(websocket.remote_address[0])
  43. elif message == 'str':
  44. return websocket.send('1')
  45. return websocket.send(message)
  46. def process_request(self, request):
  47. if request.path.startswith('/gen_'):
  48. status = http.HTTPStatus(int(request.path[5:]))
  49. if 300 <= status.value <= 300:
  50. return websockets.http11.Response(
  51. status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
  52. return self.protocol.reject(status.value, status.phrase)
  53. return self.protocol.accept(request)
  54. def create_websocket_server(**ws_kwargs):
  55. import websockets.sync.server
  56. wsd = websockets.sync.server.serve(
  57. websocket_handler, '127.0.0.1', 0,
  58. process_request=process_request, open_timeout=2, **ws_kwargs)
  59. ws_port = wsd.socket.getsockname()[1]
  60. ws_server_thread = threading.Thread(target=wsd.serve_forever)
  61. ws_server_thread.daemon = True
  62. ws_server_thread.start()
  63. return ws_server_thread, ws_port
  64. def create_ws_websocket_server():
  65. return create_websocket_server()
  66. def create_wss_websocket_server():
  67. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  68. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  69. sslctx.load_cert_chain(certfn, None)
  70. return create_websocket_server(ssl_context=sslctx)
  71. MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
  72. def create_mtls_wss_websocket_server():
  73. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  74. cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
  75. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  76. sslctx.verify_mode = ssl.CERT_REQUIRED
  77. sslctx.load_verify_locations(cafile=cacertfn)
  78. sslctx.load_cert_chain(certfn, None)
  79. return create_websocket_server(ssl_context=sslctx)
  80. def ws_validate_and_send(rh, req):
  81. rh.validate(req)
  82. max_tries = 3
  83. for i in range(max_tries):
  84. try:
  85. return rh.send(req)
  86. except TransportError as e:
  87. if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
  88. # websockets server sometimes hangs on new connections
  89. continue
  90. raise
  91. @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
  92. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  93. class TestWebsSocketRequestHandlerConformance:
  94. @classmethod
  95. def setup_class(cls):
  96. cls.ws_thread, cls.ws_port = create_ws_websocket_server()
  97. cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
  98. cls.wss_thread, cls.wss_port = create_wss_websocket_server()
  99. cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
  100. cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
  101. cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
  102. cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
  103. cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
  104. def test_basic_websockets(self, handler):
  105. with handler() as rh:
  106. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  107. assert 'upgrade' in ws.headers
  108. assert ws.status == 101
  109. ws.send('foo')
  110. assert ws.recv() == 'foo'
  111. ws.close()
  112. # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
  113. @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
  114. def test_send_types(self, handler, msg, opcode):
  115. with handler() as rh:
  116. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  117. ws.send(msg)
  118. assert int(ws.recv()) == opcode
  119. ws.close()
  120. def test_verify_cert(self, handler):
  121. with handler() as rh:
  122. with pytest.raises(CertificateVerifyError):
  123. ws_validate_and_send(rh, Request(self.wss_base_url))
  124. with handler(verify=False) as rh:
  125. ws = ws_validate_and_send(rh, Request(self.wss_base_url))
  126. assert ws.status == 101
  127. ws.close()
  128. def test_ssl_error(self, handler):
  129. with handler(verify=False) as rh:
  130. with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
  131. ws_validate_and_send(rh, Request(self.bad_wss_host))
  132. assert not issubclass(exc_info.type, CertificateVerifyError)
  133. @pytest.mark.parametrize('path,expected', [
  134. # Unicode characters should be encoded with uppercase percent-encoding
  135. ('/中文', '/%E4%B8%AD%E6%96%87'),
  136. # don't normalize existing percent encodings
  137. ('/%c7%9f', '/%c7%9f'),
  138. ])
  139. def test_percent_encode(self, handler, path, expected):
  140. with handler() as rh:
  141. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
  142. ws.send('path')
  143. assert ws.recv() == expected
  144. assert ws.status == 101
  145. ws.close()
  146. def test_remove_dot_segments(self, handler):
  147. with handler() as rh:
  148. # This isn't a comprehensive test,
  149. # but it should be enough to check whether the handler is removing dot segments
  150. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
  151. assert ws.status == 101
  152. ws.send('path')
  153. assert ws.recv() == '/test'
  154. ws.close()
  155. # We are restricted to known HTTP status codes in http.HTTPStatus
  156. # Redirects are not supported for websockets
  157. @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
  158. def test_raise_http_error(self, handler, status):
  159. with handler() as rh:
  160. with pytest.raises(HTTPError) as exc_info:
  161. ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
  162. assert exc_info.value.status == status
  163. @pytest.mark.parametrize('params,extensions', [
  164. ({'timeout': sys.float_info.min}, {}),
  165. ({}, {'timeout': sys.float_info.min}),
  166. ])
  167. def test_read_timeout(self, handler, params, extensions):
  168. with handler(**params) as rh:
  169. with pytest.raises(TransportError):
  170. ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
  171. def test_connect_timeout(self, handler):
  172. # nothing should be listening on this port
  173. connect_timeout_url = 'ws://10.255.255.255'
  174. with handler(timeout=0.01) as rh, pytest.raises(TransportError):
  175. now = time.time()
  176. ws_validate_and_send(rh, Request(connect_timeout_url))
  177. assert time.time() - now < DEFAULT_TIMEOUT
  178. # Per request timeout, should override handler timeout
  179. request = Request(connect_timeout_url, extensions={'timeout': 0.01})
  180. with handler() as rh, pytest.raises(TransportError):
  181. now = time.time()
  182. ws_validate_and_send(rh, request)
  183. assert time.time() - now < DEFAULT_TIMEOUT
  184. def test_cookies(self, handler):
  185. cookiejar = YoutubeDLCookieJar()
  186. cookiejar.set_cookie(http.cookiejar.Cookie(
  187. version=0, name='test', value='ytdlp', port=None, port_specified=False,
  188. domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
  189. path_specified=True, secure=False, expires=None, discard=False, comment=None,
  190. comment_url=None, rest={}))
  191. with handler(cookiejar=cookiejar) as rh:
  192. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  193. ws.send('headers')
  194. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  195. ws.close()
  196. with handler() as rh:
  197. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  198. ws.send('headers')
  199. assert 'cookie' not in json.loads(ws.recv())
  200. ws.close()
  201. ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
  202. ws.send('headers')
  203. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  204. ws.close()
  205. def test_source_address(self, handler):
  206. source_address = f'127.0.0.{random.randint(5, 255)}'
  207. verify_address_availability(source_address)
  208. with handler(source_address=source_address) as rh:
  209. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  210. ws.send('source_address')
  211. assert source_address == ws.recv()
  212. ws.close()
  213. def test_response_url(self, handler):
  214. with handler() as rh:
  215. url = f'{self.ws_base_url}/something'
  216. ws = ws_validate_and_send(rh, Request(url))
  217. assert ws.url == url
  218. ws.close()
  219. def test_request_headers(self, handler):
  220. with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
  221. # Global Headers
  222. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  223. ws.send('headers')
  224. headers = HTTPHeaderDict(json.loads(ws.recv()))
  225. assert headers['test1'] == 'test'
  226. ws.close()
  227. # Per request headers, merged with global
  228. ws = ws_validate_and_send(rh, Request(
  229. self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
  230. ws.send('headers')
  231. headers = HTTPHeaderDict(json.loads(ws.recv()))
  232. assert headers['test1'] == 'test'
  233. assert headers['test2'] == 'changed'
  234. assert headers['test3'] == 'test3'
  235. ws.close()
  236. @pytest.mark.parametrize('client_cert', (
  237. {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
  238. {
  239. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  240. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
  241. },
  242. {
  243. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
  244. 'client_certificate_password': 'foobar',
  245. },
  246. {
  247. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  248. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
  249. 'client_certificate_password': 'foobar',
  250. },
  251. ))
  252. def test_mtls(self, handler, client_cert):
  253. with handler(
  254. # Disable client-side validation of unacceptable self-signed testcert.pem
  255. # The test is of a check on the server side, so unaffected
  256. verify=False,
  257. client_cert=client_cert,
  258. ) as rh:
  259. ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
  260. def test_request_disable_proxy(self, handler):
  261. for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
  262. # Given handler is configured with a proxy
  263. with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
  264. # When a proxy is explicitly set to None for the request
  265. ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'http': None}))
  266. # Then no proxy should be used
  267. assert ws.status == 101
  268. ws.close()
  269. @pytest.mark.skip_handlers_if(
  270. lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
  271. def test_noproxy(self, handler):
  272. for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
  273. # Given the handler is configured with a proxy
  274. with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
  275. for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
  276. # When request no proxy includes the request url host
  277. ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
  278. # Then the proxy should not be used
  279. assert ws.status == 101
  280. ws.close()
  281. @pytest.mark.skip_handlers_if(
  282. lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
  283. def test_allproxy(self, handler):
  284. supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
  285. # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
  286. # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
  287. with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
  288. with pytest.raises(TransportError):
  289. ws_validate_and_send(rh, Request(self.ws_base_url)).close()
  290. with handler(timeout=0.1) as rh:
  291. with pytest.raises(TransportError):
  292. ws_validate_and_send(
  293. rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
  294. def create_fake_ws_connection(raised):
  295. import websockets.sync.client
  296. class FakeWsConnection(websockets.sync.client.ClientConnection):
  297. def __init__(self, *args, **kwargs):
  298. class FakeResponse:
  299. body = b''
  300. headers = {}
  301. status_code = 101
  302. reason_phrase = 'test'
  303. self.response = FakeResponse()
  304. def send(self, *args, **kwargs):
  305. raise raised()
  306. def recv(self, *args, **kwargs):
  307. raise raised()
  308. def close(self, *args, **kwargs):
  309. return
  310. return FakeWsConnection()
  311. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  312. class TestWebsocketsRequestHandler:
  313. @pytest.mark.parametrize('raised,expected', [
  314. # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
  315. (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
  316. # Requires a response object. Should be covered by HTTP error tests.
  317. # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
  318. (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
  319. # These are subclasses of InvalidHandshake
  320. (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
  321. (lambda: websockets.exceptions.NegotiationError(), TransportError),
  322. # Catch-all
  323. (lambda: websockets.exceptions.WebSocketException(), TransportError),
  324. (lambda: TimeoutError(), TransportError),
  325. # These may be raised by our create_connection implementation, which should also be caught
  326. (lambda: OSError(), TransportError),
  327. (lambda: ssl.SSLError(), SSLError),
  328. (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
  329. (lambda: socks.ProxyError(), ProxyError),
  330. ])
  331. def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
  332. import websockets.sync.client
  333. import yt_dlp.networking._websockets
  334. with handler() as rh:
  335. def fake_connect(*args, **kwargs):
  336. raise raised()
  337. monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
  338. monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
  339. with pytest.raises(expected) as exc_info:
  340. rh.send(Request('ws://fake-url'))
  341. assert exc_info.type is expected
  342. @pytest.mark.parametrize('raised,expected,match', [
  343. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
  344. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  345. (lambda: RuntimeError(), TransportError, None),
  346. (lambda: TimeoutError(), TransportError, None),
  347. (lambda: TypeError(), RequestError, None),
  348. (lambda: socks.ProxyError(), ProxyError, None),
  349. # Catch-all
  350. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  351. ])
  352. def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
  353. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  354. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  355. with pytest.raises(expected, match=match) as exc_info:
  356. ws.send('test')
  357. assert exc_info.type is expected
  358. @pytest.mark.parametrize('raised,expected,match', [
  359. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
  360. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  361. (lambda: RuntimeError(), TransportError, None),
  362. (lambda: TimeoutError(), TransportError, None),
  363. (lambda: socks.ProxyError(), ProxyError, None),
  364. # Catch-all
  365. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  366. ])
  367. def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
  368. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  369. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  370. with pytest.raises(expected, match=match) as exc_info:
  371. ws.recv()
  372. assert exc_info.type is expected