#include "mkql_wide_chain_map.h"
// NOTE(review): the header names of the five directives below are missing from
// this copy of the file (everything between angle brackets was lost).
// Restore them from version control; they are kept here, in order, as found.
#include
#include // Y_IGNORE
#include
#include
#include

namespace NKikimr {
namespace NMiniKQL {

using NYql::EnsureDynamicCast;

namespace {

// Wide-flow computation node that maps every row of an upstream wide flow to
// an output row, using one set of expressions for the first row and another
// for all following rows.
//
//  * Inputs      - external nodes that receive the columns of the fetched row.
//  * InitItems   - expressions producing the output columns for the first row.
//  * Outputs     - external nodes holding the most recently produced output row.
//  * UpdateItems - expressions producing the output columns for later rows.
//
// The node state is used only as a "first row not processed yet" flag: it is
// invalid until the first row has been handled and is set to an empty value
// afterwards.
//
// The six passthrough maps are built by GetPasstroughtMapOneToOne(a, b). They
// are used below as optional indexes: entry i, when set, is an index into the
// second argument. Presumably it means "element i of `a` is the same node as
// that element of `b`", which lets a column be forwarded without evaluating
// an expression - confirm against the helper's definition.
class TWideChain1MapWrapper : public TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper> {
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper>;
public:
    // Takes ownership of the node vectors and reserves per-context storage:
    // Inputs.size() slots in the wide-fields array and Outputs.size() extra
    // mutable values used as temporaries while updating Outputs.
    TWideChain1MapWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow,
        TComputationExternalNodePtrVector&& inputs,
        TComputationNodePtrVector&& initItems,
        TComputationExternalNodePtrVector&& outputs,
        TComputationNodePtrVector&& updateItems)
        : TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
        , Flow(flow)
        , Inputs(std::move(inputs))
        , InitItems(std::move(initItems))
        , Outputs(std::move(outputs))
        , UpdateItems(std::move(updateItems))
        , InputsOnInit(GetPasstroughtMapOneToOne(Inputs, InitItems))
        , InputsOnUpdate(GetPasstroughtMapOneToOne(Inputs, UpdateItems))
        , InitOnInputs(GetPasstroughtMapOneToOne(InitItems, Inputs))
        , UpdateOnInputs(GetPasstroughtMapOneToOne(UpdateItems, Inputs))
        , OutputsOnUpdate(GetPasstroughtMapOneToOne(Outputs, UpdateItems))
        , UpdateOnOutputs(GetPasstroughtMapOneToOne(UpdateItems, Outputs))
        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Inputs.size()))
        // Claim Outputs.size() consecutive mutable-value slots; remember the first one.
        , TempStateIndex(std::exchange(mutables.CurValueIndex, mutables.CurValueIndex + Outputs.size()))
    {}

    // Produces the next output row into `output` (entries may be null for
    // columns the caller does not need).
    //
    // Returns whatever the upstream fetch returned; only EFetchResult::One
    // means that a row was produced.
    //
    // The state is switched from "invalid" to "initialised" only after the
    // first row has really been fetched. If the upstream flow returns a
    // non-One result before delivering its first row, the next call must
    // still take the init path; otherwise the first real row would be run
    // through UpdateItems with Outputs never having been seeded. This matches
    // DoGenGetValues, which stores the empty state only in its "init" block,
    // i.e. after a positive fetch result.
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            const auto result = CalculateFirst(ctx, output);
            if (EFetchResult::One == result) {
                state = NUdf::TUnboxedValuePod();
            }
            return result;
        }

        return CalculateOther(ctx, output);
    }

