httputil.py

import atexit
import http
import logging
import multiprocessing
import os
import sys
import socket
import time
from collections import deque
from typing import Dict, Any, Optional, Tuple, Callable

import certifi
import lz4.frame
import urllib3
import zstandard
from urllib3.poolmanager import PoolManager, ProxyManager
from urllib3.response import HTTPResponse

from clickhouse_connect.driver.exceptions import ProgrammingError
from clickhouse_connect import common

logger = logging.getLogger(__name__)
# We disable this warning. Verify must be explicitly set to False, so we assume the user knows what they're doing
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Increase this number just to be safe when ClickHouse is returning progress headers
http.client._MAXHEADERS = 10000  # pylint: disable=protected-access

DEFAULT_KEEP_INTERVAL = 30
DEFAULT_KEEP_COUNT = 3
DEFAULT_KEEP_IDLE = 30

SOCKET_TCP = socket.IPPROTO_TCP
core_socket_options = [
    (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
    (SOCKET_TCP, socket.TCP_NODELAY, 1),
    (socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 256),
    (socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 256)
]
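# Each tuple above is (level, option, value) and is handed straight to
# socket.setsockopt when urllib3 opens a connection. A minimal sketch of the
# same mechanism used directly, outside this module (names are illustrative):
#
#     import socket
#     import urllib3
#
#     opts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
#     pool = urllib3.PoolManager(socket_options=opts)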
logging.getLogger('urllib3').setLevel(logging.WARNING)

_proxy_managers = {}
all_managers = {}


@atexit.register
def close_managers():
    for manager in all_managers:
        manager.clear()
# pylint: disable=no-member,too-many-arguments,too-many-branches
def get_pool_manager_options(keep_interval: int = DEFAULT_KEEP_INTERVAL,
                             keep_count: int = DEFAULT_KEEP_COUNT,
                             keep_idle: int = DEFAULT_KEEP_IDLE,
                             ca_cert: str = None,
                             verify: bool = True,
                             client_cert: str = None,
                             client_cert_key: str = None,
                             **options) -> Dict[str, Any]:
    socket_options = core_socket_options.copy()
    if getattr(socket, 'TCP_KEEPINTVL', None) is not None:
        socket_options.append((SOCKET_TCP, socket.TCP_KEEPINTVL, keep_interval))
    if getattr(socket, 'TCP_KEEPCNT', None) is not None:
        socket_options.append((SOCKET_TCP, socket.TCP_KEEPCNT, keep_count))
    if getattr(socket, 'TCP_KEEPIDLE', None) is not None:
        socket_options.append((SOCKET_TCP, socket.TCP_KEEPIDLE, keep_idle))
    if sys.platform == 'darwin':
        socket_options.append((SOCKET_TCP, getattr(socket, 'TCP_KEEPALIVE', 0x10), keep_interval))
    options['maxsize'] = options.get('maxsize', 8)
    options['retries'] = options.get('retries', 1)
    if ca_cert == 'certifi':
        ca_cert = certifi.where()
    options['cert_reqs'] = 'CERT_REQUIRED' if verify else 'CERT_NONE'
    if ca_cert:
        options['ca_certs'] = ca_cert
    if client_cert:
        options['cert_file'] = client_cert
    if client_cert_key:
        options['key_file'] = client_cert_key
    options['socket_options'] = socket_options
    options['block'] = options.get('block', False)
    return options
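# Usage sketch (illustrative, not part of the original module): build options
# with certifi's CA bundle and a larger pool, then expand them into a manager.
#
#     opts = get_pool_manager_options(ca_cert='certifi', maxsize=16)
#     # opts['cert_reqs'] == 'CERT_REQUIRED', opts['ca_certs'] is the certifi
#     # bundle path, and PoolManager(**opts) applies the whole dict.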
def get_pool_manager(keep_interval: int = DEFAULT_KEEP_INTERVAL,
                     keep_count: int = DEFAULT_KEEP_COUNT,
                     keep_idle: int = DEFAULT_KEEP_IDLE,
                     ca_cert: str = None,
                     verify: bool = True,
                     client_cert: str = None,
                     client_cert_key: str = None,
                     http_proxy: str = None,
                     https_proxy: str = None,
                     **options):
    options = get_pool_manager_options(keep_interval,
                                       keep_count,
                                       keep_idle,
                                       ca_cert,
                                       verify,
                                       client_cert,
                                       client_cert_key,
                                       **options)
    if http_proxy:
        if https_proxy:
            raise ProgrammingError('Only one of http_proxy or https_proxy should be specified')
        if not http_proxy.startswith('http'):
            http_proxy = f'http://{http_proxy}'
        manager = ProxyManager(http_proxy, **options)
    elif https_proxy:
        if not https_proxy.startswith('http'):
            https_proxy = f'https://{https_proxy}'
        manager = ProxyManager(https_proxy, **options)
    else:
        manager = PoolManager(**options)
    all_managers[manager] = int(time.time())
    return manager
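# Example (hypothetical proxy and URL): a bare proxy address is normalized to
# an http:// URL before the ProxyManager is built.
#
#     manager = get_pool_manager(http_proxy='proxy.internal:3128')
#     resp = manager.request('GET', 'http://clickhouse-host:8123/ping')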
def check_conn_expiration(manager: PoolManager):
    reset_seconds = common.get_setting('max_connection_age')
    if reset_seconds:
        last_reset = all_managers.get(manager, 0)
        now = int(time.time())
        if last_reset < now - reset_seconds:
            logger.debug('connection expiration')
            manager.clear()
            all_managers[manager] = now
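# Usage sketch: with the 'max_connection_age' setting at, say, 600 seconds,
# calling check_conn_expiration(manager) is a no-op until the manager's last
# reset is more than 600 seconds old, at which point its pooled connections
# are cleared and its timestamp in all_managers is refreshed.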
def get_proxy_manager(host: str, http_proxy):
    key = f'{host}__{http_proxy}'
    if key in _proxy_managers:
        return _proxy_managers[key]
    proxy_manager = get_pool_manager(http_proxy=http_proxy)
    _proxy_managers[key] = proxy_manager
    return proxy_manager
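# Proxy managers are cached per host/proxy pair, so repeated calls with the
# same arguments reuse a single ProxyManager and its pooled connections.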
def get_response_data(response: HTTPResponse) -> bytes:
    encoding = response.headers.get('content-encoding', None)
    if encoding == 'zstd':
        try:
            zstd_decom = zstandard.ZstdDecompressor()
            return zstd_decom.stream_reader(response.data).read()
        except zstandard.ZstdError:
            pass
    if encoding == 'lz4':
        lz4_decom = lz4.frame.LZ4FrameDecompressor()
        return lz4_decom.decompress(response.data, len(response.data))
    return response.data
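# Example (hypothetical request; assumes the server actually responds with a
# zstd-encoded body):
#
#     resp = manager.request('GET', 'http://clickhouse-host:8123/',
#                            headers={'Accept-Encoding': 'zstd'})
#     body = get_response_data(resp)  # decompressed bytes, or raw data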
def check_env_proxy(scheme: str, host: str, port: int) -> Optional[str]:
    env_var = f'{scheme}_proxy'.lower()
    proxy = os.environ.get(env_var)
    if not proxy:
        proxy = os.environ.get(env_var.upper())
    if not proxy:
        return None
    no_proxy = os.environ.get('no_proxy')
    if not no_proxy:
        no_proxy = os.environ.get('NO_PROXY')
    if not no_proxy:
        return proxy
    if no_proxy == '*':
        return None  # Wildcard no proxy means don't actually proxy anything
    host = host.lower()
    for name in no_proxy.split(','):
        name = name.strip()
        if name:
            name = name.lstrip('.').lower()
            if name in (host, f'{host}:{port}'):
                return None  # Host or host/port matches
            if host.endswith('.' + name):
                return None  # Domain matches
    return proxy
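# Matching sketch (environment values are hypothetical): with
# HTTPS_PROXY=http://proxy:3128 and NO_PROXY=.example.com,localhost,
# check_env_proxy('https', 'db.example.com', 8443) returns None (domain
# suffix match), while check_env_proxy('https', 'other.org', 8443) returns
# 'http://proxy:3128'.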
_default_pool_manager = get_pool_manager()


def default_pool_manager():
    if multiprocessing.current_process().name == 'MainProcess':
        return _default_pool_manager
    # PoolManagers don't seem to be safe for some multiprocessing environments, always return a new one
    return get_pool_manager()
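# Usage sketch: the main process shares one manager; worker processes each get
# a fresh one, so pooled sockets are never shared across fork boundaries.
#
#     pool = default_pool_manager()
#     pool.request('GET', 'http://localhost:8123/ping')  # URL is illustrative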
class ResponseSource:
    def __init__(self, response: HTTPResponse, chunk_size: int = 1024 * 1024):
        self.response = response
        compression = response.headers.get('content-encoding')
        decompress: Optional[Callable] = None
        if compression == 'zstd':
            zstd_decom = zstandard.ZstdDecompressor().decompressobj()

            def zstd_decompress(c: deque) -> Tuple[bytes, int]:
                chunk = c.popleft()
                return zstd_decom.decompress(chunk), len(chunk)

            decompress = zstd_decompress
        elif compression == 'lz4':
            lz4_decom = lz4.frame.LZ4FrameDecompressor()

            def lz_decompress(c: deque) -> Tuple[Optional[bytes], int]:
                read_amt = 0
                data = c.popleft()
                read_amt += len(data)
                if lz4_decom.unused_data:
                    read_amt += len(lz4_decom.unused_data)
                    data = lz4_decom.unused_data + data
                block = lz4_decom.decompress(data)
                if lz4_decom.unused_data:
                    read_amt -= len(lz4_decom.unused_data)
                return block, read_amt

            decompress = lz_decompress
        buffer_size = common.get_setting('http_buffer_size')

        def buffered():
            chunks = deque()
            done = False
            current_size = 0
            read_gen = response.stream(chunk_size, decompress is None)
            while True:
                while not done:
                    chunk = None
                    try:
                        chunk = next(read_gen, None)  # Always try to read at least one chunk if there are any left
                    except Exception:  # pylint: disable=broad-except
                        # By swallowing an unexpected exception reading the stream, we will let consumers decide how to
                        # handle the unexpected end of stream
                        pass
                    if not chunk:
                        done = True
                        break
                    chunks.append(chunk)
                    current_size += len(chunk)
                    if current_size > buffer_size:
                        break
                if len(chunks) == 0:
                    return
                if decompress:
                    chunk, used = decompress(chunks)
                    current_size -= used
                else:
                    chunk = chunks.popleft()
                    current_size -= len(chunk)
                if chunk:
                    yield chunk

        self.gen = buffered()

    def close(self):
        self.response.drain_conn()
        self.response.close()
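# Usage sketch (assumes 'response' is a streaming HTTPResponse obtained with
# preload_content=False, so response.stream() still has data to yield):
#
#     source = ResponseSource(response)
#     for block in source.gen:
#         handle(block)  # 'handle' is a placeholder consumer
#     source.close()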