mkql_builtins_impl.cpp 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. #include "mkql_builtins_impl.h" // Y_IGNORE
  2. #include <yql/essentials/minikql/mkql_node_builder.h> // UnpackOptionalData
  3. namespace NKikimr {
  4. namespace NMiniKQL {
  5. template <typename T>
  6. arrow::compute::InputType GetPrimitiveInputArrowType(bool tz) {
  7. return arrow::compute::InputType(AddTzType(tz, GetPrimitiveDataType<T>()), arrow::ValueDescr::ANY);
  8. }
  9. template <typename T>
  10. arrow::compute::OutputType GetPrimitiveOutputArrowType(bool tz) {
  11. return arrow::compute::OutputType(AddTzType(tz, GetPrimitiveDataType<T>()));
  12. }
  13. template arrow::compute::InputType GetPrimitiveInputArrowType<bool>(bool tz);
  14. template arrow::compute::InputType GetPrimitiveInputArrowType<i8>(bool tz);
  15. template arrow::compute::InputType GetPrimitiveInputArrowType<ui8>(bool tz);
  16. template arrow::compute::InputType GetPrimitiveInputArrowType<i16>(bool tz);
  17. template arrow::compute::InputType GetPrimitiveInputArrowType<ui16>(bool tz);
  18. template arrow::compute::InputType GetPrimitiveInputArrowType<i32>(bool tz);
  19. template arrow::compute::InputType GetPrimitiveInputArrowType<ui32>(bool tz);
  20. template arrow::compute::InputType GetPrimitiveInputArrowType<i64>(bool tz);
  21. template arrow::compute::InputType GetPrimitiveInputArrowType<ui64>(bool tz);
  22. template arrow::compute::InputType GetPrimitiveInputArrowType<float>(bool tz);
  23. template arrow::compute::InputType GetPrimitiveInputArrowType<double>(bool tz);
  24. template arrow::compute::InputType GetPrimitiveInputArrowType<char*>(bool tz);
  25. template arrow::compute::InputType GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>(bool tz);
  26. template arrow::compute::OutputType GetPrimitiveOutputArrowType<bool>(bool tz);
  27. template arrow::compute::OutputType GetPrimitiveOutputArrowType<i8>(bool tz);
  28. template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui8>(bool tz);
  29. template arrow::compute::OutputType GetPrimitiveOutputArrowType<i16>(bool tz);
  30. template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui16>(bool tz);
  31. template arrow::compute::OutputType GetPrimitiveOutputArrowType<i32>(bool tz);
  32. template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui32>(bool tz);
  33. template arrow::compute::OutputType GetPrimitiveOutputArrowType<i64>(bool tz);
  34. template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui64>(bool tz);
  35. template arrow::compute::OutputType GetPrimitiveOutputArrowType<float>(bool tz);
  36. template arrow::compute::OutputType GetPrimitiveOutputArrowType<double>(bool tz);
  37. template arrow::compute::OutputType GetPrimitiveOutputArrowType<char*>(bool tz);
  38. template arrow::compute::OutputType GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>(bool tz);
  39. arrow::compute::InputType GetPrimitiveInputArrowType(NUdf::EDataSlot slot) {
  40. switch (slot) {
  41. case NUdf::EDataSlot::Bool: return GetPrimitiveInputArrowType<bool>();
  42. case NUdf::EDataSlot::Int8: return GetPrimitiveInputArrowType<i8>();
  43. case NUdf::EDataSlot::Uint8: return GetPrimitiveInputArrowType<ui8>();
  44. case NUdf::EDataSlot::Int16: return GetPrimitiveInputArrowType<i16>();
  45. case NUdf::EDataSlot::Uint16: return GetPrimitiveInputArrowType<ui16>();
  46. case NUdf::EDataSlot::Int32: return GetPrimitiveInputArrowType<i32>();
  47. case NUdf::EDataSlot::Uint32: return GetPrimitiveInputArrowType<ui32>();
  48. case NUdf::EDataSlot::Int64: return GetPrimitiveInputArrowType<i64>();
  49. case NUdf::EDataSlot::Uint64: return GetPrimitiveInputArrowType<ui64>();
  50. case NUdf::EDataSlot::Float: return GetPrimitiveInputArrowType<float>();
  51. case NUdf::EDataSlot::Double: return GetPrimitiveInputArrowType<double>();
  52. case NUdf::EDataSlot::String: return GetPrimitiveInputArrowType<char*>();
  53. case NUdf::EDataSlot::Utf8: return GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>();
  54. case NUdf::EDataSlot::Date: return GetPrimitiveInputArrowType<ui16>();
  55. case NUdf::EDataSlot::TzDate: return GetPrimitiveInputArrowType<ui16>(true);
  56. case NUdf::EDataSlot::Datetime: return GetPrimitiveInputArrowType<ui32>();
  57. case NUdf::EDataSlot::TzDatetime: return GetPrimitiveInputArrowType<ui32>(true);
  58. case NUdf::EDataSlot::Timestamp: return GetPrimitiveInputArrowType<ui64>();
  59. case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveInputArrowType<ui64>(true);
  60. case NUdf::EDataSlot::Interval: return GetPrimitiveInputArrowType<i64>();
  61. case NUdf::EDataSlot::Date32: return GetPrimitiveInputArrowType<i32>();
  62. case NUdf::EDataSlot::TzDate32: return GetPrimitiveInputArrowType<i32>(true);
  63. case NUdf::EDataSlot::Datetime64: return GetPrimitiveInputArrowType<i64>();
  64. case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveInputArrowType<i64>(true);
  65. case NUdf::EDataSlot::Timestamp64: return GetPrimitiveInputArrowType<i64>();
  66. case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveInputArrowType<i64>(true);
  67. case NUdf::EDataSlot::Interval64: return GetPrimitiveInputArrowType<i64>();
  68. case NUdf::EDataSlot::Decimal: return GetPrimitiveInputArrowType<NYql::NDecimal::TInt128>();
  69. default:
  70. ythrow yexception() << "Unexpected data slot: " << slot;
  71. }
  72. }
  73. arrow::compute::OutputType GetPrimitiveOutputArrowType(NUdf::EDataSlot slot) {
  74. switch (slot) {
  75. case NUdf::EDataSlot::Bool: return GetPrimitiveOutputArrowType<bool>();
  76. case NUdf::EDataSlot::Int8: return GetPrimitiveOutputArrowType<i8>();
  77. case NUdf::EDataSlot::Uint8: return GetPrimitiveOutputArrowType<ui8>();
  78. case NUdf::EDataSlot::Int16: return GetPrimitiveOutputArrowType<i16>();
  79. case NUdf::EDataSlot::Uint16: return GetPrimitiveOutputArrowType<ui16>();
  80. case NUdf::EDataSlot::Int32: return GetPrimitiveOutputArrowType<i32>();
  81. case NUdf::EDataSlot::Uint32: return GetPrimitiveOutputArrowType<ui32>();
  82. case NUdf::EDataSlot::Int64: return GetPrimitiveOutputArrowType<i64>();
  83. case NUdf::EDataSlot::Uint64: return GetPrimitiveOutputArrowType<ui64>();
  84. case NUdf::EDataSlot::Float: return GetPrimitiveOutputArrowType<float>();
  85. case NUdf::EDataSlot::Double: return GetPrimitiveOutputArrowType<double>();
  86. case NUdf::EDataSlot::String: return GetPrimitiveOutputArrowType<char*>();
  87. case NUdf::EDataSlot::Utf8: return GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>();
  88. case NUdf::EDataSlot::Date: return GetPrimitiveOutputArrowType<ui16>();
  89. case NUdf::EDataSlot::TzDate: return GetPrimitiveOutputArrowType<ui16>(true);
  90. case NUdf::EDataSlot::Datetime: return GetPrimitiveOutputArrowType<ui32>();
  91. case NUdf::EDataSlot::TzDatetime: return GetPrimitiveOutputArrowType<ui32>(true);
  92. case NUdf::EDataSlot::Timestamp: return GetPrimitiveOutputArrowType<ui64>();
  93. case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveOutputArrowType<ui64>(true);
  94. case NUdf::EDataSlot::Interval: return GetPrimitiveOutputArrowType<i64>();
  95. case NUdf::EDataSlot::Date32: return GetPrimitiveOutputArrowType<i32>();
  96. case NUdf::EDataSlot::TzDate32: return GetPrimitiveOutputArrowType<i32>(true);
  97. case NUdf::EDataSlot::Datetime64: return GetPrimitiveOutputArrowType<i64>();
  98. case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveOutputArrowType<i64>(true);
  99. case NUdf::EDataSlot::Timestamp64: return GetPrimitiveOutputArrowType<i64>();
  100. case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveOutputArrowType<i64>(true);
  101. case NUdf::EDataSlot::Interval64: return GetPrimitiveOutputArrowType<i64>();
  102. case NUdf::EDataSlot::Decimal: return GetPrimitiveOutputArrowType<NYql::NDecimal::TInt128>();
  103. default:
  104. ythrow yexception() << "Unexpected data slot: " << slot;
  105. }
  106. }
  107. std::shared_ptr<arrow::DataType> AddTzType(bool addTz, const std::shared_ptr<arrow::DataType>& type) {
  108. if (!addTz) {
  109. return type;
  110. }
  111. std::vector<std::shared_ptr<arrow::Field>> fields {
  112. std::make_shared<arrow::Field>("datetime", type, false),
  113. std::make_shared<arrow::Field>("timezoneId", arrow::uint16(), false)
  114. };
  115. return std::make_shared<arrow::StructType>(fields);
  116. }
  117. std::shared_ptr<arrow::DataType> AddTzType(EPropagateTz propagateTz, const std::shared_ptr<arrow::DataType>& type) {
  118. return AddTzType(propagateTz != EPropagateTz::None, type);
  119. }
  120. std::shared_ptr<arrow::Scalar> ExtractTz(bool isTz, const std::shared_ptr<arrow::Scalar>& value) {
  121. if (!isTz) {
  122. return value;
  123. }
  124. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*value);
  125. return structScalar.value[0];
  126. }
  127. std::shared_ptr<arrow::ArrayData> ExtractTz(bool isTz, const std::shared_ptr<arrow::ArrayData>& value) {
  128. if (!isTz) {
  129. return value;
  130. }
  131. return value->child_data[0];
  132. }
  133. std::shared_ptr<arrow::Scalar> WithTz(bool propagateTz, const std::shared_ptr<arrow::Scalar>& input,
  134. const std::shared_ptr<arrow::Scalar>& value) {
  135. if (!propagateTz) {
  136. return value;
  137. }
  138. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input);
  139. auto tzId = structScalar.value[1];
  140. return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value, tzId}, input->type);
  141. }
  142. std::shared_ptr<arrow::Scalar> WithTz(EPropagateTz propagateTz,
  143. const std::shared_ptr<arrow::Scalar>& input1,
  144. const std::shared_ptr<arrow::Scalar>& input2,
  145. const std::shared_ptr<arrow::Scalar>& value) {
  146. if (propagateTz == EPropagateTz::None) {
  147. return value;
  148. }
  149. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(propagateTz == EPropagateTz::FromLeft ? *input1 : *input2);
  150. const auto tzId = structScalar.value[1];
  151. return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value,tzId}, propagateTz == EPropagateTz::FromLeft ? input1->type : input2->type);
  152. }
  153. std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, bool propagateTz,
  154. const std::shared_ptr<arrow::ArrayData>& input, arrow::MemoryPool* pool,
  155. size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
  156. if (!propagateTz) {
  157. return res;
  158. }
  159. Y_ENSURE(res->child_data.empty());
  160. std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
  161. res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
  162. res->child_data.push_back(input->child_data[1]);
  163. return res->child_data[0];
  164. }
  165. std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
  166. const std::shared_ptr<arrow::ArrayData>& input1,
  167. const std::shared_ptr<arrow::Scalar>& input2,
  168. arrow::MemoryPool* pool,
  169. size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
  170. if (propagateTz == EPropagateTz::None) {
  171. return res;
  172. }
  173. Y_ENSURE(res->child_data.empty());
  174. std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
  175. res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
  176. if (propagateTz == EPropagateTz::FromLeft) {
  177. res->child_data.push_back(input1->child_data[1]);
  178. } else {
  179. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input2);
  180. auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data();
  181. res->child_data.push_back(tzId);
  182. }
  183. return res->child_data[0];
  184. }
  185. std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
  186. const std::shared_ptr<arrow::Scalar>& input1,
  187. const std::shared_ptr<arrow::ArrayData>& input2,
  188. arrow::MemoryPool* pool,
  189. size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
  190. if (propagateTz == EPropagateTz::None) {
  191. return res;
  192. }
  193. Y_ENSURE(res->child_data.empty());
  194. std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
  195. res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
  196. if (propagateTz == EPropagateTz::FromLeft) {
  197. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input1);
  198. auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data();
  199. res->child_data.push_back(tzId);
  200. } else {
  201. res->child_data.push_back(input2->child_data[1]);
  202. }
  203. return res->child_data[0];
  204. }
  205. std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
  206. const std::shared_ptr<arrow::ArrayData>& input1,
  207. const std::shared_ptr<arrow::ArrayData>& input2,
  208. arrow::MemoryPool* pool,
  209. size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
  210. if (propagateTz == EPropagateTz::None) {
  211. return res;
  212. }
  213. Y_ENSURE(res->child_data.empty());
  214. std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
  215. res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
  216. if (propagateTz == EPropagateTz::FromLeft) {
  217. res->child_data.push_back(input1->child_data[1]);
  218. } else {
  219. res->child_data.push_back(input2->child_data[1]);
  220. }
  221. return res->child_data[0];
  222. }
  223. TPlainKernel::TPlainKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes,
  224. NUdf::TDataTypeId returnType, std::unique_ptr<arrow::compute::ScalarKernel>&& arrowKernel,
  225. TKernel::ENullMode nullMode)
  226. : TKernel(family, argTypes, returnType, nullMode)
  227. , ArrowKernel(std::move(arrowKernel))
  228. {
  229. }
  230. const arrow::compute::ScalarKernel& TPlainKernel::GetArrowKernel() const {
  231. return *ArrowKernel;
  232. }
  233. std::shared_ptr<arrow::compute::ScalarKernel> TPlainKernel::MakeArrowKernel(const TVector<TType*>&, TType*) const {
  234. ythrow yexception() << "Unsupported kernel";
  235. }
  236. bool TPlainKernel::IsPolymorphic() const {
  237. return false;
  238. }
  239. TDecimalKernel::TDecimalKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes,
  240. NUdf::TDataTypeId returnType, TStatelessArrayKernelExec exec,
  241. TKernel::ENullMode nullMode)
  242. : TKernel(family, argTypes, returnType, nullMode)
  243. , Exec(exec)
  244. {
  245. }
  246. const arrow::compute::ScalarKernel& TDecimalKernel::GetArrowKernel() const {
  247. ythrow yexception() << "Unsupported kernel";
  248. }
  249. std::shared_ptr<arrow::compute::ScalarKernel> TDecimalKernel::MakeArrowKernel(const TVector<TType*>& argTypes, TType* resultType) const {
  250. MKQL_ENSURE(argTypes.size() == 2, "Require 2 arguments");
  251. MKQL_ENSURE(argTypes[0]->GetKind() == TType::EKind::Block, "Require block");
  252. MKQL_ENSURE(argTypes[1]->GetKind() == TType::EKind::Block, "Require block");
  253. MKQL_ENSURE(resultType->GetKind() == TType::EKind::Block, "Require block");
  254. bool isOptional = false;
  255. auto dataType1 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[0])->GetItemType(), isOptional);
  256. auto dataType2 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[1])->GetItemType(), isOptional);
  257. auto dataResultType = UnpackOptionalData(static_cast<TBlockType*>(resultType)->GetItemType(), isOptional);
  258. MKQL_ENSURE(*dataType1->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
  259. MKQL_ENSURE(*dataType2->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
  260. auto decimalType1 = static_cast<TDataDecimalType*>(dataType1);
  261. auto decimalType2 = static_cast<TDataDecimalType*>(dataType2);
  262. MKQL_ENSURE(decimalType1->GetParams() == decimalType2->GetParams(), "Require same precision/scale");
  263. ui8 precision = decimalType1->GetParams().first;
  264. MKQL_ENSURE(precision >= 1&& precision <= 35, TStringBuilder() << "Wrong precision: " << (int)precision);
  265. auto k = std::make_shared<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
  266. GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal), GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal)
  267. }, GetPrimitiveOutputArrowType(*dataResultType->GetDataSlot()), Exec);
  268. k->null_handling = arrow::compute::NullHandling::INTERSECTION;
  269. k->init = [precision](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
  270. auto state = std::make_unique<TDecimalKernel::TKernelState>();
  271. state->Precision = precision;
  272. return arrow::Result(std::move(state));
  273. };
  274. return k;
  275. }
  276. bool TDecimalKernel::IsPolymorphic() const {
  277. return true;
  278. }
  279. void AddUnaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot res,
  280. TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) {
  281. auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId;
  282. auto returnType = NUdf::GetDataTypeInfo(res).TypeId;
  283. std::vector<NUdf::TDataTypeId> argTypes({ type1 });
  284. auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
  285. GetPrimitiveInputArrowType(arg1)
  286. }, GetPrimitiveOutputArrowType(res), exec);
  287. switch (nullMode) {
  288. case TKernel::ENullMode::Default:
  289. k->null_handling = arrow::compute::NullHandling::INTERSECTION;
  290. break;
  291. case TKernel::ENullMode::AlwaysNull:
  292. k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
  293. break;
  294. case TKernel::ENullMode::AlwaysNotNull:
  295. k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL;
  296. break;
  297. }
  298. owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode));
  299. }
  300. void AddBinaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot arg2, NUdf::EDataSlot res,
  301. TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) {
  302. auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId;
  303. auto type2 = NUdf::GetDataTypeInfo(arg2).TypeId;
  304. auto returnType = NUdf::GetDataTypeInfo(res).TypeId;
  305. std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 });
  306. auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
  307. GetPrimitiveInputArrowType(arg1), GetPrimitiveInputArrowType(arg2)
  308. }, GetPrimitiveOutputArrowType(res), exec);
  309. switch (nullMode) {
  310. case TKernel::ENullMode::Default:
  311. k->null_handling = arrow::compute::NullHandling::INTERSECTION;
  312. break;
  313. case TKernel::ENullMode::AlwaysNull:
  314. k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
  315. break;
  316. case TKernel::ENullMode::AlwaysNotNull:
  317. k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL;
  318. break;
  319. }
  320. owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode));
  321. }
  322. arrow::Status ExecScalarImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  323. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter, TUntypedUnaryScalarFuncPtr func,
  324. bool tz, bool propagateTz) {
  325. if (const auto& arg = batch.values.front(); !arg.scalar()->is_valid) {
  326. *res = arrow::MakeNullScalar(AddTzType(propagateTz, typeGetter()));
  327. } else {
  328. const auto argTz = ExtractTz(tz, arg.scalar());
  329. const auto valPtr = GetPrimitiveScalarValuePtr(*argTz);
  330. auto resDatum = scalarGetter();
  331. const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
  332. func(valPtr, resPtr);
  333. *res = WithTz(propagateTz, arg.scalar(), resDatum.scalar());
  334. }
  335. return arrow::Status::OK();
  336. }
  337. arrow::Status ExecArrayImpl(arrow::compute::KernelContext* kernelCtx,
  338. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  339. TUntypedUnaryArrayFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  340. bool tz, bool propagateTz) {
  341. const auto& arg = batch.values.front();
  342. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg.array(), kernelCtx->memory_pool(),
  343. outputSizeOf, outputTypeGetter());
  344. const auto& arr = *ExtractTz(tz, arg.array());
  345. auto length = arr.length;
  346. const auto valPtr = arr.buffers[1]->data();
  347. auto resPtr = resArr.buffers[1]->mutable_data();
  348. func(valPtr, resPtr, length, arr.offset);
  349. return arrow::Status::OK();
  350. }
  351. arrow::Status ExecUnaryImpl(arrow::compute::KernelContext* kernelCtx,
  352. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  353. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
  354. bool tz, bool propagateTz, size_t outputSizeOf,
  355. TUntypedUnaryScalarFuncPtr scalarFunc, TUntypedUnaryArrayFuncPtr arrayFunc) {
  356. MKQL_ENSURE(batch.values.size() == 1, "Expected single argument");
  357. const auto& arg = batch.values[0];
  358. if (arg.is_scalar()) {
  359. return ExecScalarImpl(batch, res, typeGetter, scalarGetter, scalarFunc, tz, propagateTz);
  360. } else {
  361. return ExecArrayImpl(kernelCtx, batch, res, arrayFunc, outputSizeOf, typeGetter, tz, propagateTz);
  362. }
  363. }
  364. arrow::Status ExecScalarScalarImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  365. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter, TUntypedBinaryScalarFuncPtr func,
  366. bool tz1, bool tz2, EPropagateTz propagateTz) {
  367. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  368. const auto& arg1 = batch.values[0];
  369. const auto& arg2 = batch.values[1];
  370. if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) {
  371. *res = arrow::MakeNullScalar(AddTzType(propagateTz, typeGetter()));
  372. } else {
  373. const auto arg1tz = ExtractTz(tz1, arg1.scalar());
  374. const auto arg2tz = ExtractTz(tz2, arg2.scalar());
  375. const auto val1Ptr = GetPrimitiveScalarValuePtr(*arg1tz);
  376. const auto val2Ptr = GetPrimitiveScalarValuePtr(*arg2tz);
  377. auto resDatum = scalarGetter();
  378. const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
  379. func(val1Ptr, val2Ptr, resPtr);
  380. *res = WithTz(propagateTz, arg1.scalar(), arg2.scalar(), resDatum.scalar());
  381. }
  382. return arrow::Status::OK();
  383. }
  384. arrow::Status ExecScalarArrayImpl(arrow::compute::KernelContext* kernelCtx,
  385. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  386. TUntypedBinaryArrayFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  387. bool tz1, bool tz2, EPropagateTz propagateTz) {
  388. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  389. const auto& arg1 = batch.values[0];
  390. const auto& arg2 = batch.values[1];
  391. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.scalar(), arg2.array(), kernelCtx->memory_pool(),
  392. outputSizeOf, outputTypeGetter());
  393. if (arg1.scalar()->is_valid) {
  394. const auto arg1tz = ExtractTz(tz1, arg1.scalar());
  395. const auto val1Ptr = GetPrimitiveScalarValuePtr(*arg1tz);
  396. const auto& arr2 = *ExtractTz(tz2, arg2.array());
  397. auto length = arr2.length;
  398. const auto val2Ptr = arr2.buffers[1]->data();
  399. auto resPtr = resArr.buffers[1]->mutable_data();
  400. func(val1Ptr, val2Ptr, resPtr, length, 0, arr2.offset);
  401. }
  402. return arrow::Status::OK();
  403. }
  404. arrow::Status ExecArrayScalarImpl(arrow::compute::KernelContext* kernelCtx,
  405. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  406. TUntypedBinaryArrayFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  407. bool tz1, bool tz2, EPropagateTz propagateTz) {
  408. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  409. const auto& arg1 = batch.values[0];
  410. const auto& arg2 = batch.values[1];
  411. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.array(), arg2.scalar(), kernelCtx->memory_pool(),
  412. outputSizeOf, outputTypeGetter());
  413. if (arg2.scalar()->is_valid) {
  414. const auto& arr1 = *ExtractTz(tz1, arg1.array());
  415. auto length = arr1.length;
  416. const auto val1Ptr = arr1.buffers[1]->data();
  417. const auto arg2tz = ExtractTz(tz2, arg2.scalar());
  418. const auto val2Ptr = GetPrimitiveScalarValuePtr(*arg2tz);
  419. auto resPtr = resArr.buffers[1]->mutable_data();
  420. func(val1Ptr, val2Ptr, resPtr, length, arr1.offset, 0);
  421. }
  422. return arrow::Status::OK();
  423. }
  424. arrow::Status ExecArrayArrayImpl(arrow::compute::KernelContext* kernelCtx,
  425. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  426. TUntypedBinaryArrayFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  427. bool tz1, bool tz2, EPropagateTz propagateTz) {
  428. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  429. const auto& arg1 = batch.values[0];
  430. const auto& arg2 = batch.values[1];
  431. const auto& arr1 = *ExtractTz(tz1, arg1.array());
  432. const auto& arr2 = *ExtractTz(tz2, arg2.array());
  433. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.array(), arg2.array(), kernelCtx->memory_pool(),
  434. outputSizeOf, outputTypeGetter());
  435. MKQL_ENSURE(arr1.length == arr2.length, "Expected same length");
  436. auto length = arr1.length;
  437. const auto val1Ptr = arr1.buffers[1]->data();
  438. const auto val2Ptr = arr2.buffers[1]->data();
  439. auto resPtr = resArr.buffers[1]->mutable_data();
  440. func(val1Ptr, val2Ptr, resPtr, length, arr1.offset, arr2.offset);
  441. return arrow::Status::OK();
  442. }
  443. arrow::Status ExecBinaryImpl(arrow::compute::KernelContext* kernelCtx,
  444. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  445. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
  446. bool tz1, bool tz2, EPropagateTz propagateTz, size_t outputSizeOf,
  447. TUntypedBinaryScalarFuncPtr scalarScalarFunc,
  448. TUntypedBinaryArrayFuncPtr scalarArrayFunc,
  449. TUntypedBinaryArrayFuncPtr arrayScalarFunc,
  450. TUntypedBinaryArrayFuncPtr arrayArrayFunc) {
  451. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  452. const auto& arg1 = batch.values[0];
  453. const auto& arg2 = batch.values[1];
  454. if (arg1.is_scalar()) {
  455. if (arg2.is_scalar()) {
  456. return ExecScalarScalarImpl(batch, res, typeGetter, scalarGetter, scalarScalarFunc, tz1, tz2, propagateTz);
  457. } else {
  458. return ExecScalarArrayImpl(kernelCtx, batch, res, scalarArrayFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  459. }
  460. } else {
  461. if (arg2.is_scalar()) {
  462. return ExecArrayScalarImpl(kernelCtx, batch, res, arrayScalarFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  463. } else {
  464. return ExecArrayArrayImpl(kernelCtx, batch, res, arrayArrayFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  465. }
  466. }
  467. }
  468. arrow::Status ExecScalarScalarOptImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  469. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter, TUntypedBinaryScalarOptFuncPtr func,
  470. bool tz1, bool tz2, EPropagateTz propagateTz) {
  471. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  472. const auto& arg1 = batch.values[0];
  473. const auto& arg2 = batch.values[1];
  474. if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) {
  475. *res = arrow::MakeNullScalar(AddTzType(propagateTz, typeGetter()));
  476. } else {
  477. const auto arg1tz = ExtractTz(tz1, arg1.scalar());
  478. const auto arg2tz = ExtractTz(tz2, arg2.scalar());
  479. const auto val1Ptr = GetPrimitiveScalarValuePtr(*arg1tz);
  480. const auto val2Ptr = GetPrimitiveScalarValuePtr(*arg2tz);
  481. auto resDatum = scalarGetter();
  482. const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
  483. if (!func(val1Ptr, val2Ptr, resPtr)) {
  484. *res = arrow::MakeNullScalar(AddTzType(propagateTz, typeGetter()));
  485. } else {
  486. *res = WithTz(propagateTz, arg1.scalar(), arg2.scalar(), resDatum.scalar());
  487. }
  488. }
  489. return arrow::Status::OK();
  490. }
  491. arrow::Status ExecScalarArrayOptImpl(arrow::compute::KernelContext* kernelCtx,
  492. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  493. TUntypedBinaryArrayOptFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  494. bool tz1, bool tz2, EPropagateTz propagateTz) {
  495. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  496. const auto& arg1 = batch.values[0];
  497. const auto& arg2 = batch.values[1];
  498. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.scalar(), arg2.array(), kernelCtx->memory_pool(),
  499. outputSizeOf, outputTypeGetter());
  500. if (arg1.scalar()->is_valid) {
  501. const auto arg1tz = ExtractTz(tz1, arg1.scalar());
  502. const auto val1Ptr = GetPrimitiveScalarValuePtr(*arg1tz);
  503. const auto& arr2 = *ExtractTz(tz2, arg2.array());
  504. auto length = arr2.length;
  505. const auto val2Ptr = arr2.buffers[1]->data();
  506. const auto nullCount2 = arr2.GetNullCount();
  507. const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0);
  508. auto resPtr = resArr.buffers[1]->mutable_data();
  509. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  510. func(val1Ptr, nullptr, val2Ptr, valid2, resPtr, resValid, length, 0, arr2.offset);
  511. } else {
  512. GetBitmap(resArr, 0).SetBitsTo(false);
  513. }
  514. return arrow::Status::OK();
  515. }
  516. arrow::Status ExecArrayScalarOptImpl(arrow::compute::KernelContext* kernelCtx,
  517. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  518. TUntypedBinaryArrayOptFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  519. bool tz1, bool tz2, EPropagateTz propagateTz) {
  520. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  521. const auto& arg1 = batch.values[0];
  522. const auto& arg2 = batch.values[1];
  523. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.array(), arg2.scalar(), kernelCtx->memory_pool(),
  524. outputSizeOf, outputTypeGetter());
  525. if (arg2.scalar()->is_valid) {
  526. const auto& arr1 = *ExtractTz(tz1, arg1.array());
  527. const auto val1Ptr = arr1.buffers[1]->data();
  528. auto length = arr1.length;
  529. const auto nullCount1 = arr1.GetNullCount();
  530. const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0);
  531. const auto arg2tz = ExtractTz(tz2, arg2.scalar());
  532. const auto val2Ptr = GetPrimitiveScalarValuePtr(*arg2tz);
  533. auto resPtr = resArr.buffers[1]->mutable_data();
  534. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  535. func(val1Ptr, valid1, val2Ptr, nullptr, resPtr, resValid, length, arr1.offset, 0);
  536. } else {
  537. GetBitmap(resArr, 0).SetBitsTo(false);
  538. }
  539. return arrow::Status::OK();
  540. }
  541. arrow::Status ExecArrayArrayOptImpl(arrow::compute::KernelContext* kernelCtx,
  542. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  543. TUntypedBinaryArrayOptFuncPtr func, size_t outputSizeOf, TPrimitiveDataTypeGetter outputTypeGetter,
  544. bool tz1, bool tz2, EPropagateTz propagateTz) {
  545. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  546. const auto& arg1 = batch.values[0];
  547. const auto& arg2 = batch.values[1];
  548. const auto& arr1 = *ExtractTz(tz1, arg1.array());
  549. const auto& arr2 = *ExtractTz(tz2, arg2.array());
  550. MKQL_ENSURE(arr1.length == arr2.length, "Expected same length");
  551. auto length = arr1.length;
  552. const auto val1Ptr = arr1.buffers[1]->data();
  553. const auto nullCount1 = arr1.GetNullCount();
  554. const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0);
  555. const auto val2Ptr = arr2.buffers[1]->data();
  556. const auto nullCount2 = arr2.GetNullCount();
  557. const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0);
  558. auto& resArr = *CopyTzImpl(res->array(), propagateTz, arg1.array(), arg2.array(), kernelCtx->memory_pool(),
  559. outputSizeOf, outputTypeGetter());
  560. auto resPtr = resArr.buffers[1]->mutable_data();
  561. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  562. func(val1Ptr, valid1, val2Ptr, valid2, resPtr, resValid, length, arr1.offset, arr2.offset);
  563. return arrow::Status::OK();
  564. }
  565. arrow::Status ExecBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
  566. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  567. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetter scalarGetter,
  568. bool tz1, bool tz2, EPropagateTz propagateTz, size_t outputSizeOf,
  569. TUntypedBinaryScalarOptFuncPtr scalarScalarFunc,
  570. TUntypedBinaryArrayOptFuncPtr scalarArrayFunc,
  571. TUntypedBinaryArrayOptFuncPtr arrayScalarFunc,
  572. TUntypedBinaryArrayOptFuncPtr arrayArrayFunc) {
  573. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  574. const auto& arg1 = batch.values[0];
  575. const auto& arg2 = batch.values[1];
  576. if (arg1.is_scalar()) {
  577. if (arg2.is_scalar()) {
  578. return ExecScalarScalarOptImpl(batch, res, typeGetter, scalarGetter, scalarScalarFunc, tz1, tz2, propagateTz);
  579. } else {
  580. return ExecScalarArrayOptImpl(kernelCtx, batch, res, scalarArrayFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  581. }
  582. } else {
  583. if (arg2.is_scalar()) {
  584. return ExecArrayScalarOptImpl(kernelCtx, batch, res, arrayScalarFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  585. } else {
  586. return ExecArrayArrayOptImpl(kernelCtx, batch, res, arrayArrayFunc, outputSizeOf, typeGetter, tz1, tz2, propagateTz);
  587. }
  588. }
  589. }
  590. arrow::Status ExecDecimalArrayScalarOptImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  591. TUntypedBinaryArrayOptFuncPtr func) {
  592. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  593. const auto& arg1 = batch.values[0];
  594. const auto& arg2 = batch.values[1];
  595. auto& resArr = *res->array();
  596. if (arg2.scalar()->is_valid) {
  597. const auto& arr1 = *arg1.array();
  598. const auto val1Ptr = arr1.buffers[1]->data();
  599. auto length = arr1.length;
  600. const auto nullCount1 = arr1.GetNullCount();
  601. const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0);
  602. const auto val2Ptr = GetStringScalarValue(*arg2.scalar());
  603. auto resPtr = resArr.buffers[1]->mutable_data();
  604. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  605. func(val1Ptr, valid1, val2Ptr.data(), nullptr, resPtr, resValid, length, arr1.offset, 0);
  606. } else {
  607. GetBitmap(resArr, 0).SetBitsTo(false);
  608. }
  609. return arrow::Status::OK();
  610. }
  611. arrow::Status ExecDecimalScalarArrayOptImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  612. TUntypedBinaryArrayOptFuncPtr func) {
  613. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  614. const auto& arg1 = batch.values[0];
  615. const auto& arg2 = batch.values[1];
  616. auto& resArr = *res->array();
  617. if (arg1.scalar()->is_valid) {
  618. const auto val1Ptr = GetStringScalarValue(*arg1.scalar());
  619. const auto& arr2 = *arg2.array();
  620. auto length = arr2.length;
  621. const auto val2Ptr = arr2.buffers[1]->data();
  622. const auto nullCount2 = arr2.GetNullCount();
  623. const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0);
  624. auto resPtr = resArr.buffers[1]->mutable_data();
  625. auto resValid = res->array()->GetMutableValues<uint8_t>(0);
  626. func(val1Ptr.data(), nullptr, val2Ptr, valid2, resPtr, resValid, length, 0, arr2.offset);
  627. } else {
  628. GetBitmap(resArr, 0).SetBitsTo(false);
  629. }
  630. return arrow::Status::OK();
  631. }
  632. arrow::Status ExecDecimalScalarScalarOptImpl(arrow::compute::KernelContext* kernelCtx,
  633. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  634. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
  635. TUntypedBinaryScalarOptFuncPtr func) {
  636. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  637. const auto& arg1 = batch.values[0];
  638. const auto& arg2 = batch.values[1];
  639. if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) {
  640. *res = arrow::MakeNullScalar(typeGetter());
  641. } else {
  642. const auto val1Ptr = GetStringScalarValue(*arg1.scalar());
  643. const auto val2Ptr = GetStringScalarValue(*arg2.scalar());
  644. void* resMem;
  645. auto resDatum = scalarGetter(&resMem, kernelCtx->memory_pool());
  646. if (!func(val1Ptr.data(), val2Ptr.data(), resMem)) {
  647. *res = arrow::MakeNullScalar(typeGetter());
  648. } else {
  649. *res = resDatum.scalar();
  650. }
  651. }
  652. return arrow::Status::OK();
  653. }
  654. arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
  655. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  656. TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
  657. size_t outputSizeOf,
  658. TUntypedBinaryScalarOptFuncPtr scalarScalarFunc,
  659. TUntypedBinaryArrayOptFuncPtr scalarArrayFunc,
  660. TUntypedBinaryArrayOptFuncPtr arrayScalarFunc,
  661. TUntypedBinaryArrayOptFuncPtr arrayArrayFunc) {
  662. MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
  663. const auto& arg1 = batch.values[0];
  664. const auto& arg2 = batch.values[1];
  665. if (arg1.is_scalar()) {
  666. if (arg2.is_scalar()) {
  667. return ExecDecimalScalarScalarOptImpl(kernelCtx, batch, res, typeGetter, scalarGetter, scalarScalarFunc);
  668. } else {
  669. return ExecDecimalScalarArrayOptImpl(batch, res, scalarArrayFunc);
  670. }
  671. } else {
  672. if (arg2.is_scalar()) {
  673. return ExecDecimalArrayScalarOptImpl(batch, res, arrayScalarFunc);
  674. } else {
  675. return ExecArrayArrayOptImpl(kernelCtx, batch, res, arrayArrayFunc, outputSizeOf, typeGetter, false, false, EPropagateTz::None);
  676. }
  677. }
  678. }
  679. arrow::Status ExecDecimalScalarImpl(arrow::compute::KernelContext* kernelCtx,
  680. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  681. TPrimitiveDataTypeGetter typeGetter, TUntypedUnaryScalarFuncPtr func) {
  682. if (const auto& arg = batch.values.front(); !arg.scalar()->is_valid) {
  683. *res = arrow::MakeNullScalar(typeGetter());
  684. } else {
  685. const auto valPtr = GetPrimitiveScalarValuePtr(*arg.scalar());
  686. std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool())));
  687. auto resDatum = arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
  688. func(valPtr, buffer->mutable_data());
  689. *res = resDatum.scalar();
  690. }
  691. return arrow::Status::OK();
  692. }
  693. arrow::Status ExecDecimalArrayImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  694. TUntypedUnaryArrayFuncPtr func) {
  695. const auto& arg = batch.values.front();
  696. auto& resArr = *res->array();
  697. const auto& arr = *arg.array();
  698. auto length = arr.length;
  699. const auto valPtr = arr.buffers[1]->data();
  700. auto resPtr = resArr.buffers[1]->mutable_data();
  701. func(valPtr, resPtr, length, arr.offset);
  702. return arrow::Status::OK();
  703. }
  704. arrow::Status ExecDecimalUnaryImpl(arrow::compute::KernelContext* kernelCtx,
  705. const arrow::compute::ExecBatch& batch, arrow::Datum* res,
  706. TPrimitiveDataTypeGetter typeGetter,
  707. TUntypedUnaryScalarFuncPtr scalarFunc, TUntypedUnaryArrayFuncPtr arrayFunc) {
  708. MKQL_ENSURE(batch.values.size() == 1, "Expected single argument");
  709. const auto& arg = batch.values[0];
  710. if (arg.is_scalar()) {
  711. return ExecDecimalScalarImpl(kernelCtx, batch, res, typeGetter, scalarFunc);
  712. } else {
  713. return ExecDecimalArrayImpl(batch, res, arrayFunc);
  714. }
  715. }
  716. } // namespace NMiniKQL
  717. } // namespace NKikimr