#include "mkql_apply.h" #include #include // Y_IGNORE #include #include #include #include #include
// NOTE(review): in this copy of the file the header names of every #include after
// "mkql_apply.h" are missing (only bare "#include" tokens remain, and the trailing
// "// Y_IGNORE" turns the rest of that line into a comment). Restore the original header
// paths from version control before building.

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Computation node behind the MiniKQL `Apply` / `Apply2` callables: it evaluates a
// callable value (CallableNode) and invokes its Run() with the values of ArgNodes.
// When the return type and every argument type are block types convertible to Arrow,
// the node can additionally be exposed as an Arrow scalar kernel
// (see PrepareArrowKernelComputationNode / TArrowNode).
class TApplyWrapper: public TMutableCodegeneratorPtrNode<TApplyWrapper> {
    typedef TMutableCodegeneratorPtrNode<TApplyWrapper> TBaseComputation;
public:
    // State object produced by the Arrow kernel's `init` hook (see TArrowNode). It owns a
    // private allocator, holder factory and value builder so the callable can be run by
    // the kernel outside of a TComputationContext.
    struct TKernelState : public arrow::compute::KernelState {
        explicit TKernelState(ui32 argsCount)
            : Alloc(__LOCATION__)
            , MemInfo("Apply")
            , HolderFactory(Alloc.Ref(), MemInfo)
            , ValueBuilder(HolderFactory, NUdf::EValidatePolicy::Exception)
            , Args(argsCount)
        {
            Alloc.Ref().EnableArrowTracking = false;
            // The allocator is only held while the kernel body runs (it takes a Guard on
            // it) and again in the destructor; release it once construction is done.
            Alloc.Release();
        }

        ~TKernelState()
        {
            // Re-acquire so that Args (destroyed after this body) is freed under Alloc.
            Alloc.Acquire();
        }

        TScopedAlloc Alloc;
        TMemoryUsageInfo MemInfo;
        THolderFactory HolderFactory;
        TDefaultValueBuilder ValueBuilder;
        // One slot per callable argument; refilled from the Arrow batch on every exec.
        TVector<NUdf::TUnboxedValue> Args;
    };

    // Adapts the UDF callable to an Arrow ScalarKernel: every input datum of a batch is
    // wrapped with HolderFactory.CreateArrowBlock() and passed to Callable_.Run(); the
    // returned value is unwrapped with TArrowBlock::From() into the output datum.
    class TArrowNode : public IArrowKernelComputationNode {
    public:
        TArrowNode(const TApplyWrapper* parent, const NUdf::TUnboxedValue& callable, TType* returnType, const TVector<TType*>& argsTypes)
            : Parent_(parent)
            , Callable_(callable)
            , ArgsValuesDescr_(ToValueDescr(argsTypes))
            , Kernel_(ConvertToInputTypes(argsTypes), ConvertToOutputType(returnType),
                [this](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
                    auto& state = dynamic_cast<TKernelState&>(*ctx->state());
                    auto guard = Guard(state.Alloc);
                    Y_ENSURE(batch.values.size() == state.Args.size());
                    for (ui32 i = 0; i < batch.values.size(); ++i) {
                        state.Args[i] = state.HolderFactory.CreateArrowBlock(arrow::Datum(batch.values[i]));
                    }

                    const auto& ret = Callable_.Run(&state.ValueBuilder, state.Args.data());
                    *res = TArrowBlock::From(ret).GetDatum();
                    return arrow::Status::OK();
                })
        {
            Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
            Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
            Kernel_.init = [argsCount = argsTypes.size()](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
                auto state = std::make_unique<TKernelState>(argsCount);
                return arrow::Result(std::move(state));
            };
        }

        // Kernel_'s exec lambda captures `this`: a copied TArrowNode would still call back
        // into the object it was copied from, so copying is disabled.
        TArrowNode(const TArrowNode&) = delete;
        TArrowNode& operator=(const TArrowNode&) = delete;

        TStringBuf GetKernelName() const final {
            return "Apply";
        }

        const arrow::compute::ScalarKernel& GetArrowKernel() const {
            return Kernel_;
        }

        const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
            return ArgsValuesDescr_;
        }

        // Node producing the index-th argument. Non-null for every argument index, because
        // PrepareArrowKernelComputationNode refuses to build a TArrowNode otherwise.
        const IComputationNode* GetArgument(ui32 index) const {
            return Parent_->ArgNodes[index];
        }

