#include "mkql_builtins_string_kernels.h"
#include "mkql_builtins_impl.h"  // Y_IGNORE

// Arrow compute kernels for string builtins: binary comparisons
// (==, !=, <, <=, >, >=, StartsWith, EndsWith, StringContains) and Size.
//
// Layout of this file:
//   * Untyped, non-inlined "Impl" drivers that unpack an ExecBatch and call a
//     type-erased function pointer.
//   * TBinaryStringExecs / TUnaryStringExecs, thin templates that provide the
//     typed per-element loops and hand them to the untyped drivers.
//   * Operation structs (TStr*Op) and kernel registration helpers.
//
// NOTE(review): the template parameter/argument lists in this file were
// restored from how the names are used below. The spots that could not be
// derived from this file alone are marked with NOTE(review) — please confirm
// them against mkql_builtins_impl.h.

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Returns the number of parameters of a free/static function, deduced from a
// pointer to it. Used to pick between the zero-argument and the one-argument
// flavour of TOp::DoWithEmptyLeft / TOp::DoWithEmptyRight at compile time.
template <typename Return, typename... Args>
constexpr auto GetArgumentsCount(Return(*)(Args...)) noexcept
{
    return sizeof...(Args);
}

// Type-erased signature of a scalar-vs-scalar binary string operation.
using TUntypedStringBinaryScalarFuncPtr = void(*)(std::string_view, std::string_view, void*);

// Type-erased signature of a binary string operation where at least one side
// is an array. For a scalar side the "offsets" pointer carries a pointer to a
// size_t with the scalar length and the "data" pointer carries its bytes.
using TUntypedStringBinaryArrayFuncPtr = void(*)(
    const void* stringOffsets1, const void* data1,
    const void* stringOffsets2, const void* data2,
    void* resPtr, int64_t length, int64_t offset1, int64_t offset2);

// Scalar op Scalar. Produces a null scalar of the result type when either
// input is invalid; otherwise evaluates `func` into a freshly made scalar.
Y_NO_INLINE arrow::Status ExecStringScalarScalarImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
    TUntypedStringBinaryScalarFuncPtr func) {
    const auto& arg1 = batch.values[0];
    const auto& arg2 = batch.values[1];
    if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) {
        *res = arrow::MakeNullScalar(typeGetter());
    } else {
        auto resDatum = scalarGetter();
        const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
        const auto val1 = GetStringScalarValue(*arg1.scalar());
        const auto val2 = GetStringScalarValue(*arg2.scalar());
        func(val1, val2, resPtr);
        *res = resDatum.scalar();
    }

    return arrow::Status::OK();
}

// Scalar op Array. When the scalar is invalid nothing is written to the
// result values buffer.
// NOTE(review): presumably the output validity is then handled by Arrow via
// NullHandling::INTERSECTION set at registration — confirm.
Y_NO_INLINE arrow::Status ExecStringScalarArrayImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TUntypedStringBinaryArrayFuncPtr func) {
    const auto& arg1 = batch.values[0];
    const auto& arg2 = batch.values[1];
    auto& resArr = *res->array();
    if (arg1.scalar()->is_valid) {
        const auto val1 = GetStringScalarValue(*arg1.scalar());
        const auto& arr2 = *arg2.array();
        auto length = arr2.length;
        auto resPtr = resArr.buffers[1]->mutable_data();
        // The scalar length travels through the "offsets" slot as a size_t*.
        const size_t val1Size = val1.size();
        const auto offsets2 = arr2.buffers[1]->data();
        const auto data2 = arr2.buffers[2]->data();
        func(&val1Size, val1.data(), offsets2, data2, resPtr, length, 0, arr2.offset);
    }

    return arrow::Status::OK();
}

// Array op Scalar. Mirror image of ExecStringScalarArrayImpl.
Y_NO_INLINE arrow::Status ExecStringArrayScalarImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TUntypedStringBinaryArrayFuncPtr func) {
    const auto& arg1 = batch.values[0];
    const auto& arg2 = batch.values[1];
    auto& resArr = *res->array();
    if (arg2.scalar()->is_valid) {
        const auto val2 = GetStringScalarValue(*arg2.scalar());
        const auto& arr1 = *arg1.array();
        auto length = arr1.length;
        auto resPtr = resArr.buffers[1]->mutable_data();
        // The scalar length travels through the "offsets" slot as a size_t*.
        const size_t val2Size = val2.size();
        const auto offsets1 = arr1.buffers[1]->data();
        const auto data1 = arr1.buffers[2]->data();
        func(offsets1, data1, &val2Size, val2.data(), resPtr, length, arr1.offset, 0);
    }

    return arrow::Status::OK();
}

