123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- #pragma once
- #include "mkql_computation_node_list.h"
- #include "mkql_spiller_factory.h"
- #include <yql/essentials/minikql/defs.h>
- #include <yql/essentials/minikql/mkql_node.h>
- #include <yql/essentials/minikql/mkql_node_visitor.h>
- #include <yql/essentials/minikql/mkql_function_registry.h>
- #include <yql/essentials/minikql/mkql_alloc.h>
- #include <yql/essentials/minikql/mkql_stats_registry.h>
- #include <yql/essentials/minikql/mkql_terminator.h>
- #include <yql/essentials/public/udf/udf_value.h>
- #include <yql/essentials/public/udf/udf_validate.h>
- #include <yql/essentials/public/udf/udf_value_builder.h>
- #include <library/cpp/cache/cache.h>
- #include <library/cpp/random_provider/random_provider.h>
- #include <library/cpp/time_provider/time_provider.h>
- #include <map>
- #include <unordered_set>
- #include <unordered_map>
- #include <vector>
- namespace NKikimr {
- namespace NMiniKQL {
- inline const TDefaultListRepresentation* GetDefaultListRepresentation(const NUdf::TUnboxedValuePod& value) {
- return reinterpret_cast<const TDefaultListRepresentation*>(NUdf::TBoxedValueAccessor::GetListRepresentation(*value.AsBoxed()));
- }
- enum class EGraphPerProcess {
- Multi,
- Single
- };
- struct TComputationOpts {
- TComputationOpts(IStatsRegistry* stats)
- : Stats(stats)
- {}
- IStatsRegistry *const Stats;
- };
- struct TComputationOptsFull: public TComputationOpts {
- TComputationOptsFull(IStatsRegistry* stats, TAllocState& allocState, const TTypeEnvironment& typeEnv, IRandomProvider& randomProvider,
- ITimeProvider& timeProvider, NUdf::EValidatePolicy validatePolicy, const NUdf::ISecureParamsProvider* secureParamsProvider, NUdf::ICountersProvider* countersProvider)
- : TComputationOpts(stats)
- , AllocState(allocState)
- , TypeEnv(typeEnv)
- , RandomProvider(randomProvider)
- , TimeProvider(timeProvider)
- , ValidatePolicy(validatePolicy)
- , SecureParamsProvider(secureParamsProvider)
- , CountersProvider(countersProvider)
- {}
- TAllocState& AllocState;
- const TTypeEnvironment& TypeEnv;
- IRandomProvider& RandomProvider;
- ITimeProvider& TimeProvider;
- NUdf::EValidatePolicy ValidatePolicy;
- const NUdf::ISecureParamsProvider *const SecureParamsProvider;
- NUdf::ICountersProvider *const CountersProvider;
- };
- struct TWideFieldsInitInfo {
- ui32 MutablesIndex = 0;
- ui32 WideFieldsIndex = 0;
- std::set<ui32> Used;
- };
- struct TComputationMutables {
- ui32 CurValueIndex = 0U;
- std::vector<ui32> SerializableValues; // Indices of values that need to be saved in IComputationGraph::SaveGraphState() and restored in IComputationGraph::LoadGraphState().
- ui32 CurWideFieldsIndex = 0U;
- std::vector<TWideFieldsInitInfo> WideFieldInitialize;
- std::vector<ui32> CachedValues; // Indices of values that holds temporary cached data and unreachable by dependencies
- void DeferWideFieldsInit(ui32 count, std::set<ui32> used) {
- Y_DEBUG_ABORT_UNLESS(AllOf(used, [count](ui32 i) { return i < count; }));
- WideFieldInitialize.push_back({CurValueIndex, CurWideFieldsIndex, std::move(used)});
- CurValueIndex += count;
- CurWideFieldsIndex += count;
- }
- ui32 IncrementWideFieldsIndex(ui32 addend) {
- auto cur = CurWideFieldsIndex;
- CurWideFieldsIndex += addend;
- return cur;
- }
- };
- class THolderFactory;
- // Do not reorder: used in LLVM!
- struct TComputationContextLLVM {
- const THolderFactory& HolderFactory;
- IStatsRegistry *const Stats;
- const std::unique_ptr<NUdf::TUnboxedValue[]> MutableValues;
- const NUdf::IValueBuilder *const Builder;
- float UsageAdjustor = 1.f;
- ui32 RssCounter = 0U;
- const NUdf::TSourcePosition* CalleePosition = nullptr;
- };
- struct TComputationContext : public TComputationContextLLVM {
- IRandomProvider& RandomProvider;
- ITimeProvider& TimeProvider;
- bool ExecuteLLVM = false;
- arrow::MemoryPool& ArrowMemoryPool;
- std::vector<NUdf::TUnboxedValue*> WideFields;
- const TTypeEnvironment& TypeEnv;
- const TComputationMutables Mutables;
- std::shared_ptr<ISpillerFactory> SpillerFactory;
- const NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
- NUdf::ICountersProvider *const CountersProvider;
- const NUdf::ISecureParamsProvider *const SecureParamsProvider;
- TComputationContext(const THolderFactory& holderFactory,
- const NUdf::IValueBuilder* builder,
- const TComputationOptsFull& opts,
- const TComputationMutables& mutables,
- arrow::MemoryPool& arrowMemoryPool);
- ~TComputationContext();
- // Returns true if current usage delta exceeds the memory limit
- // The function automatically adjusts memory limit taking into account RSS delta between calls
- template<bool TrackRss>
- inline bool CheckAdjustedMemLimit(ui64 memLimit, ui64 initMemUsage);
- void UpdateUsageAdjustor(ui64 memLimit);
- private:
- ui64 InitRss = 0ULL;
- ui64 LastRss = 0ULL;
- #ifndef NDEBUG
- TInstant LastPrintUsage;
- #endif
- };
- class IArrowKernelComputationNode;
- class IComputationNode {
- public:
- typedef TIntrusivePtr<IComputationNode> TPtr;
- typedef std::map<ui32, EValueRepresentation> TIndexesMap;
- virtual ~IComputationNode() {}
- virtual void InitNode(TComputationContext&) const = 0;
- virtual NUdf::TUnboxedValue GetValue(TComputationContext& compCtx) const = 0;
- virtual IComputationNode* AddDependence(const IComputationNode* node) = 0;
- virtual const IComputationNode* GetSource() const = 0;
- virtual void RegisterDependencies() const = 0;
- virtual ui32 GetIndex() const = 0;
- virtual void CollectDependentIndexes(const IComputationNode* owner, TIndexesMap& dependencies) const = 0;
- virtual ui32 GetDependencyWeight() const = 0;
- virtual ui32 GetDependencesCount() const = 0;
- virtual bool IsTemporaryValue() const = 0;
- virtual EValueRepresentation GetRepresentation() const = 0;
- virtual void PrepareStageOne() = 0;
- virtual void PrepareStageTwo() = 0;
- virtual TString DebugString() const = 0;
- virtual void Ref() = 0;
- virtual void UnRef() = 0;
- virtual ui32 RefCount() const = 0;
- virtual std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const;
- };
- class IComputationExternalNode : public IComputationNode {
- public:
- virtual NUdf::TUnboxedValue& RefValue(TComputationContext& compCtx) const = 0;
- virtual void SetValue(TComputationContext& compCtx, NUdf::TUnboxedValue&& newValue) const = 0;
- virtual void SetOwner(const IComputationNode* node) = 0;
- using TGetter = std::function<NUdf::TUnboxedValue(TComputationContext&)>;
- virtual void SetGetter(TGetter&& getter) = 0;
- virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
- };
- enum class EFetchResult : i32 {
- Finish = -1,
- Yield = 0,
- One = 1
- };
- class IComputationWideFlowNode : public IComputationNode {
- public:
- virtual EFetchResult FetchValues(TComputationContext& compCtx, NUdf::TUnboxedValue*const* values) const = 0;
- };
- class IComputationWideFlowProxyNode : public IComputationWideFlowNode {
- public:
- using TFetcher = std::function<EFetchResult(TComputationContext&, NUdf::TUnboxedValue*const*)>;
- virtual void SetFetcher(TFetcher&& fetcher) = 0;
- virtual void SetOwner(const IComputationNode* node) = 0;
- virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
- };
- using TDatumProvider = std::function<arrow::Datum()>;
- TDatumProvider MakeDatumProvider(const arrow::Datum& datum);
- TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx);
- class IArrowKernelComputationNode {
- public:
- virtual ~IArrowKernelComputationNode() = default;
- virtual TStringBuf GetKernelName() const = 0;
- virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
- virtual const std::vector<arrow::ValueDescr>& GetArgsDesc() const = 0;
- virtual const IComputationNode* GetArgument(ui32 index) const = 0;
- };
- struct TArrowKernelsTopologyItem {
- std::vector<ui32> Inputs;
- std::unique_ptr<IArrowKernelComputationNode> Node;
- };
- struct TArrowKernelsTopology {
- ui32 InputArgsCount = 0;
- std::vector<TArrowKernelsTopologyItem> Items;
- };
- using TComputationNodePtrVector = std::vector<IComputationNode*, TMKQLAllocator<IComputationNode*>>;
- using TComputationWideFlowNodePtrVector = std::vector<IComputationWideFlowNode*, TMKQLAllocator<IComputationWideFlowNode*>>;
- using TComputationExternalNodePtrVector = std::vector<IComputationExternalNode*, TMKQLAllocator<IComputationExternalNode*>>;
- using TConstComputationNodePtrVector = std::vector<const IComputationNode*, TMKQLAllocator<const IComputationNode*>>;
- using TComputationNodePtrDeque = std::deque<IComputationNode::TPtr, TMKQLAllocator<IComputationNode::TPtr>>;
- using TComputationNodeOnNodeMap = std::unordered_map<const IComputationNode*, IComputationNode*, std::hash<const IComputationNode*>, std::equal_to<const IComputationNode*>, TMKQLAllocator<std::pair<const IComputationNode *const, IComputationNode*>>>;
- class IComputationGraph {
- public:
- virtual ~IComputationGraph() {}
- virtual void Prepare() = 0;
- virtual NUdf::TUnboxedValue GetValue() = 0;
- virtual TComputationContext& GetContext() = 0;
- virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0;
- virtual const TArrowKernelsTopology* GetKernelsTopology() = 0;
- virtual const TComputationNodePtrDeque& GetNodes() const = 0;
- virtual void Invalidate() = 0; // Invalidate all mutable values in graph (may lead to udf recreation)
- virtual void InvalidateCaches() = 0; // Invalidate only cached values
- virtual TMemoryUsageInfo& GetMemInfo() const = 0;
- virtual const THolderFactory& GetHolderFactory() const = 0;
- virtual ITerminator* GetTerminator() const = 0;
- virtual bool SetExecuteLLVM(bool value) = 0;
- virtual TString SaveGraphState() = 0;
- virtual void LoadGraphState(TStringBuf state) = 0;
- };
- class TNodeFactory;
- typedef std::function<IComputationNode* (TNode* node, bool pop)> TNodeLocator;
- typedef std::function<void (IComputationNode*)> TNodePushBack;
- struct TComputationNodeFactoryContext {
- TNodeLocator NodeLocator;
- const IFunctionRegistry& FunctionRegistry;
- const TTypeEnvironment& Env;
- NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
- NUdf::ICountersProvider* CountersProvider;
- const NUdf::ISecureParamsProvider* SecureParamsProvider;
- const TNodeFactory& NodeFactory;
- const THolderFactory& HolderFactory;
- const NUdf::IValueBuilder *const Builder;
- NUdf::EValidateMode ValidateMode;
- NUdf::EValidatePolicy ValidatePolicy;
- EGraphPerProcess GraphPerProcess;
- TComputationMutables& Mutables;
- TComputationNodeOnNodeMap& ElementsCache;
- const TNodePushBack NodePushBack;
- TComputationNodeFactoryContext(
- const TNodeLocator& nodeLocator,
- const IFunctionRegistry& functionRegistry,
- const TTypeEnvironment& env,
- NUdf::ITypeInfoHelper::TPtr typeInfoHelper,
- NUdf::ICountersProvider* countersProvider,
- const NUdf::ISecureParamsProvider* secureParamsProvider,
- const TNodeFactory& nodeFactory,
- const THolderFactory& holderFactory,
- const NUdf::IValueBuilder* builder,
- NUdf::EValidateMode validateMode,
- NUdf::EValidatePolicy validatePolicy,
- EGraphPerProcess graphPerProcess,
- TComputationMutables& mutables,
- TComputationNodeOnNodeMap& elementsCache,
- TNodePushBack&& nodePushBack
- )
- : NodeLocator(nodeLocator)
- , FunctionRegistry(functionRegistry)
- , Env(env)
- , TypeInfoHelper(typeInfoHelper)
- , CountersProvider(countersProvider)
- , SecureParamsProvider(secureParamsProvider)
- , NodeFactory(nodeFactory)
- , HolderFactory(holderFactory)
- , Builder(builder)
- , ValidateMode(validateMode)
- , ValidatePolicy(validatePolicy)
- , GraphPerProcess(graphPerProcess)
- , Mutables(mutables)
- , ElementsCache(elementsCache)
- , NodePushBack(std::move(nodePushBack))
- {}
- };
- using TComputationNodeFactory = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext&)>;
- using TStreamEmitter = std::function<void(NUdf::TUnboxedValue&&)>;
- struct TPatternCacheEntry;
- struct TComputationPatternOpts {
- TComputationPatternOpts(TAllocState& allocState, const TTypeEnvironment& env)
- : AllocState(allocState)
- , Env(env)
- {}
- TComputationPatternOpts(
- TAllocState& allocState,
- const TTypeEnvironment& env,
- TComputationNodeFactory factory,
- const IFunctionRegistry* functionRegistry,
- NUdf::EValidateMode validateMode,
- NUdf::EValidatePolicy validatePolicy,
- const TString& optLLVM,
- EGraphPerProcess graphPerProcess,
- IStatsRegistry* stats = nullptr,
- NUdf::ICountersProvider* countersProvider = nullptr,
- const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr)
- : AllocState(allocState)
- , Env(env)
- , Factory(factory)
- , FunctionRegistry(functionRegistry)
- , ValidateMode(validateMode)
- , ValidatePolicy(validatePolicy)
- , OptLLVM(optLLVM)
- , GraphPerProcess(graphPerProcess)
- , Stats(stats)
- , CountersProvider(countersProvider)
- , SecureParamsProvider(secureParamsProvider)
- {}
- void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry,
- NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy,
- const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr,
- NUdf::ICountersProvider* counters = nullptr,
- const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) {
- Factory = factory;
- FunctionRegistry = functionRegistry;
- ValidateMode = validateMode;
- ValidatePolicy = validatePolicy;
- OptLLVM = optLLVM;
- GraphPerProcess = graphPerProcess;
- Stats = stats;
- CountersProvider = counters;
- SecureParamsProvider = secureParamsProvider;
- }
- void SetPatternEnv(std::shared_ptr<TPatternCacheEntry> cacheEnv) {
- PatternEnv = std::move(cacheEnv);
- }
- mutable std::shared_ptr<TPatternCacheEntry> PatternEnv;
- TAllocState& AllocState;
- const TTypeEnvironment& Env;
- TComputationNodeFactory Factory;
- const IFunctionRegistry* FunctionRegistry = nullptr;
- NUdf::EValidateMode ValidateMode = NUdf::EValidateMode::None;
- NUdf::EValidatePolicy ValidatePolicy = NUdf::EValidatePolicy::Fail;
- TString OptLLVM;
- EGraphPerProcess GraphPerProcess = EGraphPerProcess::Multi;
- IStatsRegistry* Stats = nullptr;
- NUdf::ICountersProvider* CountersProvider = nullptr;
- const NUdf::ISecureParamsProvider* SecureParamsProvider = nullptr;
- TComputationOptsFull ToComputationOptions(IRandomProvider& randomProvider, ITimeProvider& timeProvider, TAllocState* allocStatePtr = nullptr) const {
- return TComputationOptsFull(Stats, allocStatePtr ? *allocStatePtr : AllocState, Env, randomProvider, timeProvider, ValidatePolicy, SecureParamsProvider, CountersProvider);
- }
- };
- class IComputationPattern: public TAtomicRefCount<IComputationPattern> {
- public:
- typedef TIntrusivePtr<IComputationPattern> TPtr;
- virtual ~IComputationPattern() = default;
- virtual void Compile(TString optLLVM, IStatsRegistry* stats) = 0;
- virtual bool IsCompiled() const = 0;
- virtual size_t CompiledCodeSize() const = 0;
- virtual void RemoveCompiledCode() = 0;
- virtual THolder<IComputationGraph> Clone(const TComputationOptsFull& compOpts) = 0;
- virtual bool GetSuitableForCache() const = 0;
- };
- // node cookie's will be clean up when graph will be destroyed, explorer must not be changed/destroyed until that time
- IComputationPattern::TPtr MakeComputationPattern(
- TExploringNodeVisitor& explorer,
- const TRuntimeNode& root,
- const std::vector<TNode*>& entryPoints,
- const TComputationPatternOpts& opts);
- std::unique_ptr<NUdf::ISecureParamsProvider> MakeSimpleSecureParamsProvider(const THashMap<TString, TString>& secureParams);
- using TCallableComputationNodeBuilder = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext& ctx)>;
- template<typename... Ts>
- TCallableComputationNodeBuilder WrapComputationBuilder(IComputationNode* (*f)(const TComputationNodeFactoryContext&, Ts...)){
- return [f](TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- MKQL_ENSURE(callable.GetInputsCount() == sizeof...(Ts), "Incorrect number of inputs");
- return CallComputationBuilderWithArgs(f, callable, ctx, std::make_index_sequence<sizeof...(Ts)>());
- };
- }
- template<typename F, size_t... Is>
- auto CallComputationBuilderWithArgs(F* f, TCallable& callable, const TComputationNodeFactoryContext& ctx,
- const std::integer_sequence<size_t, Is...> &) {
- return f(ctx, callable.GetInput(Is)...);
- }
- } // namespace NMiniKQL
- } // namespace NKikimr
|