# httpclient.py (24 KB)
  1. import io
  2. import json
  3. import logging
  4. import re
  5. import uuid
  6. from base64 import b64encode
  7. from typing import Optional, Dict, Any, Sequence, Union, List, Callable, Generator, BinaryIO
  8. from urllib.parse import urlencode
  9. from urllib3 import Timeout
  10. from urllib3.exceptions import HTTPError
  11. from urllib3.poolmanager import PoolManager
  12. from urllib3.response import HTTPResponse
  13. from clickhouse_connect import common
  14. from clickhouse_connect.datatypes import registry
  15. from clickhouse_connect.datatypes.base import ClickHouseType
  16. from clickhouse_connect.driver.client import Client
  17. from clickhouse_connect.driver.common import dict_copy, coerce_bool, coerce_int, dict_add
  18. from clickhouse_connect.driver.compression import available_compression
  19. from clickhouse_connect.driver.ctypes import RespBuffCls
  20. from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError
  21. from clickhouse_connect.driver.external import ExternalData
  22. from clickhouse_connect.driver.httputil import ResponseSource, get_pool_manager, get_response_data, \
  23. default_pool_manager, get_proxy_manager, all_managers, check_env_proxy, check_conn_expiration
  24. from clickhouse_connect.driver.insert import InsertContext
  25. from clickhouse_connect.driver.query import QueryResult, QueryContext, quote_identifier, bind_query
  26. from clickhouse_connect.driver.summary import QuerySummary
  27. from clickhouse_connect.driver.transform import NativeTransform
  28. logger = logging.getLogger(__name__)
  29. columns_only_re = re.compile(r'LIMIT 0\s*$', re.IGNORECASE)
  30. # pylint: disable=too-many-instance-attributes
  31. class HttpClient(Client):
  32. params = {}
  33. valid_transport_settings = {'database', 'buffer_size', 'session_id',
  34. 'compress', 'decompress', 'session_timeout',
  35. 'session_check', 'query_id', 'quota_key',
  36. 'wait_end_of_query', 'client_protocol_version'}
  37. optional_transport_settings = {'send_progress_in_http_headers',
  38. 'http_headers_progress_interval_ms',
  39. 'enable_http_compression'}
  40. _owns_pool_manager = False
  41. # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements,unused-argument
  42. def __init__(self,
  43. interface: str,
  44. host: str,
  45. port: int,
  46. username: str,
  47. password: str,
  48. database: str,
  49. compress: Union[bool, str] = True,
  50. query_limit: int = 0,
  51. query_retries: int = 2,
  52. connect_timeout: int = 10,
  53. send_receive_timeout: int = 300,
  54. client_name: Optional[str] = None,
  55. verify: Union[bool, str] = True,
  56. ca_cert: Optional[str] = None,
  57. client_cert: Optional[str] = None,
  58. client_cert_key: Optional[str] = None,
  59. session_id: Optional[str] = None,
  60. settings: Optional[Dict[str, Any]] = None,
  61. pool_mgr: Optional[PoolManager] = None,
  62. http_proxy: Optional[str] = None,
  63. https_proxy: Optional[str] = None,
  64. server_host_name: Optional[str] = None,
  65. apply_server_timezone: Optional[Union[str, bool]] = None,
  66. show_clickhouse_errors: Optional[bool] = None,
  67. autogenerate_session_id: Optional[bool] = None):
  68. """
  69. Create an HTTP ClickHouse Connect client
  70. See clickhouse_connect.get_client for parameters
  71. """
  72. self.url = f'{interface}://{host}:{port}'
  73. self.headers = {}
  74. ch_settings = dict_copy(settings, self.params)
  75. self.http = pool_mgr
  76. if interface == 'https':
  77. if not https_proxy:
  78. https_proxy = check_env_proxy('https', host, port)
  79. if https_proxy and isinstance(verify, str) and verify.lower() == 'proxy':
  80. verify = 'proxy'
  81. else:
  82. verify = coerce_bool(verify)
  83. if client_cert and verify != 'proxy':
  84. if not username:
  85. raise ProgrammingError('username parameter is required for Mutual TLS authentication')
  86. self.headers['X-ClickHouse-User'] = username
  87. self.headers['X-ClickHouse-SSL-Certificate-Auth'] = 'on'
  88. # pylint: disable=too-many-boolean-expressions
  89. if not self.http and (server_host_name or ca_cert or client_cert or not verify or https_proxy):
  90. options = {'verify': verify is not False}
  91. dict_add(options, 'ca_cert', ca_cert)
  92. dict_add(options, 'client_cert', client_cert)
  93. dict_add(options, 'client_cert_key', client_cert_key)
  94. if server_host_name:
  95. if options['verify']:
  96. options['assert_hostname'] = server_host_name
  97. options['server_hostname'] = server_host_name
  98. self.http = get_pool_manager(https_proxy=https_proxy, **options)
  99. self._owns_pool_manager = True
  100. if not self.http:
  101. if not http_proxy:
  102. http_proxy = check_env_proxy('http', host, port)
  103. if http_proxy:
  104. self.http = get_proxy_manager(host, http_proxy)
  105. else:
  106. self.http = default_pool_manager()
  107. if (not client_cert or verify == 'proxy') and username:
  108. self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode()
  109. self.headers['User-Agent'] = common.build_client_name(client_name)
  110. self._read_format = self._write_format = 'Native'
  111. self._transform = NativeTransform()
  112. # There are use cases when the client needs to disable timeouts.
  113. if connect_timeout is not None:
  114. connect_timeout = coerce_int(connect_timeout)
  115. if send_receive_timeout is not None:
  116. send_receive_timeout = coerce_int(send_receive_timeout)
  117. self.timeout = Timeout(connect=connect_timeout, read=send_receive_timeout)
  118. self.http_retries = 1
  119. self._send_progress = None
  120. self._send_comp_setting = False
  121. self._progress_interval = None
  122. self._active_session = None
  123. # allow to override the global autogenerate_session_id setting via the constructor params
  124. _autogenerate_session_id = common.get_setting('autogenerate_session_id') \
  125. if autogenerate_session_id is None \
  126. else autogenerate_session_id
  127. if session_id:
  128. ch_settings['session_id'] = session_id
  129. elif 'session_id' not in ch_settings and _autogenerate_session_id:
  130. ch_settings['session_id'] = str(uuid.uuid4())
  131. if coerce_bool(compress):
  132. compression = ','.join(available_compression)
  133. self.write_compression = available_compression[0]
  134. elif compress and compress not in ('False', 'false', '0'):
  135. if compress not in available_compression:
  136. raise ProgrammingError(f'Unsupported compression method {compress}')
  137. compression = compress
  138. self.write_compression = compress
  139. else:
  140. compression = None
  141. super().__init__(database=database,
  142. uri=self.url,
  143. query_limit=query_limit,
  144. query_retries=query_retries,
  145. server_host_name=server_host_name,
  146. apply_server_timezone=apply_server_timezone,
  147. show_clickhouse_errors=show_clickhouse_errors)
  148. self.params = self._validate_settings(ch_settings)
  149. comp_setting = self._setting_status('enable_http_compression')
  150. self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable
  151. if comp_setting.is_set or comp_setting.is_writable:
  152. self.compression = compression
  153. send_setting = self._setting_status('send_progress_in_http_headers')
  154. self._send_progress = not send_setting.is_set and send_setting.is_writable
  155. if (send_setting.is_set or send_setting.is_writable) and \
  156. self._setting_status('http_headers_progress_interval_ms').is_writable:
  157. self._progress_interval = str(min(120000, max(10000, (send_receive_timeout - 5) * 1000)))
  158. def set_client_setting(self, key, value):
  159. str_value = self._validate_setting(key, value, common.get_setting('invalid_setting_action'))
  160. if str_value is not None:
  161. self.params[key] = str_value
  162. def get_client_setting(self, key) -> Optional[str]:
  163. return self.params.get(key)
  164. def _prep_query(self, context: QueryContext):
  165. final_query = super()._prep_query(context)
  166. if context.is_insert:
  167. return final_query
  168. fmt = f'\n FORMAT {self._read_format}'
  169. if isinstance(final_query, bytes):
  170. return final_query + fmt.encode()
  171. return final_query + fmt
  172. def _query_with_context(self, context: QueryContext) -> QueryResult:
  173. headers = {}
  174. params = {}
  175. if self.database:
  176. params['database'] = self.database
  177. if self.protocol_version:
  178. params['client_protocol_version'] = self.protocol_version
  179. context.block_info = True
  180. params.update(context.bind_params)
  181. params.update(self._validate_settings(context.settings))
  182. if columns_only_re.search(context.uncommented_query):
  183. response = self._raw_request(f'{context.final_query}\n FORMAT JSON',
  184. params, headers, retries=self.query_retries)
  185. json_result = json.loads(response.data)
  186. # ClickHouse will respond with a JSON object of meta, data, and some other objects
  187. # We just grab the column names and column types from the metadata sub object
  188. names: List[str] = []
  189. types: List[ClickHouseType] = []
  190. for col in json_result['meta']:
  191. names.append(col['name'])
  192. types.append(registry.get_from_name(col['type']))
  193. return QueryResult([], None, tuple(names), tuple(types))
  194. if self.compression:
  195. headers['Accept-Encoding'] = self.compression
  196. if self._send_comp_setting:
  197. params['enable_http_compression'] = '1'
  198. final_query = self._prep_query(context)
  199. if context.external_data:
  200. body = bytes()
  201. params['query'] = final_query
  202. params.update(context.external_data.query_params)
  203. fields = context.external_data.form_data
  204. else:
  205. body = final_query
  206. fields = None
  207. headers['Content-Type'] = 'text/plain; charset=utf-8'
  208. response = self._raw_request(body,
  209. params,
  210. headers,
  211. stream=True,
  212. retries=self.query_retries,
  213. fields=fields,
  214. server_wait=not context.streaming)
  215. byte_source = RespBuffCls(ResponseSource(response)) # pylint: disable=not-callable
  216. context.set_response_tz(self._check_tz_change(response.headers.get('X-ClickHouse-Timezone')))
  217. query_result = self._transform.parse_response(byte_source, context)
  218. query_result.summary = self._summary(response)
  219. return query_result
  220. def data_insert(self, context: InsertContext) -> QuerySummary:
  221. """
  222. See BaseClient doc_string for this method
  223. """
  224. if context.empty:
  225. logger.debug('No data included in insert, skipping')
  226. return QuerySummary()
  227. def error_handler(resp: HTTPResponse):
  228. # If we actually had a local exception when building the insert, throw that instead
  229. if context.insert_exception:
  230. ex = context.insert_exception
  231. context.insert_exception = None
  232. raise ex
  233. self._error_handler(resp)
  234. headers = {'Content-Type': 'application/octet-stream'}
  235. if context.compression is None:
  236. context.compression = self.write_compression
  237. if context.compression:
  238. headers['Content-Encoding'] = context.compression
  239. block_gen = self._transform.build_insert(context)
  240. params = {}
  241. if self.database:
  242. params['database'] = self.database
  243. params.update(self._validate_settings(context.settings))
  244. response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False)
  245. logger.debug('Context insert response code: %d, content: %s', response.status, response.data)
  246. context.data = None
  247. return QuerySummary(self._summary(response))
  248. def raw_insert(self, table: str = None,
  249. column_names: Optional[Sequence[str]] = None,
  250. insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None,
  251. settings: Optional[Dict] = None,
  252. fmt: Optional[str] = None,
  253. compression: Optional[str] = None) -> QuerySummary:
  254. """
  255. See BaseClient doc_string for this method
  256. """
  257. params = {}
  258. headers = {'Content-Type': 'application/octet-stream'}
  259. if compression:
  260. headers['Content-Encoding'] = compression
  261. if table:
  262. cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else ''
  263. query = f'INSERT INTO {table}{cols} FORMAT {fmt if fmt else self._write_format}'
  264. if not compression and isinstance(insert_block, str):
  265. insert_block = query + '\n' + insert_block
  266. elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)):
  267. insert_block = (query + '\n').encode() + insert_block
  268. else:
  269. params['query'] = query
  270. if self.database:
  271. params['database'] = self.database
  272. params.update(self._validate_settings(settings or {}))
  273. response = self._raw_request(insert_block, params, headers, server_wait=False)
  274. logger.debug('Raw insert response code: %d, content: %s', response.status, response.data)
  275. return QuerySummary(self._summary(response))
  276. @staticmethod
  277. def _summary(response: HTTPResponse):
  278. summary = {}
  279. if 'X-ClickHouse-Summary' in response.headers:
  280. try:
  281. summary = json.loads(response.headers['X-ClickHouse-Summary'])
  282. except json.JSONDecodeError:
  283. pass
  284. summary['query_id'] = response.headers.get('X-ClickHouse-Query-Id', '')
  285. return summary
  286. def command(self,
  287. cmd,
  288. parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
  289. data: Union[str, bytes] = None,
  290. settings: Optional[Dict] = None,
  291. use_database: int = True,
  292. external_data: Optional[ExternalData] = None) -> Union[str, int, Sequence[str], QuerySummary]:
  293. """
  294. See BaseClient doc_string for this method
  295. """
  296. cmd, params = bind_query(cmd, parameters, self.server_tz)
  297. headers = {}
  298. payload = None
  299. fields = None
  300. if external_data:
  301. if data:
  302. raise ProgrammingError('Cannot combine command data with external data') from None
  303. fields = external_data.form_data
  304. params.update(external_data.query_params)
  305. elif isinstance(data, str):
  306. headers['Content-Type'] = 'text/plain; charset=utf-8'
  307. payload = data.encode()
  308. elif isinstance(data, bytes):
  309. headers['Content-Type'] = 'application/octet-stream'
  310. payload = data
  311. if payload is None and not cmd:
  312. raise ProgrammingError('Command sent without query or recognized data') from None
  313. if payload or fields:
  314. params['query'] = cmd
  315. else:
  316. payload = cmd
  317. if use_database and self.database:
  318. params['database'] = self.database
  319. params.update(self._validate_settings(settings or {}))
  320. method = 'POST' if payload or fields else 'GET'
  321. response = self._raw_request(payload, params, headers, method, fields=fields)
  322. if response.data:
  323. try:
  324. result = response.data.decode()[:-1].split('\t')
  325. if len(result) == 1:
  326. try:
  327. return int(result[0])
  328. except ValueError:
  329. return result[0]
  330. return result
  331. except UnicodeDecodeError:
  332. return str(response.data)
  333. return QuerySummary(self._summary(response))
  334. def _error_handler(self, response: HTTPResponse, retried: bool = False) -> None:
  335. if self.show_clickhouse_errors:
  336. try:
  337. err_content = get_response_data(response)
  338. except Exception: # pylint: disable=broad-except
  339. err_content = None
  340. finally:
  341. response.close()
  342. err_str = f'HTTPDriver for {self.url} returned response code {response.status})'
  343. if err_content:
  344. err_msg = common.format_error(err_content.decode(errors='backslashreplace'))
  345. err_str = f'{err_str}\n {err_msg}'
  346. else:
  347. err_str = 'The ClickHouse server returned an error.'
  348. raise OperationalError(err_str) if retried else DatabaseError(err_str) from None
  349. def _raw_request(self,
  350. data,
  351. params: Dict[str, str],
  352. headers: Optional[Dict[str, Any]] = None,
  353. method: str = 'POST',
  354. retries: int = 0,
  355. stream: bool = False,
  356. server_wait: bool = True,
  357. fields: Optional[Dict[str, tuple]] = None,
  358. error_handler: Callable = None) -> HTTPResponse:
  359. if isinstance(data, str):
  360. data = data.encode()
  361. headers = dict_copy(self.headers, headers)
  362. attempts = 0
  363. if server_wait:
  364. params['wait_end_of_query'] = '1'
  365. # We can't actually read the progress headers, but we enable them so ClickHouse sends something
  366. # to keep the connection alive when waiting for long-running queries and (2) to get summary information
  367. # if not streaming
  368. if self._send_progress:
  369. params['send_progress_in_http_headers'] = '1'
  370. if self._progress_interval:
  371. params['http_headers_progress_interval_ms'] = self._progress_interval
  372. final_params = dict_copy(self.params, params)
  373. url = f'{self.url}?{urlencode(final_params)}'
  374. kwargs = {
  375. 'headers': headers,
  376. 'timeout': self.timeout,
  377. 'retries': self.http_retries,
  378. 'preload_content': not stream
  379. }
  380. if self.server_host_name:
  381. kwargs['assert_same_host'] = False
  382. kwargs['headers'].update({'Host': self.server_host_name})
  383. if fields:
  384. kwargs['fields'] = fields
  385. else:
  386. kwargs['body'] = data
  387. check_conn_expiration(self.http)
  388. query_session = final_params.get('session_id')
  389. while True:
  390. attempts += 1
  391. if query_session:
  392. if query_session == self._active_session:
  393. raise ProgrammingError('Attempt to execute concurrent queries within the same session.' +
  394. 'Please use a separate client instance per thread/process.')
  395. # There is a race condition here when using multiprocessing -- in that case the server will
  396. # throw an error instead, but in most cases this more helpful error will be thrown first
  397. self._active_session = query_session
  398. try:
  399. response = self.http.request(method, url, **kwargs)
  400. except HTTPError as ex:
  401. if isinstance(ex.__context__, ConnectionResetError):
  402. # The server closed the connection, probably because the Keep Alive has expired
  403. # We should be safe to retry, as ClickHouse should not have processed anything on a connection
  404. # that it killed. We also only retry this once, as multiple disconnects are unlikely to be
  405. # related to the Keep Alive settings
  406. if attempts == 1:
  407. logger.debug('Retrying remotely closed connection')
  408. continue
  409. logger.warning('Unexpected Http Driver Exception')
  410. err_url = f' ({self.url})' if self.show_clickhouse_errors else ''
  411. raise OperationalError(f'Error {ex} executing HTTP request attempt {attempts}{err_url}') from ex
  412. finally:
  413. if query_session:
  414. self._active_session = None # Make sure we always clear this
  415. if 200 <= response.status < 300:
  416. return response
  417. if response.status in (429, 503, 504):
  418. if attempts > retries:
  419. self._error_handler(response, True)
  420. logger.debug('Retrying requests with status code %d', response.status)
  421. elif error_handler:
  422. error_handler(response)
  423. else:
  424. self._error_handler(response)
  425. def raw_query(self, query: str,
  426. parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
  427. settings: Optional[Dict[str, Any]] = None,
  428. fmt: str = None,
  429. use_database: bool = True,
  430. external_data: Optional[ExternalData] = None) -> bytes:
  431. """
  432. See BaseClient doc_string for this method
  433. """
  434. body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
  435. return self._raw_request(body, params, fields=fields).data
  436. def raw_stream(self, query: str,
  437. parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
  438. settings: Optional[Dict[str, Any]] = None,
  439. fmt: str = None,
  440. use_database: bool = True,
  441. external_data: Optional[ExternalData] = None) -> io.IOBase:
  442. """
  443. See BaseClient doc_string for this method
  444. """
  445. body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
  446. return self._raw_request(body, params, fields=fields, stream=True)
  447. def _prep_raw_query(self, query: str,
  448. parameters: Optional[Union[Sequence, Dict[str, Any]]],
  449. settings: Optional[Dict[str, Any]],
  450. fmt: str,
  451. use_database: bool,
  452. external_data: Optional[ExternalData]):
  453. if fmt:
  454. query += f'\n FORMAT {fmt}'
  455. final_query, bind_params = bind_query(query, parameters, self.server_tz)
  456. params = self._validate_settings(settings or {})
  457. if use_database and self.database:
  458. params['database'] = self.database
  459. params.update(bind_params)
  460. if external_data:
  461. if isinstance(final_query, bytes):
  462. raise ProgrammingError('Cannot combine binary query data with `External Data`')
  463. body = bytes()
  464. params['query'] = final_query
  465. params.update(external_data.query_params)
  466. fields = external_data.form_data
  467. else:
  468. body = final_query
  469. fields = None
  470. return body, params, fields
  471. def ping(self):
  472. """
  473. See BaseClient doc_string for this method
  474. """
  475. try:
  476. response = self.http.request('GET', f'{self.url}/ping', timeout=3)
  477. return 200 <= response.status < 300
  478. except HTTPError:
  479. logger.debug('ping failed', exc_info=True)
  480. return False
  481. def close(self):
  482. if self._owns_pool_manager:
  483. self.http.clear()
  484. all_managers.pop(self.http, None)