// Array op Array. Both inputs must have the same length.
Y_NO_INLINE arrow::Status ExecStringArrayArrayImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TUntypedStringBinaryArrayFuncPtr func) {
    const auto& arg1 = batch.values[0];
    const auto& arg2 = batch.values[1];
    const auto& arr1 = *arg1.array();
    const auto& arr2 = *arg2.array();
    auto& resArr = *res->array();
    MKQL_ENSURE(arr1.length == arr2.length, "Expected same length");
    auto length = arr1.length;
    auto resPtr = resArr.buffers[1]->mutable_data();
    const auto offsets1 = arr1.buffers[1]->data();
    const auto offsets2 = arr2.buffers[1]->data();
    const auto data1 = arr1.buffers[2]->data();
    const auto data2 = arr2.buffers[2]->data();
    func(offsets1, data1, offsets2, data2, resPtr, length, arr1.offset, arr2.offset);
    return arrow::Status::OK();
}

// Entry point for binary string kernels: checks the arity and dispatches to
// one of the four scalar/array combinations above.
Y_NO_INLINE arrow::Status ExecStringBinaryImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
    TUntypedStringBinaryScalarFuncPtr scalarScalarFunc,
    TUntypedStringBinaryArrayFuncPtr scalarArrayFunc,
    TUntypedStringBinaryArrayFuncPtr arrayScalarFunc,
    TUntypedStringBinaryArrayFuncPtr arrayArrayFunc) {
    MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
    const auto& arg1 = batch.values[0];
    const auto& arg2 = batch.values[1];
    if (arg1.is_scalar()) {
        if (arg2.is_scalar()) {
            return ExecStringScalarScalarImpl(batch, res, typeGetter, scalarGetter, scalarScalarFunc);
        } else {
            return ExecStringScalarArrayImpl(batch, res, scalarArrayFunc);
        }
    } else {
        if (arg2.is_scalar()) {
            return ExecStringArrayScalarImpl(batch, res, arrayScalarFunc);
        } else {
            return ExecStringArrayArrayImpl(batch, res, arrayArrayFunc);
        }
    }
}

// Typed loops for a binary string operation.
//   TInput1/TInput2 - input data types; they select the Arrow offset type.
//   TOp             - operation struct providing Do(left, right) plus
//                     DoWithEmptyLeft / DoWithEmptyRight, each taking either
//                     no arguments or the length of the other operand.
//   TOutput         - element type written to the result values buffer.
template <typename TInput1, typename TInput2, typename TOp, typename TOutput>
struct TBinaryStringExecs
{
    using TOffset1 = typename TPrimitiveDataType<TInput1>::TResult::offset_type;
    using TOffset2 = typename TPrimitiveDataType<TInput2>::TResult::offset_type;

    using TTypedStringBinaryScalarFuncPtr = void(*)(std::string_view, std::string_view, TOutput*);
    using TTypedStringBinaryArrayFuncPtr = void(*)(
        const TOffset1* stringOffsets1, const char* data1,
        const TOffset2* stringOffsets2, const char* data2,
        TOutput* resPtr, int64_t length, int64_t offset1, int64_t offset2);

    static void ScalarScalarCore(std::string_view arg1, std::string_view arg2, TOutput* resPtr) {
        *resPtr = TOp::Do(arg1, arg2);
    }