#ifndef MKQL_DISABLE_CODEGEN
    // LLVM counterpart of DoCalculate. Emits IR that fetches a row from Flow,
    // stores it to the external nodes and returns the fetch status together
    // with one lazy getter per output column.
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        // i1 stack slot remembering whether the current row is the first one;
        // it is read back by the lazy getters created at the end.
        const auto flagType = Type::getInt1Ty(context);
        const auto flagPtr = new AllocaInst(flagType, 0U, "flag_ptr", &ctx.Func->getEntryBlock().back());

        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);

        const auto getres = GetNodeValues(Flow, ctx, block);

        // A non-positive fetch status means no row was produced: skip straight to "done".
        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), 0), "special", block);

        BranchInst::Create(done, good, special, block);

        block = good;

        // Materialise an input only if something depends on it or if it is not
        // passed through on both the init and the update side.
        for (auto i = 0U; i < Inputs.size(); ++i)
            if (Inputs[i]->GetDependencesCount() > 0U || !InputsOnInit[i] || !InputsOnUpdate[i])
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Inputs[i])->CreateSetValue(ctx, block, getres.second[i](ctx, block));

        const auto init = BasicBlock::Create(context, "init", ctx.Func);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);

        const auto flag = IsInvalid(statePtr, block, context);
        new StoreInst(flag, flagPtr, block);
        BranchInst::Create(init, next, flag, block);

        // First row: seed Outputs from the input getters or from InitItems.
        block = init;

        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
                const auto& map = InitOnInputs[i];
                const auto value = map ? getres.second[*map](ctx, block) : GetNodeValue(InitItems[i], ctx, block);
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, value);
            }
        }

        // Mark the state as initialised.
        new StoreInst(GetEmpty(context), statePtr, block);

        BranchInst::Create(done, block);

        // Later rows: evaluate every needed new value first, store afterwards,
        // so no output is overwritten before all of them have been computed.
        block = next;

        std::vector<Value*> outputs(Outputs.size(), nullptr);
        for (auto i = 0U; i < outputs.size(); ++i) {
            // An output that maps onto its own update slot (*dep == i) needs no store.
            if (const auto& dep = OutputsOnUpdate[i]; Outputs[i]->GetDependencesCount() > 0U || (dep && *dep != i)) {
                const auto& map = UpdateOnInputs[i];
                outputs[i] = map ? getres.second[*map](ctx, block) : GetNodeValue(UpdateItems[i], ctx, block);
            }
        }

        for (auto i = 0U; i < outputs.size(); ++i)
            if (const auto out = outputs[i])
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, out);

        BranchInst::Create(done, block);

        block = done;

        ICodegeneratorInlineWideNode::TGettersList result;
        result.reserve(Outputs.size());
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (const auto& one = InitOnInputs[i], two = UpdateOnInputs[i]; one && two && *one == *two)
                // Same input column on both paths: reuse the upstream getter.
                result.emplace_back(getres.second[*two]);
            else if (Outputs[i]->GetDependencesCount() > 0 || OutputsOnUpdate[i])
                // The value has been stored to the output node above: read it back.
                result.emplace_back([output = Outputs[i]] (const TCodegenContext& ctx, BasicBlock*& block) { return GetNodeValue(output, ctx, block); });
            else
                // Nothing was stored: compute on demand, choosing the init or the
                // update expression by the flag saved in flagPtr.
                result.emplace_back([this, i, source = getres.second, flagPtr, flagType] (const TCodegenContext& ctx, BasicBlock*& block) {
                    auto& context = ctx.Codegen.GetContext();

                    const auto init = BasicBlock::Create(context, "init", ctx.Func);
                    const auto next = BasicBlock::Create(context, "next", ctx.Func);
                    const auto done = BasicBlock::Create(context, "done", ctx.Func);

                    const auto result = PHINode::Create(Type::getInt128Ty(context), 2U, "result", done);

                    const auto flag = new LoadInst(flagType, flagPtr, "flag", block);
                    BranchInst::Create(init, next, flag, block);

                    block = init;
                    if (const auto& map = InitOnInputs[i])
                        result->addIncoming(source[*map](ctx, block), block);
                    else
                        result->addIncoming(GetNodeValue(InitItems[i], ctx, block), block);
                    BranchInst::Create(done, block);

                    block = next;
                    if (const auto& map = UpdateOnInputs[i])
                        result->addIncoming(source[*map](ctx, block), block);
                    else
                        result->addIncoming(GetNodeValue(UpdateItems[i], ctx, block), block);
                    BranchInst::Create(done, block);

                    block = done;
                    return result;
                });
        }
        return {getres.first, std::move(result)};
    }
