123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- #include "mkql_apply.h"
- #include <yql/essentials/minikql/computation/mkql_block_impl.h>
- #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
- #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
- #include <yql/essentials/minikql/mkql_node_cast.h>
- #include <library/cpp/containers/stack_array/stack_array.h>
- #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
- #include <yql/essentials/minikql/computation/mkql_value_builder.h>
- namespace NKikimr {
- namespace NMiniKQL {
- namespace {
- class TApplyWrapper: public TMutableCodegeneratorPtrNode<TApplyWrapper> {
- typedef TMutableCodegeneratorPtrNode<TApplyWrapper> TBaseComputation;
- public:
- struct TKernelState : public arrow::compute::KernelState {
- TKernelState(ui32 argsCount)
- : Alloc(__LOCATION__)
- , MemInfo("Apply")
- , HolderFactory(Alloc.Ref(), MemInfo)
- , ValueBuilder(HolderFactory, NUdf::EValidatePolicy::Exception)
- , Args(argsCount)
- {
- Alloc.Ref().EnableArrowTracking = false;
- Alloc.Release();
- }
- ~TKernelState()
- {
- Alloc.Acquire();
- }
- TScopedAlloc Alloc;
- TMemoryUsageInfo MemInfo;
- THolderFactory HolderFactory;
- TDefaultValueBuilder ValueBuilder;
- TVector<NUdf::TUnboxedValue> Args;
- };
- class TArrowNode : public IArrowKernelComputationNode {
- public:
- TArrowNode(const TApplyWrapper* parent, const NUdf::TUnboxedValue& callable, TType* returnType, const TVector<TType*>& argsTypes)
- : Parent_(parent)
- , Callable_(callable)
- , ArgsValuesDescr_(ToValueDescr(argsTypes))
- , Kernel_(ConvertToInputTypes(argsTypes), ConvertToOutputType(returnType), [this](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
- auto& state = dynamic_cast<TKernelState&>(*ctx->state());
- auto guard = Guard(state.Alloc);
- Y_ENSURE(batch.values.size() == state.Args.size());
- for (ui32 i = 0; i < batch.values.size(); ++i) {
- state.Args[i] = state.HolderFactory.CreateArrowBlock(arrow::Datum(batch.values[i]));
- }
- const auto& ret = Callable_.Run(&state.ValueBuilder, state.Args.data());
- *res = TArrowBlock::From(ret).GetDatum();
- return arrow::Status::OK();
- })
- {
- Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
- Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
- Kernel_.init = [argsCount = argsTypes.size()](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
- auto state = std::make_unique<TKernelState>(argsCount);
- return arrow::Result(std::move(state));
- };
- }
- TStringBuf GetKernelName() const final {
- return "Apply";
- }
- const arrow::compute::ScalarKernel& GetArrowKernel() const {
- return Kernel_;
- }
- const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
- return ArgsValuesDescr_;
- }
- const IComputationNode* GetArgument(ui32 index) const {
- return Parent_->ArgNodes[index];
- }
- private:
- const TApplyWrapper* Parent_;
- const NUdf::TUnboxedValue Callable_;
- const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
- arrow::compute::ScalarKernel Kernel_;
- };
- friend class TArrowNode;
- TApplyWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* callableNode,
- TComputationNodePtrVector&& argNodes, ui32 usedArgs, const NUdf::TSourcePosition& pos, TCallableType* callableType)
- : TBaseComputation(mutables, kind)
- , CallableNode(callableNode)
- , ArgNodes(std::move(argNodes))
- , UsedArgs(usedArgs)
- , Position(pos)
- , CallableType(callableType)
- {
- Stateless = false;
- }
- std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
- if (UsedArgs != CallableType->GetArgumentsCount()) {
- return {};
- }
- std::shared_ptr<arrow::DataType> t;
- if (!CallableType->GetReturnType()->IsBlock() ||
- !ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetReturnType())->GetItemType(), t)) {
- return {};
- }
- TVector<TType*> argsTypes;
- for (ui32 i = 0; i < CallableType->GetArgumentsCount(); ++i) {
- argsTypes.push_back(CallableType->GetArgumentType(i));
- if (!CallableType->GetArgumentType(i)->IsBlock() ||
- !ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetArgumentType(i))->GetItemType(), t)) {
- return {};
- }
- }
- const auto callable = CallableNode->GetValue(ctx);
- return std::make_unique<TArrowNode>(this, callable, CallableType->GetReturnType(), argsTypes);
- }
- NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
- NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, UsedArgs));
- for (size_t i = 0; i < UsedArgs; ++i) {
- if (const auto valueNode = ArgNodes[i]) {
- values[i] = valueNode->GetValue(ctx);
- }
- }
- const auto callable = CallableNode->GetValue(ctx);
- const auto prev = ctx.CalleePosition;
- ctx.CalleePosition = &Position;
- const auto ret = callable.Run(ctx.Builder, values.data());
- ctx.CalleePosition = prev;
- return ret;
- }
- #ifndef MKQL_DISABLE_CODEGEN
- void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto idxType = Type::getInt32Ty(context);
- const auto valType = Type::getInt128Ty(context);
- const auto arrayType = ArrayType::get(valType, ArgNodes.size());
- const auto args = *Stateless || ctx.AlwaysInline ?
- new AllocaInst(arrayType, 0U, "args", &ctx.Func->getEntryBlock().back()):
- new AllocaInst(arrayType, 0U, "args", block);
- ui32 i = 0;
- std::vector<std::pair<Value*, EValueRepresentation>> argsv;
- argsv.reserve(ArgNodes.size());
- for (const auto node : ArgNodes) {
- const auto argPtr = GetElementPtrInst::CreateInBounds(arrayType, args, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, i++)}, "arg_ptr", block);
- if (node) {
- GetNodeValue(argPtr, node, ctx, block);
- argsv.emplace_back(argPtr, node->GetRepresentation());
- } else {
- new StoreInst(ConstantInt::get(valType, 0), argPtr, block);
- }
- }
- if (const auto codegen = dynamic_cast<ICodegeneratorRunNode*>(CallableNode)) {
- codegen->CreateRun(ctx, block, pointer, args);
- } else {
- const auto callable = GetNodeValue(CallableNode, ctx, block);
- const auto calleePtr = GetElementPtrInst::CreateInBounds(GetCompContextType(context), ctx.Ctx, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, 6)}, "callee_ptr", block);
- const auto previous = new LoadInst(PointerType::getUnqual(GetSourcePosType(context)), calleePtr, "previous", block);
- const auto callee = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), ui64(&Position)), previous->getType(), "callee", block);
- new StoreInst(callee, calleePtr, block);
- CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, callable, ctx.Codegen, block, ctx.GetBuilder(), args);
- new StoreInst(previous, calleePtr, block);
- if (CallableNode->IsTemporaryValue()) {
- CleanupBoxed(callable, ctx, block);
- }
- }
- for (const auto& arg : argsv) {
- ValueUnRef(arg.second, arg.first, ctx, block);
- }
- }
- #endif
- private:
- void RegisterDependencies() const final {
- DependsOn(CallableNode);
- for (const auto node : ArgNodes) {
- if (node) {
- DependsOn(node);
- }
- }
- }
- IComputationNode *const CallableNode;
- const TComputationNodePtrVector ArgNodes;
- const ui32 UsedArgs;
- const NUdf::TSourcePosition Position;
- TCallableType* CallableType;
- };
- }
- IComputationNode* WrapApply(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- const bool withPos = callable.GetType()->GetName() == "Apply2";
- const ui32 deltaArgs = withPos ? 3 : 0;
- MKQL_ENSURE(callable.GetInputsCount() >= 2 + deltaArgs, "Expected at least " << (2 + deltaArgs) << " arguments");
- const auto function = callable.GetInput(0);
- MKQL_ENSURE(!function.IsImmediate() && function.GetNode()->GetType()->IsCallable(),
- "First argument of Apply must be a callable");
- const auto functionCallable = static_cast<TCallable*>(function.GetNode());
- const auto returnType = functionCallable->GetType()->GetReturnType();
- MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");
- const TStringBuf file = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().AsStringRef() : NUdf::TStringRef();
- const ui32 row = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>() : 0;
- const ui32 column = withPos ? AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<ui32>() : 0;
- const ui32 inputsCount = callable.GetInputsCount() - deltaArgs;
- const ui32 argsCount = inputsCount - 2;
- const ui32 dependentCount = AS_VALUE(TDataLiteral, callable.GetInput(1))->AsValue().Get<ui32>();
- MKQL_ENSURE(dependentCount <= argsCount, "Too many dependent nodes");
- const ui32 usedArgs = argsCount - dependentCount;
- auto callableType = static_cast<TCallableType*>(returnType);
- MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
- MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");
- TComputationNodePtrVector argNodes(callableType->GetArgumentsCount() + dependentCount);
- for (ui32 i = 2; i < 2 + usedArgs; ++i) {
- argNodes[i - 2] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
- }
- for (ui32 i = 2 + usedArgs; i < inputsCount; ++i) {
- argNodes[callableType->GetArgumentsCount() + i - 2 - usedArgs] = LocateNode(ctx.NodeLocator, callable, i + deltaArgs);
- }
- auto functionNode = LocateNode(ctx.NodeLocator, callable, 0);
- return new TApplyWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), functionNode, std::move(argNodes),
- callableType->GetArgumentsCount(), NUdf::TSourcePosition(row, column, file), callableType);
- }
- }
- }
|