#include "mkql_block_decimal.h" #include #include #include #include #include #include #include #include #include #include namespace NKikimr { namespace NMiniKQL { namespace { template struct TDecimalBlockExec { NYql::NDecimal::TInt128 Apply(NYql::NDecimal::TInt128 left, TRight right) const { return static_cast(this)->Do(left, right); } template const U* GetScalarValue(const arrow::Scalar& scalar) const { return reinterpret_cast(GetPrimitiveScalarValuePtr(scalar)); } template<> const NYql::NDecimal::TInt128* GetScalarValue(const arrow::Scalar& scalar) const { return reinterpret_cast(GetStringScalarValue(scalar).data()); } void ArrayScalarCore( const NYql::NDecimal::TInt128* val1Ptr, const ui8* valid1, const TRight* val2Ptr, const ui8* valid2, NYql::NDecimal::TInt128* resPtr, ui8* resValid, int64_t length, int64_t offset1, int64_t offset2) const { val1Ptr += offset1; Y_UNUSED(valid2); Y_UNUSED(offset2); for (int64_t i = 0; i < length; ++i, ++val1Ptr, ++resPtr) { if (!valid1 || arrow::BitUtil::GetBit(valid1, i + offset1)) { *resPtr = Apply(*val1Ptr, *val2Ptr); arrow::BitUtil::SetBit(resValid, i); } else { arrow::BitUtil::ClearBit(resValid, i); } } } void ScalarArrayCore( const NYql::NDecimal::TInt128* val1Ptr, const ui8* valid1, const TRight* val2Ptr, const ui8* valid2, NYql::NDecimal::TInt128* resPtr, ui8* resValid, int64_t length, int64_t offset1, int64_t offset2) const { val2Ptr += offset2; Y_UNUSED(valid1); Y_UNUSED(offset1); for (int64_t i = 0; i < length; ++i, ++val2Ptr, ++resPtr) { if (!valid2 || arrow::BitUtil::GetBit(valid2, i + offset2)) { *resPtr = Apply(*val1Ptr, *val2Ptr); arrow::BitUtil::SetBit(resValid, i); } else { arrow::BitUtil::ClearBit(resValid, i); } } } void ArrayArrayCore( const NYql::NDecimal::TInt128* val1Ptr, const ui8* valid1, const TRight* val2Ptr, const ui8* valid2, NYql::NDecimal::TInt128* resPtr, ui8* resValid, int64_t length, int64_t offset1, int64_t offset2) const { val1Ptr += offset1; val2Ptr += offset2; for (int64_t i = 0; i < length; ++i, ++val1Ptr, ++val2Ptr, ++resPtr) { if ((!valid1 || arrow::BitUtil::GetBit(valid1, i + offset1)) && (!valid2 || arrow::BitUtil::GetBit(valid2, i + offset2))) { *resPtr = Apply(*val1Ptr, *val2Ptr); arrow::BitUtil::SetBit(resValid, i); } else { arrow::BitUtil::ClearBit(resValid, i); } } } arrow::Status ExecScalarScalar(arrow::compute::KernelContext* kernelCtx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); const auto& arg1 = batch.values[0]; const auto& arg2 = batch.values[1]; if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) { *res = arrow::MakeNullScalar(GetPrimitiveDataType()); } else { const auto val1Ptr = GetScalarValue(*arg1.scalar()); const auto val2Ptr = GetScalarValue(*arg2.scalar()); std::shared_ptr buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool()))); auto* mem = reinterpret_cast(buffer->mutable_data()); auto resDatum = arrow::Datum(std::make_shared::TScalarResult>(buffer)); *mem = Apply(*val1Ptr, *val2Ptr); *res = resDatum; } return arrow::Status::OK(); } arrow::Status ExecScalarArray(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); const auto& arg1 = batch.values[0]; const auto& arg2 = batch.values[1]; auto& resArr = *res->array(); if (arg1.scalar()->is_valid) { const auto val1Ptr = GetScalarValue(*arg1.scalar()); const auto& arr2 = *arg2.array(); auto length = arr2.length; const auto val2Ptr = 
    arrow::Status ExecScalarArray(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
        const auto& arg1 = batch.values[0];
        const auto& arg2 = batch.values[1];
        auto& resArr = *res->array();
        if (arg1.scalar()->is_valid) {
            const auto val1Ptr = GetScalarValue<NYql::NDecimal::TInt128>(*arg1.scalar());
            const auto& arr2 = *arg2.array();
            auto length = arr2.length;
            const auto val2Ptr = reinterpret_cast<const TRight*>(arr2.buffers[1]->data());
            const auto nullCount2 = arr2.GetNullCount();
            const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<ui8>(0);
            auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
            auto resValid = res->array()->GetMutableValues<ui8>(0);
            ScalarArrayCore(val1Ptr, nullptr, val2Ptr, valid2, resPtr, resValid, length, 0, arr2.offset);
        } else {
            // A null scalar operand makes the whole result null.
            GetBitmap(resArr, 0).SetBitsTo(false);
        }

        return arrow::Status::OK();
    }

    arrow::Status ExecArrayScalar(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
        const auto& arg1 = batch.values[0];
        const auto& arg2 = batch.values[1];
        auto& resArr = *res->array();
        if (arg2.scalar()->is_valid) {
            const auto& arr1 = *arg1.array();
            const auto val1Ptr = reinterpret_cast<const NYql::NDecimal::TInt128*>(arr1.buffers[1]->data());
            auto length = arr1.length;
            const auto nullCount1 = arr1.GetNullCount();
            const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<ui8>(0);
            const auto val2Ptr = GetScalarValue<TRight>(*arg2.scalar());
            auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
            auto resValid = res->array()->GetMutableValues<ui8>(0);
            ArrayScalarCore(val1Ptr, valid1, val2Ptr, nullptr, resPtr, resValid, length, arr1.offset, 0);
        } else {
            // A null scalar operand makes the whole result null.
            GetBitmap(resArr, 0).SetBitsTo(false);
        }

        return arrow::Status::OK();
    }

    arrow::Status ExecArrayArray(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
        const auto& arg1 = batch.values[0];
        const auto& arg2 = batch.values[1];
        const auto& arr1 = *arg1.array();
        const auto& arr2 = *arg2.array();
        MKQL_ENSURE(arr1.length == arr2.length, "Expected same length");
        auto length = arr1.length;
        const auto val1Ptr = reinterpret_cast<const NYql::NDecimal::TInt128*>(arr1.buffers[1]->data());
        const auto nullCount1 = arr1.GetNullCount();
        const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<ui8>(0);
        const auto val2Ptr = reinterpret_cast<const TRight*>(arr2.buffers[1]->data());
        const auto nullCount2 = arr2.GetNullCount();
        const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<ui8>(0);
        auto& resArr = *res->array();
        auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
        auto resValid = res->array()->GetMutableValues<ui8>(0);
        ArrayArrayCore(val1Ptr, valid1, val2Ptr, valid2, resPtr, resValid, length, arr1.offset, arr2.offset);

        return arrow::Status::OK();
    }

    // Dispatch on the scalar/array combination of the two inputs.
    arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
        const auto& arg1 = batch.values[0];
        const auto& arg2 = batch.values[1];
        if (arg1.is_scalar()) {
            if (arg2.is_scalar()) {
                return ExecScalarScalar(ctx, batch, res);
            } else {
                return ExecScalarArray(batch, res);
            }
        } else {
            if (arg2.is_scalar()) {
                return ExecArrayScalar(batch, res);
            } else {
                return ExecArrayArray(batch, res);
            }
        }

        return arrow::Status::OK();
    }
};

template<typename TRight>
struct TDecimalMulBlockExec: NYql::NDecimal::TDecimalMultiplicator<TRight>, TDecimalBlockExec<TDecimalMulBlockExec<TRight>, TRight> {
    TDecimalMulBlockExec(ui8 precision, ui8 scale)
        : NYql::NDecimal::TDecimalMultiplicator<TRight>(precision, scale)
    {}
};

template<typename TRight>
struct TDecimalDivBlockExec: NYql::NDecimal::TDecimalDivisor<TRight>, TDecimalBlockExec<TDecimalDivBlockExec<TRight>, TRight> {
    TDecimalDivBlockExec(ui8 precision, ui8 scale)
        : NYql::NDecimal::TDecimalDivisor<TRight>(precision, scale)
    {}
};

template<typename TRight>
struct TDecimalModBlockExec: NYql::NDecimal::TDecimalRemainder<TRight>, TDecimalBlockExec<TDecimalModBlockExec<TRight>, TRight> {
    TDecimalModBlockExec(ui8 precision, ui8 scale)
        : NYql::NDecimal::TDecimalRemainder<TRight>(precision, scale)
    {}
};

template