    // Left side is a scalar: stringOffsets1 actually points at a size_t with
    // its length (see ExecStringScalarArrayImpl), data1 at its bytes.
    static void ScalarArrayCore(
        const TOffset1* stringOffsets1, const char* data1,
        const TOffset2* stringOffsets2, const char* data2,
        TOutput* resPtr, int64_t length, int64_t offset1, int64_t offset2) {
        Y_UNUSED(offset1);
        const auto val1 = std::string_view(data1, *reinterpret_cast<const size_t*>(stringOffsets1));
        stringOffsets2 += offset2;
        if (val1.empty()) {
            // Empty left operand: the answer depends at most on the right length,
            // so the string bytes are never touched.
            if constexpr (GetArgumentsCount(TOp::DoWithEmptyLeft) == 0) {
                std::fill(resPtr, resPtr + length, TOp::DoWithEmptyLeft());
            } else {
                for (int64_t i = 0; i < length; ++i, ++resPtr, ++stringOffsets2) {
                    *resPtr = TOp::DoWithEmptyLeft(stringOffsets2[1] - stringOffsets2[0]);
                }
            }
        } else {
            for (int64_t i = 0; i < length; ++i, ++resPtr, ++stringOffsets2) {
                std::string_view val2(data2 + stringOffsets2[0], stringOffsets2[1] - stringOffsets2[0]);
                *resPtr = TOp::Do(val1, val2);
            }
        }
    }

    // Right side is a scalar: stringOffsets2 actually points at a size_t with
    // its length (see ExecStringArrayScalarImpl), data2 at its bytes.
    static void ArrayScalarCore(
        const TOffset1* stringOffsets1, const char* data1,
        const TOffset2* stringOffsets2, const char* data2,
        TOutput* resPtr, int64_t length, int64_t offset1, int64_t offset2) {
        Y_UNUSED(offset2);
        const auto val2 = std::string_view(data2, *reinterpret_cast<const size_t*>(stringOffsets2));
        stringOffsets1 += offset1;
        if (val2.empty()) {
            // Empty right operand: the answer depends at most on the left length.
            if constexpr (GetArgumentsCount(TOp::DoWithEmptyRight) == 0) {
                std::fill(resPtr, resPtr + length, TOp::DoWithEmptyRight());
            } else {
                for (int64_t i = 0; i < length; ++i, ++resPtr, ++stringOffsets1) {
                    *resPtr = TOp::DoWithEmptyRight(stringOffsets1[1] - stringOffsets1[0]);
                }
            }
        } else {
            for (int64_t i = 0; i < length; ++i, ++resPtr, ++stringOffsets1) {
                std::string_view val1(data1 + stringOffsets1[0], stringOffsets1[1] - stringOffsets1[0]);
                *resPtr = TOp::Do(val1, val2);
            }
        }
    }

    // Both sides are arrays; element i of each side is the byte range
    // [offsets[i], offsets[i + 1]) of the corresponding data buffer.
    static void ArrayArrayCore(
        const TOffset1* stringOffsets1, const char* data1,
        const TOffset2* stringOffsets2, const char* data2,
        TOutput* resPtr, int64_t length, int64_t offset1, int64_t offset2) {
        stringOffsets1 += offset1;
        stringOffsets2 += offset2;
        for (int64_t i = 0; i < length; ++i, ++stringOffsets1, ++stringOffsets2, ++resPtr) {
            std::string_view val1(data1 + stringOffsets1[0], stringOffsets1[1] - stringOffsets1[0]);
            std::string_view val2(data2 + stringOffsets2[0], stringOffsets2[1] - stringOffsets2[0]);
            *resPtr = TOp::Do(val1, val2);
        }
    }

    // Arrow kernel entry point. The typed cores are converted to the untyped
    // pointer types so that the non-template drivers can be shared by all
    // instantiations.
    // NOTE(review): this relies on the typed and untyped signatures being
    // call-compatible on the target ABI.
    static arrow::Status Exec(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
        // NOTE(review): restored as <TOutput, bool>, based on the
        // "ui8 type is used as bool replacement" remark below — confirm.
        static_assert(!std::is_same<TOutput, bool>::value);
        TTypedStringBinaryScalarFuncPtr scalarScalarFunc = &ScalarScalarCore;
        TTypedStringBinaryArrayFuncPtr scalarArrayFunc = &ScalarArrayCore;
        TTypedStringBinaryArrayFuncPtr arrayScalarFunc = &ArrayScalarCore;
        TTypedStringBinaryArrayFuncPtr arrayArrayFunc = &ArrayArrayCore;
        return ExecStringBinaryImpl(batch, res,
            &GetPrimitiveDataType<TOutput>,
            &MakeDefaultScalarDatum<TOutput>,
            reinterpret_cast<TUntypedStringBinaryScalarFuncPtr>(scalarScalarFunc),
            reinterpret_cast<TUntypedStringBinaryArrayFuncPtr>(scalarArrayFunc),
            reinterpret_cast<TUntypedStringBinaryArrayFuncPtr>(arrayScalarFunc),
            reinterpret_cast<TUntypedStringBinaryArrayFuncPtr>(arrayArrayFunc));
    }
};