    private:
        const TApplyWrapper* Parent_;
        const NUdf::TUnboxedValue Callable_;
        const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
        arrow::compute::ScalarKernel Kernel_;
    };
    friend class TArrowNode;

    // argNodes layout (as built by WrapApply): the callable's arguments first (null for
    // omitted optional arguments), followed by the dependent nodes.
    TApplyWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* callableNode,
        TComputationNodePtrVector&& argNodes, ui32 usedArgs, const NUdf::TSourcePosition& pos, TCallableType* callableType)
        : TBaseComputation(mutables, kind)
        , CallableNode(callableNode)
        , ArgNodes(std::move(argNodes))
        , UsedArgs(usedArgs)
        , Position(pos)
        , CallableType(callableType)
    {
        Stateless = false;
    }

    // Returns an Arrow kernel node for this Apply, or an empty pointer when the call
    // cannot be executed as a kernel: some argument is missing (omitted optional
    // argument), or the result / some argument is not a block type convertible to Arrow.
    std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
        const ui32 argsCount = CallableType->GetArgumentsCount();
        if (UsedArgs != argsCount || ArgNodes.size() < argsCount) {
            return {};
        }

        // WrapApply passes the full argument count as UsedArgs, so the check above does not
        // detect omitted optional arguments; those show up as null slots in ArgNodes. The
        // kernel feeds every argument from a node (TArrowNode::GetArgument), so fall back
        // to the scalar path instead of exposing a null argument node.
        for (ui32 i = 0; i < argsCount; ++i) {
            if (!ArgNodes[i]) {
                return {};
            }
        }

        // `t` is only a scratch output: we just need to know the conversion succeeds.
        std::shared_ptr<arrow::DataType> t;
        const auto returnType = CallableType->GetReturnType();
        if (!returnType->IsBlock() || !ConvertArrowType(AS_TYPE(TBlockType, returnType)->GetItemType(), t)) {
            return {};
        }

        TVector<TType*> argsTypes;
        argsTypes.reserve(argsCount);
        for (ui32 i = 0; i < argsCount; ++i) {
            const auto argType = CallableType->GetArgumentType(i);
            if (!argType->IsBlock() || !ConvertArrowType(AS_TYPE(TBlockType, argType)->GetItemType(), t)) {
                return {};
            }
            argsTypes.push_back(argType);
        }

        const auto callable = CallableNode->GetValue(ctx);
        return std::make_unique<TArrowNode>(this, callable, returnType, argsTypes);
    }

    // Interpreted path: evaluates the first UsedArgs argument nodes (null slots stay
    // default-constructed), then calls the callable with ctx.CalleePosition temporarily
    // pointing at this Apply's source position.
    NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
        NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, UsedArgs));
        for (size_t i = 0; i < UsedArgs; ++i) {
            if (const auto valueNode = ArgNodes[i]) {
                values[i] = valueNode->GetValue(ctx);
            }
        }

        const auto callable = CallableNode->GetValue(ctx);
        const auto prev = ctx.CalleePosition;
        ctx.CalleePosition = &Position;
        // NOTE(review): CalleePosition is not restored if Run() throws; this looks deliberate
        // (the failing call's position stays visible to error handling) — confirm before
        // switching to a scope guard. The codegen path below behaves the same way.
        const auto ret = callable.Run(ctx.Builder, values.data());
        ctx.CalleePosition = prev;
        return ret;
    }

