import io
import json
import logging
import re
import uuid
from base64 import b64encode
from typing import Optional, Dict, Any, Sequence, Union, List, Callable, Generator, BinaryIO
from urllib.parse import urlencode

from urllib3 import Timeout
from urllib3.exceptions import HTTPError
from urllib3.poolmanager import PoolManager
from urllib3.response import HTTPResponse

from clickhouse_connect import common
from clickhouse_connect.datatypes import registry
from clickhouse_connect.datatypes.base import ClickHouseType
from clickhouse_connect.driver.ctypes import RespBuffCls
from clickhouse_connect.driver.client import Client
from clickhouse_connect.driver.common import dict_copy, coerce_bool, coerce_int
from clickhouse_connect.driver.compression import available_compression
from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError
from clickhouse_connect.driver.external import ExternalData
from clickhouse_connect.driver.httputil import ResponseSource, get_pool_manager, get_response_data, \
    default_pool_manager, get_proxy_manager, all_managers, check_env_proxy, check_conn_reset
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.summary import QuerySummary
from clickhouse_connect.driver.query import QueryResult, QueryContext, quote_identifier, bind_query
from clickhouse_connect.driver.transform import NativeTransform

logger = logging.getLogger(__name__)
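# Matches queries that end in 'LIMIT 0' -- used below to detect "columns only" queries
# that can be answered from metadata without parsing a data response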
columns_only_re = re.compile(r'LIMIT 0\s*$', re.IGNORECASE)


# pylint: disable=too-many-instance-attributes
class HttpClient(Client):
    params = {}
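    # Settings in these two sets are handled by the HTTP transport itself and sent as URL
    # parameters; the 'optional' set is applied only where the checks in __init__ find the
    # corresponding server setting available and writable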
    valid_transport_settings = {'database', 'buffer_size', 'session_id',
                                'compress', 'decompress', 'session_timeout',
                                'session_check', 'query_id', 'quota_key',
                                'wait_end_of_query', 'client_protocol_version'}
    optional_transport_settings = {'send_progress_in_http_headers',
                                   'http_headers_progress_interval_ms',
                                   'enable_http_compression'}
    _owns_pool_manager = False

    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements,unused-argument
    def __init__(self,
                 interface: str,
                 host: str,
                 port: int,
                 username: str,
                 password: str,
                 database: str,
                 compress: Union[bool, str] = True,
                 query_limit: int = 0,
                 query_retries: int = 2,
                 connect_timeout: int = 10,
                 send_receive_timeout: int = 300,
                 client_name: Optional[str] = None,
                 verify: bool = True,
                 ca_cert: Optional[str] = None,
                 client_cert: Optional[str] = None,
                 client_cert_key: Optional[str] = None,
                 session_id: Optional[str] = None,
                 settings: Optional[Dict[str, Any]] = None,
                 pool_mgr: Optional[PoolManager] = None,
                 http_proxy: Optional[str] = None,
                 https_proxy: Optional[str] = None,
                 server_host_name: Optional[str] = None,
                 apply_server_timezone: Optional[Union[str, bool]] = None):
        """
        Create an HTTP ClickHouse Connect client
        See clickhouse_connect.get_client for parameters
        """
        self.url = f'{interface}://{host}:{port}'
        self.headers = {}
        ch_settings = dict_copy(settings, self.params)
        self.http = pool_mgr
        if interface == 'https':
            if not https_proxy:
                https_proxy = check_env_proxy('https', host, port)
            if client_cert:
                if not username:
                    raise ProgrammingError('username parameter is required for Mutual TLS authentication')
                self.headers['X-ClickHouse-User'] = username
                self.headers['X-ClickHouse-SSL-Certificate-Auth'] = 'on'
            verify = coerce_bool(verify)
            # pylint: disable=too-many-boolean-expressions
            if not self.http and (server_host_name or ca_cert or client_cert or not verify or https_proxy):
                options = {
                    'ca_cert': ca_cert,
                    'client_cert': client_cert,
                    'verify': verify,
                    'client_cert_key': client_cert_key
                }
                if server_host_name:
                    if verify:
                        options['assert_hostname'] = server_host_name
                    options['server_hostname'] = server_host_name
                self.http = get_pool_manager(https_proxy=https_proxy, **options)
                self._owns_pool_manager = True
        if not self.http:
            if not http_proxy:
                http_proxy = check_env_proxy('http', host, port)
            if http_proxy:
                self.http = get_proxy_manager(host, http_proxy)
            else:
                self.http = default_pool_manager()
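        # Fall back to HTTP Basic authentication unless mutual TLS already identified the
        # user via the X-ClickHouse-User header above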
        if not client_cert and username:
            self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode()
        self.headers['User-Agent'] = common.build_client_name(client_name)
        self._read_format = self._write_format = 'Native'
        self._transform = NativeTransform()
        # There are use cases where the client needs to disable timeouts.
        if connect_timeout is not None:
            connect_timeout = coerce_int(connect_timeout)
        if send_receive_timeout is not None:
            send_receive_timeout = coerce_int(send_receive_timeout)
        self.timeout = Timeout(connect=connect_timeout, read=send_receive_timeout)
        self.http_retries = 1
        self._send_progress = None
        self._send_comp_setting = False
        self._progress_interval = None
        self._active_session = None

        if session_id:
            ch_settings['session_id'] = session_id
        elif 'session_id' not in ch_settings and common.get_setting('autogenerate_session_id'):
            ch_settings['session_id'] = str(uuid.uuid4())

        if coerce_bool(compress):
            compression = ','.join(available_compression)
            self.write_compression = available_compression[0]
        elif compress and compress not in ('False', 'false', '0'):
            if compress not in available_compression:
                raise ProgrammingError(f'Unsupported compression method {compress}')
            compression = compress
            self.write_compression = compress
        else:
            compression = None

        super().__init__(database=database,
                         uri=self.url,
                         query_limit=query_limit,
                         query_retries=query_retries,
                         server_host_name=server_host_name,
                         apply_server_timezone=apply_server_timezone)
        self.params = self._validate_settings(ch_settings)
        comp_setting = self._setting_status('enable_http_compression')
        self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable
        if comp_setting.is_set or comp_setting.is_writable:
            self.compression = compression
        send_setting = self._setting_status('send_progress_in_http_headers')
        self._send_progress = not send_setting.is_set and send_setting.is_writable
        if (send_setting.is_set or send_setting.is_writable) and \
                self._setting_status('http_headers_progress_interval_ms').is_writable:
            self._progress_interval = str(min(120000, max(10000, (send_receive_timeout - 5) * 1000)))

    def set_client_setting(self, key, value):
        str_value = self._validate_setting(key, value, common.get_setting('invalid_setting_action'))
        if str_value is not None:
            self.params[key] = str_value

    def get_client_setting(self, key) -> Optional[str]:
        # Client settings are stored as validated strings, so the value can be returned
        # directly (indexing the string would return only its first character)
        return self.params.get(key)

    def _prep_query(self, context: QueryContext):
        final_query = super()._prep_query(context)
        if context.is_insert:
            return final_query
        return f'{final_query}\n FORMAT {self._read_format}'

    def _query_with_context(self, context: QueryContext) -> QueryResult:
        headers = {}
        params = {}
        if self.database:
            params['database'] = self.database
        if self.protocol_version:
            params['client_protocol_version'] = self.protocol_version
            context.block_info = True
        params.update(context.bind_params)
        params.update(self._validate_settings(context.settings))
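        # A query ending in 'LIMIT 0' returns no rows, so instead of parsing an empty
        # Native response we request FORMAT JSON and build the column names and types
        # from its metadata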
        if columns_only_re.search(context.uncommented_query):
            response = self._raw_request(f'{context.final_query}\n FORMAT JSON',
                                         params, headers, retries=self.query_retries)
            json_result = json.loads(response.data)
            # ClickHouse will respond with a JSON object of meta, data, and some other objects
            # We just grab the column names and column types from the metadata sub object
            names: List[str] = []
            types: List[ClickHouseType] = []
            for col in json_result['meta']:
                names.append(col['name'])
                types.append(registry.get_from_name(col['type']))
            return QueryResult([], None, tuple(names), tuple(types))

        if self.compression:
            headers['Accept-Encoding'] = self.compression
            if self._send_comp_setting:
                params['enable_http_compression'] = '1'
        final_query = self._prep_query(context)
        if context.external_data:
            body = bytes()
            params['query'] = final_query
            params.update(context.external_data.query_params)
            fields = context.external_data.form_data
        else:
            body = final_query
            fields = None
            headers['Content-Type'] = 'text/plain; charset=utf-8'
        response = self._raw_request(body,
                                     params,
                                     headers,
                                     stream=True,
                                     retries=self.query_retries,
                                     fields=fields,
                                     server_wait=not context.streaming)
        byte_source = RespBuffCls(ResponseSource(response))  # pylint: disable=not-callable
        context.set_response_tz(self._check_tz_change(response.headers.get('X-ClickHouse-Timezone')))
        query_result = self._transform.parse_response(byte_source, context)
        query_result.summary = self._summary(response)
        return query_result

    def data_insert(self, context: InsertContext) -> QuerySummary:
        """
        See BaseClient doc_string for this method
        """
        if context.empty:
            logger.debug('No data included in insert, skipping')
            return QuerySummary()

        def error_handler(resp: HTTPResponse):
            # If we actually had a local exception when building the insert, throw that instead
            if context.insert_exception:
                ex = context.insert_exception
                context.insert_exception = None
                raise ex
            self._error_handler(resp)

        headers = {'Content-Type': 'application/octet-stream'}
        if context.compression is None:
            context.compression = self.write_compression
        if context.compression:
            headers['Content-Encoding'] = context.compression
        block_gen = self._transform.build_insert(context)

        params = {}
        if self.database:
            params['database'] = self.database
        params.update(self._validate_settings(context.settings))

        response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False)
        logger.debug('Context insert response code: %d, content: %s', response.status, response.data)
        context.data = None
        return QuerySummary(self._summary(response))

    def raw_insert(self, table: str = None,
                   column_names: Optional[Sequence[str]] = None,
                   insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None,
                   settings: Optional[Dict] = None,
                   fmt: Optional[str] = None,
                   compression: Optional[str] = None) -> QuerySummary:
        """
        See BaseClient doc_string for this method
        """
        params = {}
        headers = {'Content-Type': 'application/octet-stream'}
        if compression:
            headers['Content-Encoding'] = compression
        if table:
            cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else ''
            query = f'INSERT INTO {table}{cols} FORMAT {fmt if fmt else self._write_format}'
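            # For an uncompressed str/bytes payload the INSERT statement can be prepended
            # directly to the request body; compressed or streaming payloads send it as
            # the 'query' URL parameter instead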
            if not compression and isinstance(insert_block, str):
                insert_block = query + '\n' + insert_block
            elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)):
                insert_block = (query + '\n').encode() + insert_block
            else:
                params['query'] = query
        if self.database:
            params['database'] = self.database
        params.update(self._validate_settings(settings or {}))
        response = self._raw_request(insert_block, params, headers, server_wait=False)
        logger.debug('Raw insert response code: %d, content: %s', response.status, response.data)
        return QuerySummary(self._summary(response))

    @staticmethod
    def _summary(response: HTTPResponse):
        summary = {}
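        # ClickHouse reports query statistics (rows and bytes read/written, etc.) as a
        # JSON object in the X-ClickHouse-Summary response header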
        if 'X-ClickHouse-Summary' in response.headers:
            try:
                summary = json.loads(response.headers['X-ClickHouse-Summary'])
            except json.JSONDecodeError:
                pass
        summary['query_id'] = response.headers.get('X-ClickHouse-Query-Id', '')
        return summary

    def command(self,
                cmd,
                parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
                data: Union[str, bytes] = None,
                settings: Optional[Dict] = None,
                use_database: bool = True,
                external_data: Optional[ExternalData] = None) -> Union[str, int, Sequence[str], QuerySummary]:
        """
        See BaseClient doc_string for this method
        """
        cmd, params = bind_query(cmd, parameters, self.server_tz)
        headers = {}
        payload = None
        fields = None
        if external_data:
            if data:
                raise ProgrammingError('Cannot combine command data with external data') from None
            fields = external_data.form_data
            params.update(external_data.query_params)
        elif isinstance(data, str):
            headers['Content-Type'] = 'text/plain; charset=utf-8'
            payload = data.encode()
        elif isinstance(data, bytes):
            headers['Content-Type'] = 'application/octet-stream'
            payload = data
        if payload is None and not cmd:
            raise ProgrammingError('Command sent without query or recognized data') from None
        if payload or fields:
            params['query'] = cmd
        else:
            payload = cmd
        if use_database and self.database:
            params['database'] = self.database
        params.update(self._validate_settings(settings or {}))
        method = 'POST' if payload or fields else 'GET'
        response = self._raw_request(payload, params, headers, method, fields=fields)
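        # Command responses are tab-separated text: strip the trailing newline, split on
        # tabs, and return a bare int when a single value parses as one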
        if response.data:
            try:
                result = response.data.decode()[:-1].split('\t')
                if len(result) == 1:
                    try:
                        return int(result[0])
                    except ValueError:
                        return result[0]
                return result
            except UnicodeDecodeError:
                return str(response.data)
        return QuerySummary(self._summary(response))

    def _error_handler(self, response: HTTPResponse, retried: bool = False) -> None:
        err_str = f'HTTPDriver for {self.url} returned response code {response.status}'
        try:
            err_content = get_response_data(response)
        except Exception:  # pylint: disable=broad-except
            err_content = None
        finally:
            response.close()

        if err_content:
            err_msg = common.format_error(err_content.decode(errors='backslashreplace'))
            err_str = f'{err_str}\n {err_msg}'
        raise OperationalError(err_str) if retried else DatabaseError(err_str) from None

    def _raw_request(self,
                     data,
                     params: Dict[str, str],
                     headers: Optional[Dict[str, Any]] = None,
                     method: str = 'POST',
                     retries: int = 0,
                     stream: bool = False,
                     server_wait: bool = True,
                     fields: Optional[Dict[str, tuple]] = None,
                     error_handler: Callable = None) -> HTTPResponse:
        if isinstance(data, str):
            data = data.encode()
        headers = dict_copy(self.headers, headers)
        attempts = 0
        if server_wait:
            params['wait_end_of_query'] = '1'
        # We can't actually read the progress headers, but we enable them (1) so ClickHouse sends
        # something to keep the connection alive when waiting for long-running queries and (2) to
        # get summary information if not streaming
        if self._send_progress:
            params['send_progress_in_http_headers'] = '1'
        if self._progress_interval:
            params['http_headers_progress_interval_ms'] = self._progress_interval
        final_params = dict_copy(self.params, params)
        url = f'{self.url}?{urlencode(final_params)}'
        kwargs = {
            'headers': headers,
            'timeout': self.timeout,
            'retries': self.http_retries,
            'preload_content': not stream
        }
        if self.server_host_name:
            kwargs['assert_same_host'] = False
            kwargs['headers'].update({'Host': self.server_host_name})
        if fields:
            kwargs['fields'] = fields
        else:
            kwargs['body'] = data
        check_conn_reset(self.http)
        query_session = final_params.get('session_id')
        while True:
            attempts += 1
            if query_session:
                if query_session == self._active_session:
                    raise ProgrammingError('Attempt to execute concurrent queries within the same session. ' +
                                           'Please use a separate client instance per thread/process.')
                # There is a race condition here when using multiprocessing -- in that case the server will
                # throw an error instead, but in most cases this more helpful error will be thrown first
                self._active_session = query_session
            try:
                response = self.http.request(method, url, **kwargs)
            except HTTPError as ex:
                if isinstance(ex.__context__, ConnectionResetError):
                    # The server closed the connection, probably because the Keep Alive has expired
                    # We should be safe to retry, as ClickHouse should not have processed anything on a connection
                    # that it killed. We also only retry this once, as multiple disconnects are unlikely to be
                    # related to the Keep Alive settings
                    if attempts == 1:
                        logger.debug('Retrying remotely closed connection')
                        continue
                logger.warning('Unexpected Http Driver Exception')
                raise OperationalError(f'Error {ex} executing HTTP request attempt {attempts} {self.url}') from ex
            finally:
                if query_session:
                    self._active_session = None  # Make sure we always clear this
            if 200 <= response.status < 300:
                return response
            if response.status in (429, 503, 504):
                if attempts > retries:
                    self._error_handler(response, True)
                logger.debug('Retrying requests with status code %d', response.status)
            elif error_handler:
                error_handler(response)
            else:
                self._error_handler(response)

    def raw_query(self, query: str,
                  parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
                  settings: Optional[Dict[str, Any]] = None,
                  fmt: str = None,
                  use_database: bool = True,
                  external_data: Optional[ExternalData] = None) -> bytes:
        """
        See BaseClient doc_string for this method
        """
        body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
        return self._raw_request(body, params, fields=fields).data

    def raw_stream(self, query: str,
                   parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
                   settings: Optional[Dict[str, Any]] = None,
                   fmt: str = None,
                   use_database: bool = True,
                   external_data: Optional[ExternalData] = None) -> io.IOBase:
        """
        See BaseClient doc_string for this method
        """
        body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
        return self._raw_request(body, params, fields=fields, stream=True)

    def _prep_raw_query(self, query: str,
                        parameters: Optional[Union[Sequence, Dict[str, Any]]],
                        settings: Optional[Dict[str, Any]],
                        fmt: str,
                        use_database: bool,
                        external_data: Optional[ExternalData]):
        final_query, bind_params = bind_query(query, parameters, self.server_tz)
        if fmt:
            final_query += f'\n FORMAT {fmt}'
        params = self._validate_settings(settings or {})
        if use_database and self.database:
            params['database'] = self.database
        params.update(bind_params)
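        # With external data the query itself moves into the URL parameters and the
        # request body carries the external data as multipart form fields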
        if external_data:
            body = bytes()
            params['query'] = final_query
            params.update(external_data.query_params)
            fields = external_data.form_data
        else:
            body = final_query
            fields = None
        return body, params, fields

    def ping(self):
        """
        See BaseClient doc_string for this method
        """
        try:
            response = self.http.request('GET', f'{self.url}/ping', timeout=3)
            return 200 <= response.status < 300
        except HTTPError:
            logger.debug('ping failed', exc_info=True)
            return False

    def close(self):
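        # Only tear down a pool manager this client created; shared or default managers
        # may still be in use by other clients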
        if self._owns_pool_manager:
            self.http.clear()
            all_managers.pop(self.http, None)
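

# A minimal usage sketch (illustrative only, not part of the library). It assumes a local
# ClickHouse server on the default HTTP port 8123 with user 'default' and no password;
# '__default__' leaves the server's default database in place. In normal use, construct
# clients via clickhouse_connect.get_client() rather than instantiating HttpClient directly.
if __name__ == '__main__':
    client = HttpClient('http', 'localhost', 8123, 'default', '', '__default__')
    if client.ping():
        # query() is inherited from the base Client class and returns a QueryResult
        print(client.query('SELECT version()').result_rows)
    client.close()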