Browse Source

Support Decimal128/Decimal256 in Arrow parser

No description

---
b24f71e64f22e615ebb32f33fb2cfc5c88198c1a

Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/769

Co-authored-by: nadya02 <nadya02@yandex-team.com>
Maxim Akhmedov 6 months ago
parent
commit
c73194dcb0

+ 15 - 0
yt/yt/library/decimal/decimal.cpp

@@ -535,6 +535,21 @@ TStringBuf TDecimal::WriteBinary128(int precision, TValue128 value, char* buffer
     return TStringBuf{buffer, sizeof(TValue128)};
 }
 
+TStringBuf TDecimal::WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength)
+{
+    const size_t resultLength = GetValueBinarySize(precision);
+    switch (resultLength) {
+        case 4:
+            return WriteBinary32(precision, static_cast<i32>(value.Low), buffer, bufferLength);
+        case 8:
+            return WriteBinary64(precision, static_cast<i64>(value.Low), buffer, bufferLength);
+        case 16:
+            return WriteBinary128(precision, value, buffer, bufferLength);
+        default:
+            THROW_ERROR_EXCEPTION("Invalid precision %v", precision);
+    }
+}
+
 template <typename T>
 Y_FORCE_INLINE void CheckBufferLength(int precision, size_t bufferLength)
 {

+ 3 - 0
yt/yt/library/decimal/decimal.h

@@ -50,6 +50,9 @@ public:
     static TStringBuf WriteBinary64(int precision, i64 value, char* buffer, size_t bufferLength);
     static TStringBuf WriteBinary128(int precision, TValue128 value, char* buffer, size_t bufferLength);
 
+    // Writes either 32-bit, 64-bit or 128-bit binary value depending on precision, provided a TValue128.
+    static TStringBuf WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength);
+
     static i32 ParseBinary32(int precision, TStringBuf buffer);
     static i64 ParseBinary64(int precision, TStringBuf buffer);
     static TValue128 ParseBinary128(int precision, TStringBuf buffer);

+ 109 - 9
yt/yt/library/formats/arrow_parser.cpp

@@ -7,6 +7,8 @@
 
 #include <yt/yt/client/formats/parser.h>
 
+#include <yt/yt/library/decimal/decimal.h>
+
 #include <library/cpp/yt/memory/chunked_output_stream.h>
 
 #include <util/stream/buffer.h>
@@ -19,10 +21,13 @@
 
 #include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h>
 
+#include <contrib/libs/apache/arrow/cpp/src/arrow/util/decimal.h>
+
 namespace NYT::NFormats {
 
 using namespace NTableClient;
 using TUnversionedRowValues = std::vector<NTableClient::TUnversionedValue>;
+using namespace NDecimal;
 
 namespace {
 
@@ -31,7 +36,7 @@ namespace {
 void ThrowOnError(const arrow::Status& status)
 {
     if (!status.ok()) {
-        THROW_ERROR_EXCEPTION("Arrow error occurred: %Qv", status.message());
+        THROW_ERROR_EXCEPTION("Arrow error [%v]: %Qv", status.CodeAsString(), status.message());
     }
 }
 
@@ -158,6 +163,31 @@ public:
         return ParseNull();
     }
 
+    // Decimal types. For now, YT natively supports only Decimal128 with scale up to 35.
+    // Thus, we represent short enough decimals as native YT decimals, and wider decimals as
+    // their decimal string representation; but the latter is subject to change whenever we
+    // get the native support for Decimal128 with scale up to 38 or Decimal256 with scale up to 76.
+    arrow::Status Visit(const arrow::Decimal128Type& type) override
+    {
+        constexpr int MaximumYTDecimalPrecision = 35;
+        if (type.precision() <= MaximumYTDecimalPrecision) {
+            return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
+                return MakeDecimalBinaryValue(value, columnId, type.precision());
+            });
+        } else {
+            return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
+                return MakeDecimalTextValue<arrow::Decimal128>(value, columnId, type.scale());
+            });
+        }
+    }
+
+    arrow::Status Visit(const arrow::Decimal256Type& type) override
+    {
+        return ParseStringLikeArray<arrow::Decimal256Array>([&] (const TStringBuf& value, i64 columnId) {
+            return MakeDecimalTextValue<arrow::Decimal256>(value, columnId, type.scale());
+        });
+    }
+
 private:
     const i64 ColumnId_;
 
