mkql_apply.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #include "mkql_apply.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <library/cpp/containers/stack_array/stack_array.h>
  7. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  8. #include <yql/essentials/minikql/computation/mkql_value_builder.h>
  9. namespace NKikimr {
  10. namespace NMiniKQL {
  11. namespace {
  12. class TApplyWrapper: public TMutableCodegeneratorPtrNode<TApplyWrapper> {
  13. typedef TMutableCodegeneratorPtrNode<TApplyWrapper> TBaseComputation;
  14. public:
  15. struct TKernelState : public arrow::compute::KernelState {
  16. TKernelState(ui32 argsCount)
  17. : Alloc(__LOCATION__)
  18. , MemInfo("Apply")
  19. , HolderFactory(Alloc.Ref(), MemInfo)
  20. , ValueBuilder(HolderFactory, NUdf::EValidatePolicy::Exception)
  21. , Args(argsCount)
  22. {
  23. Alloc.Ref().EnableArrowTracking = false;
  24. Alloc.Release();
  25. }
  26. ~TKernelState()
  27. {
  28. Alloc.Acquire();
  29. }
  30. TScopedAlloc Alloc;
  31. TMemoryUsageInfo MemInfo;
  32. THolderFactory HolderFactory;
  33. TDefaultValueBuilder ValueBuilder;
  34. TVector<NUdf::TUnboxedValue> Args;
  35. };
  36. class TArrowNode : public IArrowKernelComputationNode {
  37. public:
  38. TArrowNode(const TApplyWrapper* parent, const NUdf::TUnboxedValue& callable, TType* returnType, const TVector<TType*>& argsTypes)
  39. : Parent_(parent)
  40. , Callable_(callable)
  41. , ArgsValuesDescr_(ToValueDescr(argsTypes))
  42. , Kernel_(ConvertToInputTypes(argsTypes), ConvertToOutputType(returnType), [this](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  43. auto& state = dynamic_cast<TKernelState&>(*ctx->state());
  44. auto guard = Guard(state.Alloc);
  45. Y_ENSURE(batch.values.size() == state.Args.size());
  46. for (ui32 i = 0; i < batch.values.size(); ++i) {
  47. state.Args[i] = state.HolderFactory.CreateArrowBlock(arrow::Datum(batch.values[i]));
  48. }
  49. const auto& ret = Callable_.Run(&state.ValueBuilder, state.Args.data());
  50. *res = TArrowBlock::From(ret).GetDatum();
  51. return arrow::Status::OK();
  52. })
  53. {
  54. Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  55. Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
  56. Kernel_.init = [argsCount = argsTypes.size()](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
  57. auto state = std::make_unique<TKernelState>(argsCount);
  58. return arrow::Result(std::move(state));
  59. };
  60. }
  61. TStringBuf GetKernelName() const final {
  62. return "Apply";
  63. }
  64. const arrow::compute::ScalarKernel& GetArrowKernel() const {
  65. return Kernel_;
  66. }
  67. const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
  68. return ArgsValuesDescr_;
  69. }
  70. const IComputationNode* GetArgument(ui32 index) const {
  71. return Parent_->ArgNodes[index];
  72. }
  73. private:
  74. const TApplyWrapper* Parent_;
  75. const NUdf::TUnboxedValue Callable_;
  76. const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
  77. arrow::compute::ScalarKernel Kernel_;
  78. };
  79. friend class TArrowNode;
  80. TApplyWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* callableNode,
  81. TComputationNodePtrVector&& argNodes, ui32 usedArgs, const NUdf::TSourcePosition& pos, TCallableType* callableType)
  82. : TBaseComputation(mutables, kind)
  83. , CallableNode(callableNode)
  84. , ArgNodes(std::move(argNodes))
  85. , UsedArgs(usedArgs)
  86. , Position(pos)
  87. , CallableType(callableType)
  88. {
  89. Stateless = false;
  90. }
  91. std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
  92. if (UsedArgs != CallableType->GetArgumentsCount()) {
  93. return {};
  94. }
  95. std::shared_ptr<arrow::DataType> t;
  96. if (!CallableType->GetReturnType()->IsBlock() ||
  97. !ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetReturnType())->GetItemType(), t)) {
  98. return {};
  99. }
  100. TVector<TType*> argsTypes;
  101. for (ui32 i = 0; i < CallableType->GetArgumentsCount(); ++i) {
  102. argsTypes.push_back(CallableType->GetArgumentType(i));
  103. if (!CallableType->GetArgumentType(i)->IsBlock() ||
  104. !ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetArgumentType(i))->GetItemType(), t)) {
  105. return {};
  106. }
  107. }
  108. const auto callable = CallableNode->GetValue(ctx);
  109. return std::make_unique<TArrowNode>(this, callable, CallableType->GetReturnType(), argsTypes);
  110. }
  111. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  112. NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, UsedArgs));
  113. for (size_t i = 0; i < UsedArgs; ++i) {
  114. if (const auto valueNode = ArgNodes[i]) {
  115. values[i] = valueNode->GetValue(ctx);
  116. }
  117. }
  118. const auto callable = CallableNode->GetValue(ctx);
  119. const auto prev = ctx.CalleePosition;
  120. ctx.CalleePosition = &Position;
  121. const auto ret = callable.Run(ctx.Builder, values.data());
  122. ctx.CalleePosition = prev;
  123. return ret;
  124. }
  125. #ifndef MKQL_DISABLE_CODEGEN
  126. void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
  127. auto& context = ctx.Codegen.GetContext();
  128. const auto idxType = Type::getInt32Ty(context);
  129. const auto valType = Type::getInt128Ty(context);
  130. const auto arrayType = ArrayType::get(valType, ArgNodes.size());
  131. const auto args = *Stateless || ctx.AlwaysInline ?
  132. new AllocaInst(arrayType, 0U, "args", &ctx.Func->getEntryBlock().back()):
  133. new AllocaInst(arrayType, 0U, "args", block);
  134. ui32 i = 0;
  135. std::vector<std::pair<Value*, EValueRepresentation>> argsv;
  136. argsv.reserve(ArgNodes.size());
  137. for (const auto node : ArgNodes) {
  138. const auto argPtr = GetElementPtrInst::CreateInBounds(arrayType, args, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, i++)}, "arg_ptr", block);
  139. if (node) {
  140. GetNodeValue(argPtr, node, ctx, block);
  141. argsv.emplace_back(argPtr, node->GetRepresentation());
  142. } else {
  143. new StoreInst(ConstantInt::get(valType, 0), argPtr, block);
  144. }
  145. }
  146. if (const auto codegen = dynamic_cast<ICodegeneratorRunNode*>(CallableNode)) {
  147. codegen->CreateRun(ctx, block, pointer, args);
  148. } else {
  149. const auto callable = GetNodeValue(CallableNode, ctx, block);
  150. const auto calleePtr = GetElementPtrInst::CreateInBounds(GetCompContextType(context), ctx.Ctx, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, 6)}, "callee_ptr", block);
  151. const auto previous = new LoadInst(PointerType::getUnqual(GetSourcePosType(context)), calleePtr, "previous", block);
  152. const auto callee = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), ui64(&Position)), previous->getType(), "callee", block);
  153. new StoreInst(callee, calleePtr, block);
  154. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, callable, ctx.Codegen, block, ctx.GetBuilder(), args);
  155. new StoreInst(previous, calleePtr, block);
  156. if (CallableNode->IsTemporaryValue()) {
  157. CleanupBoxed(callable, ctx, block);
  158. }
  159. }
  160. for (const auto& arg : argsv) {
  161. ValueUnRef(arg.second, arg.first, ctx, block);
  162. }
  163. }
  164. #endif
  165. private:
  166. void RegisterDependencies() const final {
  167. DependsOn(CallableNode);
  168. for (const auto node : ArgNodes) {
  169. if (node) {
  170. DependsOn(node);
  171. }
  172. }
  173. }
  174. IComputationNode *const CallableNode;
  175. const TComputationNodePtrVector ArgNodes;
  176. const ui32 UsedArgs;
  177. const NUdf::TSourcePosition Position;
  178. TCallableType* CallableType;
  179. };
  180. }
  181. IComputationNode* WrapApply(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  182. const bool withPos = callable.GetType()->GetName() == "Apply2";
  183. const ui32 deltaArgs = withPos ? 3 : 0;
  184. MKQL_ENSURE(callable.GetInputsCount() >= 2 + deltaArgs, "Expected at least " << (2 + deltaArgs) << " arguments");
  185. const auto function = callable.GetInput(0);
  186. MKQL_ENSURE(!function.IsImmediate() && function.GetNode()->GetType()->IsCallable(),
  187. "First argument of Apply must be a callable");
  188. const auto functionCallable = static_cast<TCallable*>(function.GetNode());
  189. const auto returnType = functionCallable->GetType()->GetReturnType();
  190. MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");
  191. const TStringBuf file = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().AsStringRef() : NUdf::TStringRef();
  192. const ui32 row = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>() : 0;
  193. const ui32 column = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<ui32>() : 0;
  194. const ui32 inputsCount = callable.GetInputsCount() - deltaArgs;
  195. const ui32 argsCount = inputsCount - 2;
  196. const ui32 dependentCount = AS_VALUE(TDataLiteral, callable.GetInput(1))->AsValue().Get<ui32>();
  197. MKQL_ENSURE(dependentCount <= argsCount, "Too many dependent nodes");
  198. const ui32 usedArgs = argsCount - dependentCount;
  199. auto callableType = static_cast<TCallableType*>(returnType);
  200. MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
  201. MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");
  202. TComputationNodePtrVector argNodes(callableType->GetArgumentsCount() + dependentCount);
  203. for (ui32 i = 2; i < 2 + usedArgs; ++i) {
  204. argNodes[i - 2] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
  205. }
  206. for (ui32 i = 2 + usedArgs; i < inputsCount; ++i) {
  207. argNodes[callableType->GetArgumentsCount() + i - 2 - usedArgs] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
  208. }
  209. auto functionNode = LocateNode(ctx.NodeLocator, callable, 0);
  210. return new TApplyWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), functionNode, std::move(argNodes),
  211. callableType->GetArgumentsCount(), NUdf::TSourcePosition(row, column, file), callableType);
  212. }
  213. }
  214. }