123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- import logging
- import re
- from typing import Optional, Sequence, List, Dict
- from clickhouse_connect.datatypes.registry import get_from_name
- from clickhouse_connect.driver.common import unescape_identifier
- from clickhouse_connect.driver.exceptions import ProgrammingError
- from clickhouse_connect.driver import Client
- from clickhouse_connect.driver.parser import parse_callable
- from clickhouse_connect.driver.query import remove_sql_comments
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Case-insensitive match for a statement that starts with INSERT INTO;
# group 1 captures the text that follows those keywords.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', flags=re.IGNORECASE)

# ClickHouse type objects resolved once from the datatype registry by name.
str_type = get_from_name('String')
int_type = get_from_name('Int32')
# pylint: disable=too-many-instance-attributes
class Cursor:
    """
    DB-API style cursor that runs operations through a ClickHouse Connect client.

    See :ref:`https://peps.python.org/pep-0249/`

    The rows of the most recent operation are kept in ``data``; ``_ix`` is the
    position of the next row to hand out and ``_rowcount`` is the number of rows
    held in ``data``.  ``data`` is ``None`` until an operation has been executed
    and again after :meth:`close`.
    """

    def __init__(self, client: 'Client'):
        """
        :param client: ClickHouse Connect client used to run queries and inserts
        """
        self.client = client
        self.arraysize = 1
        self.data: Optional[Sequence] = None
        self.names = []
        self.types = []
        self._rowcount = 0
        self._summary: List[Dict[str, str]] = []
        self._ix = 0

    def check_valid(self):
        """
        Ensure the cursor currently holds a result set.

        :raises ProgrammingError: if no operation has been executed or the cursor was closed
        """
        if self.data is None:
            raise ProgrammingError('Cursor is not valid')

    @property
    def description(self):
        """
        One 7-tuple per result column: (name, type, None, None, None, None, True).
        """
        return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]

    @property
    def rowcount(self):
        """
        Number of rows held by the cursor for the most recent operation.
        """
        return self._rowcount

    @property
    def summary(self) -> List[Dict[str, str]]:
        """
        Query summaries collected from every query this cursor has executed, in order.
        """
        return self._summary

    def close(self):
        """
        Drop the current result set; fetch methods raise ProgrammingError afterwards.
        """
        self.data = None

    def execute(self, operation: str, parameters=None):
        """
        Run a single operation through the client and load its result set.

        :param operation: statement passed to ``client.query``
        :param parameters: optional parameters passed to ``client.query`` unchanged
        """
        query_result = self.client.query(operation, parameters)
        self.data = query_result.result_set
        self._rowcount = len(self.data)
        # A new result set must be read from its first row, even on a reused cursor.
        self._ix = 0
        self._summary.append(query_result.summary)
        if query_result.column_names:
            self.names = query_result.column_names
            self.types = [x.name for x in query_result.column_types]
        elif self.data:
            # No column metadata: synthesize names and use the Python classes of the first row.
            self.names = [f'col_{x}' for x in range(len(self.data[0]))]
            self.types = [x.__class__ for x in self.data[0]]

    def _try_bulk_insert(self, operation: str, data):
        """
        Attempt to run an ``INSERT INTO ... VALUES`` operation as one ``client.insert`` call.

        :param operation: the statement given to executemany
        :param data: non-empty sequence of mapping rows; the keys of the first row
          are used as the column names and each row's values as the row data
        :return: True if the rows were sent with ``client.insert``; False if the
          operation is not a suitable INSERT, in which case nothing was sent
        """
        match = insert_re.match(remove_sql_comments(operation))
        if not match:
            return False
        temp = match.group(1)
        # str.find returns -1 when the character is absent, so only real positions may
        # take part in min(); otherwise -1 would win and truncate the table name.
        boundaries = [pos for pos in (temp.find(' '), temp.find('(')) if pos >= 0]
        if not boundaries:
            return False  # Nothing follows the table name, so there is no VALUES clause
        table_end = min(boundaries)
        table = temp[:table_end].strip()
        if not table:
            return False
        temp = temp[table_end:].strip()
        # startswith is safe when nothing is left after the table name.
        if temp.startswith('('):
            _, op_columns, temp = parse_callable(temp)
        else:
            op_columns = None
        if 'VALUES' not in temp.upper():
            return False
        col_names = list(data[0].keys())
        if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
            return False  # Data sent in doesn't match the columns in the insert statement
        data_values = [list(row.values()) for row in data]
        self.client.insert(table, data_values, col_names)
        # An insert produces no rows; keep the counters in step with the empty
        # result so later fetch calls cannot index past it.
        self.data = []
        self._rowcount = 0
        self._ix = 0
        return True

    def executemany(self, operation, parameters):
        """
        Run an operation once per parameter row, or as a single bulk insert when possible.

        Does nothing when ``parameters`` is empty.  Otherwise the rows returned by
        every execution are concatenated into the cursor's result set.

        :param operation: statement to execute
        :param parameters: sequence of parameter rows
        :raises ProgrammingError: if a TypeError is raised while processing the parameters
        """
        if not parameters or self._try_bulk_insert(operation, parameters):
            return
        self.data = []
        self._rowcount = 0
        self._ix = 0
        try:
            for param_row in parameters:
                query_result = self.client.query(operation, param_row)
                self.data.extend(query_result.result_set)
                if self.names or self.types:
                    if query_result.column_names != self.names:
                        logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
                                       self.names, query_result.column_names, operation)
                else:
                    self.names = query_result.column_names
                    # Store type names, as execute does, so description is consistent.
                    self.types = [x.name for x in query_result.column_types or ()]
                self._summary.append(query_result.summary)
        except TypeError as ex:
            raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
        self._rowcount = len(self.data)

    def fetchall(self):
        """
        Return the whole current result set and mark the cursor as exhausted.

        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        # NOTE(review): this returns every row, including rows already consumed by
        # fetchone/fetchmany; PEP 249 describes "remaining" rows -- confirm whether
        # callers rely on the current behavior before changing it.
        ret = self.data
        self._ix = self._rowcount
        return ret

    def fetchone(self):
        """
        Return the next row, or None when the result set is exhausted.

        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        if self._ix >= self._rowcount:
            return None
        val = self.data[self._ix]
        self._ix += 1
        return val

    def fetchmany(self, size: int = -1):
        """
        Return up to ``size`` of the remaining rows and advance past them.

        :param size: maximum number of rows to return; a negative value (the
          default) returns all remaining rows
        :return: slice of the result set, empty when no rows remain
        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        remaining = max(self._rowcount - self._ix, 0)
        # Cap at what is left so the position never runs past the end of the data.
        count = remaining if size < 0 else min(size, remaining)
        end = self._ix + count
        ret = self.data[self._ix: end]
        self._ix = end
        return ret

    def nextset(self):
        """
        Not supported by this cursor.
        """
        raise NotImplementedError

    def callproc(self, *args, **kwargs):
        """
        Not supported by this cursor.
        """
        raise NotImplementedError
|