test_socks.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. #!/usr/bin/env python3
  2. # Allow direct execution
  3. import os
  4. import sys
  5. import threading
  6. import unittest
  7. import pytest
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. import abc
  10. import contextlib
  11. import enum
  12. import functools
  13. import http.server
  14. import json
  15. import random
  16. import socket
  17. import struct
  18. import time
  19. from socketserver import (
  20. BaseRequestHandler,
  21. StreamRequestHandler,
  22. ThreadingTCPServer,
  23. )
  24. from test.helper import http_server_port, verify_address_availability
  25. from yt_dlp.networking import Request
  26. from yt_dlp.networking.exceptions import ProxyError, TransportError
  27. from yt_dlp.socks import (
  28. SOCKS4_REPLY_VERSION,
  29. SOCKS4_VERSION,
  30. SOCKS5_USER_AUTH_SUCCESS,
  31. SOCKS5_USER_AUTH_VERSION,
  32. SOCKS5_VERSION,
  33. Socks5AddressType,
  34. Socks5Auth,
  35. )
# Status byte the test SOCKS5 server puts in its username/password
# sub-negotiation reply when the supplied credentials do not match.
SOCKS5_USER_AUTH_FAILURE = 0x1
  37. class Socks4CD(enum.IntEnum):
  38. REQUEST_GRANTED = 90
  39. REQUEST_REJECTED_OR_FAILED = 91
  40. REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD = 92
  41. REQUEST_REJECTED_DIFFERENT_USERID = 93
  42. class Socks5Reply(enum.IntEnum):
  43. SUCCEEDED = 0x0
  44. GENERAL_FAILURE = 0x1
  45. CONNECTION_NOT_ALLOWED = 0x2
  46. NETWORK_UNREACHABLE = 0x3
  47. HOST_UNREACHABLE = 0x4
  48. CONNECTION_REFUSED = 0x5
  49. TTL_EXPIRED = 0x6
  50. COMMAND_NOT_SUPPORTED = 0x7
  51. ADDRESS_TYPE_NOT_SUPPORTED = 0x8
  52. class SocksTestRequestHandler(BaseRequestHandler):
  53. def __init__(self, *args, socks_info=None, **kwargs):
  54. self.socks_info = socks_info
  55. super().__init__(*args, **kwargs)
  56. class SocksProxyHandler(BaseRequestHandler):
  57. def __init__(self, request_handler_class, socks_server_kwargs, *args, **kwargs):
  58. self.socks_kwargs = socks_server_kwargs or {}
  59. self.request_handler_class = request_handler_class
  60. super().__init__(*args, **kwargs)
  61. class Socks5ProxyHandler(StreamRequestHandler, SocksProxyHandler):
  62. # SOCKS5 protocol https://tools.ietf.org/html/rfc1928
  63. # SOCKS5 username/password authentication https://tools.ietf.org/html/rfc1929
  64. def handle(self):
  65. sleep = self.socks_kwargs.get('sleep')
  66. if sleep:
  67. time.sleep(sleep)
  68. version, nmethods = self.connection.recv(2)
  69. assert version == SOCKS5_VERSION
  70. methods = list(self.connection.recv(nmethods))
  71. auth = self.socks_kwargs.get('auth')
  72. if auth is not None and Socks5Auth.AUTH_USER_PASS not in methods:
  73. self.connection.sendall(struct.pack('!BB', SOCKS5_VERSION, Socks5Auth.AUTH_NO_ACCEPTABLE))
  74. self.server.close_request(self.request)
  75. return
  76. elif Socks5Auth.AUTH_USER_PASS in methods:
  77. self.connection.sendall(struct.pack('!BB', SOCKS5_VERSION, Socks5Auth.AUTH_USER_PASS))
  78. _, user_len = struct.unpack('!BB', self.connection.recv(2))
  79. username = self.connection.recv(user_len).decode()
  80. pass_len = ord(self.connection.recv(1))
  81. password = self.connection.recv(pass_len).decode()
  82. if username == auth[0] and password == auth[1]:
  83. self.connection.sendall(struct.pack('!BB', SOCKS5_USER_AUTH_VERSION, SOCKS5_USER_AUTH_SUCCESS))
  84. else:
  85. self.connection.sendall(struct.pack('!BB', SOCKS5_USER_AUTH_VERSION, SOCKS5_USER_AUTH_FAILURE))
  86. self.server.close_request(self.request)
  87. return
  88. elif Socks5Auth.AUTH_NONE in methods:
  89. self.connection.sendall(struct.pack('!BB', SOCKS5_VERSION, Socks5Auth.AUTH_NONE))
  90. else:
  91. self.connection.sendall(struct.pack('!BB', SOCKS5_VERSION, Socks5Auth.AUTH_NO_ACCEPTABLE))
  92. self.server.close_request(self.request)
  93. return
  94. version, command, _, address_type = struct.unpack('!BBBB', self.connection.recv(4))
  95. socks_info = {
  96. 'version': version,
  97. 'auth_methods': methods,
  98. 'command': command,
  99. 'client_address': self.client_address,
  100. 'ipv4_address': None,
  101. 'domain_address': None,
  102. 'ipv6_address': None,
  103. }
  104. if address_type == Socks5AddressType.ATYP_IPV4:
  105. socks_info['ipv4_address'] = socket.inet_ntoa(self.connection.recv(4))
  106. elif address_type == Socks5AddressType.ATYP_DOMAINNAME:
  107. socks_info['domain_address'] = self.connection.recv(ord(self.connection.recv(1))).decode()
  108. elif address_type == Socks5AddressType.ATYP_IPV6:
  109. socks_info['ipv6_address'] = socket.inet_ntop(socket.AF_INET6, self.connection.recv(16))
  110. else:
  111. self.server.close_request(self.request)
  112. socks_info['port'] = struct.unpack('!H', self.connection.recv(2))[0]
  113. # dummy response, the returned IP is just a placeholder
  114. self.connection.sendall(struct.pack(
  115. '!BBBBIH', SOCKS5_VERSION, self.socks_kwargs.get('reply', Socks5Reply.SUCCEEDED), 0x0, 0x1, 0x7f000001, 40000))
  116. self.request_handler_class(self.request, self.client_address, self.server, socks_info=socks_info)
  117. class Socks4ProxyHandler(StreamRequestHandler, SocksProxyHandler):
  118. # SOCKS4 protocol http://www.openssh.com/txt/socks4.protocol
  119. # SOCKS4A protocol http://www.openssh.com/txt/socks4a.protocol
  120. def _read_until_null(self):
  121. return b''.join(iter(functools.partial(self.connection.recv, 1), b'\x00'))
  122. def handle(self):
  123. sleep = self.socks_kwargs.get('sleep')
  124. if sleep:
  125. time.sleep(sleep)
  126. socks_info = {
  127. 'version': SOCKS4_VERSION,
  128. 'command': None,
  129. 'client_address': self.client_address,
  130. 'ipv4_address': None,
  131. 'port': None,
  132. 'domain_address': None,
  133. }
  134. version, command, dest_port, dest_ip = struct.unpack('!BBHI', self.connection.recv(8))
  135. socks_info['port'] = dest_port
  136. socks_info['command'] = command
  137. if version != SOCKS4_VERSION:
  138. self.server.close_request(self.request)
  139. return
  140. use_remote_dns = False
  141. if 0x0 < dest_ip <= 0xFF:
  142. use_remote_dns = True
  143. else:
  144. socks_info['ipv4_address'] = socket.inet_ntoa(struct.pack('!I', dest_ip))
  145. user_id = self._read_until_null().decode()
  146. if user_id != (self.socks_kwargs.get('user_id') or ''):
  147. self.connection.sendall(struct.pack(
  148. '!BBHI', SOCKS4_REPLY_VERSION, Socks4CD.REQUEST_REJECTED_DIFFERENT_USERID, 0x00, 0x00000000))
  149. self.server.close_request(self.request)
  150. return
  151. if use_remote_dns:
  152. socks_info['domain_address'] = self._read_until_null().decode()
  153. # dummy response, the returned IP is just a placeholder
  154. self.connection.sendall(
  155. struct.pack(
  156. '!BBHI', SOCKS4_REPLY_VERSION,
  157. self.socks_kwargs.get('cd_reply', Socks4CD.REQUEST_GRANTED), 40000, 0x7f000001))
  158. self.request_handler_class(self.request, self.client_address, self.server, socks_info=socks_info)
  159. class IPv6ThreadingTCPServer(ThreadingTCPServer):
  160. address_family = socket.AF_INET6
  161. class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestRequestHandler):
  162. def do_GET(self):
  163. if self.path == '/socks_info':
  164. payload = json.dumps(self.socks_info.copy())
  165. self.send_response(200)
  166. self.send_header('Content-Type', 'application/json; charset=utf-8')
  167. self.send_header('Content-Length', str(len(payload)))
  168. self.end_headers()
  169. self.wfile.write(payload.encode())
  170. class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
  171. def handle(self):
  172. import websockets.sync.server
  173. protocol = websockets.ServerProtocol()
  174. connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
  175. connection.handshake()
  176. for message in connection:
  177. if message == 'socks_info':
  178. connection.send(json.dumps(self.socks_info))
  179. connection.close()
  180. @contextlib.contextmanager
  181. def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
  182. server = server_thread = None
  183. try:
  184. bind_address = bind_ip or '127.0.0.1'
  185. server_type = ThreadingTCPServer if '.' in bind_address else IPv6ThreadingTCPServer
  186. server = server_type(
  187. (bind_address, 0), functools.partial(socks_server_class, request_handler, socks_server_kwargs))
  188. server_port = http_server_port(server)
  189. server_thread = threading.Thread(target=server.serve_forever)
  190. server_thread.daemon = True
  191. server_thread.start()
  192. if '.' not in bind_address:
  193. yield f'[{bind_address}]:{server_port}'
  194. else:
  195. yield f'{bind_address}:{server_port}'
  196. finally:
  197. server.shutdown()
  198. server.server_close()
  199. server_thread.join(2.0)
  200. class SocksProxyTestContext(abc.ABC):
  201. REQUEST_HANDLER_CLASS = None
  202. def socks_server(self, server_class, *args, **kwargs):
  203. return socks_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
  204. @abc.abstractmethod
  205. def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
  206. """return a dict of socks_info"""
  207. class HTTPSocksTestProxyContext(SocksProxyTestContext):
  208. REQUEST_HANDLER_CLASS = SocksHTTPTestRequestHandler
  209. def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
  210. request = Request(f'http://{target_domain or "127.0.0.1"}:{target_port or "40000"}/socks_info', **req_kwargs)
  211. handler.validate(request)
  212. return json.loads(handler.send(request).read().decode())
  213. class WebSocketSocksTestProxyContext(SocksProxyTestContext):
  214. REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
  215. def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
  216. request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
  217. handler.validate(request)
  218. ws = handler.send(request)
  219. ws.send('socks_info')
  220. socks_info = ws.recv()
  221. ws.close()
  222. return json.loads(socks_info)
