123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- import sys
- import array
- from typing import Any, Iterable
- from clickhouse_connect.driver.exceptions import StreamCompleteException
- from clickhouse_connect.driver.types import ByteSource
- must_swap = sys.byteorder == 'big'
- class ResponseBuffer(ByteSource):
- slots = 'slice_sz', 'buf_loc', 'end', 'gen', 'buffer', 'slice'
- def __init__(self, source):
- self.slice_sz = 4096
- self.buf_loc = 0
- self.buf_sz = 0
- self.source = source
- self.gen = source.gen
- self.buffer = bytes()
- def read_bytes(self, sz: int):
- if self.buf_loc + sz <= self.buf_sz:
- self.buf_loc += sz
- return self.buffer[self.buf_loc - sz: self.buf_loc]
- # Create a temporary buffer that bridges two or more source chunks
- bridge = bytearray(self.buffer[self.buf_loc: self.buf_sz])
- self.buf_loc = 0
- self.buf_sz = 0
- while len(bridge) < sz:
- chunk = next(self.gen, None)
- if not chunk:
- raise StreamCompleteException
- x = len(chunk)
- if len(bridge) + x <= sz:
- bridge.extend(chunk)
- else:
- tail = sz - len(bridge)
- bridge.extend(chunk[:tail])
- self.buffer = chunk
- self.buf_sz = x
- self.buf_loc = tail
- return bridge
- def read_byte(self) -> int:
- if self.buf_loc < self.buf_sz:
- self.buf_loc += 1
- return self.buffer[self.buf_loc - 1]
- self.buf_sz = 0
- self.buf_loc = 0
- chunk = next(self.gen, None)
- if not chunk:
- raise StreamCompleteException
- x = len(chunk)
- if x > 1:
- self.buffer = chunk
- self.buf_loc = 1
- self.buf_sz = x
- return chunk[0]
- def read_leb128(self) -> int:
- sz = 0
- shift = 0
- while True:
- b = self.read_byte()
- sz += ((b & 0x7f) << shift)
- if (b & 0x80) == 0:
- return sz
- shift += 7
- def read_leb128_str(self) -> str:
- sz = self.read_leb128()
- return self.read_bytes(sz).decode()
- def read_uint64(self) -> int:
- return int.from_bytes(self.read_bytes(8), 'little', signed=False)
- def read_str_col(self,
- num_rows: int,
- encoding: str,
- nullable: bool = False,
- null_obj: Any = None) -> Iterable[str]:
- column = []
- app = column.append
- null_map = self.read_bytes(num_rows) if nullable else None
- for ix in range(num_rows):
- sz = 0
- shift = 0
- while True:
- b = self.read_byte()
- sz += ((b & 0x7f) << shift)
- if (b & 0x80) == 0:
- break
- shift += 7
- x = self.read_bytes(sz)
- if null_map and null_map[ix]:
- app(null_obj)
- elif encoding:
- try:
- app(x.decode(encoding))
- except UnicodeDecodeError:
- app(x.hex())
- else:
- app(x)
- return column
- def read_bytes_col(self, sz: int, num_rows: int) -> Iterable[bytes]:
- source = self.read_bytes(sz * num_rows)
- return [bytes(source[x:x+sz]) for x in range(0, sz * num_rows, sz)]
- def read_fixed_str_col(self, sz: int, num_rows: int, encoding: str) -> Iterable[str]:
- source = self.read_bytes(sz * num_rows)
- column = []
- app = column.append
- for ix in range(0, sz * num_rows, sz):
- try:
- app(str(source[ix: ix + sz], encoding).rstrip('\x00'))
- except UnicodeDecodeError:
- app(source[ix: ix + sz].hex())
- return column
- def read_array(self, array_type: str, num_rows: int) -> Iterable[Any]:
- column = array.array(array_type)
- sz = column.itemsize * num_rows
- b = self.read_bytes(sz)
- column.frombytes(b)
- if must_swap:
- column.byteswap()
- return column
- @property
- def last_message(self):
- if len(self.buffer) == 0:
- return None
- return self.buffer.decode()
- def close(self):
- if self.source:
- self.source.close()
- self.source = None
|