cursor.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import logging
  2. import re
  3. from typing import Optional, Sequence, List, Dict
  4. from clickhouse_connect.datatypes.registry import get_from_name
  5. from clickhouse_connect.driver.common import unescape_identifier
  6. from clickhouse_connect.driver.exceptions import ProgrammingError
  7. from clickhouse_connect.driver import Client
  8. from clickhouse_connect.driver.parser import parse_callable
  9. from clickhouse_connect.driver.query import remove_sql_comments
# Module-level logger; Cursor.executemany uses it to warn about inconsistent column names
logger = logging.getLogger(__name__)
# Matches a statement that begins with INSERT INTO (case-insensitive, leading whitespace allowed) and captures
# everything after it.  The pattern has no re.DOTALL flag, so it only matches when that remainder contains no
# embedded newline.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)
# Type objects looked up by name from the datatypes registry.
# NOTE(review): neither is referenced in the visible part of this file -- confirm they are used elsewhere.
str_type = get_from_name('String')
int_type = get_from_name('Int32')
  14. # pylint: disable=too-many-instance-attributes
  15. class Cursor:
  16. """
  17. See :ref:`https://peps.python.org/pep-0249/`
  18. """
  19. def __init__(self, client: Client):
  20. self.client = client
  21. self.arraysize = 1
  22. self.data: Optional[Sequence] = None
  23. self.names = []
  24. self.types = []
  25. self._rowcount = 0
  26. self._summary: List[Dict[str, str]] = []
  27. self._ix = 0
  28. def check_valid(self):
  29. if self.data is None:
  30. raise ProgrammingError('Cursor is not valid')
  31. @property
  32. def description(self):
  33. return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]
  34. @property
  35. def rowcount(self):
  36. return self._rowcount
  37. @property
  38. def summary(self) -> List[Dict[str, str]]:
  39. return self._summary
  40. def close(self):
  41. self.data = None
  42. def execute(self, operation: str, parameters=None):
  43. query_result = self.client.query(operation, parameters)
  44. self.data = query_result.result_set
  45. self._rowcount = len(self.data)
  46. self._summary.append(query_result.summary)
  47. if query_result.column_names:
  48. self.names = query_result.column_names
  49. self.types = [x.name for x in query_result.column_types]
  50. elif self.data:
  51. self.names = [f'col_{x}' for x in range(len(self.data[0]))]
  52. self.types = [x.__class__ for x in self.data[0]]
  53. def _try_bulk_insert(self, operation: str, data):
  54. match = insert_re.match(remove_sql_comments(operation))
  55. if not match:
  56. return False
  57. temp = match.group(1)
  58. table_end = min(temp.find(' '), temp.find('('))
  59. table = temp[:table_end].strip()
  60. temp = temp[table_end:].strip()
  61. if temp[0] == '(':
  62. _, op_columns, temp = parse_callable(temp)
  63. else:
  64. op_columns = None
  65. if 'VALUES' not in temp.upper():
  66. return False
  67. col_names = list(data[0].keys())
  68. if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
  69. return False # Data sent in doesn't match the columns in the insert statement
  70. data_values = [list(row.values()) for row in data]
  71. self.client.insert(table, data_values, col_names)
  72. self.data = []
  73. return True
  74. def executemany(self, operation, parameters):
  75. if not parameters or self._try_bulk_insert(operation, parameters):
  76. return
  77. self.data = []
  78. try:
  79. for param_row in parameters:
  80. query_result = self.client.query(operation, param_row)
  81. self.data.extend(query_result.result_set)
  82. if self.names or self.types:
  83. if query_result.column_names != self.names:
  84. logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
  85. self.names, query_result.column_names, operation)
  86. else:
  87. self.names = query_result.column_names
  88. self.types = query_result.column_types
  89. self._summary.append(query_result.summary)
  90. except TypeError as ex:
  91. raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
  92. self._rowcount = len(self.data)
  93. def fetchall(self):
  94. self.check_valid()
  95. ret = self.data
  96. self._ix = self._rowcount
  97. return ret
  98. def fetchone(self):
  99. self.check_valid()
  100. if self._ix >= self._rowcount:
  101. return None
  102. val = self.data[self._ix]
  103. self._ix += 1
  104. return val
  105. def fetchmany(self, size: int = -1):
  106. self.check_valid()
  107. end = self._ix + max(size, self._rowcount - self._ix)
  108. ret = self.data[self._ix: end]
  109. self._ix = end
  110. return ret
  111. def nextset(self):
  112. raise NotImplementedError
  113. def callproc(self, *args, **kwargs):
  114. raise NotImplementedError