#pragma once

// NOTE(review): in this copy of the file the targets of the bare #include
// directives below and every template argument list (after reinterpret_cast,
// std::vector, std::set, std::function, std::unique_ptr, template, ...) are
// absent. They must be restored from the original before this header can
// compile; nothing below attempts to guess them.

#include "mkql_computation_node_list.h"
#include "mkql_spiller_factory.h"

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace NKikimr {
namespace NMiniKQL {

// Returns the list representation obtained from the boxed value held by `value`
// (via TBoxedValueAccessor::GetListRepresentation), cast to
// TDefaultListRepresentation. `value` is dereferenced through AsBoxed()
// unconditionally.
inline const TDefaultListRepresentation* GetDefaultListRepresentation(const NUdf::TUnboxedValuePod& value) {
    return reinterpret_cast(NUdf::TBoxedValueAccessor::GetListRepresentation(*value.AsBoxed()));
}

enum class EGraphPerProcess {
    Multi,
    Single
};

// Minimal computation options: only the stats registry pointer.
struct TComputationOpts {
    TComputationOpts(IStatsRegistry* stats)
        : Stats(stats)
    {}

    IStatsRegistry *const Stats;
};

// Full set of computation options. Note that AllocState, TypeEnv,
// RandomProvider and TimeProvider are stored as references, so the referenced
// objects must outlive this struct.
struct TComputationOptsFull: public TComputationOpts {
    TComputationOptsFull(IStatsRegistry* stats, TAllocState& allocState, const TTypeEnvironment& typeEnv, IRandomProvider& randomProvider,
            ITimeProvider& timeProvider, NUdf::EValidatePolicy validatePolicy, const NUdf::ISecureParamsProvider* secureParamsProvider, NUdf::ICountersProvider* countersProvider)
        : TComputationOpts(stats)
        , AllocState(allocState)
        , TypeEnv(typeEnv)
        , RandomProvider(randomProvider)
        , TimeProvider(timeProvider)
        , ValidatePolicy(validatePolicy)
        , SecureParamsProvider(secureParamsProvider)
        , CountersProvider(countersProvider)
    {}

    TAllocState& AllocState;
    const TTypeEnvironment& TypeEnv;
    IRandomProvider& RandomProvider;
    ITimeProvider& TimeProvider;
    NUdf::EValidatePolicy ValidatePolicy;
    const NUdf::ISecureParamsProvider *const SecureParamsProvider;
    NUdf::ICountersProvider *const CountersProvider;
};

// One deferred wide-fields initialization record; instances are created by
// TComputationMutables::DeferWideFieldsInit.
struct TWideFieldsInitInfo {
    ui32 MutablesIndex = 0;
    ui32 WideFieldsIndex = 0;
    std::set Used;
};

// Running counters and index lists collected for a computation graph.
struct TComputationMutables {
    ui32 CurValueIndex = 0U;
    std::vector SerializableValues; // Indices of values that need to be saved in IComputationGraph::SaveGraphState() and restored in IComputationGraph::LoadGraphState().
    ui32 CurWideFieldsIndex = 0U;
    std::vector WideFieldInitialize;
    std::vector CachedValues; // Indices of values that hold temporary cached data and are unreachable by dependencies

    // Appends a record {CurValueIndex, CurWideFieldsIndex, used} to
    // WideFieldInitialize, then advances both CurValueIndex and
    // CurWideFieldsIndex by `count`.
    // Every element of `used` is asserted to be < count
    // (Y_DEBUG_ABORT_UNLESS — presumably checked in debug builds only; confirm).
    void DeferWideFieldsInit(ui32 count, std::set used) {
        Y_DEBUG_ABORT_UNLESS(AllOf(used, [count](ui32 i) { return i < count; }));
        WideFieldInitialize.push_back({CurValueIndex, CurWideFieldsIndex, std::move(used)});
        CurValueIndex += count;
        CurWideFieldsIndex += count;
    }

    // Advances CurWideFieldsIndex by `addend` and returns the value it had
    // before the increment.
    ui32 IncrementWideFieldsIndex(ui32 addend) {
        auto cur = CurWideFieldsIndex;
        CurWideFieldsIndex += addend;
        return cur;
    }
};

class THolderFactory;

// Do not reorder: used in LLVM!
struct TComputationContextLLVM {
    const THolderFactory& HolderFactory;
    IStatsRegistry *const Stats;
    const std::unique_ptr MutableValues;
    const NUdf::IValueBuilder *const Builder;
    float UsageAdjustor = 1.f;
    ui32 RssCounter = 0U;
    const NUdf::TSourcePosition* CalleePosition = nullptr;
};

// Computation context: extends the LLVM-visible part above with providers,
// the Arrow memory pool, wide fields and a by-value copy of the mutables.
// Constructor and destructor are defined elsewhere.
struct TComputationContext : public TComputationContextLLVM {
    IRandomProvider& RandomProvider;
    ITimeProvider& TimeProvider;
    bool ExecuteLLVM = false;
    arrow::MemoryPool& ArrowMemoryPool;
    std::vector WideFields;
    const TTypeEnvironment& TypeEnv;
    const TComputationMutables Mutables;
    std::shared_ptr SpillerFactory;
    const NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
    NUdf::ICountersProvider *const CountersProvider;
    const NUdf::ISecureParamsProvider *const SecureParamsProvider;

    TComputationContext(const THolderFactory& holderFactory,
        const NUdf::IValueBuilder* builder,
        const TComputationOptsFull& opts,
        const TComputationMutables& mutables,
        arrow::MemoryPool& arrowMemoryPool);

    ~TComputationContext();

    // Returns true if current usage delta exceeds the memory limit
    // The function automatically adjusts memory limit taking into account RSS delta between calls
    template
    inline bool CheckAdjustedMemLimit(ui64 memLimit, ui64 initMemUsage);

    void UpdateUsageAdjustor(ui64 memLimit);
private:
    ui64 InitRss = 0ULL;
    ui64 LastRss = 0ULL;
#ifndef NDEBUG
    // Present only when NDEBUG is not defined.
    TInstant LastPrintUsage;
#endif
};

class IArrowKernelComputationNode;

// Abstract interface of a computation graph node. Every member except the
// destructor and PrepareArrowKernelComputationNode is pure virtual.
class IComputationNode {
public:
    typedef TIntrusivePtr TPtr;
    typedef std::map TIndexesMap;

    virtual ~IComputationNode() {}

    virtual void InitNode(TComputationContext&) const = 0;

    virtual NUdf::TUnboxedValue GetValue(TComputationContext& compCtx) const = 0;

    virtual IComputationNode* AddDependence(const IComputationNode* node) = 0;

    virtual const IComputationNode* GetSource() const = 0;

    virtual void RegisterDependencies() const = 0;

    virtual ui32 GetIndex() const = 0;
    virtual void CollectDependentIndexes(const IComputationNode* owner, TIndexesMap& dependencies) const = 0;
    virtual ui32 GetDependencyWeight() const = 0;
    virtual ui32 GetDependencesCount() const = 0;

    virtual bool IsTemporaryValue() const = 0;

    virtual EValueRepresentation GetRepresentation() const = 0;

    virtual void PrepareStageOne() = 0;
    virtual void PrepareStageTwo() = 0;

    virtual TString DebugString() const = 0;

    // Reference-counting hooks (presumably what TPtr relies on — confirm).
    virtual void Ref() = 0;
    virtual void UnRef() = 0;
    virtual ui32 RefCount() const = 0;

    // Not pure: declared here and defined outside this header.
    virtual std::unique_ptr PrepareArrowKernelComputationNode(TComputationContext& ctx) const;
};

// Node whose value can be referenced, set and invalidated from outside, and
// which accepts an owner node and a getter callback.
class IComputationExternalNode : public IComputationNode {
public:
    virtual NUdf::TUnboxedValue& RefValue(TComputationContext& compCtx) const = 0;
    virtual void SetValue(TComputationContext& compCtx, NUdf::TUnboxedValue&& newValue) const = 0;
    virtual void SetOwner(const IComputationNode* node) = 0;

    using TGetter = std::function;
    virtual void SetGetter(TGetter&& getter) = 0;
    virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
};

// Result codes returned by IComputationWideFlowNode::FetchValues.
enum class EFetchResult : i32 {
    Finish = -1,
    Yield = 0,
    One = 1
};

// Node that additionally exposes FetchValues, writing through an array of
// value pointers.
class IComputationWideFlowNode : public IComputationNode {
public:
    virtual EFetchResult FetchValues(TComputationContext& compCtx, NUdf::TUnboxedValue*const* values) const = 0;
};

// Wide flow node whose fetching is delegated to an externally supplied
// TFetcher callback.
class IComputationWideFlowProxyNode : public IComputationWideFlowNode {
public:
    using TFetcher = std::function;
    virtual void SetFetcher(TFetcher&& fetcher) = 0;
    virtual void SetOwner(const IComputationNode* node) = 0;
    virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
};

using TDatumProvider = std::function;

// Two factories for TDatumProvider: from an existing arrow::Datum, or from a
// computation node plus context. Both are defined elsewhere.
TDatumProvider MakeDatumProvider(const arrow::Datum& datum);
TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx);

// Interface describing an Arrow scalar kernel together with its name,
// argument descriptors and argument nodes.
class IArrowKernelComputationNode {
public:
    virtual ~IArrowKernelComputationNode() = default;

    virtual TStringBuf GetKernelName() const = 0;
    virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
    virtual const std::vector& GetArgsDesc() const = 0;
    virtual const IComputationNode* GetArgument(ui32 index) const = 0;
};

struct TArrowKernelsTopologyItem {
    std::vector Inputs;
    std::unique_ptr Node;
};

struct TArrowKernelsTopology {
    ui32 InputArgsCount = 0;
    std::vector Items;
};

using TComputationNodePtrVector = std::vector>;
using TComputationWideFlowNodePtrVector = std::vector>;
using TComputationExternalNodePtrVector = std::vector>;
using TConstComputationNodePtrVector = std::vector>;
using TComputationNodePtrDeque = std::deque>;
using TComputationNodeOnNodeMap = std::unordered_map, std::equal_to, TMKQLAllocator>>;

// Abstract interface of an instantiated computation graph.
class IComputationGraph {
public:
    virtual ~IComputationGraph() {}
    virtual void Prepare() = 0;
    virtual NUdf::TUnboxedValue GetValue() = 0;
    virtual TComputationContext& GetContext() = 0;
    virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0;
    virtual const TArrowKernelsTopology* GetKernelsTopology() = 0;
    virtual const TComputationNodePtrDeque& GetNodes() const = 0;
    virtual void Invalidate() = 0; // Invalidate all mutable values in graph (may lead to udf recreation)
    virtual void InvalidateCaches() = 0; // Invalidate only cached values
    virtual TMemoryUsageInfo& GetMemInfo() const = 0;
    virtual const THolderFactory& GetHolderFactory() const = 0;
    virtual ITerminator* GetTerminator() const = 0;
    virtual bool SetExecuteLLVM(bool value) = 0;
    virtual TString SaveGraphState() = 0;
    virtual void LoadGraphState(TStringBuf state) = 0;
};

class TNodeFactory;

typedef std::function TNodeLocator;
typedef std::function TNodePushBack;

// Bundle of everything handed to a computation node factory. Several members
// are references (FunctionRegistry, Env, NodeFactory, HolderFactory, Mutables,
// ElementsCache), so the referenced objects must outlive the context.
struct TComputationNodeFactoryContext {
    TNodeLocator NodeLocator;
    const IFunctionRegistry& FunctionRegistry;
    const TTypeEnvironment& Env;
    NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
    NUdf::ICountersProvider* CountersProvider;
    const NUdf::ISecureParamsProvider* SecureParamsProvider;
    const TNodeFactory& NodeFactory;
    const THolderFactory& HolderFactory;
    const NUdf::IValueBuilder *const Builder;
    NUdf::EValidateMode ValidateMode;
    NUdf::EValidatePolicy ValidatePolicy;
    EGraphPerProcess GraphPerProcess;
    TComputationMutables& Mutables;
    TComputationNodeOnNodeMap& ElementsCache;
    const TNodePushBack NodePushBack;

    // Stores every argument in the member of the same name; nodePushBack is
    // moved in, nodeLocator and typeInfoHelper are copied.
    TComputationNodeFactoryContext(
            const TNodeLocator& nodeLocator,
            const IFunctionRegistry& functionRegistry,
            const TTypeEnvironment& env,
            NUdf::ITypeInfoHelper::TPtr typeInfoHelper,
            NUdf::ICountersProvider* countersProvider,
            const NUdf::ISecureParamsProvider* secureParamsProvider,
            const TNodeFactory& nodeFactory,
            const THolderFactory& holderFactory,
            const NUdf::IValueBuilder* builder,
            NUdf::EValidateMode validateMode,
            NUdf::EValidatePolicy validatePolicy,
            EGraphPerProcess graphPerProcess,
            TComputationMutables& mutables,
            TComputationNodeOnNodeMap& elementsCache,
            TNodePushBack&& nodePushBack
            )
        : NodeLocator(nodeLocator)
        , FunctionRegistry(functionRegistry)
        , Env(env)
        , TypeInfoHelper(typeInfoHelper)
        , CountersProvider(countersProvider)
        , SecureParamsProvider(secureParamsProvider)
        , NodeFactory(nodeFactory)
        , HolderFactory(holderFactory)
        , Builder(builder)
        , ValidateMode(validateMode)
        , ValidatePolicy(validatePolicy)
        , GraphPerProcess(graphPerProcess)
        , Mutables(mutables)
        , ElementsCache(elementsCache)
        , NodePushBack(std::move(nodePushBack))
    {}
};

using TComputationNodeFactory = std::function;
using TStreamEmitter = std::function;

struct TPatternCacheEntry;

// Options used when building a computation pattern. AllocState and Env are
// references fixed at construction; all other fields have in-class defaults
// and can be replaced later through SetOptions / SetPatternEnv.
struct TComputationPatternOpts {
    // Short form: binds only AllocState and Env, leaving every other field at
    // its in-class default.
    TComputationPatternOpts(TAllocState& allocState, const TTypeEnvironment& env)
        : AllocState(allocState)
        , Env(env)
    {}

    // Long form: binds AllocState and Env and initializes all option fields
    // except PatternEnv.
    TComputationPatternOpts(
        TAllocState& allocState,
        const TTypeEnvironment& env,
        TComputationNodeFactory factory,
        const IFunctionRegistry* functionRegistry,
        NUdf::EValidateMode validateMode,
        NUdf::EValidatePolicy validatePolicy,
        const TString& optLLVM,
        EGraphPerProcess graphPerProcess,
        IStatsRegistry* stats = nullptr,
        NUdf::ICountersProvider* countersProvider = nullptr,
        const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr)
        : AllocState(allocState)
        , Env(env)
        , Factory(factory)
        , FunctionRegistry(functionRegistry)
        , ValidateMode(validateMode)
        , ValidatePolicy(validatePolicy)
        , OptLLVM(optLLVM)
        , GraphPerProcess(graphPerProcess)
        , Stats(stats)
        , CountersProvider(countersProvider)
        , SecureParamsProvider(secureParamsProvider)
    {}

    // Overwrites the same option fields the long constructor sets, apart from
    // the AllocState and Env references. Omitted trailing arguments reset the
    // corresponding pointers to nullptr.
    void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry,
        NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy,
        const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr,
        NUdf::ICountersProvider* counters = nullptr, const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) {
        Factory = factory;
        FunctionRegistry = functionRegistry;
        ValidateMode = validateMode;
        ValidatePolicy = validatePolicy;
        OptLLVM = optLLVM;
        GraphPerProcess = graphPerProcess;
        Stats = stats;
        CountersProvider = counters;
        SecureParamsProvider = secureParamsProvider;
    }

    // Takes shared ownership of `cacheEnv` and stores it in PatternEnv.
    void SetPatternEnv(std::shared_ptr cacheEnv) {
        PatternEnv = std::move(cacheEnv);
    }

    // Declared mutable, so it may also be assigned through a const reference
    // to these options.
    mutable std::shared_ptr PatternEnv;
    TAllocState& AllocState;
    const TTypeEnvironment& Env;

    TComputationNodeFactory Factory;
    const IFunctionRegistry* FunctionRegistry = nullptr;
    NUdf::EValidateMode ValidateMode = NUdf::EValidateMode::None;
    NUdf::EValidatePolicy ValidatePolicy = NUdf::EValidatePolicy::Fail;
    TString OptLLVM;
    EGraphPerProcess GraphPerProcess = EGraphPerProcess::Multi;
    IStatsRegistry* Stats = nullptr;
    NUdf::ICountersProvider* CountersProvider = nullptr;
    const NUdf::ISecureParamsProvider* SecureParamsProvider = nullptr;

    // Builds a TComputationOptsFull from these options plus the supplied
    // providers. When `allocStatePtr` is non-null the state it points to is
    // used in place of this object's AllocState.
    TComputationOptsFull ToComputationOptions(IRandomProvider& randomProvider, ITimeProvider& timeProvider, TAllocState* allocStatePtr = nullptr) const {
        return TComputationOptsFull(Stats, allocStatePtr ? *allocStatePtr : AllocState, Env, randomProvider, timeProvider, ValidatePolicy, SecureParamsProvider, CountersProvider);
    }
};

