mkql_block_decimal.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
#include "mkql_block_decimal.h"
#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/mkql_type_builder.h>
#include <yql/essentials/minikql/computation/mkql_block_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_builder.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/public/udf/arrow/block_builder.h>
#include <yql/essentials/public/udf/arrow/block_reader.h>
#include <yql/essentials/public/udf/arrow/util.h>
#include <yql/essentials/public/decimal/yql_decimal.h>

#include <type_traits>
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  14. namespace {
  15. template<typename T, typename TRight>
  16. struct TDecimalBlockExec {
  17. NYql::NDecimal::TInt128 Apply(NYql::NDecimal::TInt128 left, TRight right) const {
  18. return static_cast<const T*>(this)->Do(left, right);
  19. }
  20. template<typename U>
  21. const U* GetScalarValue(const arrow::Scalar& scalar) const {
  22. return reinterpret_cast<const U*>(GetPrimitiveScalarValuePtr(scalar));
  23. }
  24. template<>
  25. const NYql::NDecimal::TInt128* GetScalarValue<NYql::NDecimal::TInt128>(const arrow::Scalar& scalar) const {
  26. return reinterpret_cast<const NYql::NDecimal::TInt128*>(GetStringScalarValue(scalar).data());
  27. }
  28. void ArrayScalarCore(
  29. const NYql::NDecimal::TInt128* val1Ptr,
  30. const ui8* valid1,
  31. const TRight* val2Ptr,
  32. const ui8* valid2,
  33. NYql::NDecimal::TInt128* resPtr,
  34. ui8* resValid,
  35. int64_t length,
  36. int64_t offset1,
  37. int64_t offset2) const {
  38. val1Ptr += offset1;
  39. Y_UNUSED(valid2);
  40. Y_UNUSED(offset2);
  41. for (int64_t i = 0; i < length; ++i, ++val1Ptr, ++resPtr) {
  42. if (!valid1 || arrow::BitUtil::GetBit(valid1, i + offset1)) {
  43. *resPtr = Apply(*val1Ptr, *val2Ptr);
  44. arrow::BitUtil::SetBit(resValid, i);
  45. } else {
  46. arrow::BitUtil::ClearBit(resValid, i);
  47. }
  48. }
  49. }
  50. void ScalarArrayCore(
  51. const NYql::NDecimal::TInt128* val1Ptr,
  52. const ui8* valid1,
  53. const TRight* val2Ptr,
  54. const ui8* valid2,
  55. NYql::NDecimal::TInt128* resPtr,
  56. ui8* resValid,
  57. int64_t length,
  58. int64_t offset1,
  59. int64_t offset2) const {
  60. val2Ptr += offset2;
  61. Y_UNUSED(valid1);
  62. Y_UNUSED(offset1);
  63. for (int64_t i = 0; i < length; ++i, ++val2Ptr, ++resPtr) {
  64. if (!valid2 || arrow::BitUtil::GetBit(valid2, i + offset2)) {
  65. *resPtr = Apply(*val1Ptr, *val2Ptr);
  66. arrow::BitUtil::SetBit(resValid, i);
  67. } else {
  68. arrow::BitUtil::ClearBit(resValid, i);
  69. }
  70. }
  71. }
  72. void ArrayArrayCore(
  73. const NYql::NDecimal::TInt128* val1Ptr,
  74. const ui8* valid1,
  75. const TRight* val2Ptr,
  76. const ui8* valid2,
  77. NYql::NDecimal::TInt128* resPtr,
  78. ui8* resValid,
  79. int64_t length,
  80. int64_t offset1,
  81. int64_t offset2) const
  82. {
  83. val1Ptr += offset1;
  84. val2Ptr += offset2;
  85. for (int64_t i = 0; i < length; ++i, ++val1Ptr, ++val2Ptr, ++resPtr) {
  86. if ((!valid1 || arrow::BitUtil::GetBit(valid1, i + offset1)) &&
  87. (!valid2 || arrow::BitUtil::GetBit(valid2, i + offset2))) {
  88. *resPtr = Apply(*val1Ptr, *val2Ptr);
  89. arrow::BitUtil::SetBit(resValid, i);
  90. } else {
  91. arrow::BitUtil::ClearBit(resValid, i);
  92. }
  93. }
  94. }
  95. arrow::Status ExecScalarScalar(arrow::compute::KernelContext* kernelCtx,
  96. const arrow::compute::ExecBatch& batch, arrow::Datum* res) const
  97. {
  98. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  99. const auto& arg1 = batch.values[0];
  100. const auto& arg2 = batch.values[1];
  101. if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) {
  102. *res = arrow::MakeNullScalar(GetPrimitiveDataType<NYql::NDecimal::TInt128>());
  103. } else {
  104. const auto val1Ptr = GetScalarValue<NYql::NDecimal::TInt128>(*arg1.scalar());
  105. const auto val2Ptr = GetScalarValue<TRight>(*arg2.scalar());
  106. std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool())));
  107. auto* mem = reinterpret_cast<NYql::NDecimal::TInt128*>(buffer->mutable_data());
  108. auto resDatum = arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
  109. *mem = Apply(*val1Ptr, *val2Ptr);
  110. *res = resDatum;
  111. }
  112. return arrow::Status::OK();
  113. }
  114. arrow::Status ExecScalarArray(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const
  115. {
  116. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  117. const auto& arg1 = batch.values[0];
  118. const auto& arg2 = batch.values[1];
  119. auto& resArr = *res->array();
  120. if (arg1.scalar()->is_valid) {
  121. const auto val1Ptr = GetScalarValue<NYql::NDecimal::TInt128>(*arg1.scalar());
  122. const auto& arr2 = *arg2.array();
  123. auto length = arr2.length;
  124. const auto val2Ptr = reinterpret_cast<const TRight*>(arr2.buffers[1]->data());
  125. const auto nullCount2 = arr2.GetNullCount();
  126. const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0);
  127. auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
  128. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  129. ScalarArrayCore(val1Ptr, nullptr, val2Ptr, valid2, resPtr, resValid, length, 0, arr2.offset);
  130. } else {
  131. GetBitmap(resArr, 0).SetBitsTo(false);
  132. }
  133. return arrow::Status::OK();
  134. }
  135. arrow::Status ExecArrayScalar(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const
  136. {
  137. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  138. const auto& arg1 = batch.values[0];
  139. const auto& arg2 = batch.values[1];
  140. auto& resArr = *res->array();
  141. if (arg2.scalar()->is_valid) {
  142. const auto& arr1 = *arg1.array();
  143. const auto val1Ptr = reinterpret_cast<const NYql::NDecimal::TInt128*>(arr1.buffers[1]->data());
  144. auto length = arr1.length;
  145. const auto nullCount1 = arr1.GetNullCount();
  146. const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0);
  147. const auto val2Ptr = GetScalarValue<TRight>(*arg2.scalar());
  148. auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
  149. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  150. ArrayScalarCore(val1Ptr, valid1, val2Ptr, nullptr, resPtr, resValid, length, arr1.offset, 0);
  151. } else {
  152. GetBitmap(resArr, 0).SetBitsTo(false);
  153. }
  154. return arrow::Status::OK();
  155. }
  156. arrow::Status ExecArrayArray(const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  157. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  158. const auto& arg1 = batch.values[0];
  159. const auto& arg2 = batch.values[1];
  160. const auto& arr1 = *arg1.array();
  161. const auto& arr2 = *arg2.array();
  162. MKQL_ENSURE(arr1.length == arr2.length, "Expected same length");
  163. auto length = arr1.length;
  164. const auto val1Ptr = reinterpret_cast<const NYql::NDecimal::TInt128*>(arr1.buffers[1]->data());
  165. const auto nullCount1 = arr1.GetNullCount();
  166. const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0);
  167. const auto val2Ptr = reinterpret_cast<const TRight*>(arr2.buffers[1]->data());
  168. const auto nullCount2 = arr2.GetNullCount();
  169. const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0);
  170. auto& resArr = *res->array();
  171. auto resPtr = reinterpret_cast<NYql::NDecimal::TInt128*>(resArr.buffers[1]->mutable_data());
  172. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  173. ArrayArrayCore(val1Ptr, valid1, val2Ptr, valid2, resPtr, resValid, length, arr1.offset, arr2.offset);
  174. return arrow::Status::OK();
  175. }
  176. arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  177. Y_UNUSED(ctx);
  178. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  179. const auto& arg1 = batch.values[0];
  180. const auto& arg2 = batch.values[1];
  181. if (arg1.is_scalar()) {
  182. if (arg2.is_scalar()) {
  183. return ExecScalarScalar(ctx, batch, res);
  184. } else {
  185. return ExecScalarArray(batch, res);
  186. }
  187. } else {
  188. if (arg2.is_scalar()) {
  189. return ExecArrayScalar(batch, res);
  190. } else {
  191. return ExecArrayArray(batch, res);
  192. }
  193. }
  194. return arrow::Status::OK();
  195. }
  196. };
  197. template<typename TRight>
  198. struct TDecimalMulBlockExec: NYql::NDecimal::TDecimalMultiplicator<TRight>, TDecimalBlockExec<TDecimalMulBlockExec<TRight>, TRight> {
  199. TDecimalMulBlockExec(
  200. ui8 precision,
  201. ui8 scale)
  202. : NYql::NDecimal::TDecimalMultiplicator<TRight>(precision, scale)
  203. { }
  204. };
  205. template<typename TRight>
  206. struct TDecimalDivBlockExec: NYql::NDecimal::TDecimalDivisor<TRight>, TDecimalBlockExec<TDecimalDivBlockExec<TRight>, TRight> {
  207. TDecimalDivBlockExec(
  208. ui8 precision,
  209. ui8 scale)
  210. : NYql::NDecimal::TDecimalDivisor<TRight>(precision, scale)
  211. { }
  212. };
  213. template<typename TRight>
  214. struct TDecimalModBlockExec: NYql::NDecimal::TDecimalRemainder<TRight>, TDecimalBlockExec<TDecimalModBlockExec<TRight>, TRight> {
  215. TDecimalModBlockExec(
  216. ui8 precision,
  217. ui8 scale)
  218. : NYql::NDecimal::TDecimalRemainder<TRight>(precision, scale)
  219. { }
  220. };
  221. template<template <typename> class TExec>
  222. std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockKernel(const TVector<TType*>& argTypes, TType* resultType) {
  223. MKQL_ENSURE(argTypes.size() == 2, "Require 2 arguments");
  224. MKQL_ENSURE(argTypes[0]->GetKind() == TType::EKind::Block, "Require block");
  225. MKQL_ENSURE(argTypes[1]->GetKind() == TType::EKind::Block, "Require block");
  226. MKQL_ENSURE(resultType->GetKind() == TType::EKind::Block, "Require block");
  227. bool isOptional = false;
  228. auto dataType1 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[0])->GetItemType(), isOptional);
  229. auto dataType2 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[1])->GetItemType(), isOptional);
  230. auto dataResultType = UnpackOptionalData(static_cast<TBlockType*>(resultType)->GetItemType(), isOptional);
  231. MKQL_ENSURE(*dataType1->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
  232. MKQL_ENSURE(*dataResultType->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
  233. auto decimalType1 = static_cast<TDataDecimalType*>(dataType1);
  234. auto decimalResultType = static_cast<TDataDecimalType*>(dataResultType);
  235. MKQL_ENSURE(decimalType1->GetParams() == decimalResultType->GetParams(), "Require same precision/scale");
  236. auto [precision, scale] = decimalType1->GetParams();
  237. MKQL_ENSURE(precision >= 1&& precision <= 35, TStringBuilder() << "Wrong precision: " << (int)precision);
  238. auto createKernel = [&](auto exec) {
  239. auto k = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
  240. [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  241. return exec->Exec(ctx, batch, res);
  242. });
  243. k->null_handling = arrow::compute::NullHandling::INTERSECTION;
  244. return k;
  245. };
  246. switch (dataType2->GetSchemeType()) {
  247. case NUdf::TDataType<NUdf::TDecimal>::Id: {
  248. return createKernel(std::make_shared<TExec<NYql::NDecimal::TInt128>>(precision, scale));
  249. }
  250. #define MAKE_PRIMITIVE_TYPE_MUL(type) \
  251. case NUdf::TDataType<type>::Id: { \
  252. return createKernel(std::make_shared<TExec<type>>(precision, scale)); \
  253. }
  254. INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_MUL)
  255. #undef MAKE_PRIMITIVE_TYPE_MUL
  256. default:
  257. Y_ABORT("Unupported type.");
  258. }
  259. }
  260. template<template <typename> class TExec>
  261. IComputationNode* WrapBlockDecimal(TStringBuf name, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  262. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  263. auto first = callable.GetInput(0);
  264. auto second = callable.GetInput(1);
  265. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  266. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  267. auto firstCompute = LocateNode(ctx.NodeLocator, callable, 0);
  268. auto secondCompute = LocateNode(ctx.NodeLocator, callable, 1);
  269. TComputationNodePtrVector argsNodes = { firstCompute, secondCompute };
  270. TVector<TType*> argsTypes = { firstType, secondType };
  271. std::shared_ptr<arrow::compute::ScalarKernel> kernel = MakeBlockKernel<TExec>(argsTypes, callable.GetType()->GetReturnType());
  272. return new TBlockFuncNode(ctx.Mutables, name, std::move(argsNodes), argsTypes, *kernel, kernel);
  273. }
  274. }
  275. IComputationNode* WrapBlockDecimalMul(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  276. return WrapBlockDecimal<TDecimalMulBlockExec>("DecimalMul", callable, ctx);
  277. }
  278. IComputationNode* WrapBlockDecimalDiv(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  279. return WrapBlockDecimal<TDecimalDivBlockExec>("DecimalDiv", callable, ctx);
  280. }
  281. IComputationNode* WrapBlockDecimalMod(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  282. return WrapBlockDecimal<TDecimalModBlockExec>("DecimalMod", callable, ctx);
  283. }
  284. }
  285. }