#endif
private:
    // Handles the first row: fetches it from Flow and fills Outputs / `output`
    // from InitItems. Returns the fetch status unchanged when it is not One.
    EFetchResult CalculateFirst(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        auto** fields = ctx.WideFields.data() + WideFieldsIndex;

        // Choose where Flow writes each input column: straight into an output
        // node or the caller's slot when the column is passed through and has
        // no other dependents, into the input node otherwise. A null pointer
        // is used when the passed-through column has no destination.
        for (auto i = 0U; i < Inputs.size(); ++i) {
            if (const auto& map = InputsOnInit[i]; map && !Inputs[i]->GetDependencesCount()) {
                if (const auto& to = UpdateOnOutputs[*map]) {
                    fields[i] = &Outputs[*to]->RefValue(ctx);
                    continue;
                } else if (const auto out = output[*map]) {
                    fields[i] = out;
                    continue;
                }
            } else {
                fields[i] = &Inputs[i]->RefValue(ctx);
                continue;
            }

            fields[i] = nullptr;
        }

        if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
            return result;

        // Seed the output nodes that were not filled directly by the fetch.
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
                if (const auto& map = InitOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
                    Outputs[i]->SetValue(ctx, InitItems[i]->GetValue(ctx));
                }
            }
        }

        // Deliver the requested columns to the caller.
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (const auto out = output[i]) {
                if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
                    *out = Outputs[i]->GetValue(ctx);
                else {
                    if (const auto& map = InitOnInputs[i]) {
                        if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
                            // Passed-through input fetched into another caller slot:
                            // copy it from there unless this is that very slot.
                            if (const auto first = *InputsOnInit[from]; first != i)
                                *out = *output[first];
                            continue;
                        }
                    }

                    *out = InitItems[i]->GetValue(ctx);
                }
            }
        }

        return EFetchResult::One;
    }

    // Handles every row after the first: fetches it from Flow and fills
    // Outputs / `output` from UpdateItems. Returns the fetch status unchanged
    // when it is not One.
    EFetchResult CalculateOther(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        auto** fields = ctx.WideFields.data() + WideFieldsIndex;

        // Same routing as in CalculateFirst, but driven by the update-side map
        // and without the direct-to-output-node shortcut.
        for (auto i = 0U; i < Inputs.size(); ++i) {
            if (const auto& map = InputsOnUpdate[i]; map && !Inputs[i]->GetDependencesCount()) {
                if (const auto out = output[*map]) {
                    fields[i] = out;
                    continue;
                }
            } else {
                fields[i] = &Inputs[i]->RefValue(ctx);
                continue;
            }

            fields[i] = nullptr;
        }

        if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
            return result;

        // Phase 1: evaluate all new values into the temporaries, so that no
        // Outputs[i] is overwritten before every UpdateItems has been evaluated.
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
                if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
                    ctx.MutableValues[TempStateIndex + i] = UpdateItems[i]->GetValue(ctx);
                }
            }
        }

        // Phase 2: move the temporaries into the output nodes.
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
                if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
                    Outputs[i]->SetValue(ctx, std::move(ctx.MutableValues[TempStateIndex + i]));
                }
            }
        }

        // Deliver the requested columns to the caller.
        for (auto i = 0U; i < Outputs.size(); ++i) {
            if (const auto out = output[i]) {
                if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
                    *out = Outputs[i]->GetValue(ctx);
                else {
                    if (const auto& map = UpdateOnInputs[i]) {
                        if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
                            // Passed-through input fetched into another caller slot:
                            // copy it from there unless this is that very slot.
                            if (const auto first = *InputsOnUpdate[from]; first != i)
                                *out = *output[first];
                            continue;
                        }
                    }

                    *out = UpdateItems[i]->GetValue(ctx);
                }
            }
        }

        return EFetchResult::One;
    }

    // Registers the external nodes as owned by, and the item expressions as
    // dependencies of, the node returned by FlowDependsOn(Flow), if any.
    void RegisterDependencies() const final {
        if (const auto flow = FlowDependsOn(Flow)) {
            std::for_each(Inputs.cbegin(), Inputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
            std::for_each(Outputs.cbegin(), Outputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
            std::for_each(InitItems.cbegin(), InitItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
            std::for_each(UpdateItems.cbegin(), UpdateItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
        }
    }

    IComputationWideFlowNode* const Flow;

    const TComputationExternalNodePtrVector Inputs;
    const TComputationNodePtrVector InitItems;
    const TComputationExternalNodePtrVector Outputs;
    const TComputationNodePtrVector UpdateItems;

    const TPasstroughtMap InputsOnInit, InputsOnUpdate, InitOnInputs, UpdateOnInputs, OutputsOnUpdate, UpdateOnOutputs;

    // Offset of this node's Inputs.size() pointer slots inside ctx.WideFields.
    const ui32 WideFieldsIndex;
    // Index of the first of Outputs.size() temporaries inside ctx.MutableValues.
    const ui32 TempStateIndex;
};

}

// Builds a TWideChain1MapWrapper from a callable whose inputs are laid out as
//   [0]                                  the upstream wide flow,
//   [1 .. inW]                           input argument nodes,
//   [inW + 1 .. inW + outW]              init expressions,
//   [inW + outW + 1 .. inW + 2*outW]     output (previous row) argument nodes,
//   [inW + 2*outW + 1 .. inW + 3*outW]   update expressions,
// where inW / outW are the widths of the input / result flow types.
// Fails via MKQL_ENSURE on a wrong input count and throws yexception when
// input 0 is not a wide flow node.
IComputationNode* WrapWideChain1Map(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected argument.");

    const auto inputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
    const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
    MKQL_ENSURE(callable.GetInputsCount() == inputWidth + outputWidth * 3U + 1U, "Wrong signature.");

    const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
    if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
        TComputationNodePtrVector initOutput(outputWidth, nullptr), updateOutput(outputWidth, nullptr);
        auto index = inputWidth;
        std::generate(initOutput.begin(), initOutput.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); });

        // Skip the output argument nodes that sit between init and update expressions.
        index += outputWidth;
        std::generate(updateOutput.begin(), updateOutput.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); });

        TComputationExternalNodePtrVector inputs(inputWidth, nullptr), outputs(outputWidth, nullptr);
        index = 0U;
        std::generate(inputs.begin(), inputs.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });

        // Skip the init expressions that sit between input and output argument nodes.
        index += outputWidth;
        std::generate(outputs.begin(), outputs.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });

        return new TWideChain1MapWrapper(ctx.Mutables, wide, std::move(inputs), std::move(initOutput), std::move(outputs), std::move(updateOutput));
    }

    THROW yexception() << "Expected wide flow.";
}

}
}