# httpclient.py -- HTTP(S) transport implementation of the ClickHouse Connect client
  1. import io
  2. import json
  3. import logging
  4. import re
  5. import uuid
  6. from base64 import b64encode
  7. from typing import Optional, Dict, Any, Sequence, Union, List, Callable, Generator, BinaryIO
  8. from urllib.parse import urlencode
  9. from urllib3 import Timeout
  10. from urllib3.exceptions import HTTPError
  11. from urllib3.poolmanager import PoolManager
  12. from urllib3.response import HTTPResponse
  13. from clickhouse_connect import common
  14. from clickhouse_connect.datatypes import registry
  15. from clickhouse_connect.datatypes.base import ClickHouseType
  16. from clickhouse_connect.driver.client import Client
  17. from clickhouse_connect.driver.common import dict_copy, coerce_bool, coerce_int, dict_add
  18. from clickhouse_connect.driver.compression import available_compression
  19. from clickhouse_connect.driver.ctypes import RespBuffCls
  20. from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError
  21. from clickhouse_connect.driver.external import ExternalData
  22. from clickhouse_connect.driver.httputil import ResponseSource, get_pool_manager, get_response_data, \
  23. default_pool_manager, get_proxy_manager, all_managers, check_env_proxy, check_conn_expiration
  24. from clickhouse_connect.driver.insert import InsertContext
  25. from clickhouse_connect.driver.query import QueryResult, QueryContext
  26. from clickhouse_connect.driver.binding import quote_identifier, bind_query
  27. from clickhouse_connect.driver.summary import QuerySummary
  28. from clickhouse_connect.driver.transform import NativeTransform
  29. logger = logging.getLogger(__name__)
  30. columns_only_re = re.compile(r'LIMIT 0\s*$', re.IGNORECASE)
  31. ex_header = 'X-ClickHouse-Exception-Code'