// Type-erased signatures of a unary string operation (scalar and array form).
using TUntypedStringUnaryScalarFuncPtr = void(*)(std::string_view, void*);
using TUntypedStringUnaryArrayFuncPtr = void(*)(const void* stringOffsets, const void* data, void* resPtr, int64_t length, int64_t offset);

// Unary op on a scalar. Produces a null scalar of the result type when the
// input is invalid.
Y_NO_INLINE arrow::Status ExecStringScalarImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
    TUntypedStringUnaryScalarFuncPtr func) {
    const auto& arg = batch.values[0];
    if (!arg.scalar()->is_valid) {
        *res = arrow::MakeNullScalar(typeGetter());
    } else {
        auto resDatum = scalarGetter();
        const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
        const auto val = GetStringScalarValue(*arg.scalar());
        func(val, resPtr);
        *res = resDatum.scalar();
    }

    return arrow::Status::OK();
}

// Unary op on an array: runs `func` over every element of the input.
Y_NO_INLINE arrow::Status ExecStringArrayImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TUntypedStringUnaryArrayFuncPtr func) {
    const auto& arg = batch.values[0];
    auto& resArr = *res->array();
    const auto& arr = *arg.array();
    const auto length = arr.length;
    const auto resValues = resArr.buffers[1]->mutable_data();
    const auto offsets = arr.buffers[1]->data();
    const auto data = arr.buffers[2]->data();
    func(offsets, data, resValues, length, arr.offset);
    return arrow::Status::OK();
}

// Entry point for unary string kernels: checks the arity and dispatches on
// whether the argument is a scalar or an array.
Y_NO_INLINE arrow::Status ExecStringUnaryImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
    TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
    TUntypedStringUnaryScalarFuncPtr scalarFunc, TUntypedStringUnaryArrayFuncPtr arrayFunc) {
    MKQL_ENSURE(batch.values.size() == 1, "Expected single argument");
    const auto& arg = batch.values[0];
    if (arg.is_scalar()) {
        return ExecStringScalarImpl(batch, res, typeGetter, scalarGetter, scalarFunc);
    } else {
        return ExecStringArrayImpl(batch, res, arrayFunc);
    }
}

// Typed loops for a unary string operation.
//   TInput  - input data type; selects the Arrow offset type.
//   TOp     - operation struct providing Do(value).
//   TOutput - element type written to the result values buffer.
template <typename TInput, typename TOp, typename TOutput>
struct TUnaryStringExecs
{
    using TOffset = typename TPrimitiveDataType<TInput>::TResult::offset_type;
    using TTypedStringUnaryScalarFuncPtr = void(*)(std::string_view, TOutput* resPtr);
    using TTypedStringUnaryArrayFuncPtr = void(*)(const TOffset* offsets, const char* data, TOutput* resPtr, int64_t length, int64_t offset);

    static void ScalarCore(std::string_view arg, TOutput* resPtr) {
        *resPtr = TOp::Do(arg);
    }

    static void ArrayCore(const TOffset* stringOffsets, const char* data, TOutput* resPtr, int64_t length, int64_t offset) {
        stringOffsets += offset;
        for (int64_t i = 0; i < length; ++i, ++stringOffsets, ++resPtr) {
            std::string_view val(data + stringOffsets[0], stringOffsets[1] - stringOffsets[0]);
            *resPtr = TOp::Do(val);
        }
    }

    // Arrow kernel entry point; see TBinaryStringExecs::Exec for the cast rationale.
    static arrow::Status Exec(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
        TTypedStringUnaryScalarFuncPtr scalarFunc = &ScalarCore;
        TTypedStringUnaryArrayFuncPtr arrayFunc = &ArrayCore;
        return ExecStringUnaryImpl(batch, res,
            &GetPrimitiveDataType<TOutput>,
            &MakeDefaultScalarDatum<TOutput>,
            reinterpret_cast<TUntypedStringUnaryScalarFuncPtr>(scalarFunc),
            reinterpret_cast<TUntypedStringUnaryArrayFuncPtr>(arrayFunc));
    }
};

