// Overview of this translation unit (namespace NKikimr::NMiniKQL).
//
// Two computation node classes live in an anonymous namespace, followed by two
// factory functions:
//
//   THeadWrapper  - produces the first item of the list computed by the node `List`.
//   TLastWrapper  - produces the last item of the list computed by the node `List`.
//   WrapHead / WrapLast - check that the callable has exactly one input
//                   (MKQL_ENSURE, message "Expected 1 args"), test IsOptional() on
//                   the callable's return type, and construct the wrapper from
//                   ctx.Mutables, the value representation of the return type and
//                   the node located for input 0.
//
// DoCalculate (both classes):
//   * reads the list value via List->GetValue(ctx);
//   * if value.GetElements() returns a non-null pointer, the item is taken directly
//     through that pointer (*ptr for head, ptr[size - 1] for last) when the list
//     length is non-zero;
//   * otherwise, if value.GetListIterator() yields an iterator, the item comes from
//     iter.Next(result) (head: first successful call; last: keep calling until Next
//     returns false);
//   * in every other case a default-constructed NUdf::TUnboxedValue is returned.
//
// DoGenerateGetValue (both classes) is compiled only when MKQL_DISABLE_CODEGEN is
// not defined. It builds the same decision tree out of basic blocks named
// "fast" / "slow" / "many" / "none" / "done" (plus "good", and for last also
// "nope" / "loop"), writes the outcome through `result`, and finally calls
// CleanupBoxed on the list value when List->IsTemporaryValue() is true.
//
// NOTE(review): the text below appears to have lost its template argument lists
// and include targets, and to have been collapsed onto a few physical lines
// (e.g. `template class`, `TMutableCodegeneratorPtrNode>`, `MakeOptionalIf()`,
// `CallBoxedValueVirtualMethod(` without a method selector, two bare `#include`
// directives, and both branches of each factory reading identically).
// Presumably an extraction artifact - confirm against a pristine copy before
// building or editing the code itself.
#include "mkql_tooptional.h" #include // Y_IGNORE #include namespace NKikimr { namespace NMiniKQL { namespace { template class THeadWrapper : public TMutableCodegeneratorPtrNode> { typedef TMutableCodegeneratorPtrNode> TBaseComputation; public: THeadWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list) : TBaseComputation(mutables, kind), List(list) {} NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { const auto& value = List->GetValue(ctx); if (const auto ptr = value.GetElements()) { if (value.GetListLength() > 0ULL) { return NUdf::TUnboxedValuePod(*ptr).MakeOptionalIf(); } } else if (const auto iter = value.GetListIterator()) { NUdf::TUnboxedValue result; if (iter.Next(result)) { return result.Release().MakeOptionalIf(); } } return NUdf::TUnboxedValue(); } #ifndef MKQL_DISABLE_CODEGEN void DoGenerateGetValue(const TCodegenContext& ctx, Value* result, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto valueType = Type::getInt128Ty(context); const auto ptrType = PointerType::getUnqual(valueType); const auto list = GetNodeValue(List, ctx, block); const auto elements = CallBoxedValueVirtualMethod(ptrType, list, ctx.Codegen, block); const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block); const auto fast = BasicBlock::Create(context, "fast", ctx.Func); const auto slow = BasicBlock::Create(context, "slow", ctx.Func); const auto many = BasicBlock::Create(context, "many", ctx.Func); const auto none = BasicBlock::Create(context, "none", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); const auto good = IsOptional ? 
// THeadWrapper::DoGenerateGetValue, continued. "good" is a separate block only
// when IsOptional holds; otherwise it is the same block as "done". The branch on
// `null` goes to "slow" when the elements pointer compares equal to null and to
// "fast" otherwise. In "slow", the iterator value is loaded out of `result`,
// `result` is overwritten with zero before the next call, and UnRefBoxed is applied
// to the iterator before branching on the call's status. "none" stores a zero
// value into `result`.
BasicBlock::Create(context, "good", ctx.Func) : done; BranchInst::Create(slow, fast, null, block); { block = fast; const auto size = CallBoxedValueVirtualMethod(Type::getInt64Ty(context), list, ctx.Codegen, block); const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 0ULL), "test", block); BranchInst::Create(many, none, test, block); block = many; const auto item = new LoadInst(valueType, elements, "item", block); ValueAddRef(this->GetRepresentation(), item, ctx, block); new StoreInst(IsOptional ? MakeOptional(context, item, block) : item, result, block); BranchInst::Create(done, block); } { block = slow; CallBoxedValueVirtualMethod(result, list, ctx.Codegen, block); const auto iter = new LoadInst(valueType, result, "iter", block); new StoreInst(ConstantInt::get(valueType, 0ULL), result, block); const auto status = CallBoxedValueVirtualMethod(Type::getInt1Ty(context), iter, ctx.Codegen, block, result); UnRefBoxed(iter, ctx, block); BranchInst::Create(good, none, status, block); if constexpr (IsOptional) { block = good; const auto item = new LoadInst(valueType, result, "item", block); new StoreInst(MakeOptional(context, item, block), result, block); BranchInst::Create(done, block); } } block = none; new StoreInst(ConstantInt::get(valueType, 0ULL), result, block); BranchInst::Create(done, block); block = done; if (List->IsTemporaryValue()) CleanupBoxed(list, ctx, block); } #endif private: void RegisterDependencies() const final { this->DependsOn(List); } IComputationNode* const List; }; template class TLastWrapper : public TMutableCodegeneratorPtrNode> { typedef TMutableCodegeneratorPtrNode> TBaseComputation; public: TLastWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list) : TBaseComputation(mutables, kind), List(list) {} NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { const auto& value = List->GetValue(ctx); if (const auto ptr = value.GetElements()) { 
// TLastWrapper::DoCalculate, continued. With a non-null elements pointer and a
// non-zero length the result is ptr[size - 1]. On the iterator path Next() is
// called until it returns false and whatever is then held in `result` is returned.
// NOTE(review): this relies on a failing Next() leaving `result` untouched -
// confirm against the iterator contract.
if (const auto size = value.GetListLength()) { return NUdf::TUnboxedValuePod(ptr[size - 1U]).MakeOptionalIf(); } } else if (const auto iter = value.GetListIterator()) { NUdf::TUnboxedValue result; if (iter.Next(result)) { while (iter.Next(result)) continue; return result.Release().MakeOptionalIf(); } } return NUdf::TUnboxedValue(); } #ifndef MKQL_DISABLE_CODEGEN void DoGenerateGetValue(const TCodegenContext& ctx, Value* result, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto valueType = Type::getInt128Ty(context); const auto ptrType = PointerType::getUnqual(valueType); const auto list = GetNodeValue(List, ctx, block); const auto elements = CallBoxedValueVirtualMethod(ptrType, list, ctx.Codegen, block); const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block); const auto fast = BasicBlock::Create(context, "fast", ctx.Func); const auto slow = BasicBlock::Create(context, "slow", ctx.Func); const auto nope = BasicBlock::Create(context, "nope", ctx.Func); const auto loop = BasicBlock::Create(context, "loop", ctx.Func); const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto many = BasicBlock::Create(context, "many", ctx.Func); const auto none = BasicBlock::Create(context, "none", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); BranchInst::Create(slow, fast, null, block); { block = fast; const auto size = CallBoxedValueVirtualMethod(Type::getInt64Ty(context), list, ctx.Codegen, block); const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 0ULL), "test", block); BranchInst::Create(many, none, test, block); block = many; const auto index = BinaryOperator::CreateSub(size, ConstantInt::get(size->getType(), 1), "index", block); const auto last = GetElementPtrInst::CreateInBounds(valueType, elements, {index}, "last", block); const auto item = new LoadInst(valueType, 
// TLastWrapper::DoGenerateGetValue, continued. "item" is loaded through the `last`
// pointer (elements indexed by size - 1) and handed to ValueAddRef before being
// stored into `result`. On the slow path, a false status from the first call goes
// to "nope" (UnRefBoxed on the iterator, then "none"); a true status enters "loop",
// which repeats the call and branches back to itself while the status stays true,
// leaving for "good" once it is false.
last, "item", block); ValueAddRef(this->GetRepresentation(), item, ctx, block); new StoreInst(IsOptional ? MakeOptional(context, item, block) : item, result, block); BranchInst::Create(done, block); } { block = slow; CallBoxedValueVirtualMethod(result, list, ctx.Codegen, block); const auto iter = new LoadInst(valueType, result, "iter", block); new StoreInst(ConstantInt::get(valueType, 0ULL), result, block); const auto first = CallBoxedValueVirtualMethod(Type::getInt1Ty(context), iter, ctx.Codegen, block, result); BranchInst::Create(loop, nope, first, block); block = nope; UnRefBoxed(iter, ctx, block); BranchInst::Create(none, block); block = loop; const auto status = CallBoxedValueVirtualMethod(Type::getInt1Ty(context), iter, ctx.Codegen, block, result); BranchInst::Create(loop, good, status, block); block = good; UnRefBoxed(iter, ctx, block); if constexpr (IsOptional) { const auto item = new LoadInst(valueType, result, "item", block); new StoreInst(MakeOptional(context, item, block), result, block); } BranchInst::Create(done, block); } block = none; new StoreInst(ConstantInt::get(valueType, 0ULL), result, block); BranchInst::Create(done, block); block = done; if (List->IsTemporaryValue()) CleanupBoxed(list, ctx, block); } #endif private: void RegisterDependencies() const final { this->DependsOn(List); } IComputationNode* const List; }; } IComputationNode* WrapHead(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args"); if (AS_TYPE(TOptionalType, callable.GetType()->GetReturnType())->IsOptional()) { return new THeadWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0)); } else { return new THeadWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0)); } } IComputationNode* WrapLast(TCallable& callable, const TComputationNodeFactoryContext& ctx) { 
// WrapLast, continued. Has the same shape as WrapHead: it requires exactly one
// input and constructs a TLastWrapper in either branch of the IsOptional() test.
// NOTE(review): the two branches presumably differed in a template argument that
// is no longer visible here - confirm.
MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args"); if (AS_TYPE(TOptionalType, callable.GetType()->GetReturnType())->IsOptional()) { return new TLastWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0)); } else { return new TLastWrapper(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0)); } } } }