cursor.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import logging
  2. import re
  3. from typing import Optional, Sequence
  4. from clickhouse_connect.datatypes.registry import get_from_name
  5. from clickhouse_connect.driver.common import unescape_identifier
  6. from clickhouse_connect.driver.exceptions import ProgrammingError
  7. from clickhouse_connect.driver import Client
  8. from clickhouse_connect.driver.parser import parse_callable
  9. from clickhouse_connect.driver.query import remove_sql_comments
  10. logger = logging.getLogger(__name__)
  11. insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)
  12. str_type = get_from_name('String')
  13. int_type = get_from_name('Int32')
  14. class Cursor:
  15. """
  16. See :ref:`https://peps.python.org/pep-0249/`
  17. """
  18. def __init__(self, client: Client):
  19. self.client = client
  20. self.arraysize = 1
  21. self.data: Optional[Sequence] = None
  22. self.names = []
  23. self.types = []
  24. self._rowcount = 0
  25. self._ix = 0
  26. def check_valid(self):
  27. if self.data is None:
  28. raise ProgrammingError('Cursor is not valid')
  29. @property
  30. def description(self):
  31. return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]
  32. @property
  33. def rowcount(self):
  34. return self._rowcount
  35. def close(self):
  36. self.data = None
  37. def execute(self, operation: str, parameters=None):
  38. query_result = self.client.query(operation, parameters)
  39. self.data = query_result.result_set
  40. self._rowcount = len(self.data)
  41. if query_result.column_names:
  42. self.names = query_result.column_names
  43. self.types = [x.name for x in query_result.column_types]
  44. elif self.data:
  45. self.names = [f'col_{x}' for x in range(len(self.data[0]))]
  46. self.types = [x.__class__ for x in self.data[0]]
  47. def _try_bulk_insert(self, operation: str, data):
  48. match = insert_re.match(remove_sql_comments(operation))
  49. if not match:
  50. return False
  51. temp = match.group(1)
  52. table_end = min(temp.find(' '), temp.find('('))
  53. table = temp[:table_end].strip()
  54. temp = temp[table_end:].strip()
  55. if temp[0] == '(':
  56. _, op_columns, temp = parse_callable(temp)
  57. else:
  58. op_columns = None
  59. if 'VALUES' not in temp.upper():
  60. return False
  61. col_names = list(data[0].keys())
  62. if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
  63. return False # Data sent in doesn't match the columns in the insert statement
  64. data_values = [list(row.values()) for row in data]
  65. self.client.insert(table, data_values, col_names)
  66. self.data = []
  67. return True
  68. def executemany(self, operation, parameters):
  69. if not parameters or self._try_bulk_insert(operation, parameters):
  70. return
  71. self.data = []
  72. try:
  73. for param_row in parameters:
  74. query_result = self.client.query(operation, param_row)
  75. self.data.extend(query_result.result_set)
  76. if self.names or self.types:
  77. if query_result.column_names != self.names:
  78. logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
  79. self.names, query_result.column_names, operation)
  80. else:
  81. self.names = query_result.column_names
  82. self.types = query_result.column_types
  83. except TypeError as ex:
  84. raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
  85. self._rowcount = len(self.data)
  86. def fetchall(self):
  87. self.check_valid()
  88. ret = self.data
  89. self._ix = self._rowcount
  90. return ret
  91. def fetchone(self):
  92. self.check_valid()
  93. if self._ix >= self._rowcount:
  94. return None
  95. val = self.data[self._ix]
  96. self._ix += 1
  97. return val
  98. def fetchmany(self, size: int = -1):
  99. self.check_valid()
  100. end = self._ix + max(size, self._rowcount - self._ix)
  101. ret = self.data[self._ix: end]
  102. self._ix = end
  103. return ret
  104. def nextset(self):
  105. raise NotImplementedError
  106. def callproc(self, *args, **kwargs):
  107. raise NotImplementedError