#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: materializes all ArgNodes into a stack array of 128-bit values, calls
    // the callable (inline via ICodegeneratorRunNode when possible, otherwise through a
    // virtual Run with CalleePosition swapped), then releases the evaluated arguments.
    void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto idxType = Type::getInt32Ty(context);
        const auto valType = Type::getInt128Ty(context);
        const auto arrayType = ArrayType::get(valType, ArgNodes.size());
        // Stateless or always-inlined code allocates the argument array in the function's
        // entry block; otherwise it is allocated in the current block.
        const auto args = *Stateless || ctx.AlwaysInline
            ? new AllocaInst(arrayType, 0U, "args", &ctx.Func->getEntryBlock().back())
            : new AllocaInst(arrayType, 0U, "args", block);

        ui32 i = 0;
        std::vector<std::pair<Value*, EValueRepresentation>> argsv;
        argsv.reserve(ArgNodes.size());
        for (const auto node : ArgNodes) {
            const auto argPtr = GetElementPtrInst::CreateInBounds(arrayType, args,
                {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, i++)}, "arg_ptr", block);
            if (node) {
                GetNodeValue(argPtr, node, ctx, block);
                argsv.emplace_back(argPtr, node->GetRepresentation());
            } else {
                // Missing argument: store a zero value, mirroring the default-constructed
                // slot DoCalculate leaves for a null node.
                new StoreInst(ConstantInt::get(valType, 0), argPtr, block);
            }
        }

        if (const auto codegen = dynamic_cast<ICodegeneratorRunNode*>(CallableNode)) {
            codegen->CreateRun(ctx, block, pointer, args);
        } else {
            const auto callable = GetNodeValue(CallableNode, ctx, block);
            // NOTE(review): field index 6 of the computation-context LLVM type is assumed to
            // be CalleePosition (cf. DoCalculate) — keep in sync with GetCompContextType.
            const auto calleePtr = GetElementPtrInst::CreateInBounds(GetCompContextType(context), ctx.Ctx,
                {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, 6)}, "callee_ptr", block);
            const auto previous = new LoadInst(PointerType::getUnqual(GetSourcePosType(context)), calleePtr, "previous", block);
            const auto callee = CastInst::Create(Instruction::IntToPtr,
                ConstantInt::get(Type::getInt64Ty(context), ui64(&Position)), previous->getType(), "callee", block);
            new StoreInst(callee, calleePtr, block);
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, callable, ctx.Codegen, block, ctx.GetBuilder(), args);
            new StoreInst(previous, calleePtr, block);
            if (CallableNode->IsTemporaryValue()) {
                CleanupBoxed(callable, ctx, block);
            }
        }

        for (const auto& arg : argsv) {
            ValueUnRef(arg.second, arg.first, ctx, block);
        }
    }
#endif

private:
    void RegisterDependencies() const final {
        DependsOn(CallableNode);
        for (const auto node : ArgNodes) {
            if (node) {
                DependsOn(node);
            }
        }
    }

    IComputationNode *const CallableNode;
    const TComputationNodePtrVector ArgNodes;
    const ui32 UsedArgs;
    const NUdf::TSourcePosition Position;
    TCallableType* CallableType;
};

}

// Builds the computation node for an `Apply` or `Apply2` callable.
//
// Inputs of `callable`:
//   0                   - the function node; its type must be a callable type whose return
//                         type is itself a callable type (the one that is applied);
//   1                   - ui32 data literal: number of trailing dependent nodes;
//   2, 3, 4 (Apply2)    - source position: file, row and column data literals;
//   then usedArgs       - the call arguments (optional trailing arguments may be omitted);
//   then dependentCount - dependent nodes, placed after all callable arguments in ArgNodes.
IComputationNode* WrapApply(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const bool withPos = callable.GetType()->GetName() == "Apply2";
    const ui32 deltaArgs = withPos ? 3 : 0;
    MKQL_ENSURE(callable.GetInputsCount() >= 2 + deltaArgs, "Expected at least " << (2 + deltaArgs) << " arguments");

    const auto function = callable.GetInput(0);
    MKQL_ENSURE(!function.IsImmediate() && function.GetNode()->GetType()->IsCallable(),
        "First argument of Apply must be a callable");

    const auto functionCallable = static_cast<TCallable*>(function.GetNode());
    const auto returnType = functionCallable->GetType()->GetReturnType();
    MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");

    const TStringBuf file = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().AsStringRef() : NUdf::TStringRef();
    const ui32 row = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>() : 0;
    const ui32 column = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<ui32>() : 0;

    const ui32 inputsCount = callable.GetInputsCount() - deltaArgs;
    const ui32 argsCount = inputsCount - 2;

    const ui32 dependentCount = AS_VALUE(TDataLiteral, callable.GetInput(1))->AsValue().Get<ui32>();
    MKQL_ENSURE(dependentCount <= argsCount, "Too many dependent nodes");
    const ui32 usedArgs = argsCount - dependentCount;

    auto callableType = static_cast<TCallableType*>(returnType);
    MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
    MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");

    // Slots [usedArgs, GetArgumentsCount()) stay null: those optional arguments were omitted.
    TComputationNodePtrVector argNodes(callableType->GetArgumentsCount() + dependentCount);
    for (ui32 i = 2; i < 2 + usedArgs; ++i) {
        argNodes[i - 2] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
    }

    for (ui32 i = 2 + usedArgs; i < inputsCount; ++i) {
        argNodes[callableType->GetArgumentsCount() + i - 2 - usedArgs] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
    }

    auto functionNode = LocateNode(ctx.NodeLocator, callable, 0);
    return new TApplyWrapper(ctx.Mutables,
        GetValueRepresentation(callable.GetType()->GetReturnType()),
        functionNode, std::move(argNodes),
        callableType->GetArgumentsCount(),
        NUdf::TSourcePosition(row, column, file),
        callableType);
}

}
}