# pylint: disable=too-many-instance-attributes
class HttpClient(Client):
    """HTTP(S) transport implementation of the ClickHouse Connect Client."""
    # Default query parameters sent with every request; copied per-instance in __init__
    params = {}
    # Settings that are always passed through as HTTP query parameters
    valid_transport_settings = {'database', 'buffer_size', 'session_id',
                                'compress', 'decompress', 'session_timeout',
                                'session_check', 'query_id', 'quota_key',
                                'wait_end_of_query', 'client_protocol_version'}
    # Settings only used if the connected server reports them as settable
    optional_transport_settings = {'send_progress_in_http_headers',
                                   'http_headers_progress_interval_ms',
                                   'enable_http_compression'}
    # True when this client built its own PoolManager and must dispose of it in close()
    _owns_pool_manager = False

    # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals,too-many-branches,too-many-statements,unused-argument
    def __init__(self,
                 interface: str,
                 host: str,
                 port: int,
                 username: str,
                 password: str,
                 database: str,
                 compress: Union[bool, str] = True,
                 query_limit: int = 0,
                 query_retries: int = 2,
                 connect_timeout: int = 10,
                 send_receive_timeout: int = 300,
                 client_name: Optional[str] = None,
                 verify: Union[bool, str] = True,
                 ca_cert: Optional[str] = None,
                 client_cert: Optional[str] = None,
                 client_cert_key: Optional[str] = None,
                 session_id: Optional[str] = None,
                 settings: Optional[Dict[str, Any]] = None,
                 pool_mgr: Optional[PoolManager] = None,
                 http_proxy: Optional[str] = None,
                 https_proxy: Optional[str] = None,
                 server_host_name: Optional[str] = None,
                 apply_server_timezone: Optional[Union[str, bool]] = None,
                 show_clickhouse_errors: Optional[bool] = None,
                 autogenerate_session_id: Optional[bool] = None,
                 tls_mode: Optional[str] = None):
        """
        Create an HTTP ClickHouse Connect client
        See clickhouse_connect.get_client for parameters
        """
        self.url = f'{interface}://{host}:{port}'
        self.headers = {}
        self.params = dict_copy(HttpClient.params)
        ch_settings = dict_copy(settings, self.params)
        self.http = pool_mgr
        # -- TLS / pool manager setup ------------------------------------------------
        if interface == 'https':
            # 'proxy' verify value is shorthand for verify=True with proxy TLS mode
            if isinstance(verify, str) and verify.lower() == 'proxy':
                verify = True
                tls_mode = tls_mode or 'proxy'
            if not https_proxy:
                https_proxy = check_env_proxy('https', host, port)
            verify = coerce_bool(verify)
            # Mutual TLS -- authenticate with the client certificate instead of Basic auth
            if client_cert and (tls_mode is None or tls_mode == 'mutual'):
                if not username:
                    raise ProgrammingError('username parameter is required for Mutual TLS authentication')
                self.headers['X-ClickHouse-User'] = username
                self.headers['X-ClickHouse-SSL-Certificate-Auth'] = 'on'
            # A dedicated pool manager is only needed for non-default TLS options or a proxy
            # pylint: disable=too-many-boolean-expressions
            if not self.http and (server_host_name or ca_cert or client_cert or not verify or https_proxy):
                options = {'verify': verify}
                dict_add(options, 'ca_cert', ca_cert)
                dict_add(options, 'client_cert', client_cert)
                dict_add(options, 'client_cert_key', client_cert_key)
                if server_host_name:
                    # Only assert the hostname when certificate verification is on
                    if options['verify']:
                        options['assert_hostname'] = server_host_name
                    options['server_hostname'] = server_host_name
                self.http = get_pool_manager(https_proxy=https_proxy, **options)
                self._owns_pool_manager = True
        if not self.http:
            if not http_proxy:
                http_proxy = check_env_proxy('http', host, port)
            if http_proxy:
                self.http = get_proxy_manager(host, http_proxy)
            else:
                # Shared default pool manager -- not owned, so close() won't destroy it
                self.http = default_pool_manager()
        # -- Authentication and standard headers -------------------------------------
        if (not client_cert or tls_mode in ('strict', 'proxy')) and username:
            self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode()
        self.headers['User-Agent'] = common.build_client_name(client_name)
        self._read_format = self._write_format = 'Native'
        self._transform = NativeTransform()
        # There are use cases when the client needs to disable timeouts.
        if connect_timeout is not None:
            connect_timeout = coerce_int(connect_timeout)
        if send_receive_timeout is not None:
            send_receive_timeout = coerce_int(send_receive_timeout)
        self.timeout = Timeout(connect=connect_timeout, read=send_receive_timeout)
        self.http_retries = 1
        self._send_progress = None
        self._send_comp_setting = False
        self._progress_interval = None
        # Tracks the session_id of an in-flight request to detect concurrent session use
        self._active_session = None
        # allow to override the global autogenerate_session_id setting via the constructor params
        _autogenerate_session_id = common.get_setting('autogenerate_session_id') \
            if autogenerate_session_id is None \
            else autogenerate_session_id
        if session_id:
            ch_settings['session_id'] = session_id
        elif 'session_id' not in ch_settings and _autogenerate_session_id:
            ch_settings['session_id'] = str(uuid.uuid4())
        # -- Compression configuration -----------------------------------------------
        if coerce_bool(compress):
            # True means accept any supported codec; write with the preferred one
            compression = ','.join(available_compression)
            self.write_compression = available_compression[0]
        elif compress and compress not in ('False', 'false', '0'):
            # A specific codec name was requested
            if compress not in available_compression:
                raise ProgrammingError(f'Unsupported compression method {compress}')
            compression = compress
            self.write_compression = compress
        else:
            compression = None
        super().__init__(database=database,
                         uri=self.url,
                         query_limit=query_limit,
                         query_retries=query_retries,
                         server_host_name=server_host_name,
                         apply_server_timezone=apply_server_timezone,
                         show_clickhouse_errors=show_clickhouse_errors)
        self.params = dict_copy(self.params, self._validate_settings(ch_settings))
        # Only send compression/progress settings the server actually accepts
        comp_setting = self._setting_status('enable_http_compression')
        self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable
        if comp_setting.is_set or comp_setting.is_writable:
            self.compression = compression
        send_setting = self._setting_status('send_progress_in_http_headers')
        self._send_progress = not send_setting.is_set and send_setting.is_writable
        if (send_setting.is_set or send_setting.is_writable) and \
                self._setting_status('http_headers_progress_interval_ms').is_writable:
            # Keep-alive progress interval clamped to [10s, 120s], below the read timeout
            self._progress_interval = str(min(120000, max(10000, (send_receive_timeout - 5) * 1000)))
  162. def set_client_setting(self, key, value):
  163. str_value = self._validate_setting(key, value, common.get_setting('invalid_setting_action'))
  164. if str_value is not None:
  165. self.params[key] = str_value
  166. def get_client_setting(self, key) -> Optional[str]:
  167. return self.params.get(key)
  168. def _prep_query(self, context: QueryContext):
  169. final_query = super()._prep_query(context)
  170. if context.is_insert:
  171. return final_query
  172. fmt = f'\n FORMAT {self._read_format}'
  173. if isinstance(final_query, bytes):
  174. return final_query + fmt.encode()
  175. return final_query + fmt
    def _query_with_context(self, context: QueryContext) -> QueryResult:
        """Execute the query described by *context* over HTTP and return the parsed result."""
        headers = {}
        params = {}
        if self.database:
            params['database'] = self.database
        if self.protocol_version:
            params['client_protocol_version'] = self.protocol_version
            # Protocol version implies the server sends block info with each native block
            context.block_info = True
        params.update(context.bind_params)
        params.update(self._validate_settings(context.settings))
        if not context.is_insert and columns_only_re.search(context.uncommented_query):
            # A "LIMIT 0" query only needs column names/types, so request JSON and
            # read the metadata instead of streaming a Native format result
            response = self._raw_request(f'{context.final_query}\n FORMAT JSON',
                                         params, headers, retries=self.query_retries)
            json_result = json.loads(response.data)
            # ClickHouse will respond with a JSON object of meta, data, and some other objects
            # We just grab the column names and column types from the metadata sub object
            names: List[str] = []
            types: List[ClickHouseType] = []
            for col in json_result['meta']:
                names.append(col['name'])
                types.append(registry.get_from_name(col['type']))
            return QueryResult([], None, tuple(names), tuple(types))
        if self.compression:
            headers['Accept-Encoding'] = self.compression
            if self._send_comp_setting:
                params['enable_http_compression'] = '1'
        final_query = self._prep_query(context)
        if context.external_data:
            # External data travels as multipart form fields; the query becomes a URL parameter
            body = bytes()
            params['query'] = final_query
            params.update(context.external_data.query_params)
            fields = context.external_data.form_data
        else:
            body = final_query
            fields = None
            headers['Content-Type'] = 'text/plain; charset=utf-8'
        response = self._raw_request(body,
                                     params,
                                     headers,
                                     stream=True,
                                     retries=self.query_retries,
                                     fields=fields,
                                     server_wait=not context.streaming)
        byte_source = RespBuffCls(ResponseSource(response))  # pylint: disable=not-callable
        # Propagate any server timezone change into the query context before parsing
        context.set_response_tz(self._check_tz_change(response.headers.get('X-ClickHouse-Timezone')))
        query_result = self._transform.parse_response(byte_source, context)
        query_result.summary = self._summary(response)
        return query_result
    def data_insert(self, context: InsertContext) -> QuerySummary:
        """
        See BaseClient doc_string for this method
        """
        if context.empty:
            logger.debug('No data included in insert, skipping')
            return QuerySummary()

        def error_handler(resp: HTTPResponse):
            # If we actually had a local exception when building the insert, throw that instead
            if context.insert_exception:
                ex = context.insert_exception
                context.insert_exception = None
                raise ex
            self._error_handler(resp)

        headers = {'Content-Type': 'application/octet-stream'}
        if context.compression is None:
            # Fall back to the client-level write compression configured in __init__
            context.compression = self.write_compression
        if context.compression:
            headers['Content-Encoding'] = context.compression
        # build_insert yields the encoded insert payload (generator of byte blocks)
        block_gen = self._transform.build_insert(context)
        params = {}
        if self.database:
            params['database'] = self.database
        params.update(self._validate_settings(context.settings))
        response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False)
        logger.debug('Context insert response code: %d, content: %s', response.status, response.data)
        # Drop the reference to the inserted data so it can be garbage collected promptly
        context.data = None
        return QuerySummary(self._summary(response))
  252. def raw_insert(self, table: str = None,
  253. column_names: Optional[Sequence[str]] = None,
  254. insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None,
  255. settings: Optional[Dict] = None,
  256. fmt: Optional[str] = None,
  257. compression: Optional[str] = None) -> QuerySummary:
  258. """
  259. See BaseClient doc_string for this method
  260. """
  261. params = {}
  262. headers = {'Content-Type': 'application/octet-stream'}
  263. if compression:
  264. headers['Content-Encoding'] = compression
  265. if table:
  266. cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else ''
  267. query = f'INSERT INTO {table}{cols} FORMAT {fmt if fmt else self._write_format}'
  268. if not compression and isinstance(insert_block, str):
  269. insert_block = query + '\n' + insert_block
  270. elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)):
  271. insert_block = (query + '\n').encode() + insert_block
  272. else:
  273. params['query'] = query
  274. if self.database:
  275. params['database'] = self.database
  276. params.update(self._validate_settings(settings or {}))
  277. response = self._raw_request(insert_block, params, headers, server_wait=False)
  278. logger.debug('Raw insert response code: %d, content: %s', response.status, response.data)
  279. return QuerySummary(self._summary(response))
  280. @staticmethod
  281. def _summary(response: HTTPResponse):
  282. summary = {}
  283. if 'X-ClickHouse-Summary' in response.headers:
  284. try:
  285. summary = json.loads(response.headers['X-ClickHouse-Summary'])
  286. except json.JSONDecodeError:
  287. pass
  288. summary['query_id'] = response.headers.get('X-ClickHouse-Query-Id', '')
  289. return summary
    def command(self,
                cmd,
                parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
                data: Union[str, bytes] = None,
                settings: Optional[Dict] = None,
                use_database: int = True,
                external_data: Optional[ExternalData] = None) -> Union[str, int, Sequence[str], QuerySummary]:
        """
        See BaseClient doc_string for this method
        """
        cmd, params = bind_query(cmd, parameters, self.server_tz)
        headers = {}
        payload = None
        fields = None
        if external_data:
            if data:
                raise ProgrammingError('Cannot combine command data with external data') from None
            fields = external_data.form_data
            params.update(external_data.query_params)
        elif isinstance(data, str):
            headers['Content-Type'] = 'text/plain; charset=utf-8'
            payload = data.encode()
        elif isinstance(data, bytes):
            headers['Content-Type'] = 'application/octet-stream'
            payload = data
        if payload is None and not cmd:
            raise ProgrammingError('Command sent without query or recognized data') from None
        if payload or fields:
            # A request body (or form fields) is present, so the command travels as a URL parameter
            params['query'] = cmd
        else:
            # Otherwise the command text itself is the request body
            payload = cmd
        if use_database and self.database:
            params['database'] = self.database
        params.update(self._validate_settings(settings or {}))
        method = 'POST' if payload or fields else 'GET'
        response = self._raw_request(payload, params, headers, method, fields=fields, server_wait=False)
        if response.data:
            try:
                # Strip the trailing newline and split tab separated values;
                # a single value is returned as an int when possible, else a str
                result = response.data.decode()[:-1].split('\t')
                if len(result) == 1:
                    try:
                        return int(result[0])
                    except ValueError:
                        return result[0]
                return result
            except UnicodeDecodeError:
                # Non-text response body -- return its string representation
                return str(response.data)
        return QuerySummary(self._summary(response))
  338. def _error_handler(self, response: HTTPResponse, retried: bool = False) -> None:
  339. if self.show_clickhouse_errors:
  340. try:
  341. err_content = get_response_data(response)
  342. except Exception: # pylint: disable=broad-except
  343. err_content = None
  344. finally:
  345. response.close()
  346. err_str = f'HTTPDriver for {self.url} returned response code {response.status}'
  347. err_code = response.headers.get(ex_header)
  348. if err_code:
  349. err_str = f'HTTPDriver for {self.url} received ClickHouse error code {err_code}'
  350. if err_content:
  351. err_msg = common.format_error(err_content.decode(errors='backslashreplace'))
  352. if err_msg.startswith('Code'):
  353. err_str = f'{err_str}\n {err_msg}'
  354. else:
  355. err_str = 'The ClickHouse server returned an error.'
  356. raise OperationalError(err_str) if retried else DatabaseError(err_str) from None
    def _raw_request(self,
                     data,
                     params: Dict[str, str],
                     headers: Optional[Dict[str, Any]] = None,
                     method: str = 'POST',
                     retries: int = 0,
                     stream: bool = False,
                     server_wait: bool = True,
                     fields: Optional[Dict[str, tuple]] = None,
                     error_handler: Callable = None) -> HTTPResponse:
        """Send a request to the ClickHouse server, retrying transient failures.

        :param data: request body (str/bytes or a generator of byte blocks)
        :param params: query parameters merged over the client-level params
        :param retries: extra attempts for 429/503/504 responses
        :param stream: when True the response body is not preloaded
        :param server_wait: when True ask the server to buffer the full result
        :param fields: multipart form fields (external data); replaces the body
        :param error_handler: custom callable invoked for non-retryable errors
        :return: the urllib3 HTTPResponse on success
        """
        if isinstance(data, str):
            data = data.encode()
        headers = dict_copy(self.headers, headers)
        attempts = 0
        if server_wait:
            params['wait_end_of_query'] = '1'
        # We can't actually read the progress headers, but we enable them so ClickHouse sends something
        # to keep the connection alive when waiting for long-running queries and (2) to get summary information
        # if not streaming
        if self._send_progress:
            params['send_progress_in_http_headers'] = '1'
        if self._progress_interval:
            params['http_headers_progress_interval_ms'] = self._progress_interval
        final_params = dict_copy(self.params, params)
        url = f'{self.url}?{urlencode(final_params)}'
        kwargs = {
            'headers': headers,
            'timeout': self.timeout,
            'retries': self.http_retries,
            'preload_content': not stream
        }
        if self.server_host_name:
            # Route through the configured host but present server_host_name to the server
            kwargs['assert_same_host'] = False
            kwargs['headers'].update({'Host': self.server_host_name})
        if fields:
            kwargs['fields'] = fields
        else:
            kwargs['body'] = data
        check_conn_expiration(self.http)
        query_session = final_params.get('session_id')
        while True:
            attempts += 1
            if query_session:
                if query_session == self._active_session:
                    raise ProgrammingError('Attempt to execute concurrent queries within the same session.' +
                                           'Please use a separate client instance per thread/process.')
                # There is a race condition here when using multiprocessing -- in that case the server will
                # throw an error instead, but in most cases this more helpful error will be thrown first
                self._active_session = query_session
            try:
                response = self.http.request(method, url, **kwargs)
            except HTTPError as ex:
                if isinstance(ex.__context__, ConnectionResetError):
                    # The server closed the connection, probably because the Keep Alive has expired
                    # We should be safe to retry, as ClickHouse should not have processed anything on a connection
                    # that it killed. We also only retry this once, as multiple disconnects are unlikely to be
                    # related to the Keep Alive settings
                    if attempts == 1:
                        logger.debug('Retrying remotely closed connection')
                        continue
                logger.warning('Unexpected Http Driver Exception')
                err_url = f' ({self.url})' if self.show_clickhouse_errors else ''
                raise OperationalError(f'Error {ex} executing HTTP request attempt {attempts}{err_url}') from ex
            finally:
                if query_session:
                    self._active_session = None  # Make sure we always clear this
            # A 2xx status can still carry a ClickHouse exception header when streaming
            if 200 <= response.status < 300 and not response.headers.get(ex_header):
                return response
            if response.status in (429, 503, 504):
                # Throttled/unavailable -- retry until the caller's retry budget is spent
                if attempts > retries:
                    self._error_handler(response, True)
                logger.debug('Retrying requests with status code %d', response.status)
            elif error_handler:
                error_handler(response)
            else:
                self._error_handler(response)
  433. def raw_query(self, query: str,
  434. parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
  435. settings: Optional[Dict[str, Any]] = None,
  436. fmt: str = None,
  437. use_database: bool = True,
  438. external_data: Optional[ExternalData] = None) -> bytes:
  439. """
  440. See BaseClient doc_string for this method
  441. """
  442. body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
  443. return self._raw_request(body, params, fields=fields).data
  444. def raw_stream(self, query: str,
  445. parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
  446. settings: Optional[Dict[str, Any]] = None,
  447. fmt: str = None,
  448. use_database: bool = True,
  449. external_data: Optional[ExternalData] = None) -> io.IOBase:
  450. """
  451. See BaseClient doc_string for this method
  452. """
  453. body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
  454. return self._raw_request(body, params, fields=fields, stream=True, server_wait=False)
  455. def _prep_raw_query(self, query: str,
  456. parameters: Optional[Union[Sequence, Dict[str, Any]]],
  457. settings: Optional[Dict[str, Any]],
  458. fmt: str,
  459. use_database: bool,
  460. external_data: Optional[ExternalData]):
  461. if fmt:
  462. query += f'\n FORMAT {fmt}'
  463. final_query, bind_params = bind_query(query, parameters, self.server_tz)
  464. params = self._validate_settings(settings or {})
  465. if use_database and self.database:
  466. params['database'] = self.database
  467. params.update(bind_params)
  468. if external_data:
  469. if isinstance(final_query, bytes):
  470. raise ProgrammingError('Cannot combine binary query data with `External Data`')
  471. body = bytes()
  472. params['query'] = final_query
  473. params.update(external_data.query_params)
  474. fields = external_data.form_data
  475. else:
  476. body = final_query
  477. fields = None
  478. return body, params, fields
  479. def ping(self):
  480. """
  481. See BaseClient doc_string for this method
  482. """
  483. try:
  484. response = self.http.request('GET', f'{self.url}/ping', timeout=3)
  485. return 200 <= response.status < 300
  486. except HTTPError:
  487. logger.debug('ping failed', exc_info=True)
  488. return False
    def close_connections(self):
        # Discard all pooled connections without destroying the pool manager itself
        self.http.clear()
    def close(self):
        # Only tear down the pool manager when this client created it in __init__;
        # a shared/default manager may still be in use by other clients
        if self._owns_pool_manager:
            self.http.clear()
            all_managers.pop(self.http, None)