udf_arrow_helpers.h 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747
  1. #pragma once
  2. #include <yql/essentials/public/udf/udf_type_builder.h>
  3. #include <yql/essentials/public/udf/udf_value.h>
  4. #include <yql/essentials/public/udf/udf_helpers.h>
  5. #include <yql/essentials/public/udf/udf_data_type.h>
  6. #include <yql/essentials/public/udf/udf_type_inspection.h>
  7. #include "defs.h"
  8. #include "util.h"
  9. #include "args_dechunker.h"
  10. #include "block_reader.h"
  11. #include "block_builder.h"
  12. #include "memory_pool.h"
  13. #include <arrow/array/array_base.h>
  14. #include <arrow/array/util.h>
  15. #include <arrow/c/bridge.h>
  16. #include <arrow/chunked_array.h>
  17. #include <arrow/compute/kernel.h>
  18. #include <arrow/compute/exec_internal.h>
  19. #include <arrow/util/bitmap_ops.h>
  20. namespace NYql {
  21. namespace NUdf {
  22. using TExec = arrow::Status(*)(arrow::compute::KernelContext*, const arrow::compute::ExecBatch&, arrow::Datum*);
  23. class TUdfKernelState : public arrow::compute::KernelState {
  24. public:
  25. TUdfKernelState(const TVector<const TType*>& argTypes, const TType* outputType, bool onlyScalars, const ITypeInfoHelper* typeInfoHelper, const IValueBuilder* valueBuilder)
  26. : ArgTypes_(argTypes)
  27. , OutputType_(outputType)
  28. , OnlyScalars_(onlyScalars)
  29. , TypeInfoHelper_(typeInfoHelper)
  30. , ValueBuilder_(valueBuilder)
  31. {
  32. Readers_.resize(ArgTypes_.size());
  33. }
  34. IBlockReader& GetReader(ui32 index) {
  35. if (!Readers_[index]) {
  36. Readers_[index] = MakeBlockReader(*TypeInfoHelper_, ArgTypes_[index]);
  37. }
  38. return *Readers_[index];
  39. }
  40. IArrayBuilder& GetArrayBuilder() {
  41. Y_ENSURE(!OnlyScalars_);
  42. if (!ArrayBuilder_) {
  43. ArrayBuilder_ = MakeArrayBuilder(*TypeInfoHelper_, OutputType_, *GetYqlMemoryPool(), TypeInfoHelper_->GetMaxBlockLength(OutputType_), &ValueBuilder_->GetPgBuilder());
  44. }
  45. return *ArrayBuilder_;
  46. }
  47. IScalarBuilder& GetScalarBuilder() {
  48. Y_ENSURE(OnlyScalars_);
  49. if (!ScalarBuilder_) {
  50. ScalarBuilder_ = MakeScalarBuilder(*TypeInfoHelper_, OutputType_);
  51. }
  52. return *ScalarBuilder_;
  53. }
  54. const IValueBuilder& GetValueBuilder() {
  55. Y_ENSURE(ValueBuilder_);
  56. return *ValueBuilder_;
  57. }
  58. private:
  59. const TVector<const TType*> ArgTypes_;
  60. const TType* OutputType_;
  61. const bool OnlyScalars_;
  62. const ITypeInfoHelper* TypeInfoHelper_;
  63. const IValueBuilder* ValueBuilder_;
  64. TVector<std::unique_ptr<IBlockReader>> Readers_;
  65. std::unique_ptr<IArrayBuilder> ArrayBuilder_;
  66. std::unique_ptr<IScalarBuilder> ScalarBuilder_;
  67. };
  68. class TSimpleArrowUdfImpl : public TBoxedValue {
  69. public:
  70. TSimpleArrowUdfImpl(const TVector<const TType*> argBlockTypes, const TType* outputType, bool onlyScalars,
  71. TExec exec, IFunctionTypeInfoBuilder& builder, const TString& name,
  72. arrow::compute::NullHandling::type nullHandling)
  73. : OnlyScalars_(onlyScalars)
  74. , Exec_(exec)
  75. , Pos_(GetSourcePosition(builder))
  76. , Name_(name)
  77. , OutputType_(outputType)
  78. , NullDatum_(arrow::Datum(std::make_shared<arrow::NullScalar>()))
  79. {
  80. TypeInfoHelper_ = builder.TypeInfoHelper();
  81. Kernel_.null_handling = nullHandling;
  82. Kernel_.exec = Exec_;
  83. std::vector<arrow::compute::InputType> inTypes;
  84. for (const auto& blockType : argBlockTypes) {
  85. TBlockTypeInspector blockInspector(*TypeInfoHelper_, blockType);
  86. Y_ENSURE(blockInspector);
  87. ArgTypes_.push_back(blockInspector.GetItemType());
  88. auto arrowTypeHandle = TypeInfoHelper_->MakeArrowType(blockInspector.GetItemType());
  89. Y_ENSURE(arrowTypeHandle);
  90. ArrowSchema s;
  91. arrowTypeHandle->Export(&s);
  92. auto type = ARROW_RESULT(arrow::ImportType(&s));
  93. ArgArrowTypes_.emplace_back(type);
  94. auto shape = blockInspector.IsScalar() ? arrow::ValueDescr::SCALAR : arrow::ValueDescr::ARRAY;
  95. inTypes.emplace_back(arrow::compute::InputType(type, shape));
  96. ArgsValuesDescr_.emplace_back(arrow::ValueDescr(type, shape));
  97. }
  98. ReturnArrowTypeHandle_ = TypeInfoHelper_->MakeArrowType(outputType);
  99. Y_ENSURE(ReturnArrowTypeHandle_);
  100. ArrowSchema s;
  101. ReturnArrowTypeHandle_->Export(&s);
  102. auto outputShape = onlyScalars ? arrow::ValueDescr::SCALAR : arrow::ValueDescr::ARRAY;
  103. arrow::compute::OutputType outType(arrow::ValueDescr(ARROW_RESULT(arrow::ImportType(&s)), outputShape));
  104. Kernel_.signature = arrow::compute::KernelSignature::Make(std::move(inTypes), std::move(outType));
  105. }
  106. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const final {
  107. try {
  108. TVector<arrow::Datum> argDatums(ArgArrowTypes_.size());
  109. for (ui32 i = 0; i < ArgArrowTypes_.size(); ++i) {
  110. bool isScalar;
  111. ui64 length;
  112. // If no value is given to the UDF, pass the Null datum as an
  113. // optional argument value to the fixed-arg kernel.
  114. // XXX: Use bool operator for TUnboxedValuePod object instead
  115. // of its HasValue method due to just(null) semantics.
  116. if (!args[i]) {
  117. argDatums[i] = NullDatum_;
  118. continue;
  119. }
  120. ui32 chunkCount = valueBuilder->GetArrowBlockChunks(args[i], isScalar, length);
  121. if (isScalar) {
  122. ArrowArray a;
  123. valueBuilder->ExportArrowBlock(args[i], 0, &a);
  124. auto arr = ARROW_RESULT(arrow::ImportArray(&a, ArgArrowTypes_[i]));
  125. auto scalar = ARROW_RESULT(arr->GetScalar(0));
  126. argDatums[i] = scalar;
  127. } else {
  128. TVector<std::shared_ptr<arrow::Array>> imported(chunkCount);
  129. for (ui32 k = 0; k < chunkCount; ++k) {
  130. ArrowArray a;
  131. valueBuilder->ExportArrowBlock(args[i], k, &a);
  132. auto arr = ARROW_RESULT(arrow::ImportArray(&a, ArgArrowTypes_[i]));
  133. imported[k] = arr;
  134. }
  135. if (chunkCount == 1) {
  136. argDatums[i] = imported.front();
  137. } else {
  138. argDatums[i] = ARROW_RESULT(arrow::ChunkedArray::Make(std::move(imported), ArgArrowTypes_[i]));
  139. }
  140. }
  141. }
  142. TUdfKernelState kernelState(ArgTypes_, OutputType_, OnlyScalars_, TypeInfoHelper_.Get(), valueBuilder);
  143. arrow::compute::ExecContext execContext(GetYqlMemoryPool());
  144. arrow::compute::KernelContext kernelContext(&execContext);
  145. kernelContext.SetState(&kernelState);
  146. arrow::Datum res;
  147. if (OnlyScalars_) {
  148. auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
  149. ARROW_OK(executor->Init(&kernelContext, { &Kernel_, ArgsValuesDescr_, nullptr }));
  150. auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
  151. ARROW_OK(executor->Execute(argDatums, listener.get()));
  152. res = executor->WrapResults(argDatums, listener->values());
  153. } else {
  154. TArgsDechunker dechunker(std::move(argDatums));
  155. std::vector<arrow::Datum> chunk;
  156. TVector<std::shared_ptr<arrow::ArrayData>> arrays;
  157. while (dechunker.Next(chunk)) {
  158. auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
  159. ARROW_OK(executor->Init(&kernelContext, { &Kernel_, ArgsValuesDescr_, nullptr }));
  160. arrow::compute::detail::DatumAccumulator listener;
  161. ARROW_OK(executor->Execute(chunk, &listener));
  162. auto output = executor->WrapResults(chunk, listener.values());
  163. ForEachArrayData(output, [&](const auto& arr) { arrays.push_back(arr); });
  164. }
  165. res = MakeArray(arrays);
  166. }
  167. if (OnlyScalars_) {
  168. auto arr = ARROW_RESULT(arrow::MakeArrayFromScalar(*res.scalar(), 1));
  169. ArrowArray a;
  170. ARROW_OK(arrow::ExportArray(*arr, &a));
  171. return valueBuilder->ImportArrowBlock(&a, 1, true, *ReturnArrowTypeHandle_);
  172. } else {
  173. TVector<ArrowArray> a;
  174. if (res.is_array()) {
  175. a.resize(1);
  176. ARROW_OK(arrow::ExportArray(*res.make_array(), &a[0]));
  177. } else {
  178. Y_ENSURE(res.is_arraylike());
  179. a.resize(res.chunks().size());
  180. for (ui32 i = 0; i < res.chunks().size(); ++i) {
  181. ARROW_OK(arrow::ExportArray(*res.chunks()[i], &a[i]));
  182. }
  183. }
  184. return valueBuilder->ImportArrowBlock(a.data(), a.size(), false, *ReturnArrowTypeHandle_);
  185. }
  186. } catch (const std::exception&) {
  187. TStringBuilder sb;
  188. sb << Pos_ << " ";
  189. sb << CurrentExceptionMessage();
  190. sb << Endl << "[" << Name_ << "]";
  191. UdfTerminate(sb.c_str());
  192. }
  193. }
  194. private:
  195. const bool OnlyScalars_;
  196. const TExec Exec_;
  197. TSourcePosition Pos_;
  198. const TString Name_;
  199. const TType* OutputType_;
  200. ITypeInfoHelper::TPtr TypeInfoHelper_;
  201. TVector<std::shared_ptr<arrow::DataType>> ArgArrowTypes_;
  202. IArrowType::TPtr ReturnArrowTypeHandle_;
  203. arrow::compute::ScalarKernel Kernel_;
  204. std::vector<arrow::ValueDescr> ArgsValuesDescr_;
  205. TVector<const TType*> ArgTypes_;
  206. const arrow::Datum NullDatum_;
  207. };
  208. inline void SetCallableArgumentAttributes(IFunctionArgTypesBuilder& argsBuilder,
  209. const TCallableTypeInspector& callableInspector, const ui32 index) {
  210. if (callableInspector.GetArgumentName(index).Size() > 0) {
  211. argsBuilder.Name(callableInspector.GetArgumentName(index));
  212. }
  213. if (callableInspector.GetArgumentFlags(index) != 0) {
  214. argsBuilder.Flags(callableInspector.GetArgumentFlags(index));
  215. }
  216. }
  217. inline void PrepareSimpleArrowUdf(IFunctionTypeInfoBuilder& builder, TType* signature, TType* userType, TExec exec, bool typesOnly,
  218. const TString& name, arrow::compute::NullHandling::type nullHandling = arrow::compute::NullHandling::type::COMPUTED_NO_PREALLOCATE) {
  219. auto typeInfoHelper = builder.TypeInfoHelper();
  220. TCallableTypeInspector callableInspector(*typeInfoHelper, signature);
  221. Y_ENSURE(callableInspector);
  222. Y_ENSURE(callableInspector.GetArgsCount() > 0);
  223. TTupleTypeInspector userTypeInspector(*typeInfoHelper, userType);
  224. Y_ENSURE(userTypeInspector);
  225. Y_ENSURE(userTypeInspector.GetElementsCount() == 3);
  226. TTupleTypeInspector argsInspector(*typeInfoHelper, userTypeInspector.GetElementType(0));
  227. Y_ENSURE(argsInspector);
  228. Y_ENSURE(argsInspector.GetElementsCount() <= callableInspector.GetArgsCount());
  229. const ui32 omitted = callableInspector.GetArgsCount() - argsInspector.GetElementsCount();
  230. Y_ENSURE(omitted <= callableInspector.GetOptionalArgsCount());
  231. bool hasBlocks = false;
  232. bool onlyScalars = true;
  233. for (ui32 i = 0; i < argsInspector.GetElementsCount(); ++i) {
  234. TBlockTypeInspector blockInspector(*typeInfoHelper, argsInspector.GetElementType(i));
  235. if (blockInspector) {
  236. if (i == 0) {
  237. hasBlocks = true;
  238. } else {
  239. Y_ENSURE(hasBlocks);
  240. }
  241. onlyScalars = onlyScalars && blockInspector.IsScalar();
  242. }
  243. }
  244. builder.SupportsBlocks();
  245. builder.UserType(userType);
  246. Y_ENSURE(hasBlocks);
  247. TVector<const TType*> argBlockTypes;
  248. auto argsBuilder = builder.Args(callableInspector.GetArgsCount());
  249. for (ui32 i = 0; i < argsInspector.GetElementsCount(); ++i) {
  250. TBlockTypeInspector blockInspector(*typeInfoHelper, argsInspector.GetElementType(i));
  251. auto type = callableInspector.GetArgType(i);
  252. auto argBlockType = builder.Block(blockInspector.IsScalar())->Item(type).Build();
  253. argsBuilder->Add(argBlockType);
  254. SetCallableArgumentAttributes(*argsBuilder, callableInspector, i);
  255. argBlockTypes.emplace_back(argBlockType);
  256. }
  257. // XXX: Append the Block types for the omitted arguments to preserve the
  258. // fixed-arg kernel signature. Unlikely to the required arguments,
  259. // initialized above, the type of the omitted argument has to be passed to
  260. // the specialized UDF signature builder (i.e. argsBuilder) as an original
  261. // TOptional parameter type. At the same time, all of omitted arguments have
  262. // to be substituted with Null datums, so all the original types of the
  263. // optional parameters are wrapped type with the Block type with the Scalar
  264. // shape in the UDFKernel signature (i.e. argBlockTypes).
  265. for (ui32 i = argsInspector.GetElementsCount(); i < callableInspector.GetArgsCount(); i++) {
  266. auto optType = callableInspector.GetArgType(i);
  267. argsBuilder->Add(optType);
  268. SetCallableArgumentAttributes(*argsBuilder, callableInspector, i);
  269. argBlockTypes.emplace_back(builder.Block(true)->Item(optType).Build());
  270. }
  271. builder.Returns(builder.Block(onlyScalars)->Item(callableInspector.GetReturnType()).Build());
  272. // XXX: Only the omitted parameters should be specified as optional
  273. // arguments in this context.
  274. if (omitted) {
  275. builder.OptionalArgs(omitted);
  276. }
  277. if (callableInspector.GetPayload().Size() > 0) {
  278. builder.PayloadImpl(callableInspector.GetPayload());
  279. }
  280. if (!typesOnly) {
  281. builder.Implementation(new TSimpleArrowUdfImpl(argBlockTypes, callableInspector.GetReturnType(),
  282. onlyScalars, exec, builder, name, nullHandling));
  283. }
  284. }
  285. template<typename TBuilder>
  286. TBuilder* CastToArrayBuilderImpl(IArrayBuilder& builder) {
  287. static_assert(std::is_base_of_v<IArrayBuilder, TBuilder>);
  288. auto* builderImpl = dynamic_cast<TBuilder*>(&builder);
  289. Y_ENSURE(builderImpl, TStringBuilder() << "Got " << typeid(builder).name() << " as ArrayBuilder");
  290. return builderImpl;
  291. }
  292. template<typename TScalarBuilderImpl>
  293. TScalarBuilderImpl* CastToScalarBuilderImpl(IScalarBuilder& builder) {
  294. static_assert(std::is_base_of_v<IScalarBuilder, TScalarBuilderImpl>);
  295. auto* builderImpl = dynamic_cast<TScalarBuilderImpl*>(&builder);
  296. Y_ENSURE(builderImpl, TStringBuilder() << "Got " << typeid(builder).name() << " as ArrayBuilder");
  297. return builderImpl;
  298. }
  299. template<typename TReader>
  300. TReader* CastToBlockReaderImpl(IBlockReader& reader) {
  301. static_assert(std::is_base_of_v<IBlockReader, TReader>);
  302. auto* readerImpl = dynamic_cast<TReader*>(&reader);
  303. Y_ENSURE(readerImpl, TStringBuilder() << "Got " << typeid(reader).name() << " as BlockReader");
  304. return readerImpl;
  305. }
  306. template <typename TDerived, typename TReader = IBlockReader, typename TArrayBuilderImpl = IArrayBuilder, typename TScalarBuilderImpl = IScalarBuilder>
  307. struct TUnaryKernelExec {
  308. static arrow::Status Do(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  309. auto& state = dynamic_cast<TUdfKernelState&>(*ctx->state());
  310. auto& reader = state.GetReader(0);
  311. auto* readerImpl = CastToBlockReaderImpl<TReader>(reader);
  312. const auto& arg = batch.values[0];
  313. if (arg.is_scalar()) {
  314. auto& builder = state.GetScalarBuilder();
  315. auto* builderImpl = CastToScalarBuilderImpl<TScalarBuilderImpl>(builder);
  316. auto item = readerImpl->GetScalarItem(*arg.scalar());
  317. TDerived::Process(&state.GetValueBuilder(), item, [&](TBlockItem out) {
  318. *res = builderImpl->Build(out);
  319. });
  320. }
  321. else {
  322. auto& array = *arg.array();
  323. auto& builder = state.GetArrayBuilder();
  324. auto* builderImpl = CastToArrayBuilderImpl<TArrayBuilderImpl>(builder);
  325. size_t maxBlockLength = builderImpl->MaxLength();
  326. Y_ENSURE(maxBlockLength > 0);
  327. TVector<std::shared_ptr<arrow::ArrayData>> outputArrays;
  328. for (int64_t i = 0; i < array.length;) {
  329. for (size_t j = 0; j < maxBlockLength && i < array.length; ++j, ++i) {
  330. auto item = readerImpl->GetItem(array, i);
  331. TDerived::Process(&state.GetValueBuilder(), item, [&](TBlockItem out) {
  332. builderImpl->Add(out);
  333. });
  334. }
  335. auto outputDatum = builderImpl->Build(false);
  336. ForEachArrayData(outputDatum, [&](const auto& arr) { outputArrays.push_back(arr); });
  337. }
  338. *res = MakeArray(outputArrays);
  339. }
  340. return arrow::Status::OK();
  341. }
  342. };
  343. template <typename TDerived, typename TReader1 = IBlockReader, typename TReader2 = IBlockReader, typename TArrayBuilderImpl = IArrayBuilder, typename TScalarBuilderImpl = IScalarBuilder>
  344. struct TBinaryKernelExec {
  345. static arrow::Status Do(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  346. auto& state = dynamic_cast<TUdfKernelState&>(*ctx->state());
  347. auto& reader1 = state.GetReader(0);
  348. auto* reader1Impl = CastToBlockReaderImpl<TReader1>(reader1);
  349. auto& reader2 = state.GetReader(1);
  350. auto* reader2Impl = CastToBlockReaderImpl<TReader2>(reader2);
  351. const auto& arg1 = batch.values[0];
  352. const auto& arg2 = batch.values[1];
  353. if (arg1.is_scalar() && arg2.is_scalar()) {
  354. auto& builder = state.GetScalarBuilder();
  355. auto* builderImpl = CastToScalarBuilderImpl<TScalarBuilderImpl>(builder);
  356. auto item1 = reader1Impl->GetScalarItem(*arg1.scalar());
  357. auto item2 = reader2Impl->GetScalarItem(*arg2.scalar());
  358. TDerived::Process(&state.GetValueBuilder(), item1, item2, [&](TBlockItem out) {
  359. *res = builderImpl->Build(out);
  360. });
  361. }
  362. else if (arg1.is_scalar() && arg2.is_array()) {
  363. auto item1 = reader1Impl->GetScalarItem(*arg1.scalar());
  364. auto& array2 = *arg2.array();
  365. auto& builder = state.GetArrayBuilder();
  366. auto* builderImpl = CastToArrayBuilderImpl<TArrayBuilderImpl>(builder);
  367. size_t maxBlockLength = builder.MaxLength();
  368. Y_ENSURE(maxBlockLength > 0);
  369. TVector<std::shared_ptr<arrow::ArrayData>> outputArrays;
  370. for (int64_t i = 0; i < array2.length;) {
  371. for (size_t j = 0; j < maxBlockLength && i < array2.length; ++j, ++i) {
  372. auto item2 = reader2Impl->GetItem(array2, i);
  373. TDerived::Process(&state.GetValueBuilder(), item1, item2, [&](TBlockItem out) {
  374. builderImpl->Add(out);
  375. });
  376. }
  377. auto outputDatum = builder.Build(false);
  378. ForEachArrayData(outputDatum, [&](const auto& arr) { outputArrays.push_back(arr); });
  379. }
  380. *res = MakeArray(outputArrays);
  381. } else if (arg1.is_array() && arg2.is_scalar()) {
  382. auto& array1 = *arg1.array();
  383. auto item2 = reader2Impl->GetScalarItem(*arg2.scalar());
  384. auto& builder = state.GetArrayBuilder();
  385. auto* builderImpl = CastToArrayBuilderImpl<TArrayBuilderImpl>(builder);
  386. size_t maxBlockLength = builder.MaxLength();
  387. Y_ENSURE(maxBlockLength > 0);
  388. TVector<std::shared_ptr<arrow::ArrayData>> outputArrays;
  389. for (int64_t i = 0; i < array1.length;) {
  390. for (size_t j = 0; j < maxBlockLength && i < array1.length; ++j, ++i) {
  391. auto item1 = reader1Impl->GetItem(array1, i);
  392. TDerived::Process(&state.GetValueBuilder(), item1, item2, [&](TBlockItem out) {
  393. builderImpl->Add(out);
  394. });
  395. }
  396. auto outputDatum = builder.Build(false);
  397. ForEachArrayData(outputDatum, [&](const auto& arr) { outputArrays.push_back(arr); });
  398. }
  399. *res = MakeArray(outputArrays);
  400. } else {
  401. Y_ENSURE(arg1.is_array() && arg2.is_array());
  402. auto& array1 = *arg1.array();
  403. auto& array2 = *arg2.array();
  404. auto& builder = state.GetArrayBuilder();
  405. auto* builderImpl = CastToArrayBuilderImpl<TArrayBuilderImpl>(builder);
  406. size_t maxBlockLength = builder.MaxLength();
  407. Y_ENSURE(maxBlockLength > 0);
  408. TVector<std::shared_ptr<arrow::ArrayData>> outputArrays;
  409. Y_ENSURE(array1.length == array2.length);
  410. for (int64_t i = 0; i < array1.length;) {
  411. for (size_t j = 0; j < maxBlockLength && i < array1.length; ++j, ++i) {
  412. auto item1 = reader1Impl->GetItem(array1, i);
  413. auto item2 = reader2Impl->GetItem(array2, i);
  414. TDerived::Process(&state.GetValueBuilder(), item1, item2, [&](TBlockItem out) {
  415. builderImpl->Add(out);
  416. });
  417. }
  418. auto outputDatum = builder.Build(false);
  419. ForEachArrayData(outputDatum, [&](const auto& arr) { outputArrays.push_back(arr); });
  420. }
  421. *res = MakeArray(outputArrays);
  422. }
  423. return arrow::Status::OK();
  424. }
  425. };
  426. template <typename TDerived, size_t Argc, typename TArrayBuilderImpl = IArrayBuilder, typename TScalarBuilderImpl = IScalarBuilder>
  427. struct TGenericKernelExec {
  428. static arrow::Status Do(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  429. auto& state = dynamic_cast<TUdfKernelState&>(*ctx->state());
  430. Y_ENSURE(batch.num_values() == Argc);
  431. // XXX: Since Arrow arrays ought to have the valid length value, use
  432. // this constant to check whether all the arrays in the given batch have
  433. // the same length and also as an indicator whether there is no array
  434. // arguments in the given batch.
  435. int64_t alength = arrow::Datum::kUnknownLength;
  436. // XXX: Allocate fixed-size buffer to pass the parameters into the
  437. // Process routine (stored into BlockItem), since only the content
  438. // of the particular cells will be updated in the main "process" loop.
  439. std::array<TBlockItem, Argc> args;
  440. const TBlockItem items(args.data());
  441. // XXX: Introduce scalar/array mapping to avoid excess scalar copy ops
  442. // in the main "process" loop.
  443. std::array<bool, Argc> needUpdate;
  444. needUpdate.fill(false);
  445. for (size_t k = 0; k < Argc; k++) {
  446. auto& arg = batch[k];
  447. Y_ENSURE(arg.is_scalar() || arg.is_array());
  448. if (arg.is_scalar()) {
  449. continue;
  450. }
  451. if (alength == arrow::Datum::kUnknownLength) {
  452. alength = arg.length();
  453. } else {
  454. Y_ENSURE(arg.length() == alength);
  455. }
  456. needUpdate[k] = true;
  457. }
  458. // Specialize the case, when all given arguments are scalar.
  459. if (alength == arrow::Datum::kUnknownLength) {
  460. auto& builder = state.GetScalarBuilder();
  461. auto* builderImpl = CastToScalarBuilderImpl<TScalarBuilderImpl>(builder);
  462. for (size_t k = 0; k < Argc; k++) {
  463. auto& reader = state.GetReader(k);
  464. args[k] = reader.GetScalarItem(*batch[k].scalar());
  465. }
  466. TDerived::Process(&state.GetValueBuilder(), items, [&](TBlockItem out) {
  467. *res = builderImpl->Build(out);
  468. });
  469. } else {
  470. auto& builder = state.GetArrayBuilder();
  471. auto* builderImpl = CastToArrayBuilderImpl<TArrayBuilderImpl>(builder);
  472. size_t maxBlockLength = builder.MaxLength();
  473. Y_ENSURE(maxBlockLength > 0);
  474. TVector<std::shared_ptr<arrow::ArrayData>> outputArrays;
  475. // Initialize all scalar arguments before the main "process" loop.
  476. for (size_t k = 0; k < Argc; k++) {
  477. if (needUpdate[k]) {
  478. continue;
  479. }
  480. auto& reader = state.GetReader(k);
  481. args[k] = reader.GetScalarItem(*batch[k].scalar());
  482. }
  483. for (int64_t i = 0; i < alength;) {
  484. for (size_t j = 0; j < maxBlockLength && i < alength; ++j, ++i) {
  485. // Update array arguments and call the Process routine.
  486. for (size_t k = 0; k < Argc; k++) {
  487. if (!needUpdate[k]) {
  488. continue;
  489. }
  490. auto& reader = state.GetReader(k);
  491. args[k] = reader.GetItem(*batch[k].array(), i);
  492. }
  493. TDerived::Process(&state.GetValueBuilder(), items, [&](TBlockItem out) {
  494. builderImpl->Add(out);
  495. });
  496. }
  497. auto outputDatum = builderImpl->Build(false);
  498. ForEachArrayData(outputDatum, [&](const auto& arr) { outputArrays.push_back(arr); });
  499. }
  500. *res = MakeArray(outputArrays);
  501. }
  502. return arrow::Status::OK();
  503. }
  504. };
  505. template <typename TInput, typename TOutput, TOutput(*Core)(TInput)>
  506. arrow::Status UnaryPreallocatedExecImpl(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  507. Y_UNUSED(ctx);
  508. auto& inArray = batch.values[0].array();
  509. auto& outArray = res->array();
  510. const TInput* inValues = inArray->GetValues<TInput>(1);
  511. TOutput* outValues = outArray->GetMutableValues<TOutput>(1);
  512. auto length = inArray->length;
  513. for (int64_t i = 0; i < length; ++i) {
  514. outValues[i] = Core(inValues[i]);
  515. }
  516. return arrow::Status::OK();
  517. }
  518. template <typename TReader, typename TOutput, TOutput(*Core)(TBlockItem)>
  519. arrow::Status UnaryPreallocatedReaderExecImpl(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  520. Y_UNUSED(ctx);
  521. static_assert(std::is_base_of_v<IBlockReader, TReader>);
  522. TReader reader;
  523. auto& inArray = batch.values[0].array();
  524. auto& outArray = res->array();
  525. TOutput* outValues = outArray->GetMutableValues<TOutput>(1);
  526. auto length = inArray->length;
  527. for (int64_t i = 0; i < length; ++i) {
  528. auto item = reader.GetItem(*inArray, i);
  529. outValues[i] = Core(item);
  530. }
  531. return arrow::Status::OK();
  532. }
  533. template<typename TInput, typename TOutput, std::pair<TOutput, bool> Core(TInput)>
  534. struct TUnaryUnsafeFixedSizeFilterKernel {
  535. static arrow::Status Do(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  536. static_assert(std::is_arithmetic<TInput>::value);
  537. Y_UNUSED(ctx);
  538. const auto& inArray = batch.values.front().array();
  539. const auto* inValues = inArray->GetValues<TInput>(1);
  540. const auto length = inArray->length;
  541. auto& outArray = res->array();
  542. auto* outValues = outArray->GetMutableValues<TOutput>(1);
  543. TTypedBufferBuilder<uint8_t> nullBuilder(GetYqlMemoryPool());
  544. nullBuilder.Reserve(length);
  545. bool isAllNull = inArray->GetNullCount() == length;
  546. if (!isAllNull) {
  547. for (i64 i = 0; i < length; ++i) {
  548. auto [output, isValid] = Core(inValues[i]);
  549. outValues[i] = output;
  550. nullBuilder.UnsafeAppend(isValid);
  551. }
  552. } else {
  553. nullBuilder.UnsafeAppend(length, 0);
  554. }
  555. auto validMask = nullBuilder.Finish();
  556. validMask = MakeDenseBitmap(validMask->data(), length, GetYqlMemoryPool());
  557. auto inMask = inArray->buffers[0];
  558. if (inMask) {
  559. outArray->buffers[0] = AllocateBitmapWithReserve(length, GetYqlMemoryPool());
  560. arrow::internal::BitmapAnd(validMask->data(), 0, inArray->buffers[0]->data(), inArray->offset, outArray->length, outArray->offset, outArray->buffers[0]->mutable_data());
  561. } else {
  562. outArray->buffers[0] = std::move(validMask);
  563. }
  564. return arrow::Status::OK();
  565. }
  566. };
  567. template <typename TInput, typename TOutput, TOutput(*Core)(TInput)>
  568. class TUnaryOverOptionalImpl : public TBoxedValue {
  569. public:
  570. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const final {
  571. Y_UNUSED(valueBuilder);
  572. if (!args[0]) {
  573. return {};
  574. }
  575. return TUnboxedValuePod(Core(args[0].Get<TInput>()));
  576. }
  577. };
  578. }
  579. }
  580. #define BEGIN_ARROW_UDF_IMPL(udfNameBlocks, signatureFunc, optArgc, isStrict) \
  581. class udfNameBlocks { \
  582. public: \
  583. typedef bool TTypeAwareMarker; \
  584. static const ::NYql::NUdf::TStringRef& Name() { \
  585. static auto name = ::NYql::NUdf::TStringRef::Of(#udfNameBlocks).Substring(1, 256); \
  586. return name; \
  587. } \
  588. static bool IsStrict() { \
  589. return isStrict; \
  590. } \
  591. static ::NYql::NUdf::TType* GetSignatureType(::NYql::NUdf::IFunctionTypeInfoBuilder& builder) { \
  592. return builder.SimpleSignatureType<signatureFunc>(optArgc); \
  593. } \
  594. static bool DeclareSignature(\
  595. const ::NYql::NUdf::TStringRef& name, \
  596. ::NYql::NUdf::TType* userType, \
  597. ::NYql::NUdf::IFunctionTypeInfoBuilder& builder, \
  598. bool typesOnly); \
  599. };
  600. #define BEGIN_SIMPLE_ARROW_UDF(udfName, signatureFunc) \
  601. BEGIN_ARROW_UDF_IMPL(udfName##_BlocksImpl, signatureFunc, 0, false) \
  602. UDF_IMPL(udfName, builder.SimpleSignature<signatureFunc>().SupportsBlocks();, ;, ;, "", "", udfName##_BlocksImpl)
  603. #define BEGIN_SIMPLE_STRICT_ARROW_UDF(udfName, signatureFunc) \
  604. BEGIN_ARROW_UDF_IMPL(udfName##_BlocksImpl, signatureFunc, 0, true) \
  605. UDF_IMPL(udfName, builder.SimpleSignature<signatureFunc>().SupportsBlocks().IsStrict();, ;, ;, "", "", udfName##_BlocksImpl)
  606. #define BEGIN_SIMPLE_ARROW_UDF_WITH_OPTIONAL_ARGS(udfName, signatureFunc, optArgc) \
  607. BEGIN_ARROW_UDF_IMPL(udfName##_BlocksImpl, signatureFunc, optArgc, false) \
  608. UDF_IMPL(udfName, builder.SimpleSignature<signatureFunc>().SupportsBlocks().OptionalArgs(optArgc);, ;, ;, "", "", udfName##_BlocksImpl)
  609. #define BEGIN_SIMPLE_STRICT_ARROW_UDF_WITH_OPTIONAL_ARGS(udfName, signatureFunc, optArgc) \
  610. BEGIN_ARROW_UDF_IMPL(udfName##_BlocksImpl, signatureFunc, optArgc, true) \
  611. UDF_IMPL(udfName, builder.SimpleSignature<signatureFunc>().SupportsBlocks().IsStrict().OptionalArgs(optArgc);, ;, ;, "", "", udfName##_BlocksImpl)
  612. #define END_ARROW_UDF(udfNameBlocks, exec) \
  613. inline bool udfNameBlocks::DeclareSignature(\
  614. const ::NYql::NUdf::TStringRef& name, \
  615. ::NYql::NUdf::TType* userType, \
  616. ::NYql::NUdf::IFunctionTypeInfoBuilder& builder, \
  617. bool typesOnly) { \
  618. if (Name() == name) { \
  619. if (IsStrict()) { \
  620. builder.IsStrict(); \
  621. } \
  622. PrepareSimpleArrowUdf(builder, GetSignatureType(builder), userType, exec, typesOnly, TString(name)); \
  623. return true; \
  624. } \
  625. return false; \
  626. }
  627. #define END_ARROW_UDF_WITH_NULL_HANDLING(udfNameBlocks, exec, nullHandling) \
  628. inline bool udfNameBlocks::DeclareSignature(\
  629. const ::NYql::NUdf::TStringRef& name, \
  630. ::NYql::NUdf::TType* userType, \
  631. ::NYql::NUdf::IFunctionTypeInfoBuilder& builder, \
  632. bool typesOnly) { \
  633. if (Name() == name) { \
  634. PrepareSimpleArrowUdf(builder, GetSignatureType(builder), userType, exec, typesOnly, TString(name), nullHandling); \
  635. return true; \
  636. } \
  637. return false; \
  638. }
  639. #define END_SIMPLE_ARROW_UDF(udfName, exec) \
  640. END_ARROW_UDF(udfName##_BlocksImpl, exec)
  641. #define END_SIMPLE_ARROW_UDF_WITH_NULL_HANDLING(udfName, exec, nullHandling) \
  642. END_ARROW_UDF_WITH_NULL_HANDLING(udfName##_BlocksImpl, exec, nullHandling)