// -------------------------------------------------------------------------------------
// String comparison
// -------------------------------------------------------------------------------------

// Each operation provides:
//   Do(left, right)     - the general case;
//   DoWithEmptyLeft     - result when the left operand is empty;
//   DoWithEmptyRight    - result when the right operand is empty.
// The DoWithEmpty* shortcuts take the length of the other operand when the
// result depends on it, and no arguments when the result is a constant.

struct TStrEqualsOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left == right;
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen == 0;
    }
    static inline bool DoWithEmptyRight(size_t leftLen) {
        return leftLen == 0;
    }
};

struct TStrNotEqualsOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left != right;
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen != 0;
    }
    static inline bool DoWithEmptyRight(size_t leftLen) {
        return leftLen != 0;
    }
};

struct TStrLessOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left < right;
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen != 0;
    }
    static constexpr bool DoWithEmptyRight() {
        return false;
    }
};

struct TStrLessOrEqualOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left <= right;
    }
    static constexpr bool DoWithEmptyLeft() {
        return true;
    }
    static inline bool DoWithEmptyRight(size_t leftLen) {
        return leftLen == 0;
    }
};

struct TStrGreaterOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left > right;
    }
    static constexpr bool DoWithEmptyLeft() {
        return false;
    }
    static inline bool DoWithEmptyRight(size_t leftLen) {
        return leftLen != 0;
    }
};

struct TStrGreaterOrEqualOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left >= right;
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen == 0;
    }
    static constexpr bool DoWithEmptyRight() {
        return true;
    }
};

struct TStrStartsWithOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left.starts_with(right);
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen == 0;
    }
    static constexpr bool DoWithEmptyRight() {
        return true;
    }
};

struct TStrEndsWithOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left.ends_with(right);
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen == 0;
    }
    static constexpr bool DoWithEmptyRight() {
        return true;
    }
};

struct TStrContainsOp {
    static inline bool Do(std::string_view left, std::string_view right) {
        return left.contains(right);
    }
    static inline bool DoWithEmptyLeft(size_t rightLen) {
        return rightLen == 0;
    }
    static constexpr bool DoWithEmptyRight() {
        return true;
    }
};

// Builds an Arrow kernel for a two-argument string predicate and hands it to
// `kernelFamily`. Kept out of line so the template wrapper below stays tiny.
Y_NO_INLINE void AddCompareStringKernelImpl(TKernelFamilyBase& kernelFamily, NUdf::TDataTypeId type1, NUdf::TDataTypeId type2,
    const arrow::compute::ArrayKernelExec& exec, arrow::compute::InputType&& inputType1, arrow::compute::InputType&& inputType2,
    arrow::compute::OutputType&& outputType) {
    std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 });
    // NOTE(review): result type restored as bool (all ops above return bool) — confirm.
    NUdf::TDataTypeId returnType = NUdf::TDataType<bool>::Id;

    // NOTE(review): kernel classes restored as arrow::compute::ScalarKernel and
    // TPlainKernel; neither name is visible in this file — confirm against
    // mkql_builtins_impl.h.
    auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
        inputType1, inputType2
    }, outputType, exec);
    k->null_handling = arrow::compute::NullHandling::INTERSECTION;
    kernelFamily.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(kernelFamily, argTypes, returnType, std::move(k), TKernel::ENullMode::Default));
}

// Registers the predicate TOp for one (TInput1, TInput2) pair of string types.
template <typename TInput1, typename TInput2, typename TOp>
void AddCompareStringKernel(TKernelFamilyBase& kernelFamily) {
    // ui8 type is used as bool replacement
    using TOutput = ui8;
    using TExecs = TBinaryStringExecs<TInput1, TInput2, TOp, TOutput>;
    AddCompareStringKernelImpl(kernelFamily, NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id, &TExecs::Exec,
        GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>(), GetPrimitiveOutputArrowType<bool>()
    );
}

// Registers the predicate TOp for all four pairings of the two string types.
// NOTE(review): the concrete string types (char* and NUdf::TUtf8) are not
// visible in this file and were restored from convention — confirm.
template <typename TOp>
void AddCompareStringKernels(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernel<char*,       char*,       TOp>(kernelFamily);
    AddCompareStringKernel<char*,       NUdf::TUtf8, TOp>(kernelFamily);
    AddCompareStringKernel<NUdf::TUtf8, char*,       TOp>(kernelFamily);
    AddCompareStringKernel<NUdf::TUtf8, NUdf::TUtf8, TOp>(kernelFamily);
}