// Abstract, reference-counted computation pattern: can be compiled, queried
// for compiled code size, stripped of compiled code, and cloned with a given
// set of computation options.
class IComputationPattern: public TAtomicRefCount {
public:
    typedef TIntrusivePtr TPtr;

    virtual ~IComputationPattern() = default;
    virtual void Compile(TString optLLVM, IStatsRegistry* stats) = 0;
    virtual bool IsCompiled() const = 0;
    virtual size_t CompiledCodeSize() const = 0;
    virtual void RemoveCompiledCode() = 0;
    virtual THolder Clone(const TComputationOptsFull& compOpts) = 0;
    virtual bool GetSuitableForCache() const = 0;
};

// Node cookies will be cleaned up when the graph is destroyed; the explorer must not be changed or destroyed until then.
IComputationPattern::TPtr MakeComputationPattern(
        TExploringNodeVisitor& explorer,
        const TRuntimeNode& root,
        const std::vector& entryPoints,
        const TComputationPatternOpts& opts);

std::unique_ptr MakeSimpleSecureParamsProvider(const THashMap& secureParams);

using TCallableComputationNodeBuilder = std::function;

// Adapts a plain builder function `f` into a TCallableComputationNodeBuilder.
// The returned lambda first checks, via MKQL_ENSURE, that the callable has
// exactly sizeof...(Ts) inputs, then forwards to
// CallComputationBuilderWithArgs.
template
TCallableComputationNodeBuilder WrapComputationBuilder(IComputationNode* (*f)(const TComputationNodeFactoryContext&, Ts...)){
    return [f](TCallable& callable, const TComputationNodeFactoryContext& ctx) {
        MKQL_ENSURE(callable.GetInputsCount() == sizeof...(Ts), "Incorrect number of inputs");
        return CallComputationBuilderWithArgs(f, callable, ctx, std::make_index_sequence());
    };
}

// Calls `f(ctx, callable.GetInput(Is)...)`, expanding one GetInput call per
// index in the sequence. The sequence argument itself is unnamed and only
// carries the indices.
template
auto CallComputationBuilderWithArgs(F* f, TCallable& callable, const TComputationNodeFactoryContext& ctx,
                                    const std::integer_sequence &) {
    return f(ctx, callable.GetInput(Is)...);
}

} // namespace NMiniKQL
} // namespace NKikimr