common.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import array
  2. import struct
  3. import sys
  4. from typing import Sequence, MutableSequence, Dict, Optional, Union, Generator
  5. from clickhouse_connect.driver.exceptions import ProgrammingError, StreamClosedError, DataError
  6. from clickhouse_connect.driver.types import Closable
  7. # pylint: disable=invalid-name
  8. must_swap = sys.byteorder == 'big'
  9. int_size = array.array('i').itemsize
  10. low_card_version = 1
  11. array_map = {1: 'b', 2: 'h', 4: 'i', 8: 'q'}
  12. decimal_prec = {32: 9, 64: 18, 128: 38, 256: 79}
  13. if int_size == 2:
  14. array_map[4] = 'l'
  15. array_sizes = {v: k for k, v in array_map.items()}
  16. array_sizes['f'] = 4
  17. array_sizes['d'] = 8
  18. np_date_types = {0: '[s]', 3: '[ms]', 6: '[us]', 9: '[ns]'}
  19. def array_type(size: int, signed: bool):
  20. """
  21. Determines the Python array.array code for the requested byte size
  22. :param size: byte size
  23. :param signed: whether int types should be signed or unsigned
  24. :return: Python array.array code
  25. """
  26. try:
  27. code = array_map[size]
  28. except KeyError:
  29. return None
  30. return code if signed else code.upper()
  31. def write_array(code: str, column: Sequence, dest: MutableSequence, col_name: Optional[str]=None):
  32. """
  33. Write a column of native Python data matching the array.array code
  34. :param code: Python array.array code matching the column data type
  35. :param column: Column of native Python values
  36. :param dest: Destination byte buffer
  37. :param col_name: Optional column name for error tracking
  38. """
  39. try:
  40. buff = struct.Struct(f'<{len(column)}{code}')
  41. dest += buff.pack(*column)
  42. except (TypeError, OverflowError, struct.error) as ex:
  43. col_msg = ''
  44. if col_name:
  45. col_msg = f' for source column `{col_name}`'
  46. raise DataError(f'Unable to create Python array{col_msg}. This is usually caused by trying to insert None ' +
  47. 'values into a ClickHouse column that is not Nullable') from ex
  48. def write_uint64(value: int, dest: MutableSequence):
  49. """
  50. Write a single UInt64 value to a binary write buffer
  51. :param value: UInt64 value to write
  52. :param dest: Destination byte buffer
  53. """
  54. dest.extend(value.to_bytes(8, 'little'))
  55. def write_leb128(value: int, dest: MutableSequence):
  56. """
  57. Write a LEB128 encoded integer to a target binary buffer
  58. :param value: Integer value (positive only)
  59. :param dest: Target buffer
  60. """
  61. while True:
  62. b = value & 0x7f
  63. value >>= 7
  64. if value == 0:
  65. dest.append(b)
  66. return
  67. dest.append(0x80 | b)
  68. def decimal_size(prec: int):
  69. """
  70. Determine the bit size of a ClickHouse or Python Decimal needed to store a value of the requested precision
  71. :param prec: Precision of the Decimal in total number of base 10 digits
  72. :return: Required bit size
  73. """
  74. if prec < 1 or prec > 79:
  75. raise ArithmeticError(f'Invalid precision {prec} for ClickHouse Decimal type')
  76. if prec < 10:
  77. return 32
  78. if prec < 19:
  79. return 64
  80. if prec < 39:
  81. return 128
  82. return 256
  83. def unescape_identifier(x: str) -> str:
  84. if x.startswith('`') and x.endswith('`'):
  85. return x[1:-1]
  86. return x
  87. def dict_copy(source: Dict = None, update: Optional[Dict] = None) -> Dict:
  88. copy = source.copy() if source else {}
  89. if update:
  90. copy.update(update)
  91. return copy
  92. def dict_add(source: Dict, key: str, value: any) -> Dict:
  93. if value is not None:
  94. source[key] = value
  95. return source
  96. def empty_gen():
  97. yield from ()
  98. def coerce_int(val: Optional[Union[str, int]]) -> int:
  99. if not val:
  100. return 0
  101. return int(val)
  102. def coerce_bool(val: Optional[Union[str, bool]]):
  103. if not val:
  104. return False
  105. return val is True or (isinstance(val, str) and val.lower() in ('true', '1', 'y', 'yes'))
  106. def first_value(column: Sequence, nullable:bool = True):
  107. if nullable:
  108. return next((x for x in column if x is not None), None)
  109. if len(column):
  110. return column[0]
  111. return None
  112. class SliceView(Sequence):
  113. """
  114. Provides a view into a sequence rather than copying. Borrows liberally from
  115. https://gist.github.com/mathieucaroff/0cf094325fb5294fb54c6a577f05a2c1
  116. Also see the discussion on SO: https://stackoverflow.com/questions/3485475/can-i-create-a-view-on-a-python-list
  117. """
  118. slots = ('_source', '_range')
  119. def __init__(self, source: Sequence, source_slice: Optional[slice] = None):
  120. if isinstance(source, SliceView):
  121. self._source = source._source
  122. self._range = source._range[source_slice]
  123. else:
  124. self._source = source
  125. if source_slice is None:
  126. self._range = range(len(source))
  127. else:
  128. self._range = range(len(source))[source_slice]
  129. def __len__(self):
  130. return len(self._range)
  131. def __getitem__(self, i):
  132. if isinstance(i, slice):
  133. return SliceView(self._source, i)
  134. return self._source[self._range[i]]
  135. def __str__(self):
  136. r = self._range
  137. return str(self._source[slice(r.start, r.stop, r.step)])
  138. def __repr__(self):
  139. r = self._range
  140. return f'SliceView({self._source[slice(r.start, r.stop, r.step)]})'
  141. def __eq__(self, other):
  142. if self is other:
  143. return True
  144. if len(self) != len(other):
  145. return False
  146. for v, w in zip(self, other):
  147. if v != w:
  148. return False
  149. return True
  150. class StreamContext:
  151. """
  152. Wraps a generator and its "source" in a Context. This ensures that the source will be "closed" even if the
  153. generator is not fully consumed or there is an exception during consumption
  154. """
  155. __slots__ = 'source', 'gen', '_in_context'
  156. def __init__(self, source: Closable, gen: Generator):
  157. self.source = source
  158. self.gen = gen
  159. self._in_context = False
  160. def __iter__(self):
  161. return self
  162. def __next__(self):
  163. if not self._in_context:
  164. raise ProgrammingError('Stream should be used within a context')
  165. return next(self.gen)
  166. def __enter__(self):
  167. if not self.gen:
  168. raise StreamClosedError
  169. self._in_context = True
  170. return self
  171. def __exit__(self, exc_type, exc_val, exc_tb):
  172. self._in_context = False
  173. self.source.close()
  174. self.gen = None