transform.py

import logging
from typing import Union

from clickhouse_connect.datatypes import registry
from clickhouse_connect.driver.common import write_leb128
from clickhouse_connect.driver.exceptions import StreamCompleteException, StreamFailureError
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.npquery import NumpyResult
from clickhouse_connect.driver.query import QueryResult, QueryContext
from clickhouse_connect.driver.types import ByteSource
from clickhouse_connect.driver.compression import get_compressor

_EMPTY_CTX = QueryContext()

logger = logging.getLogger(__name__)


class NativeTransform:
    # pylint: disable=too-many-locals
    @staticmethod
    def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> Union[NumpyResult, QueryResult]:
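        """Parse a ClickHouse Native-format response stream into a QueryResult,
        or a NumpyResult when context.use_numpy is set. The first block is read
        eagerly; remaining blocks are consumed lazily through a generator."""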
        names = []
        col_types = []
        block_num = 0
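
        # Wire layout of each Native block, as consumed below: an optional
        # 8-byte block info section, a LEB128 column count, a LEB128 row count,
        # then per column a LEB128-prefixed name, a LEB128-prefixed type name,
        # and the column data. Names and types are recorded from the first
        # block only; later blocks reuse the cached col_types.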
        def get_block():
            nonlocal block_num
            result_block = []
            try:
                try:
                    if context.block_info:
                        source.read_bytes(8)
                    num_cols = source.read_leb128()
                except StreamCompleteException:
                    return None
                num_rows = source.read_leb128()
                for col_num in range(num_cols):
                    name = source.read_leb128_str()
                    type_name = source.read_leb128_str()
                    if block_num == 0:
                        names.append(name)
                        col_type = registry.get_from_name(type_name)
                        col_types.append(col_type)
                    else:
                        col_type = col_types[col_num]
                    if num_rows == 0:
                        result_block.append(tuple())
                    else:
                        context.start_column(name)
                        column = col_type.read_column(source, num_rows, context)
                        result_block.append(column)
            except Exception as ex:
                source.close()
                if isinstance(ex, StreamCompleteException):
                    # We ran out of data before it was expected; this could be ClickHouse
                    # reporting an error in the response
                    if source.last_message:
                        raise StreamFailureError(extract_error_message(source.last_message)) from None
                raise
            block_num += 1
            return result_block
        first_block = get_block()
        if first_block is None:
            return NumpyResult() if context.use_numpy else QueryResult([])

        def gen():
            yield first_block
            while True:
                next_block = get_block()
                if next_block is None:
                    return
                yield next_block

        if context.use_numpy:
            res_types = [col.dtype if hasattr(col, 'dtype') else 'O' for col in first_block]
            return NumpyResult(gen(), tuple(names), tuple(col_types), res_types, source)
        return QueryResult(None, gen(), tuple(names), tuple(col_types), context.column_oriented, source)

    @staticmethod
    def build_insert(context: InsertContext):
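        """Serialize the blocks produced by an InsertContext into Native-format
        bytes, returning a generator of (optionally compressed) chunks suitable
        for streaming to the server."""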
        compressor = get_compressor(context.compression)
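
        # The insert stream mirrors the read path: for each block, the block
        # prefix, LEB128 column and row counts, then per column a LEB128-prefixed
        # name and insert type name followed by the serialized column data. Each
        # assembled block is run through the configured compressor.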
        def chunk_gen():
            for block in context.next_block():
                output = bytearray()
                output += block.prefix
                write_leb128(block.column_count, output)
                write_leb128(block.row_count, output)
                for col_name, col_type, data in zip(block.column_names, block.column_types, block.column_data):
                    col_enc = col_name.encode()
                    write_leb128(len(col_enc), output)
                    output += col_enc
                    col_enc = col_type.insert_name.encode()
                    write_leb128(len(col_enc), output)
                    output += col_enc
                    context.start_column(col_name)
                    try:
                        col_type.write_column(data, output, context)
                    except Exception as ex:  # pylint: disable=broad-except
                        # This is hideous, but some low level serializations can fail while streaming
                        # the insert if the user has included bad data in the column. We need to ensure that the
                        # insert fails (using garbage data) to avoid a partial insert, and use the context to
                        # propagate the correct exception to the user
                        logger.error('Error serializing column `%s` into data type `%s`',
                                     col_name, col_type.name, exc_info=True)
                        context.insert_exception = ex
                        yield 'INTERNAL EXCEPTION WHILE SERIALIZING'.encode()
                        return
                yield compressor.compress_block(output)
            footer = compressor.flush()
            if footer:
                yield footer

        return chunk_gen()


def extract_error_message(message: bytes) -> str:
    if len(message) > 1024:
        message = message[-1024:]
    error_start = message.find('Code: '.encode())
    if error_start != -1:
        message = message[error_start:]
    try:
        message_str = message.decode()
    except UnicodeError:
        message_str = f'unrecognized data found in stream: `{message.hex()[128:]}`'
    return message_str
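

# For reference: the LEB128 encoding used throughout this module writes an
# unsigned integer 7 bits at a time, low bits first, with the high bit of each
# byte set except on the last. A minimal standalone sketch (assuming unsigned
# input; `leb128` is a hypothetical helper, not part of this module):
#
#     def leb128(value: int) -> bytes:
#         out = bytearray()
#         while True:
#             byte = value & 0x7F
#             value >>= 7
#             if value:
#                 out.append(byte | 0x80)  # continuation bit: more bytes follow
#             else:
#                 out.append(byte)         # final byte, high bit clear
#                 return bytes(out)
#
#     assert leb128(300) == b'\xac\x02'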