mkql_scalar_apply.cpp

#include "mkql_scalar_apply.h"

#include <yql/essentials/public/udf/arrow/memory_pool.h>
#include <yql/essentials/minikql/computation/mkql_block_reader.h>
#include <yql/essentials/minikql/computation/mkql_block_builder.h>
#include <yql/essentials/minikql/computation/mkql_block_impl.h>
#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/arrow/arrow_util.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_value_builder.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/parser/pg_wrapper/interface/utils.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {
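
// ScalarApply evaluates an arbitrary lambda over block-typed arguments.
// Each input block item is read back into an unboxed value, bound to the
// corresponding lambda argument node, and the lambda result is converted
// into a block item of the return type.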
class TScalarApplyWrapper : public TMutableComputationNode<TScalarApplyWrapper> {
public:
    struct TAccessors {
        TAccessors(const TVector<TType*>& argsTypes, TType* returnType, const NUdf::IPgBuilder& pgBuilder)
            : PgBuilder(pgBuilder)
        {
            auto returnItemType = AS_TYPE(TBlockType, returnType)->GetItemType();
            ReturnConverter = MakeBlockItemConverter(TTypeInfoHelper(), returnItemType, pgBuilder);
            ArgsConverters.reserve(argsTypes.size());
            ArgsReaders.reserve(argsTypes.size());
            for (auto type : argsTypes) {
                ArgsConverters.emplace_back(MakeBlockItemConverter(TTypeInfoHelper(), AS_TYPE(TBlockType, type)->GetItemType(), pgBuilder));
                ArgsReaders.emplace_back(MakeBlockReader(TTypeInfoHelper(), AS_TYPE(TBlockType, type)->GetItemType()));
            }
        }

        const NUdf::IPgBuilder& PgBuilder;
        std::unique_ptr<IBlockItemConverter> ReturnConverter;
        TVector<std::unique_ptr<IBlockItemConverter>> ArgsConverters;
        TVector<std::unique_ptr<IBlockReader>> ArgsReaders;
        bool ScalarsProcessed = false;
    };
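
    // State for one Arrow kernel instance. Kernel execution may happen on a
    // thread without a bound MiniKQL allocator, so the state owns its own
    // TScopedAlloc (released in the constructor, re-acquired in the
    // destructor) together with a private type environment, holder factory,
    // value builder, and computation context derived from the original one.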
    struct TKernelState : public arrow::compute::KernelState {
        TKernelState(const TVector<TType*>& argsTypes, TType* returnType, const TComputationContext& originalContext)
            : Alloc(__LOCATION__)
            , TypeEnv(Alloc)
            , MemInfo("ScalarApply")
            , FunctionRegistry(originalContext.HolderFactory.GetFunctionRegistry()->Clone())
            , HolderFactory(Alloc.Ref(), MemInfo, FunctionRegistry.Get())
            , ValueBuilder(HolderFactory, NUdf::EValidatePolicy::Exception)
            , PgBuilder(NYql::CreatePgBuilder())
            , Accessors(argsTypes, returnType, *PgBuilder)
            , RandomProvider(CreateDefaultRandomProvider())
            , TimeProvider(CreateDefaultTimeProvider())
            , Ctx(HolderFactory, &ValueBuilder, TComputationOptsFull(
                nullptr, Alloc.Ref(), TypeEnv, *RandomProvider, *TimeProvider, NUdf::EValidatePolicy::Exception, originalContext.SecureParamsProvider, originalContext.CountersProvider),
                originalContext.Mutables, *NYql::NUdf::GetYqlMemoryPool())
        {
            Alloc.Ref().EnableArrowTracking = false;
            Alloc.Release();
        }

        ~TKernelState()
        {
            Alloc.Acquire();
        }

        TScopedAlloc Alloc;
        TTypeEnvironment TypeEnv;
        TMemoryUsageInfo MemInfo;
        const IFunctionRegistry::TPtr FunctionRegistry;
        THolderFactory HolderFactory;
        TDefaultValueBuilder ValueBuilder;
        std::unique_ptr<NUdf::IPgBuilder> PgBuilder;
        TAccessors Accessors;
        TIntrusivePtr<IRandomProvider> RandomProvider;
        TIntrusivePtr<ITimeProvider> TimeProvider;
        TComputationContext Ctx;
    };
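
    // Adapter exposing this node as an Arrow compute scalar kernel. The exec
    // callback wraps the incoming ExecBatch values into datum providers and
    // delegates to CalculateImpl under the kernel state's allocator guard;
    // the init callback creates a fresh TKernelState per kernel instance.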
    class TArrowNode : public IArrowKernelComputationNode {
    public:
        TArrowNode(const TScalarApplyWrapper* parent, TComputationContext& originalContext)
            : Parent_(parent)
            , OriginalContext_(originalContext)
            , ArgsValuesDescr_(ToValueDescr(parent->ArgsTypes_))
            , Kernel_(ConvertToInputTypes(parent->ArgsTypes_), ConvertToOutputType(parent->ReturnType_),
                [parent](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
                    auto& state = dynamic_cast<TKernelState&>(*ctx->state());
                    auto guard = Guard(state.Alloc);
                    TVector<TDatumProvider> providers;
                    providers.reserve(batch.values.size());
                    for (const auto& v : batch.values) {
                        providers.emplace_back(MakeDatumProvider(v));
                    }

                    *res = parent->CalculateImpl(providers, state.Accessors, *NYql::NUdf::GetYqlMemoryPool(), state.Ctx);
                    return arrow::Status::OK();
                })
        {
            Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
            Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
            Kernel_.init = [parent, ctx = &OriginalContext_](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
                auto state = std::make_unique<TKernelState>(parent->ArgsTypes_, parent->ReturnType_, *ctx);
                return arrow::Result(std::move(state));
            };
        }

        TStringBuf GetKernelName() const final {
            return "ScalarApply";
        }

        const arrow::compute::ScalarKernel& GetArrowKernel() const {
            return Kernel_;
        }

        const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
            return ArgsValuesDescr_;
        }

        const IComputationNode* GetArgument(ui32 index) const {
            return Parent_->Args_[index];
        }

    private:
        const TScalarApplyWrapper* Parent_;
        const TComputationContext& OriginalContext_;
        const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
        arrow::compute::ScalarKernel Kernel_;
    };
    friend class TArrowNode;
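
    // Args_ are the block-valued input nodes; LambdaArgs_ are the external
    // nodes the lambda body reads its arguments from (an entry may be null,
    // in which case it is skipped); LambdaRoot_ is the lambda body itself.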
    TScalarApplyWrapper(TComputationMutables& mutables, const TVector<TType*>& argsTypes, TType* returnType,
        TVector<IComputationNode*>&& args, TVector<IComputationExternalNode*>&& lambdaArgs, IComputationNode* lambdaRoot)
        : TMutableComputationNode(mutables)
        , StateIndex_(mutables.CurValueIndex++)
        , ArgsTypes_(argsTypes)
        , ReturnType_(returnType)
        , Args_(std::move(args))
        , LambdaArgs_(std::move(lambdaArgs))
        , LambdaRoot_(lambdaRoot)
    {
        MKQL_ENSURE(Args_.size() == LambdaArgs_.size(), "Mismatch args count");
    }
    std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
        return std::make_unique<TArrowNode>(this, ctx);
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        TVector<TDatumProvider> providers;
        providers.reserve(Args_.size());
        for (auto arg : Args_) {
            providers.emplace_back(MakeDatumProvider(arg, ctx));
        }

        auto& state = GetState(ctx);
        return ctx.HolderFactory.CreateArrowBlock(CalculateImpl(providers, state.Accessors, ctx.ArrowMemoryPool, ctx));
    }
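
    // Common evaluation path shared by DoCalculate and the Arrow kernel.
    // Scalar return shape: bind the (scalar) arguments once, evaluate the
    // lambda, and convert the result with ConvertScalar. Array return shape:
    // for each row bind the per-row items (scalar arguments are bound only
    // on the first pass, after which ScalarsProcessed short-circuits them)
    // and append each lambda result to the array builder.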
    arrow::Datum CalculateImpl(const TVector<TDatumProvider>& providers, TAccessors& accessors, arrow::MemoryPool& memoryPool,
        TComputationContext& ctx) const {
        TVector<arrow::Datum> args;
        args.reserve(providers.size());
        size_t length = 1;
        for (const auto& prov : providers) {
            args.emplace_back(prov());
            if (!args.back().is_scalar()) {
                length = args.back().array()->length;
            }
        }

        auto returnItemType = AS_TYPE(TBlockType, ReturnType_)->GetItemType();
        if (AS_TYPE(TBlockType, ReturnType_)->GetShape() == TBlockType::EShape::Scalar) {
            if (!accessors.ScalarsProcessed) {
                for (ui32 j = 0; j < Args_.size(); ++j) {
                    if (!LambdaArgs_[j]) {
                        continue;
                    }

                    auto item = accessors.ArgsReaders[j]->GetScalarItem(*args[j].scalar());
                    auto value = accessors.ArgsConverters[j]->MakeValue(item, ctx.HolderFactory);
                    LambdaArgs_[j]->SetValue(ctx, value);
                }

                accessors.ScalarsProcessed = true;
            }

            auto value = LambdaRoot_->GetValue(ctx);
            return ConvertScalar(returnItemType, value, memoryPool);
        } else {
            auto builder = MakeArrayBuilder(TTypeInfoHelper(), returnItemType, memoryPool, length, &accessors.PgBuilder);
            for (size_t i = 0; i < length; ++i) {
                for (ui32 j = 0; j < Args_.size(); ++j) {
                    if (!LambdaArgs_[j]) {
                        continue;
                    }

                    if (args[j].is_scalar() && accessors.ScalarsProcessed) {
                        continue;
                    }

                    auto item = args[j].is_scalar() ?
                        accessors.ArgsReaders[j]->GetScalarItem(*args[j].scalar()) :
                        accessors.ArgsReaders[j]->GetItem(*args[j].array(), i);
                    auto value = accessors.ArgsConverters[j]->MakeValue(item, ctx.HolderFactory);
                    LambdaArgs_[j]->SetValue(ctx, value);
                }

                accessors.ScalarsProcessed = true;
                auto value = LambdaRoot_->GetValue(ctx);
                auto item = accessors.ReturnConverter->MakeItem(value);
                builder->Add(item);
            }

            return builder->Build(true);
        }
    }
private:
    void RegisterDependencies() const final {
        for (auto arg : Args_) {
            this->DependsOn(arg);
        }

        for (ui32 i = 0; i < Args_.size(); ++i) {
            Args_[i]->AddDependence(LambdaArgs_[i]);
            this->Own(LambdaArgs_[i]);
        }

        this->DependsOn(LambdaRoot_);
    }
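
    // Lazily created per-context state: the TAccessors (block readers and
    // converters) are built on first use and cached in
    // ctx.MutableValues[StateIndex_].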
    struct TState : public TComputationValue<TState> {
        using TComputationValue::TComputationValue;

        TState(TMemoryUsageInfo* memInfo, const TVector<TType*>& argsTypes, TType* returnType, const NUdf::IPgBuilder& pgBuilder)
            : TComputationValue(memInfo)
            , Accessors(argsTypes, returnType, pgBuilder)
        {
        }

        TAccessors Accessors;
    };

    TState& GetState(TComputationContext& ctx) const {
        auto& result = ctx.MutableValues[StateIndex_];
        if (!result.HasValue()) {
            result = ctx.HolderFactory.Create<TState>(ArgsTypes_, ReturnType_, ctx.Builder->GetPgBuilder());
        }

        return *static_cast<TState*>(result.AsBoxed().Get());
    }

    const ui32 StateIndex_;
    const TVector<TType*> ArgsTypes_;
    TType* const ReturnType_;
    const TVector<IComputationNode*> Args_;
    const TVector<IComputationExternalNode*> LambdaArgs_;
    IComputationNode* const LambdaRoot_;
};

} // namespace
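
// Callable layout: N argument nodes, then N external nodes for the lambda
// arguments, then the lambda root; hence an odd input count of at least 3.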
IComputationNode* WrapScalarApply(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE((callable.GetInputsCount() >= 3) && (callable.GetInputsCount() % 2 == 1), "Bad args count");
    auto lambdaRoot = LocateNode(ctx.NodeLocator, callable, callable.GetInputsCount() - 1);
    auto argsCount = (callable.GetInputsCount() - 1) / 2;
    TVector<IComputationNode*> args(argsCount);
    TVector<IComputationExternalNode*> lambdaArgs(argsCount);
    TVector<TType*> argsTypes(argsCount);
    for (ui32 i = 0; i < argsCount; ++i) {
        args[i] = LocateNode(ctx.NodeLocator, callable, i);
        lambdaArgs[i] = LocateExternalNode(ctx.NodeLocator, callable, i + argsCount);
        argsTypes[i] = callable.GetType()->GetArgumentType(i);
    }

    return new TScalarApplyWrapper(ctx.Mutables, argsTypes, callable.GetType()->GetReturnType(),
        std::move(args), std::move(lambdaArgs), lambdaRoot);
}

} // namespace NMiniKQL
} // namespace NKikimr