// Computation nodes that read a single element of a tuple/struct-like value
// (factories WrapNth and WrapMember below), plus a helper node that exposes the
// value's raw elements pointer.
//
// NOTE(review): in this copy the text between angle brackets appears to be
// missing (e.g. `template class`, `TMutableCodegeneratorNode>`,
// `reinterpret_cast(`, bare `#include`), and the file is collapsed onto a few
// physical lines. Confirm against the upstream file before building.
//
// The line below contains:
//  * the includes and the opening of namespaces NKikimr::NMiniKQL and an
//    anonymous namespace;
//  * class TElementsWrapper (uses a compile-time flag IsOptional):
//      - the constructor passes EValueRepresentation::Embedded to the base and
//        stores the input node in Array;
//      - DoCalculate reads Array's value and returns a TUnboxedValuePod built
//        from the reinterpret_cast of array.GetElements(); when IsOptional is
//        set and the array value is falsy it returns an empty TUnboxedValuePod;
//      - DoGenerateGetValue (only when MKQL_DISABLE_CODEGEN is not defined)
//        emits IR for the same logic: the boxed-value virtual call result is
//        converted with PtrToInt to i64 and passed through SetterFor. In the
//        IsOptional case it branches on IsEmpty(array) to "done" (PHI incoming
//        is a zero constant) or to "good" (PHI incoming is the wrapped pointer).
#include "mkql_element.h" #include // Y_IGNORE #include #include namespace NKikimr { namespace NMiniKQL { namespace { template class TElementsWrapper : public TMutableCodegeneratorNode> { typedef TMutableCodegeneratorNode> TBaseComputation; public: TElementsWrapper(TComputationMutables& mutables, IComputationNode* array) : TBaseComputation(mutables, EValueRepresentation::Embedded), Array(array) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { const auto& array = Array->GetValue(compCtx); if constexpr (IsOptional) { return array ? NUdf::TUnboxedValuePod(reinterpret_cast(array.GetElements())) : NUdf::TUnboxedValuePod(); } else { return NUdf::TUnboxedValuePod(reinterpret_cast(array.GetElements())); } } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto array = GetNodeValue(Array, ctx, block); const auto elementsType = PointerType::getUnqual(array->getType()); if constexpr (IsOptional) { const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); const auto result = PHINode::Create(array->getType(), 2U, "result", done); result->addIncoming(ConstantInt::get(array->getType(), 0ULL), block); BranchInst::Create(done, good, IsEmpty(array, block, context), block); block = good; const auto elements = CallBoxedValueVirtualMethod(elementsType, array, ctx.Codegen, block); const auto cast = CastInst::Create(Instruction::PtrToInt, elements, Type::getInt64Ty(context), "cast", block); const auto wide = SetterFor(cast, context, block); result->addIncoming(wide, block); BranchInst::Create(done, block); block = done; return result; } else { const auto elements = CallBoxedValueVirtualMethod(elementsType, array, ctx.Codegen, block); const auto cast = CastInst::Create(Instruction::PtrToInt, elements, Type::getInt64Ty(context), "cast", block); return SetterFor(cast, context, block); 
// The line below contains:
//  * the end of TElementsWrapper: RegisterDependencies declares a dependency
//    on Array, followed by the Array member;
//  * class TElementWrapper (also uses the compile-time flag IsOptional):
//      - the constructor takes the value representation kind, a cache node, the
//        array node and a fixed element index;
//      - DoCalculate: when Cache->GetDependencesCount() > 1U it reads the cache
//        value; with IsOptional and a falsy cache value it returns an empty
//        TUnboxedValue; if cache.Get() is non-null the result is reinterpreted
//        and indexed with Index. Otherwise it falls back to
//        Array->GetValue(ctx).GetElement(Index), returning an empty value for a
//        falsy array when IsOptional is set.
//        NOTE(review): presumably the cache is skipped when this node is its
//        only consumer - confirm the meaning of GetDependencesCount().
//      - DoGenerateGetElement (codegen only) emits the non-cached path: a boxed
//        value virtual call with the constant index that writes into `pointer`,
//        followed by CleanupBoxed(array, ...) when Array->IsTemporaryValue().
//        With IsOptional it first branches on IsEmpty(array): block "zero"
//        stores a zero constant to `pointer`, block "good" performs the call,
//        and both continue in block "exit".
} } #endif private: void RegisterDependencies() const final { this->DependsOn(Array); } IComputationNode* const Array; }; template class TElementWrapper : public TMutableCodegeneratorPtrNode> { typedef TMutableCodegeneratorPtrNode> TBaseComputation; public: TElementWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* cache, IComputationNode* array, ui32 index) : TBaseComputation(mutables, kind), Cache(cache), Array(array), Index(index) {} NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { if (Cache->GetDependencesCount() > 1U) { const auto cache = Cache->GetValue(ctx); if (IsOptional && !cache) { return NUdf::TUnboxedValue(); } if (const auto elements = cache.Get()) { return reinterpret_cast(elements)[Index]; } } const auto& array = Array->GetValue(ctx); if constexpr (IsOptional) { return array ? array.GetElement(Index) : NUdf::TUnboxedValue(); } else { return array.GetElement(Index); } } #ifndef MKQL_DISABLE_CODEGEN void DoGenerateGetElement(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto array = GetNodeValue(Array, ctx, block); const auto index = ConstantInt::get(Type::getInt32Ty(context), Index); if constexpr (IsOptional) { const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto zero = BasicBlock::Create(context, "zero", ctx.Func); const auto exit = BasicBlock::Create(context, "exit", ctx.Func); BranchInst::Create(zero, good, IsEmpty(array, block, context), block); block = zero; new StoreInst(ConstantInt::get(array->getType(), 0ULL), pointer, block); BranchInst::Create(exit, block); block = good; CallBoxedValueVirtualMethod(pointer, array, ctx.Codegen, block, index); if (Array->IsTemporaryValue()) CleanupBoxed(array, ctx, block); BranchInst::Create(exit, block); block = exit; } else { CallBoxedValueVirtualMethod(pointer, array, ctx.Codegen, block, index); if (Array->IsTemporaryValue()) CleanupBoxed(array, ctx, block); } 
// The line below contains the end of DoGenerateGetElement and then
// TElementWrapper::DoGenerateGetValue (codegen only), which mirrors DoCalculate:
//  * if Cache->GetDependencesCount() <= 1U it delegates to DoGenerateGetElement;
//  * otherwise it loads the cache node value and creates blocks "fast", "slow"
//    and "done";
//  * with IsOptional, a cache value equal to zero goes to block "none", which
//    stores zero into `pointer` and jumps to "done"; other values continue in
//    block "good";
//  * the cache value is truncated to i64 and converted with IntToPtr; a
//    non-null pointer selects "fast", a null pointer selects "slow";
//  * "fast": in-bounds GEP by the constant Index, load of the item,
//    ValueAddRef for this node's representation, store into `pointer`;
//  * "slow": DoGenerateGetElement;
//  * both paths end in "done", which becomes the current block on return.
// It is followed by RegisterDependencies, which declares dependencies on both
// Array and Cache.
} void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const { if (Cache->GetDependencesCount() <= 1U) { return DoGenerateGetElement(ctx, pointer, block); } auto& context = ctx.Codegen.GetContext(); const auto cache = GetNodeValue(Cache, ctx, block); const auto fast = BasicBlock::Create(context, "fast", ctx.Func); const auto slow = BasicBlock::Create(context, "slow", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); if constexpr (IsOptional) { const auto zero = ConstantInt::get(cache->getType(), 0ULL); const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cache, zero, "check", block); const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto none = BasicBlock::Create(context, "none", ctx.Func); BranchInst::Create(none, good, check, block); block = none; new StoreInst(zero, pointer, block); BranchInst::Create(done, block); block = good; } const auto trunc = CastInst::Create(Instruction::Trunc, cache, Type::getInt64Ty(context), "trunc", block); const auto type = PointerType::getUnqual(cache->getType()); const auto elements = CastInst::Create(Instruction::IntToPtr, trunc, type, "elements", block); const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(type), "fill", block); BranchInst::Create(fast, slow, fill, block); block = fast; const auto index = ConstantInt::get(Type::getInt32Ty(context), this->Index); const auto ptr = GetElementPtrInst::CreateInBounds(cache->getType(), elements, {index}, "ptr", block); const auto item = new LoadInst(cache->getType(), ptr, "item", block); ValueAddRef(this->GetRepresentation(), item, ctx, block); new StoreInst(item, pointer, block); BranchInst::Create(done, block); block = slow; DoGenerateGetElement(ctx, pointer, block); BranchInst::Create(done, block); block = done; } #endif private: void RegisterDependencies() const final { this->DependsOn(Array); this->DependsOn(Cache); } 
// The line below contains:
//  * TElementWrapper's members (Cache, Array, Index) and the end of the class;
//  * WrapElements: creates a TElementsWrapper over `array`, choosing the
//    instantiation by `isOptional`.
//    NOTE(review): both branches look identical here, most likely because the
//    template arguments were lost from this copy - confirm.
//  * the end of the anonymous namespace;
//  * WrapNth: factory for element access on a tuple. It requires exactly two
//    inputs ("Expected two args."), unwraps an optional around input 0's static
//    type to get the TTupleType, reads the index from the data literal in
//    input 1 and checks it against GetElementsCount() ("Bad tuple index").
//    It then looks up or inserts the tuple node in ctx.ElementsCache; on first
//    insertion it creates the elements node via WrapElements and registers it
//    with ctx.NodePushBack. The returned TElementWrapper receives the element
//    type's value representation, the cached elements node, the tuple node and
//    the index.
//  * the beginning of WrapMember, which follows the same steps for TStructType
//    using GetMembersCount()/GetMemberType() and the message "Bad member index".
IComputationNode *const Cache; IComputationNode *const Array; const ui32 Index; }; IComputationNode* WrapElements(IComputationNode* array, const TComputationNodeFactoryContext& ctx, bool isOptional) { if (isOptional) { return new TElementsWrapper(ctx.Mutables, array); } else { return new TElementsWrapper(ctx.Mutables, array); } } } IComputationNode* WrapNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args."); const auto input = callable.GetInput(0U); bool isOptional; const auto tupleType = AS_TYPE(TTupleType, UnpackOptional(input.GetStaticType(), isOptional)); const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U)); const auto index = indexData->AsValue().Get(); MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index"); const auto tuple = LocateNode(ctx.NodeLocator, callable, 0); const auto ins = ctx.ElementsCache.emplace(tuple, nullptr); if (ins.second) { ctx.NodePushBack(ins.first->second = WrapElements(tuple, ctx, isOptional)); } if (isOptional) { return new TElementWrapper(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); } else { return new TElementWrapper(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); } } IComputationNode* WrapMember(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args."); const auto input = callable.GetInput(0U); bool isOptional; const auto structType = AS_TYPE(TStructType, UnpackOptional(input.GetStaticType(), isOptional)); const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U)); const auto index = indexData->AsValue().Get(); MKQL_ENSURE(index < structType->GetMembersCount(), "Bad member index"); const auto structObj = LocateNode(ctx.NodeLocator, callable, 0U); const auto ins = ctx.ElementsCache.emplace(structObj, nullptr); if 
// The line below finishes WrapMember: on first insertion into
// ctx.ElementsCache it creates and registers the shared elements node, then
// returns a TElementWrapper for the member (instantiation chosen by
// isOptional), and finally closes namespaces NMiniKQL and NKikimr.
(ins.second) { ctx.NodePushBack(ins.first->second = WrapElements(structObj, ctx, isOptional)); } if (isOptional) { return new TElementWrapper(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index); } else { return new TElementWrapper(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index); } } } }