// -------------------------------------------------------------------------------------
// String size
// -------------------------------------------------------------------------------------

// Returns the length of the string in bytes, converted to TOutput.
template <typename TOutput>
struct TStrSizeOp {
    static inline TOutput Do(std::string_view input) {
        return static_cast<TOutput>(input.size());
    }
};

// Builds an Arrow kernel for a one-argument string function and hands it to
// `kernelFamily`.
Y_NO_INLINE void AddSizeStringKernelImpl(TKernelFamilyBase& kernelFamily, NUdf::TDataTypeId type1, NUdf::TDataTypeId returnType,
    const arrow::compute::ArrayKernelExec& exec, arrow::compute::InputType&& inputType1, arrow::compute::OutputType&& outputType) {
    std::vector<NUdf::TDataTypeId> argTypes({ type1 });

    // NOTE(review): kernel classes restored as in AddCompareStringKernelImpl — confirm.
    auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
        inputType1
    }, outputType, exec);
    k->null_handling = arrow::compute::NullHandling::INTERSECTION;
    kernelFamily.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(kernelFamily, argTypes, returnType, std::move(k), TKernel::ENullMode::Default));
}

// Registers Size for one string input type; the result is ui32.
template <typename TInput>
void AddSizeStringKernel(TKernelFamilyBase& kernelFamily) {
    using TOutput = ui32;
    using TOp = TStrSizeOp<TOutput>;
    using TExecs = TUnaryStringExecs<TInput, TOp, TOutput>;

    AddSizeStringKernelImpl(kernelFamily, NUdf::TDataType<TInput>::Id, NUdf::TDataType<TOutput>::Id, &TExecs::Exec,
        GetPrimitiveInputArrowType<TInput>(), GetPrimitiveOutputArrowType<TOutput>());
}

} // namespace

void RegisterStringKernelEquals(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrEqualsOp>(kernelFamily);
}

void RegisterStringKernelNotEquals(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrNotEqualsOp>(kernelFamily);
}

void RegisterStringKernelLess(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrLessOp>(kernelFamily);
}

void RegisterStringKernelLessOrEqual(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrLessOrEqualOp>(kernelFamily);
}

void RegisterStringKernelGreater(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrGreaterOp>(kernelFamily);
}

void RegisterStringKernelGreaterOrEqual(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrGreaterOrEqualOp>(kernelFamily);
}

// NOTE(review): input types restored as char* and NUdf::TUtf8 — confirm.
void RegisterStringKernelSize(TKernelFamilyBase& kernelFamily) {
    AddSizeStringKernel<char*>(kernelFamily);
    AddSizeStringKernel<NUdf::TUtf8>(kernelFamily);
}

void RegisterStringKernelStartsWith(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrStartsWithOp>(kernelFamily);
}

void RegisterStringKernelEndsWith(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrEndsWithOp>(kernelFamily);
}

void RegisterStringKernelContains(TKernelFamilyBase& kernelFamily) {
    AddCompareStringKernels<TStrContainsOp>(kernelFamily);
}

// Creates the "Size" kernel family and stores it in `kernelFamilyMap`.
void RegisterSizeBuiltin(TKernelFamilyMap& kernelFamilyMap) {
    auto family = std::make_unique<TKernelFamilyBase>();
    RegisterStringKernelSize(*family);

    kernelFamilyMap["Size"] = std::move(family);
}

// Creates the "StartsWith", "EndsWith" and "StringContains" kernel families
// and stores them in `kernelFamilyMap`.
void RegisterWith(TKernelFamilyMap& kernelFamilyMap) {
    auto family = std::make_unique<TKernelFamilyBase>();
    RegisterStringKernelStartsWith(*family);
    kernelFamilyMap["StartsWith"] = std::move(family);

    family = std::make_unique<TKernelFamilyBase>();
    RegisterStringKernelEndsWith(*family);
    kernelFamilyMap["EndsWith"] = std::move(family);

    family = std::make_unique<TKernelFamilyBase>();
    RegisterStringKernelContains(*family);
    kernelFamilyMap["StringContains"] = std::move(family);
}

}
}