test_websockets.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. #!/usr/bin/env python3
  2. # Allow direct execution
  3. import os
  4. import sys
  5. import time
  6. import pytest
  7. from test.helper import verify_address_availability
  8. from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. import http.client
  11. import http.cookiejar
  12. import http.server
  13. import json
  14. import random
  15. import ssl
  16. import threading
  17. from yt_dlp import socks, traverse_obj
  18. from yt_dlp.cookies import YoutubeDLCookieJar
  19. from yt_dlp.dependencies import websockets
  20. from yt_dlp.networking import Request
  21. from yt_dlp.networking.exceptions import (
  22. CertificateVerifyError,
  23. HTTPError,
  24. ProxyError,
  25. RequestError,
  26. SSLError,
  27. TransportError,
  28. )
  29. from yt_dlp.utils.networking import HTTPHeaderDict
  30. TEST_DIR = os.path.dirname(os.path.abspath(__file__))
  31. def websocket_handler(websocket):
  32. for message in websocket:
  33. if isinstance(message, bytes):
  34. if message == b'bytes':
  35. return websocket.send('2')
  36. elif isinstance(message, str):
  37. if message == 'headers':
  38. return websocket.send(json.dumps(dict(websocket.request.headers)))
  39. elif message == 'path':
  40. return websocket.send(websocket.request.path)
  41. elif message == 'source_address':
  42. return websocket.send(websocket.remote_address[0])
  43. elif message == 'str':
  44. return websocket.send('1')
  45. return websocket.send(message)
  46. def process_request(self, request):
  47. if request.path.startswith('/gen_'):
  48. status = http.HTTPStatus(int(request.path[5:]))
  49. if 300 <= status.value <= 300:
  50. return websockets.http11.Response(
  51. status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
  52. return self.protocol.reject(status.value, status.phrase)
  53. elif request.path.startswith('/get_cookie'):
  54. response = self.protocol.accept(request)
  55. response.headers['Set-Cookie'] = 'test=ytdlp'
  56. return response
  57. return self.protocol.accept(request)
  58. def create_websocket_server(**ws_kwargs):
  59. import websockets.sync.server
  60. wsd = websockets.sync.server.serve(
  61. websocket_handler, '127.0.0.1', 0,
  62. process_request=process_request, open_timeout=2, **ws_kwargs)
  63. ws_port = wsd.socket.getsockname()[1]
  64. ws_server_thread = threading.Thread(target=wsd.serve_forever)
  65. ws_server_thread.daemon = True
  66. ws_server_thread.start()
  67. return ws_server_thread, ws_port
  68. def create_ws_websocket_server():
  69. return create_websocket_server()
  70. def create_wss_websocket_server():
  71. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  72. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  73. sslctx.load_cert_chain(certfn, None)
  74. return create_websocket_server(ssl=sslctx)
  75. MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
  76. def create_mtls_wss_websocket_server():
  77. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  78. cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
  79. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  80. sslctx.verify_mode = ssl.CERT_REQUIRED
  81. sslctx.load_verify_locations(cafile=cacertfn)
  82. sslctx.load_cert_chain(certfn, None)
  83. return create_websocket_server(ssl=sslctx)
  84. def create_legacy_wss_websocket_server():
  85. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  86. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  87. sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
  88. sslctx.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
  89. sslctx.load_cert_chain(certfn, None)
  90. return create_websocket_server(ssl=sslctx)
  91. def ws_validate_and_send(rh, req):
  92. rh.validate(req)
  93. max_tries = 3
  94. for i in range(max_tries):
  95. try:
  96. return rh.send(req)
  97. except TransportError as e:
  98. if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
  99. # websockets server sometimes hangs on new connections
  100. continue
  101. raise
  102. @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
  103. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  104. class TestWebsSocketRequestHandlerConformance:
  105. @classmethod
  106. def setup_class(cls):
  107. cls.ws_thread, cls.ws_port = create_ws_websocket_server()
  108. cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
  109. cls.wss_thread, cls.wss_port = create_wss_websocket_server()
  110. cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
  111. cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
  112. cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
  113. cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
  114. cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
  115. cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
  116. cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'
  117. def test_basic_websockets(self, handler):
  118. with handler() as rh:
  119. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  120. assert 'upgrade' in ws.headers
  121. assert ws.status == 101
  122. ws.send('foo')
  123. assert ws.recv() == 'foo'
  124. ws.close()
  125. # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
  126. @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
  127. def test_send_types(self, handler, msg, opcode):
  128. with handler() as rh:
  129. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  130. ws.send(msg)
  131. assert int(ws.recv()) == opcode
  132. ws.close()
  133. def test_verify_cert(self, handler):
  134. with handler() as rh:
  135. with pytest.raises(CertificateVerifyError):
  136. ws_validate_and_send(rh, Request(self.wss_base_url))
  137. with handler(verify=False) as rh:
  138. ws = ws_validate_and_send(rh, Request(self.wss_base_url))
  139. assert ws.status == 101
  140. ws.close()
  141. def test_ssl_error(self, handler):
  142. with handler(verify=False) as rh:
  143. with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
  144. ws_validate_and_send(rh, Request(self.bad_wss_host))
  145. assert not issubclass(exc_info.type, CertificateVerifyError)
  146. def test_legacy_ssl_extension(self, handler):
  147. with handler(verify=False) as rh:
  148. ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
  149. assert ws.status == 101
  150. ws.close()
  151. # Ensure only applies to request extension
  152. with pytest.raises(SSLError):
  153. ws_validate_and_send(rh, Request(self.legacy_wss_host))
  154. def test_legacy_ssl_support(self, handler):
  155. with handler(verify=False, legacy_ssl_support=True) as rh:
  156. ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
  157. assert ws.status == 101
  158. ws.close()
  159. @pytest.mark.parametrize('path,expected', [
  160. # Unicode characters should be encoded with uppercase percent-encoding
  161. ('/中文', '/%E4%B8%AD%E6%96%87'),
  162. # don't normalize existing percent encodings
  163. ('/%c7%9f', '/%c7%9f'),
  164. ])
  165. def test_percent_encode(self, handler, path, expected):
  166. with handler() as rh:
  167. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
  168. ws.send('path')
  169. assert ws.recv() == expected
  170. assert ws.status == 101
  171. ws.close()
  172. def test_remove_dot_segments(self, handler):
  173. with handler() as rh:
  174. # This isn't a comprehensive test,
  175. # but it should be enough to check whether the handler is removing dot segments
  176. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
  177. assert ws.status == 101
  178. ws.send('path')
  179. assert ws.recv() == '/test'
  180. ws.close()
  181. # We are restricted to known HTTP status codes in http.HTTPStatus
  182. # Redirects are not supported for websockets
  183. @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
  184. def test_raise_http_error(self, handler, status):
  185. with handler() as rh:
  186. with pytest.raises(HTTPError) as exc_info:
  187. ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
  188. assert exc_info.value.status == status
  189. @pytest.mark.parametrize('params,extensions', [
  190. ({'timeout': sys.float_info.min}, {}),
  191. ({}, {'timeout': sys.float_info.min}),
  192. ])
  193. def test_read_timeout(self, handler, params, extensions):
  194. with handler(**params) as rh:
  195. with pytest.raises(TransportError):
  196. ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
  197. def test_connect_timeout(self, handler):
  198. # nothing should be listening on this port
  199. connect_timeout_url = 'ws://10.255.255.255'
  200. with handler(timeout=0.01) as rh, pytest.raises(TransportError):
  201. now = time.time()
  202. ws_validate_and_send(rh, Request(connect_timeout_url))
  203. assert time.time() - now < DEFAULT_TIMEOUT
  204. # Per request timeout, should override handler timeout
  205. request = Request(connect_timeout_url, extensions={'timeout': 0.01})
  206. with handler() as rh, pytest.raises(TransportError):
  207. now = time.time()
  208. ws_validate_and_send(rh, request)
  209. assert time.time() - now < DEFAULT_TIMEOUT
  210. def test_cookies(self, handler):
  211. cookiejar = YoutubeDLCookieJar()
  212. cookiejar.set_cookie(http.cookiejar.Cookie(
  213. version=0, name='test', value='ytdlp', port=None, port_specified=False,
  214. domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
  215. path_specified=True, secure=False, expires=None, discard=False, comment=None,
  216. comment_url=None, rest={}))
  217. with handler(cookiejar=cookiejar) as rh:
  218. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  219. ws.send('headers')
  220. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  221. ws.close()
  222. with handler() as rh:
  223. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  224. ws.send('headers')
  225. assert 'cookie' not in json.loads(ws.recv())
  226. ws.close()
  227. ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
  228. ws.send('headers')
  229. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  230. ws.close()
  231. @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
  232. def test_cookie_sync_only_cookiejar(self, handler):
  233. # Ensure that cookies are ONLY being handled by the cookiejar
  234. with handler() as rh:
  235. ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
  236. ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
  237. ws.send('headers')
  238. assert 'cookie' not in json.loads(ws.recv())
  239. ws.close()
  240. @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
  241. def test_cookie_sync_delete_cookie(self, handler):
  242. # Ensure that cookies are ONLY being handled by the cookiejar
  243. cookiejar = YoutubeDLCookieJar()
  244. with handler(verbose=True, cookiejar=cookiejar) as rh:
  245. ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
  246. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  247. ws.send('headers')
  248. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  249. ws.close()
  250. cookiejar.clear_session_cookies()
  251. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  252. ws.send('headers')
  253. assert 'cookie' not in json.loads(ws.recv())
  254. ws.close()
  255. def test_source_address(self, handler):
  256. source_address = f'127.0.0.{random.randint(5, 255)}'
  257. verify_address_availability(source_address)
  258. with handler(source_address=source_address) as rh:
  259. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  260. ws.send('source_address')
  261. assert source_address == ws.recv()
  262. ws.close()
  263. def test_response_url(self, handler):
  264. with handler() as rh:
  265. url = f'{self.ws_base_url}/something'
  266. ws = ws_validate_and_send(rh, Request(url))
  267. assert ws.url == url
  268. ws.close()
  269. def test_request_headers(self, handler):
  270. with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
  271. # Global Headers
  272. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  273. ws.send('headers')
  274. headers = HTTPHeaderDict(json.loads(ws.recv()))
  275. assert headers['test1'] == 'test'
  276. ws.close()
  277. # Per request headers, merged with global
  278. ws = ws_validate_and_send(rh, Request(
  279. self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
  280. ws.send('headers')
  281. headers = HTTPHeaderDict(json.loads(ws.recv()))
  282. assert headers['test1'] == 'test'
  283. assert headers['test2'] == 'changed'
  284. assert headers['test3'] == 'test3'
  285. ws.close()
  286. @pytest.mark.parametrize('client_cert', (
  287. {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
  288. {
  289. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  290. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
  291. },
  292. {
  293. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
  294. 'client_certificate_password': 'foobar',
  295. },
  296. {
  297. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  298. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
  299. 'client_certificate_password': 'foobar',
  300. },
  301. ))
  302. def test_mtls(self, handler, client_cert):
  303. with handler(
  304. # Disable client-side validation of unacceptable self-signed testcert.pem
  305. # The test is of a check on the server side, so unaffected
  306. verify=False,
  307. client_cert=client_cert,
  308. ) as rh:
  309. ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
  310. def test_request_disable_proxy(self, handler):
  311. for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
  312. # Given handler is configured with a proxy
  313. with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
  314. # When a proxy is explicitly set to None for the request
  315. ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'http': None}))
  316. # Then no proxy should be used
  317. assert ws.status == 101
  318. ws.close()
  319. @pytest.mark.skip_handlers_if(
  320. lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
  321. def test_noproxy(self, handler):
  322. for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
  323. # Given the handler is configured with a proxy
  324. with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
  325. for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
  326. # When request no proxy includes the request url host
  327. ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
  328. # Then the proxy should not be used
  329. assert ws.status == 101
  330. ws.close()
  331. @pytest.mark.skip_handlers_if(
  332. lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
  333. def test_allproxy(self, handler):
  334. supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
  335. # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
  336. # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
  337. with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
  338. with pytest.raises(TransportError):
  339. ws_validate_and_send(rh, Request(self.ws_base_url)).close()
  340. with handler(timeout=0.1) as rh:
  341. with pytest.raises(TransportError):
  342. ws_validate_and_send(
  343. rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
  344. def create_fake_ws_connection(raised):
  345. import websockets.sync.client
  346. class FakeWsConnection(websockets.sync.client.ClientConnection):
  347. def __init__(self, *args, **kwargs):
  348. class FakeResponse:
  349. body = b''
  350. headers = {}
  351. status_code = 101
  352. reason_phrase = 'test'
  353. self.response = FakeResponse()
  354. def send(self, *args, **kwargs):
  355. raise raised()
  356. def recv(self, *args, **kwargs):
  357. raise raised()
  358. def close(self, *args, **kwargs):
  359. return
  360. return FakeWsConnection()
  361. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  362. class TestWebsocketsRequestHandler:
  363. @pytest.mark.parametrize('raised,expected', [
  364. # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
  365. (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
  366. # Requires a response object. Should be covered by HTTP error tests.
  367. # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
  368. (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
  369. # These are subclasses of InvalidHandshake
  370. (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
  371. (lambda: websockets.exceptions.NegotiationError(), TransportError),
  372. # Catch-all
  373. (lambda: websockets.exceptions.WebSocketException(), TransportError),
  374. (lambda: TimeoutError(), TransportError),
  375. # These may be raised by our create_connection implementation, which should also be caught
  376. (lambda: OSError(), TransportError),
  377. (lambda: ssl.SSLError(), SSLError),
  378. (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
  379. (lambda: socks.ProxyError(), ProxyError),
  380. ])
  381. def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
  382. import websockets.sync.client
  383. import yt_dlp.networking._websockets
  384. with handler() as rh:
  385. def fake_connect(*args, **kwargs):
  386. raise raised()
  387. monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
  388. monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
  389. with pytest.raises(expected) as exc_info:
  390. rh.send(Request('ws://fake-url'))
  391. assert exc_info.type is expected
  392. @pytest.mark.parametrize('raised,expected,match', [
  393. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
  394. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  395. (lambda: RuntimeError(), TransportError, None),
  396. (lambda: TimeoutError(), TransportError, None),
  397. (lambda: TypeError(), RequestError, None),
  398. (lambda: socks.ProxyError(), ProxyError, None),
  399. # Catch-all
  400. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  401. ])
  402. def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
  403. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  404. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  405. with pytest.raises(expected, match=match) as exc_info:
  406. ws.send('test')
  407. assert exc_info.type is expected
  408. @pytest.mark.parametrize('raised,expected,match', [
  409. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
  410. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  411. (lambda: RuntimeError(), TransportError, None),
  412. (lambda: TimeoutError(), TransportError, None),
  413. (lambda: socks.ProxyError(), ProxyError, None),
  414. # Catch-all
  415. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  416. ])
  417. def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
  418. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  419. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  420. with pytest.raises(expected, match=match) as exc_info:
  421. ws.recv()
  422. assert exc_info.type is expected