binding.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import ipaddress
  2. import re
  3. import uuid
  4. from datetime import tzinfo, datetime, date
  5. from enum import Enum
  6. from typing import Optional, Union, Sequence, Dict, Any, Tuple
  7. import pytz
  8. from clickhouse_connect import common
  9. from clickhouse_connect.driver.common import dict_copy
  10. from clickhouse_connect.json_impl import any_to_json
# A single backslash: the prefix escape_str puts in front of special characters
BS = '\\'
# Characters that escape_str prefixes with a backslash (backslash, single quote, backtick, tab, newline)
must_escape = (BS, '\'', '`', '\t', '\n')
# Matches a brace-enclosed segment containing a colon, e.g. {name:Type}.  bind_query uses it to choose
# between client-side interpolation (no match) and returning a dictionary of 'param_' prefixed values
external_bind_re = re.compile(r'{.+:.+}')
  14. class DT64Param:
  15. def __init__(self, value: datetime):
  16. self.value = value
  17. def format(self, tz: tzinfo, top_level:bool) -> str:
  18. value = self.value
  19. if tz:
  20. value = value.astimezone(tz)
  21. s = value.strftime('%Y-%m-%d %H:%M:%S.%f')
  22. if top_level:
  23. return s
  24. return f"'{s}'"
  25. def quote_identifier(identifier: str):
  26. first_char = identifier[0]
  27. if first_char in ('`', '"') and identifier[-1] == first_char:
  28. # Identifier is already quoted, assume that it's valid
  29. return identifier
  30. return f'`{escape_str(identifier)}`'
  31. def finalize_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]],
  32. server_tz: Optional[tzinfo] = None) -> str:
  33. while query.endswith(';'):
  34. query = query[:-1]
  35. if not parameters:
  36. return query
  37. if hasattr(parameters, 'items'):
  38. return query % {k: format_query_value(v, server_tz) for k, v in parameters.items()}
  39. return query % tuple(format_query_value(v, server_tz) for v in parameters)
  40. # pylint: disable=too-many-locals,too-many-branches
  41. def bind_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]],
  42. server_tz: Optional[tzinfo] = None) -> Tuple[str, Dict[str, str]]:
  43. while query.endswith(';'):
  44. query = query[:-1]
  45. if not parameters:
  46. return query, {}
  47. binary_binds = None
  48. if isinstance(parameters, dict):
  49. params_copy = dict_copy(parameters)
  50. binary_binds = {k: v for k, v in params_copy.items() if k.startswith('$') and k.endswith('$') and len(k) > 1}
  51. for key in binary_binds.keys():
  52. del params_copy[key]
  53. final_params = {}
  54. for k, v in params_copy.items():
  55. if k.endswith('_64'):
  56. if isinstance(v, datetime):
  57. k = k[:-3]
  58. v = DT64Param(v)
  59. elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], datetime):
  60. k = k[:-3]
  61. v = [DT64Param(x) for x in v]
  62. final_params[k] = v
  63. if external_bind_re.search(query) is None:
  64. query, bound_params = finalize_query(query, final_params, server_tz), {}
  65. else:
  66. bound_params = {f'param_{k}': format_bind_value(v, server_tz) for k, v in final_params.items()}
  67. else:
  68. query, bound_params = finalize_query(query, parameters, server_tz), {}
  69. if binary_binds:
  70. binary_query = query.encode()
  71. binary_indexes = {}
  72. for k, v in binary_binds.items():
  73. key = k.encode()
  74. item_index = 0
  75. while True:
  76. item_index = binary_query.find(key, item_index)
  77. if item_index == -1:
  78. break
  79. binary_indexes[item_index + len(key)] = key, v
  80. item_index += len(key)
  81. query = b''
  82. start = 0
  83. for loc in sorted(binary_indexes.keys()):
  84. key, value = binary_indexes[loc]
  85. query += binary_query[start:loc] + value + key
  86. start = loc
  87. query += binary_query[start:]
  88. return query, bound_params
  89. def format_str(value: str):
  90. return f"'{escape_str(value)}'"
  91. def escape_str(value: str):
  92. return ''.join(f'{BS}{c}' if c in must_escape else c for c in value)
  93. # pylint: disable=too-many-return-statements
  94. def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC):
  95. """
  96. Format Python values in a ClickHouse query
  97. :param value: Python object
  98. :param server_tz: Server timezone for adjusting datetime values
  99. :return: Literal string for python value
  100. """
  101. if value is None:
  102. return 'NULL'
  103. if isinstance(value, str):
  104. return format_str(value)
  105. if isinstance(value, DT64Param):
  106. return value.format(server_tz, False)
  107. if isinstance(value, datetime):
  108. if value.tzinfo is not None or server_tz != pytz.UTC:
  109. value = value.astimezone(server_tz)
  110. return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
  111. if isinstance(value, date):
  112. return f"'{value.isoformat()}'"
  113. if isinstance(value, list):
  114. return f"[{', '.join(str_query_value(x, server_tz) for x in value)}]"
  115. if isinstance(value, tuple):
  116. return f"({', '.join(str_query_value(x, server_tz) for x in value)})"
  117. if isinstance(value, dict):
  118. if common.get_setting('dict_parameter_format') == 'json':
  119. return format_str(any_to_json(value).decode())
  120. pairs = [str_query_value(k, server_tz) + ':' + str_query_value(v, server_tz)
  121. for k, v in value.items()]
  122. return f"{{{', '.join(pairs)}}}"
  123. if isinstance(value, Enum):
  124. return format_query_value(value.value, server_tz)
  125. if isinstance(value, (uuid.UUID, ipaddress.IPv4Address, ipaddress.IPv6Address)):
  126. return f"'{value}'"
  127. return value
  128. def str_query_value(value: Any, server_tz: tzinfo = pytz.UTC):
  129. return str(format_query_value(value, server_tz))
  130. # pylint: disable=too-many-branches
  131. def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool = True):
  132. """
  133. Format Python values in a ClickHouse query
  134. :param value: Python object
  135. :param server_tz: Server timezone for adjusting datetime values
  136. :param top_level: Flag for top level for nested structures
  137. :return: Literal string for python value
  138. """
  139. def recurse(x):
  140. return format_bind_value(x, server_tz, False)
  141. if value is None:
  142. return '\\N'
  143. if isinstance(value, str):
  144. if top_level:
  145. # At the top levels, strings must not be surrounded by quotes
  146. return escape_str(value)
  147. return format_str(value)
  148. if isinstance(value, DT64Param):
  149. return value.format(server_tz, top_level)
  150. if isinstance(value, datetime):
  151. value = value.astimezone(server_tz)
  152. val = value.strftime('%Y-%m-%d %H:%M:%S')
  153. if top_level:
  154. return val
  155. return f"'{val}'"
  156. if isinstance(value, date):
  157. if top_level:
  158. return value.isoformat()
  159. return f"'{value.isoformat()}'"
  160. if isinstance(value, list):
  161. return f"[{', '.join(recurse(x) for x in value)}]"
  162. if isinstance(value, tuple):
  163. return f"({', '.join(recurse(x) for x in value)})"
  164. if isinstance(value, dict):
  165. if common.get_setting('dict_parameter_format') == 'json':
  166. return any_to_json(value).decode()
  167. pairs = [recurse(k) + ':' + recurse(v)
  168. for k, v in value.items()]
  169. return f"{{{', '.join(pairs)}}}"
  170. if isinstance(value, Enum):
  171. return recurse(value.value)
  172. return str(value)