@@ -209,7 +239,7 @@ private:
     }
 
     template <typename ArrayType>
-    arrow::Status ParseStringLikeArray()
+    arrow::Status ParseStringLikeArray(auto makeUnversionedValueFunc)
     {
         auto array = std::static_pointer_cast<ArrayType>(Array_);
         for (int rowIndex = 0; rowIndex < array->length(); ++rowIndex) {
@@ -225,12 +255,23 @@ private:
                 BufferForStringLikeValues_->Advance(element.size());
                 auto value = TStringBuf(buffer, element.size());
 
-                (*RowValues_)[rowIndex] = MakeUnversionedStringValue(value, ColumnId_);
+                (*RowValues_)[rowIndex] = makeUnversionedValueFunc(value, ColumnId_);
             }
         }
         return arrow::Status::OK();
     }
 
+    template <typename ArrayType>
+    arrow::Status ParseStringLikeArray()
+    {
+        // Note that MakeUnversionedStringValue actually has third argument in its signature,
+        // which leads to a "too few arguments" in the point of its invocation if we try to pass
+        // it directly to ParseStringLikeArray.
+        return ParseStringLikeArray<ArrayType>([] (const TStringBuf& value, i64 columnId) {
+            return MakeUnversionedStringValue(value, columnId);
+        });
+    }
+
     arrow::Status ParseBoolean()
     {
         auto array = std::static_pointer_cast<arrow::BooleanArray>(Array_);
@@ -252,6 +293,34 @@ private:
         }
         return arrow::Status::OK();
     }
+
+    TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& value, i64 columnId, int precision)
+    {
+        // NB: arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit,
+        // while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit,
+        // and MSB-flipped to ensure lexical sorting order.
+        TDecimal::TValue128 value128;
+        YT_VERIFY(value.size() == sizeof(value128));
+        std::memcpy(&value128, value.data(), value.size());
+
+        const auto maxByteCount = sizeof(value128);
+        char* buffer = BufferForStringLikeValues_->Preallocate(maxByteCount);
+        auto decimalBinary = TDecimal::WriteBinaryVariadic(precision, value128, buffer, maxByteCount);
+        BufferForStringLikeValues_->Advance(decimalBinary.size());
+
+        return MakeUnversionedStringValue(decimalBinary, columnId);
+    }
+
+    template <class TArrowDecimalType>
+    TUnversionedValue MakeDecimalTextValue(const TStringBuf& value, i64 columnId, int scale)
+    {
+        TArrowDecimalType decimal(reinterpret_cast<const uint8_t*>(value.data()));
+        auto string = decimal.ToString(scale);
+        char* buffer = BufferForStringLikeValues_->Preallocate(string.size());
+        std::memcpy(buffer, string.data(), string.size());
+        BufferForStringLikeValues_->Advance(string.size());
+        return MakeUnversionedStringValue(TStringBuf(buffer, string.size()), columnId);
+    }
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -552,12 +621,14 @@ private:
 ////////////////////////////////////////////////////////////////////////////////
 
 void CheckArrowType(
+    auto ytTypeOrMetatype,
     const std::shared_ptr<arrow::DataType>& arrowType,
     std::initializer_list<arrow::Type::type> allowedTypes)
 {
     if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) {
-        THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv",
-            arrowType->name());
+        THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv",
+            arrowType->name(),
+            ytTypeOrMetatype);
     }
 }
 