# Maps the `ctx` parameter used in the parametrized test classes to the
# context class that the `ctx` fixture instantiates.
CTX_MAP = {
    'http': HTTPSocksTestProxyContext,
    'ws': WebSocketSocksTestProxyContext,
}
  227. @pytest.fixture(scope='module')
  228. def ctx(request):
  229. return CTX_MAP[request.param]()
  230. @pytest.mark.parametrize(
  231. 'handler,ctx', [
  232. ('Urllib', 'http'),
  233. ('Requests', 'http'),
  234. ('Websockets', 'ws'),
  235. ('CurlCFFI', 'http'),
  236. ], indirect=True)
  237. class TestSocks4Proxy:
  238. def test_socks4_no_auth(self, handler, ctx):
  239. with handler() as rh:
  240. with ctx.socks_server(Socks4ProxyHandler) as server_address:
  241. response = ctx.socks_info_request(
  242. rh, proxies={'all': f'socks4://{server_address}'})
  243. assert response['version'] == 4
  244. def test_socks4_auth(self, handler, ctx):
  245. with handler() as rh:
  246. with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
  247. with pytest.raises(ProxyError):
  248. ctx.socks_info_request(rh, proxies={'all': f'socks4://{server_address}'})
  249. response = ctx.socks_info_request(
  250. rh, proxies={'all': f'socks4://user:@{server_address}'})
  251. assert response['version'] == 4
  252. def test_socks4a_ipv4_target(self, handler, ctx):
  253. with ctx.socks_server(Socks4ProxyHandler) as server_address:
  254. with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
  255. response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
  256. assert response['version'] == 4
  257. assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
  258. def test_socks4a_domain_target(self, handler, ctx):
  259. with ctx.socks_server(Socks4ProxyHandler) as server_address:
  260. with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
  261. response = ctx.socks_info_request(rh, target_domain='localhost')
  262. assert response['version'] == 4
  263. assert response['ipv4_address'] is None
  264. assert response['domain_address'] == 'localhost'
  265. def test_ipv4_client_source_address(self, handler, ctx):
  266. with ctx.socks_server(Socks4ProxyHandler) as server_address:
  267. source_address = f'127.0.0.{random.randint(5, 255)}'
  268. verify_address_availability(source_address)
  269. with handler(proxies={'all': f'socks4://{server_address}'},
  270. source_address=source_address) as rh:
  271. response = ctx.socks_info_request(rh)
  272. assert response['client_address'][0] == source_address
  273. assert response['version'] == 4
  274. @pytest.mark.parametrize('reply_code', [
  275. Socks4CD.REQUEST_REJECTED_OR_FAILED,
  276. Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
  277. Socks4CD.REQUEST_REJECTED_DIFFERENT_USERID,
  278. ])
  279. def test_socks4_errors(self, handler, ctx, reply_code):
  280. with ctx.socks_server(Socks4ProxyHandler, cd_reply=reply_code) as server_address:
  281. with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
  282. with pytest.raises(ProxyError):
  283. ctx.socks_info_request(rh)
  284. def test_ipv6_socks4_proxy(self, handler, ctx):
  285. with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
  286. with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
  287. response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
  288. assert response['client_address'][0] == '::1'
  289. assert response['ipv4_address'] == '127.0.0.1'
  290. assert response['version'] == 4
  291. def test_timeout(self, handler, ctx):
  292. with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
  293. with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
  294. with pytest.raises(TransportError):
  295. ctx.socks_info_request(rh)
  296. @pytest.mark.parametrize(
  297. 'handler,ctx', [
  298. ('Urllib', 'http'),
  299. ('Requests', 'http'),
  300. ('Websockets', 'ws'),
  301. ('CurlCFFI', 'http'),
  302. ], indirect=True)
  303. class TestSocks5Proxy:
  304. def test_socks5_no_auth(self, handler, ctx):
  305. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  306. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  307. response = ctx.socks_info_request(rh)
  308. assert response['auth_methods'] == [0x0]
  309. assert response['version'] == 5
  310. def test_socks5_user_pass(self, handler, ctx):
  311. with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
  312. with handler() as rh:
  313. with pytest.raises(ProxyError):
  314. ctx.socks_info_request(rh, proxies={'all': f'socks5://{server_address}'})
  315. response = ctx.socks_info_request(
  316. rh, proxies={'all': f'socks5://test:testpass@{server_address}'})
  317. assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
  318. assert response['version'] == 5
  319. def test_socks5_ipv4_target(self, handler, ctx):
  320. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  321. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  322. response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
  323. assert response['ipv4_address'] == '127.0.0.1'
  324. assert response['version'] == 5
  325. def test_socks5_domain_target(self, handler, ctx):
  326. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  327. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  328. response = ctx.socks_info_request(rh, target_domain='localhost')
  329. assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
  330. assert response['version'] == 5
  331. def test_socks5h_domain_target(self, handler, ctx):
  332. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  333. with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
  334. response = ctx.socks_info_request(rh, target_domain='localhost')
  335. assert response['ipv4_address'] is None
  336. assert response['domain_address'] == 'localhost'
  337. assert response['version'] == 5
  338. def test_socks5h_ip_target(self, handler, ctx):
  339. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  340. with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
  341. response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
  342. assert response['ipv4_address'] == '127.0.0.1'
  343. assert response['domain_address'] is None
  344. assert response['version'] == 5
  345. def test_socks5_ipv6_destination(self, handler, ctx):
  346. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  347. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  348. response = ctx.socks_info_request(rh, target_domain='[::1]')
  349. assert response['ipv6_address'] == '::1'
  350. assert response['version'] == 5
  351. def test_ipv6_socks5_proxy(self, handler, ctx):
  352. with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
  353. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  354. response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
  355. assert response['client_address'][0] == '::1'
  356. assert response['ipv4_address'] == '127.0.0.1'
  357. assert response['version'] == 5
  358. # XXX: is there any feasible way of testing IPv6 source addresses?
  359. # Same would go for non-proxy source_address test...
  360. def test_ipv4_client_source_address(self, handler, ctx):
  361. with ctx.socks_server(Socks5ProxyHandler) as server_address:
  362. source_address = f'127.0.0.{random.randint(5, 255)}'
  363. verify_address_availability(source_address)
  364. with handler(proxies={'all': f'socks5://{server_address}'}, source_address=source_address) as rh:
  365. response = ctx.socks_info_request(rh)
  366. assert response['client_address'][0] == source_address
  367. assert response['version'] == 5
  368. @pytest.mark.parametrize('reply_code', [
  369. Socks5Reply.GENERAL_FAILURE,
  370. Socks5Reply.CONNECTION_NOT_ALLOWED,
  371. Socks5Reply.NETWORK_UNREACHABLE,
  372. Socks5Reply.HOST_UNREACHABLE,
  373. Socks5Reply.CONNECTION_REFUSED,
  374. Socks5Reply.TTL_EXPIRED,
  375. Socks5Reply.COMMAND_NOT_SUPPORTED,
  376. Socks5Reply.ADDRESS_TYPE_NOT_SUPPORTED,
  377. ])
  378. def test_socks5_errors(self, handler, ctx, reply_code):
  379. with ctx.socks_server(Socks5ProxyHandler, reply=reply_code) as server_address:
  380. with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
  381. with pytest.raises(ProxyError):
  382. ctx.socks_info_request(rh)
  383. def test_timeout(self, handler, ctx):
  384. with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
  385. with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
  386. with pytest.raises(TransportError):
  387. ctx.socks_info_request(rh)
  388. if __name__ == '__main__':
  389. unittest.main()