_helper.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import os
  5. import socket
  6. import ssl
  7. import sys
  8. import typing
  9. import urllib.parse
  10. import urllib.request
  11. from .exceptions import RequestError
  12. from ..dependencies import certifi
  13. from ..socks import ProxyType, sockssocket
  14. from ..utils import format_field, traverse_obj
  15. if typing.TYPE_CHECKING:
  16. from collections.abc import Iterable
  17. from ..utils.networking import HTTPHeaderDict
  def ssl_load_certs(context: ssl.SSLContext, use_certifi=True):
      """Load trusted CA certificates into ``context``.

      When the ``certifi`` dependency is available and ``use_certifi`` is true, the
      certifi CA bundle is loaded. Otherwise the platform defaults are loaded with
      ``load_default_certs()``. If that raises ``ssl.SSLError``, the Windows "CA" and
      "ROOT" stores are loaded certificate-by-certificate (only on win32 builds that
      provide ``ssl.enum_certificates``), and ``set_default_verify_paths()`` is then
      called on every platform.
      """
      if use_certifi and certifi:
          context.load_verify_locations(cafile=certifi.where())
          return

      try:
          context.load_default_certs()
      except ssl.SSLError:
          # Work around the issue in load_default_certs when there are bad certificates. See:
          # https://github.com/yt-dlp/yt-dlp/issues/1060,
          # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
          # enum_certificates is not present in mingw python. See https://github.com/yt-dlp/yt-dlp/issues/1151
          can_read_windows_stores = sys.platform == 'win32' and hasattr(ssl, 'enum_certificates')
          if can_read_windows_stores:
              for store_name in ('CA', 'ROOT'):
                  ssl_load_windows_store_certs(context, store_name)
          context.set_default_verify_paths()
  def ssl_load_windows_store_certs(ssl_context, storename):
      """Add server-auth certificates from the Windows store ``storename`` to ``ssl_context``.

      Only entries whose encoding is ``'x509_asn'`` and whose trust value is ``True``
      or contains the SERVER_AUTH OID are used. A ``PermissionError`` while
      enumerating the store makes this a no-op. Certificates rejected by
      ``load_verify_locations`` with ``ssl.SSLError`` are skipped individually.
      """
      # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
      server_auth_oid = ssl.Purpose.SERVER_AUTH.oid
      try:
          trusted_certs = [
              cert_data
              for cert_data, encoding, trust in ssl.enum_certificates(storename)
              if encoding == 'x509_asn' and (trust is True or server_auth_oid in trust)
          ]
      except PermissionError:
          return

      for cert_data in trusted_certs:
          try:
              ssl_context.load_verify_locations(cadata=cert_data)
          except ssl.SSLError:
              # One unparsable certificate must not stop the remaining ones from loading
              pass
  def make_socks_proxy_opts(socks_proxy):
      """Translate a SOCKS proxy URL into a dict of proxy options.

      Accepted schemes (case-insensitive): ``socks4``, ``socks4a``, ``socks5`` and
      ``socks5h``. ``rdns`` is True only for the ``socks4a``/``socks5h`` variants
      (presumably "let the proxy resolve hostnames"). The port defaults to 1080.
      Username and password are decoded with ``urllib.parse.unquote_plus``, so a
      literal ``+`` becomes a space; empty or missing values are passed through unchanged.

      :raises ValueError: for any other scheme, or if the URL's port is invalid.
      """
      parsed = urllib.parse.urlparse(socks_proxy)
      scheme = parsed.scheme.lower()

      if scheme == 'socks4':
          proxy_type = ProxyType.SOCKS4
      elif scheme == 'socks4a':
          proxy_type = ProxyType.SOCKS4A
      elif scheme in ('socks5', 'socks5h'):
          proxy_type = ProxyType.SOCKS5
      else:
          raise ValueError(f'Unknown SOCKS proxy version: {scheme}')
      remote_dns = scheme in ('socks4a', 'socks5h')

      def _decode_credential(value):
          return urllib.parse.unquote_plus(value) if value else value

      return {
          'proxytype': proxy_type,
          'addr': parsed.hostname,
          'port': parsed.port or 1080,
          'rdns': remote_dns,
          'username': _decode_credential(parsed.username),
          'password': _decode_credential(parsed.password),
      }
  def select_proxy(url, proxies):
      """Unified proxy selector for all backends.

      :param url: the request URL.
      :param proxies: mapping of proxy settings. It is keyed by URL scheme, with an
          optional ``'all'`` fallback entry and an optional ``'no'`` bypass list.
      :return: None if the host is bypassed. Bypass applies when ``'no'`` is present
          and either the ``no`` list (via ``urllib.request.proxy_bypass_environment``)
          or the system settings (via ``urllib.request.proxy_bypass``) match the host.
          Otherwise returns the result of looking up the URL's scheme (default
          ``'http'``) and ``'all'`` in ``proxies`` with ``traverse_obj``. Presumably
          this is the first non-empty of the two; confirm against ``traverse_obj``.
      """
      url_components = urllib.parse.urlparse(url)
      if 'no' in proxies:
          # hostname is None for URLs without a network location (e.g. 'http:///path').
          # Fall back to '' so the concatenation does not raise TypeError; a no='*'
          # entry still matches the empty host.
          hostport = (url_components.hostname or '') + format_field(url_components.port, None, ':%s')
          if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}):
              return None
          if urllib.request.proxy_bypass(hostport):  # check system settings
              return None

      return traverse_obj(proxies, url_components.scheme or 'http', 'all')
  def get_redirect_method(method, status):
      """Unified redirect method handling.

      Return the HTTP method to use when following a redirect with ``status``
      from a request that used ``method``:

      * 303: ``GET`` for everything except ``HEAD``, which stays ``HEAD``.
      * 301/302: ``POST`` becomes ``GET``; other methods are kept.
      * anything else: ``method`` unchanged.
      """
      if status == 303:
          # A 303 must either use GET or HEAD for subsequent request
          # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
          return method if method == 'HEAD' else 'GET'

      if method == 'POST' and status in (301, 302):
          # 301 and 302 redirects are commonly turned into a GET from a POST
          # for subsequent requests by browsers, so we'll do the same.
          # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
          # https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
          return 'GET'

      return method
  def make_ssl_context(
      verify=True,
      client_certificate=None,
      client_certificate_key=None,
      client_certificate_password=None,
      legacy_support=False,
      use_certifi=True,
  ):
      """Create a client-side ``ssl.SSLContext`` for outgoing TLS connections.

      :param verify: when true, require a valid server certificate with a matching
          hostname and load trusted CAs via ``ssl_load_certs``. When false, disable
          both checks.
      :param client_certificate: optional client certificate chain file. It is passed
          to ``SSLContext.load_cert_chain`` together with the next two parameters.
      :param client_certificate_key: optional separate private-key file.
      :param client_certificate_password: optional password for the private key.
      :param legacy_support: set SSL_OP_LEGACY_SERVER_CONNECT and OpenSSL's
          ``DEFAULT`` cipher list instead of the restricted settings below.
      :param use_certifi: forwarded to ``ssl_load_certs``. It is only used when
          ``verify`` is true.
      :return: the configured context.
      :raises RequestError: if the client certificate cannot be loaded. This covers
          any ``OSError``, including ``ssl.SSLError`` and a missing or unreadable file.
      """
      context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
      # check_hostname is assigned before verify_mode because CERT_NONE is rejected
      # while check_hostname is still enabled
      context.check_hostname = verify
      context.verify_mode = ssl.CERT_REQUIRED if verify else ssl.CERT_NONE

      # OpenSSL 1.1.1+ Python 3.8+ keylog file
      if hasattr(context, 'keylog_filename'):
          context.keylog_filename = os.environ.get('SSLKEYLOGFILE') or None

      # Some servers may reject requests if ALPN extension is not sent. See:
      # https://github.com/python/cpython/issues/85140
      # https://github.com/yt-dlp/yt-dlp/issues/3878
      with contextlib.suppress(NotImplementedError):
          context.set_alpn_protocols(['http/1.1'])

      if verify:
          ssl_load_certs(context, use_certifi)

      if legacy_support:
          context.options |= 4  # SSL_OP_LEGACY_SERVER_CONNECT
          context.set_ciphers('DEFAULT')  # compat
      elif ssl.OPENSSL_VERSION_INFO >= (1, 1, 1) and not ssl.OPENSSL_VERSION.startswith('LibreSSL'):
          # Use the default SSL ciphers and minimum TLS version settings from Python 3.10 [1].
          # This is to ensure consistent behavior across Python versions and libraries, and help avoid fingerprinting
          # in some situations [2][3].
          # Python 3.10 only supports OpenSSL 1.1.1+ [4]. Because this change is likely
          # untested on older versions, we only apply this to OpenSSL 1.1.1+ to be safe.
          # LibreSSL is excluded until further investigation due to cipher support issues [5][6].
          # 1. https://github.com/python/cpython/commit/e983252b516edb15d4338b0a47631b59ef1e2536
          # 2. https://github.com/yt-dlp/yt-dlp/issues/4627
          # 3. https://github.com/yt-dlp/yt-dlp/pull/5294
          # 4. https://peps.python.org/pep-0644/
          # 5. https://peps.python.org/pep-0644/#libressl-support
          # 6. https://github.com/yt-dlp/yt-dlp/commit/5b9f253fa0aee996cf1ed30185d4b502e00609c4#commitcomment-89054368
          context.set_ciphers(
              '@SECLEVEL=2:ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES:DHE+AES:!aNULL:!eNULL:!aDSS:!SHA1:!AESCCM')
          context.minimum_version = ssl.TLSVersion.TLSv1_2

      if client_certificate:
          try:
              context.load_cert_chain(
                  client_certificate, keyfile=client_certificate_key,
                  password=client_certificate_password)
          except OSError as e:
              # ssl.SSLError is a subclass of OSError. Catching the base class also turns a
              # missing/unreadable certificate file (FileNotFoundError, PermissionError)
              # into the documented RequestError instead of leaking a raw OS error.
              raise RequestError('Unable to load client certificate') from e

          # Let the server request the client certificate after the handshake, where supported
          if getattr(context, 'post_handshake_auth', None) is not None:
              context.post_handshake_auth = True
      return context
  class InstanceStoreMixin:
      """Mixin that caches objects produced by ``_create_instance``.

      Each object is stored together with the keyword arguments it was created from.
      ``_get_instance`` returns the stored object whose kwargs compare equal (``==``)
      to the requested ones, creating and storing a new one otherwise. Lookup is a
      linear scan, so the kwargs values need not be hashable.
      """

      def __init__(self, **kwargs):
          self.__instances = []  # (kwargs, instance) pairs in creation order
          super().__init__(**kwargs)  # So that both MRO works

      @staticmethod
      def _create_instance(**kwargs):
          """Build a new instance for ``kwargs``; subclasses must override this."""
          raise NotImplementedError

      def _get_instance(self, **kwargs):
          """Return the cached instance for equal ``kwargs``, creating it on first use."""
          missing = object()  # sentinel so a cached ``None`` is still treated as a hit
          instance = next(
              (stored for stored_kwargs, stored in self.__instances if stored_kwargs == kwargs),
              missing,
          )
          if instance is missing:
              instance = self._create_instance(**kwargs)
              self.__instances.append((kwargs, instance))
          return instance

      def _close_instance(self, instance):
          """Call ``instance.close()`` when the instance has a callable ``close``."""
          closer = getattr(instance, 'close', None)
          if callable(closer):
              closer()

      def _clear_instances(self):
          """Close every cached instance in creation order, then empty the cache."""
          for entry in self.__instances:
              self._close_instance(entry[1])
          self.__instances.clear()
  def add_accept_encoding_header(headers: HTTPHeaderDict, supported_encodings: Iterable[str]):
      """Set ``Accept-Encoding`` on ``headers`` unless the caller already provided one.

      The value is the comma-separated ``supported_encodings``, or ``'identity'``
      when that list is empty. ``headers`` is modified in place.
      """
      if 'Accept-Encoding' in headers:
          return
      joined = ', '.join(supported_encodings)
      headers['Accept-Encoding'] = joined if joined else 'identity'
  def wrap_request_errors(func):
      """Decorate a handler method so escaping ``RequestError``s record the handler.

      If a ``RequestError`` raised by ``func`` has ``handler`` set to None, it is set
      to the instance the method was called on (``self``) before the exception is
      re-raised. Return values and other exceptions pass through untouched.
      """
      @functools.wraps(func)
      def _call_and_tag(self, *args, **kwargs):
          try:
              return func(self, *args, **kwargs)
          except RequestError as error:
              if error.handler is None:
                  error.handler = self
              raise

      return _call_and_tag
  def _socket_connect(ip_addr, timeout, source_address):
      """Create a socket for one ``socket.getaddrinfo()`` entry and connect it.

      ``ip_addr`` is the 5-tuple ``(family, type, proto, canonname, sockaddr)``.
      ``timeout`` is applied unless it is the ``socket._GLOBAL_DEFAULT_TIMEOUT``
      sentinel, and a truthy ``source_address`` is bound before connecting. On
      ``OSError`` the socket is closed and the error re-raised.
      """
      family, sock_type, protocol, _canonname, remote_sockaddr = ip_addr
      sock = socket.socket(family, sock_type, protocol)
      try:
          if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
              sock.settimeout(timeout)
          if source_address:
              sock.bind(source_address)
          sock.connect(remote_sockaddr)
      except OSError:
          sock.close()
          raise
      return sock
  def create_socks_proxy_socket(dest_addr, proxy_args, proxy_ip_addr, timeout, source_address):
      """Connect to ``dest_addr`` through the SOCKS proxy at one resolved proxy address.

      ``proxy_ip_addr`` is a ``getaddrinfo()``-style 5-tuple for the proxy. Its sockaddr
      host and port replace the ``'addr'``/``'port'`` entries of a copy of
      ``proxy_args``, and that copy is passed to ``sockssocket.setproxy``;
      ``proxy_args`` itself is not modified. ``timeout`` is applied unless it is the
      ``socket._GLOBAL_DEFAULT_TIMEOUT`` sentinel, and a truthy ``source_address`` is
      bound first. On ``OSError`` the socket is closed and the error re-raised.
      """
      family, sock_type, protocol, _canonname, proxy_sockaddr = proxy_ip_addr
      sock = sockssocket(family, sock_type, protocol)
      try:
          options = proxy_args.copy()
          # Point the proxy settings at the already-resolved proxy address
          options['addr'], options['port'] = proxy_sockaddr[0], proxy_sockaddr[1]
          sock.setproxy(**options)
          if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
              sock.settimeout(timeout)
          if source_address:
              sock.bind(source_address)
          sock.connect(dest_addr)
      except OSError:
          sock.close()
          raise
      return sock
  def create_connection(
      address,
      timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
      source_address=None,
      *,
      _create_socket_func=_socket_connect,
  ):
      """Connect to ``address`` (a ``(host, port)`` pair) and return the connected socket.

      This is a work-alike of ``socket.create_connection()``. Every address from
      ``socket.getaddrinfo()`` is tried in order, and the first socket returned by
      ``_create_socket_func(ip_addr, timeout, source_address)`` is returned. Unlike
      the stdlib version, when ``source_address`` is given only remote addresses of
      the same family are tried: IPv6 if its host contains ``':'``, otherwise IPv4.

      :raises OSError: if getaddrinfo yields nothing, if no remote address matches
          the source address family, or (re-raising the last failure) if every
          connection attempt fails.
      """
      # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
      host, port = address
      candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
      if not candidates:
          raise OSError('getaddrinfo returns an empty list')

      if source_address is not None:
          family = socket.AF_INET6 if ':' in source_address[0] else socket.AF_INET
          candidates = [entry for entry in candidates if entry[0] == family]
          if not candidates:
              raise OSError(
                  f'No remote IPv{4 if family == socket.AF_INET else 6} addresses available for connect. '
                  f'Can\'t use "{source_address[0]}" as source address')

      last_error = None
      for entry in candidates:
          try:
              sock = _create_socket_func(entry, timeout, source_address)
          except OSError as exc:
              last_error = exc
          else:
              # Explicitly break __traceback__ reference cycle left by an earlier failure
              # https://bugs.python.org/issue36820
              last_error = None
              return sock

      try:
          raise last_error
      finally:
          # Explicitly break __traceback__ reference cycle
          # https://bugs.python.org/issue36820
          last_error = None