@@ -573,6 +644,7 @@ void CheckMatchingArrowTypes(
 
         case ESimpleLogicalValueType::Interval:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::INT8,
@@ -597,6 +669,7 @@ void CheckMatchingArrowTypes(
         case ESimpleLogicalValueType::Datetime:
         case ESimpleLogicalValueType::Timestamp:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::UINT8,
@@ -611,6 +684,7 @@ void CheckMatchingArrowTypes(
         case ESimpleLogicalValueType::Json:
         case ESimpleLogicalValueType::Utf8:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::STRING,
@@ -618,13 +692,16 @@ void CheckMatchingArrowTypes(
                     arrow::Type::LARGE_STRING,
                     arrow::Type::LARGE_BINARY,
                     arrow::Type::FIXED_SIZE_BINARY,
-                    arrow::Type::DICTIONARY
+                    arrow::Type::DICTIONARY,
+                    arrow::Type::DECIMAL128,
+                    arrow::Type::DECIMAL256,
                 });
             break;
 
         case ESimpleLogicalValueType::Float:
         case ESimpleLogicalValueType::Double:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::HALF_FLOAT,
@@ -636,12 +713,14 @@ void CheckMatchingArrowTypes(
 
         case ESimpleLogicalValueType::Boolean:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {arrow::Type::BOOL, arrow::Type::DICTIONARY});
             break;
 
         case ESimpleLogicalValueType::Any:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::INT8,
@@ -679,6 +758,7 @@ void CheckMatchingArrowTypes(
         case ESimpleLogicalValueType::Null:
         case ESimpleLogicalValueType::Void:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::NA,
@@ -688,6 +768,7 @@ void CheckMatchingArrowTypes(
 
         case ESimpleLogicalValueType::Uuid:
             CheckArrowType(
+                columnType,
                 column->type(),
                 {
                     arrow::Type::STRING,
@@ -749,9 +830,10 @@ void PrepareArrayForComplexType(
     int columnIndex,
     int columnId)
 {
-    switch (denullifiedLogicalType->GetMetatype()) {
+    switch (auto metatype = denullifiedLogicalType->GetMetatype()) {
         case ELogicalMetatype::List:
             CheckArrowType(
+                metatype,
                 column->type(),
                 {
                     arrow::Type::LIST,
@@ -761,6 +843,7 @@ void PrepareArrayForComplexType(
 
         case ELogicalMetatype::Dict:
             CheckArrowType(
+                metatype,
                 column->type(),
                 {
                     arrow::Type::MAP,
@@ -770,32 +853,49 @@ void PrepareArrayForComplexType(
 
         case ELogicalMetatype::Struct:
             CheckArrowType(
+                metatype,
                 column->type(),
                 {
                     arrow::Type::STRUCT,
                     arrow::Type::BINARY
                 });
             break;
+
         case ELogicalMetatype::Decimal:
+            CheckArrowType(
+                metatype,
+                column->type(),
+                {
+                    arrow::Type::DECIMAL128,
+                    arrow::Type::DECIMAL256
+                });
+            break;
+
         case ELogicalMetatype::Optional:
         case ELogicalMetatype::Tuple:
         case ELogicalMetatype::VariantTuple:
         case ELogicalMetatype::VariantStruct:
-            CheckArrowType(column->type(), {arrow::Type::BINARY});
+            CheckArrowType(metatype, column->type(), {arrow::Type::BINARY});
             break;
 
         default:
             THROW_ERROR_EXCEPTION("Unexpected arrow type in complex type %Qv", column->type()->name());
     }
 
-    if (column->type()->id() == arrow::Type::BINARY) {
+    if (column->type()->id() == arrow::Type::BINARY ||
+        column->type()->id() == arrow::Type::DECIMAL128 ||
+        column->type()->id() == arrow::Type::DECIMAL256)
+    {
         TUnversionedRowValues stringValues(rowsValues[columnIndex].size());
         TArraySimpleVisitor visitor(columnId, column, bufferForStringLikeValues, &stringValues);
         ThrowOnError(column->type()->Accept(&visitor));
         for (int offset = 0; offset < std::ssize(rowsValues[columnIndex]); offset++) {
             if (column->IsNull(offset)) {
                 rowsValues[columnIndex][offset] = MakeUnversionedNullValue(columnId);
+            } else if (column->type()->id() == arrow::Type::DECIMAL128 || column->type()->id() == arrow::Type::DECIMAL256) {
+                rowsValues[columnIndex][offset] = MakeUnversionedStringValue(stringValues[offset].AsStringBuf(), columnId);
             } else {
+                // TODO(max): is it even correct? Binary is not necessarily a correct YSON...
                 rowsValues[columnIndex][offset] = MakeUnversionedCompositeValue(stringValues[offset].AsStringBuf(), columnId);
             }
         }