transform.py

import logging
from typing import Union

from clickhouse_connect.datatypes import registry
from clickhouse_connect.driver.common import write_leb128
from clickhouse_connect.driver.exceptions import StreamCompleteException, StreamFailureError
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.npquery import NumpyResult
from clickhouse_connect.driver.query import QueryResult, QueryContext
from clickhouse_connect.driver.types import ByteSource
from clickhouse_connect.driver.compression import get_compressor

_EMPTY_CTX = QueryContext()

logger = logging.getLogger(__name__)


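# Native block framing handled by this module (as read/written below): each block carries an
# optional block-info section, a LEB128 column count, a LEB128 row count, and then, per column,
# a LEB128-length-prefixed name, a LEB128-length-prefixed type name, and the encoded column
# data. parse_response reads this layout from a ByteSource; build_insert writes it.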
class NativeTransform:
    # pylint: disable=too-many-locals
    @staticmethod
    def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> Union[NumpyResult, QueryResult]:
        names = []
        col_types = []
        block_num = 0

        def get_block():
            nonlocal block_num
            result_block = []
            try:
                try:
                    if context.block_info:
                        source.read_bytes(8)
                    num_cols = source.read_leb128()
                except StreamCompleteException:
                    return None
                num_rows = source.read_leb128()
                for col_num in range(num_cols):
                    name = source.read_leb128_str()
                    type_name = source.read_leb128_str()
                    if block_num == 0:
                        names.append(name)
                        col_type = registry.get_from_name(type_name)
                        col_types.append(col_type)
                    else:
                        col_type = col_types[col_num]
                    if num_rows == 0:
                        result_block.append(tuple())
                    else:
                        context.start_column(name)
                        column = col_type.read_column(source, num_rows, context)
                        result_block.append(column)
            except Exception as ex:
                source.close()
                if isinstance(ex, StreamCompleteException):
                    # We ran out of data before it was expected, this could be ClickHouse reporting an error
                    # in the response
                    message = source.last_message
                    if len(message) > 1024:
                        message = message[-1024:]
                    error_start = message.find('Code: ')
                    if error_start != -1:
                        message = message[error_start:]
                    raise StreamFailureError(message) from None
                raise
            block_num += 1
            return result_block

        first_block = get_block()
        if first_block is None:
            return NumpyResult() if context.use_numpy else QueryResult([])

        def gen():
            yield first_block
            while True:
                next_block = get_block()
                if next_block is None:
                    return
                yield next_block

        if context.use_numpy:
            res_types = [col.dtype if hasattr(col, 'dtype') else 'O' for col in first_block]
            return NumpyResult(gen(), tuple(names), tuple(col_types), res_types, source)
        return QueryResult(None, gen(), tuple(names), tuple(col_types), context.column_oriented, source)
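
    # Blocks are parsed lazily: gen() above yields one block at a time from the ByteSource, so
    # the response stream is only fully read as the caller iterates the returned
    # QueryResult/NumpyResult.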

    @staticmethod
    def build_insert(context: InsertContext):
        compressor = get_compressor(context.compression)

        def chunk_gen():
            for block in context.next_block():
                output = bytearray()
                output += block.prefix
                write_leb128(block.column_count, output)
                write_leb128(block.row_count, output)
                for col_name, col_type, data in zip(block.column_names, block.column_types, block.column_data):
                    col_enc = col_name.encode()
                    write_leb128(len(col_enc), output)
                    output += col_enc
                    col_enc = col_type.name.encode()
                    write_leb128(len(col_enc), output)
                    output += col_enc
                    context.start_column(col_name)
                    try:
                        col_type.write_column(data, output, context)
                    except Exception as ex:  # pylint: disable=broad-except
                        # This is hideous, but some low level serializations can fail while streaming
                        # the insert if the user has included bad data in the column. We need to ensure that the
                        # insert fails (using garbage data) to avoid a partial insert, and use the context to
                        # propagate the correct exception to the user
                        logger.error('Error serializing column `%s` into data type `%s`',
                                     col_name, col_type.name, exc_info=True)
                        context.insert_exception = ex
                        yield 'INTERNAL EXCEPTION WHILE SERIALIZING'.encode()
                        return
                yield compressor.compress_block(output)
            footer = compressor.flush()
            if footer:
                yield footer

        return chunk_gen()
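
# A minimal usage sketch (illustrative only; the surrounding client wiring is assumed and not
# part of this module): the bytes generator returned by build_insert() is streamed directly as
# the request body, e.g.
#
#     ctx = InsertContext(...)                   # built by the calling client
#     body = NativeTransform.build_insert(ctx)   # generator of (optionally compressed) chunks
#     http_session.post(url, data=body)          # hypothetical streaming POST
#
# If a column fails to serialize mid-stream, the deliberately invalid trailing chunk above makes
# the server reject the whole insert, and context.insert_exception carries the original error
# back to the caller.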