123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465 |
- #include "mkql_block_skiptake.h"
- #include <yql/essentials/minikql/computation/mkql_block_impl.h>
- #include <yql/essentials/minikql/arrow/arrow_defs.h>
- #include <yql/essentials/minikql/arrow/arrow_util.h>
- #include <yql/essentials/minikql/mkql_type_builder.h>
- #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
- #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
- #include <yql/essentials/minikql/mkql_node_builder.h>
- #include <yql/essentials/minikql/mkql_node_cast.h>
- namespace NKikimr {
- namespace NMiniKQL {
- namespace {
- class TWideSkipBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideSkipBlocksWrapper> {
- using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideSkipBlocksWrapper>;
- public:
- TWideSkipBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
- : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Count(count), Width(size - 1U)
- {}
- EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
- if (state.IsInvalid()) {
- state = Count->GetValue(ctx);
- }
- if (auto count = state.Get<ui64>()) while (true) {
- if (const auto result = Flow->FetchValues(ctx, output); EFetchResult::One != result) {
- state = NUdf::TUnboxedValuePod(count);
- return result;
- }
- if (const auto blockSize = GetBlockCount(*output[Width]); count < blockSize) {
- state = NUdf::TUnboxedValuePod::Zero();
- *output[Width] = MakeBlockCount(ctx.HolderFactory, blockSize - count);
- for (auto i = 0U; i < Width; ++i)
- if (const auto out = output[i])
- *out = SliceBlock(ctx.HolderFactory, *out, count);
- return EFetchResult::One;
- } else
- count -= blockSize;
- }
- return Flow->FetchValues(ctx, output);
- }
- #ifndef MKQL_DISABLE_CODEGEN
- TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto indexType = Type::getInt64Ty(context);
- const auto valueType = Type::getInt128Ty(context);
- const auto atTop = &ctx.Func->getEntryBlock().back();
- const auto offsetPtr = new AllocaInst(indexType, 0U, "offset_ptr", atTop);
- const auto sizePtr = new AllocaInst(indexType, 0U, "size_ptr", atTop);
- const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideSkipBlocksWrapper::SliceBlock));
- const auto sliceType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), valueType, indexType}, false);
- const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice", atTop);
- const auto name = "GetBlockCount";
- ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
- const auto getCountType = FunctionType::get(indexType, { valueType }, false);
- const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);
- const auto init = BasicBlock::Create(context, "init", ctx.Func);
- const auto main = BasicBlock::Create(context, "main", ctx.Func);
- const auto load = new LoadInst(valueType, statePtr, "load", block);
- const auto state = PHINode::Create(valueType, 2U, "state", main);
- state->addIncoming(load, block);
- BranchInst::Create(init, main, IsInvalid(load, block, context), block);
- block = init;
- GetNodeValue(statePtr, Count, ctx, block);
- const auto save = new LoadInst(valueType, statePtr, "save", block);
- state->addIncoming(save, block);
- BranchInst::Create(main, block);
- block = main;
- const auto work = BasicBlock::Create(context, "work", ctx.Func);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto test = BasicBlock::Create(context, "test", ctx.Func);
- const auto over = BasicBlock::Create(context, "over", ctx.Func);
- const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto resultType = Type::getInt32Ty(context);
- const auto result = PHINode::Create(resultType, 2U, "result", done);
- const auto trunc = GetterFor<ui64>(state, context, block);
- const auto count = PHINode::Create(trunc->getType(), 2U, "count", work);
- count->addIncoming(trunc, block);
- BranchInst::Create(work, block);
- block = work;
- const auto getres = GetNodeValues(Flow, ctx, block);
- const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
- BranchInst::Create(pass, good, special, block);
- block = good;
- const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "more", block);
- BranchInst::Create(test, pass, more, block);
- block = test;
- const auto countValue = getres.second.back()(ctx, block);
- const auto height = CallInst::Create(getCount, { countValue }, "height", block);
- ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);
- const auto part = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, count, height, "part", block);
- const auto decr = BinaryOperator::CreateSub(count, height, "decr", block);
- count->addIncoming(decr, block);
- BranchInst::Create(over, work, part, block);
- block = over;
- const auto tail = BinaryOperator::CreateSub(height, count, "tail", block);
- new StoreInst(count, offsetPtr, block);
- new StoreInst(tail, sizePtr, block);
- new StoreInst(GetFalse(context), statePtr, block);
- result->addIncoming(getres.first, block);
- BranchInst::Create(done, block);
- block = pass;
- new StoreInst(ConstantInt::get(indexType, 0), offsetPtr, block);
- new StoreInst(ConstantInt::get(indexType, 0), sizePtr, block);
- new StoreInst(SetterFor<ui64>(count, context, block), statePtr, block);
- result->addIncoming(getres.first, block);
- BranchInst::Create(done, block);
- block = done;
- ICodegeneratorInlineWideNode::TGettersList getters(getres.second.size());
- getters.back() = [sizePtr, indexType, valueType, getSize = getres.second.back()](const TCodegenContext& ctx, BasicBlock*& block) {
- auto& context = ctx.Codegen.GetContext();
- const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
- const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
- const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
- const auto height = PHINode::Create(valueType, 2U, "state", exit);
- const auto count = new LoadInst(indexType, sizePtr, "count", block);
- const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "work", block);
- BranchInst::Create(calc, pass, work, block);
- block = calc;
- const auto makeCountFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&MakeBlockCount));
- const auto makeCountType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), indexType}, false);
- const auto makeCountPtr = CastInst::Create(Instruction::IntToPtr, makeCountFunc, PointerType::getUnqual(makeCountType), "make_count_func", block);
- const auto slice = CallInst::Create(makeCountType, makeCountPtr, {ctx.GetFactory(), count}, "slice", block);
- height->addIncoming(slice, block);
- BranchInst::Create(exit, block);
- block = pass;
- const auto size = getSize(ctx, block);
- height->addIncoming(size, block);
- BranchInst::Create(exit, block);
- block = exit;
- return height;
- };
- for (auto idx = 0U; idx < Width; ++idx) {
- getters[idx] = [offsetPtr, indexType, valueType, sliceType, slicePtr, getBlock = getres.second[idx]](const TCodegenContext& ctx, BasicBlock*& block) {
- auto& context = ctx.Codegen.GetContext();
- const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
- const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
- const auto output = PHINode::Create(valueType, 2U, "output", exit);
- const auto offset = new LoadInst(indexType, offsetPtr, "offset", block);
- const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, offset, ConstantInt::get(indexType, 0), "work", block);
- const auto value = getBlock(ctx, block);
- output->addIncoming(value, block);
- BranchInst::Create(calc, exit, work, block);
- block = calc;
- const auto slice = CallInst::Create(sliceType, slicePtr, {ctx.GetFactory(), value, offset}, "slice", block);
- ValueCleanup(EValueRepresentation::Any, value, ctx, block);
- output->addIncoming(slice, block);
- BranchInst::Create(exit, block);
- block = exit;
- return output;
- };
- }
- return {result, std::move(getters)};
- }
- #endif
- private:
- static NUdf::TUnboxedValuePod SliceBlock(const THolderFactory& holderFactory, NUdf::TUnboxedValuePod block, const uint64_t offset) {
- const auto& datum = TArrowBlock::From(block).GetDatum();
- return datum.is_scalar() ? block : holderFactory.CreateArrowBlock(DeepSlice(datum.array(), offset, datum.array()->length - offset));
- }
- void RegisterDependencies() const final {
- if (const auto flow = FlowDependsOn(Flow)) {
- DependsOn(flow, Count);
- }
- }
- IComputationWideFlowNode* const Flow;
- IComputationNode* const Count;
- const ui32 Width;
- };
- class TWideTakeBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideTakeBlocksWrapper> {
- using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeBlocksWrapper>;
- public:
- TWideTakeBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
- : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Count(count), Width(size - 1U)
- {}
- EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
- if (state.IsInvalid()) {
- state = Count->GetValue(ctx);
- }
- if (const auto count = state.Get<ui64>()) {
- if (const auto result = Flow->FetchValues(ctx, output); EFetchResult::One == result) {
- if (const auto blockSize = GetBlockCount(*output[Width]); count < blockSize) {
- state = NUdf::TUnboxedValuePod::Zero();
- *output[Width] = MakeBlockCount(ctx.HolderFactory, count);
- for (auto i = 0U; i < Width; ++i)
- if (const auto out = output[i])
- *out = SliceBlock(ctx.HolderFactory, *out, count);
- } else
- state = NUdf::TUnboxedValuePod(ui64(count - blockSize));
- return EFetchResult::One;
- } else {
- return result;
- }
- }
- return EFetchResult::Finish;
- }
- #ifndef MKQL_DISABLE_CODEGEN
- TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto indexType = Type::getInt64Ty(context);
- const auto valueType = Type::getInt128Ty(context);
- const auto atTop = &ctx.Func->getEntryBlock().back();
- const auto sizePtr = new AllocaInst(indexType, 0U, "size_ptr", atTop);
- new StoreInst(ConstantInt::get(indexType, 0), sizePtr, atTop);
- const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideTakeBlocksWrapper::SliceBlock));
- const auto sliceType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), valueType, indexType}, false);
- const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice", atTop);
- const auto name = "GetBlockCount";
- ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
- const auto getCountType = FunctionType::get(indexType, { valueType }, false);
- const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);
- const auto init = BasicBlock::Create(context, "init", ctx.Func);
- const auto main = BasicBlock::Create(context, "main", ctx.Func);
- const auto load = new LoadInst(valueType, statePtr, "load", block);
- const auto state = PHINode::Create(valueType, 2U, "state", main);
- state->addIncoming(load, block);
- BranchInst::Create(init, main, IsInvalid(load, block, context), block);
- block = init;
- GetNodeValue(statePtr, Count, ctx, block);
- const auto save = new LoadInst(valueType, statePtr, "save", block);
- state->addIncoming(save, block);
- BranchInst::Create(main, block);
- block = main;
- const auto work = BasicBlock::Create(context, "work", ctx.Func);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto resultType = Type::getInt32Ty(context);
- const auto result = PHINode::Create(resultType, 3U, "result", done);
- result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(EFetchResult::Finish)), block);
- const auto count = GetterFor<ui64>(state, context, block);
- const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(count->getType(), 0ULL), "plus", block);
- BranchInst::Create(work, done, plus, block);
- block = work;
- const auto getres = GetNodeValues(Flow, ctx, block);
- const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
- result->addIncoming(getres.first, block);
- BranchInst::Create(done, good, special, block);
- block = good;
- const auto countValue = getres.second.back()(ctx, block);
- const auto height = CallInst::Create(getCount, { countValue }, "height", block);
- ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);
- const auto part = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, count, height, "part", block);
- const auto decr = BinaryOperator::CreateSub(count, height, "decr", block);
- const auto next = SelectInst::Create(part, ConstantInt::get(indexType, 0), decr, "next", block);
- const auto size = SelectInst::Create(part, count, ConstantInt::get(indexType, 0), "size", block);
- new StoreInst(SetterFor<ui64>(next, context, block), statePtr, block);
- new StoreInst(size, sizePtr, block);
- result->addIncoming(getres.first, block);
- BranchInst::Create(done, block);
- block = done;
- ICodegeneratorInlineWideNode::TGettersList getters(getres.second.size());
- getters.back() = [sizePtr, indexType, valueType, getSize = getres.second.back()](const TCodegenContext& ctx, BasicBlock*& block) {
- auto& context = ctx.Codegen.GetContext();
- const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
- const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
- const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
- const auto height = PHINode::Create(valueType, 2U, "state", exit);
- const auto count = new LoadInst(indexType, sizePtr, "count", block);
- const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "work", block);
- BranchInst::Create(calc, pass, work, block);
- block = calc;
- const auto makeCountFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&MakeBlockCount));
- const auto makeCountType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), indexType}, false);
- const auto makeCountPtr = CastInst::Create(Instruction::IntToPtr, makeCountFunc, PointerType::getUnqual(makeCountType), "make_count_func", block);
- const auto slice = CallInst::Create(makeCountType, makeCountPtr, {ctx.GetFactory(), count}, "slice", block);
- height->addIncoming(slice, block);
- BranchInst::Create(exit, block);
- block = pass;
- const auto size = getSize(ctx, block);
- height->addIncoming(size, block);
- BranchInst::Create(exit, block);
- block = exit;
- return height;
- };
- for (auto idx = 0U; idx < Width; ++idx) {
- getters[idx] = [sizePtr, indexType, valueType, sliceType, slicePtr, getBlock = getres.second[idx]](const TCodegenContext& ctx, BasicBlock*& block) {
- auto& context = ctx.Codegen.GetContext();
- const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
- const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
- const auto output = PHINode::Create(valueType, 2U, "output", exit);
- const auto size = new LoadInst(indexType, sizePtr, "size", block);
- const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(indexType, 0), "work", block);
- const auto value = getBlock(ctx, block);
- output->addIncoming(value, block);
- BranchInst::Create(calc, exit, work, block);
- block = calc;
- const auto slice = CallInst::Create(sliceType, slicePtr, {ctx.GetFactory(), value, size}, "slice", block);
- ValueCleanup(EValueRepresentation::Any, value, ctx, block);
- output->addIncoming(slice, block);
- BranchInst::Create(exit, block);
- block = exit;
- return output;
- };
- }
- return {result, std::move(getters)};
- }
- #endif
- private:
- static NUdf::TUnboxedValuePod SliceBlock(const THolderFactory& holderFactory, NUdf::TUnboxedValuePod block, const uint64_t offset) {
- const auto& datum = TArrowBlock::From(block).GetDatum();
- return datum.is_scalar() ? block : holderFactory.CreateArrowBlock(DeepSlice(datum.array(), 0ULL, offset));
- }
- void RegisterDependencies() const final {
- if (const auto flow = FlowDependsOn(Flow)) {
- DependsOn(flow, Count);
- }
- }
- IComputationWideFlowNode* const Flow;
- IComputationNode* const Count;
- const ui32 Width;
- };
- IComputationNode* WrapSkipTake(bool skip, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
- const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
- const auto flowWidth = GetWideComponentsCount(flowType);
- MKQL_ENSURE(flowWidth > 0, "Expected at least one column");
- auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
- MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
- const auto count = LocateNode(ctx.NodeLocator, callable, 1);
- const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType());
- MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
- if (skip) {
- return new TWideSkipBlocksWrapper(ctx.Mutables, wideFlow, count, flowWidth);
- }
- return new TWideTakeBlocksWrapper(ctx.Mutables, wideFlow, count, flowWidth);
- }
- } //namespace
- IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- bool skip = true;
- return WrapSkipTake(skip, callable, ctx);
- }
- IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- bool skip = false;
- return WrapSkipTake(skip, callable, ctx);
- }
- }
- }
|