mkql_computation_node.h 17 KB


  1. #pragma once
#include "mkql_computation_node_list.h"
#include "mkql_spiller_factory.h"
#include <yql/essentials/minikql/defs.h>
#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/minikql/mkql_node_visitor.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/minikql/mkql_alloc.h>
#include <yql/essentials/minikql/mkql_stats_registry.h>
#include <yql/essentials/minikql/mkql_terminator.h>
#include <yql/essentials/public/udf/udf_value.h>
#include <yql/essentials/public/udf/udf_validate.h>
#include <yql/essentials/public/udf/udf_value_builder.h>
#include <library/cpp/cache/cache.h>
#include <library/cpp/random_provider/random_provider.h>
#include <library/cpp/time_provider/time_provider.h>

#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
  21. namespace NKikimr {
  22. namespace NMiniKQL {
  23. inline const TDefaultListRepresentation* GetDefaultListRepresentation(const NUdf::TUnboxedValuePod& value) {
  24. return reinterpret_cast<const TDefaultListRepresentation*>(NUdf::TBoxedValueAccessor::GetListRepresentation(*value.AsBoxed()));
  25. }
  26. enum class EGraphPerProcess {
  27. Multi,
  28. Single
  29. };
  30. struct TComputationOpts {
  31. TComputationOpts(IStatsRegistry* stats)
  32. : Stats(stats)
  33. {}
  34. IStatsRegistry *const Stats;
  35. };
  36. struct TComputationOptsFull: public TComputationOpts {
  37. TComputationOptsFull(IStatsRegistry* stats, TAllocState& allocState, const TTypeEnvironment& typeEnv, IRandomProvider& randomProvider,
  38. ITimeProvider& timeProvider, NUdf::EValidatePolicy validatePolicy, const NUdf::ISecureParamsProvider* secureParamsProvider, NUdf::ICountersProvider* countersProvider)
  39. : TComputationOpts(stats)
  40. , AllocState(allocState)
  41. , TypeEnv(typeEnv)
  42. , RandomProvider(randomProvider)
  43. , TimeProvider(timeProvider)
  44. , ValidatePolicy(validatePolicy)
  45. , SecureParamsProvider(secureParamsProvider)
  46. , CountersProvider(countersProvider)
  47. {}
  48. TAllocState& AllocState;
  49. const TTypeEnvironment& TypeEnv;
  50. IRandomProvider& RandomProvider;
  51. ITimeProvider& TimeProvider;
  52. NUdf::EValidatePolicy ValidatePolicy;
  53. const NUdf::ISecureParamsProvider *const SecureParamsProvider;
  54. NUdf::ICountersProvider *const CountersProvider;
  55. };
  56. struct TWideFieldsInitInfo {
  57. ui32 MutablesIndex = 0;
  58. ui32 WideFieldsIndex = 0;
  59. std::set<ui32> Used;
  60. };
  61. struct TComputationMutables {
  62. ui32 CurValueIndex = 0U;
  63. std::vector<ui32> SerializableValues; // Indices of values that need to be saved in IComputationGraph::SaveGraphState() and restored in IComputationGraph::LoadGraphState().
  64. ui32 CurWideFieldsIndex = 0U;
  65. std::vector<TWideFieldsInitInfo> WideFieldInitialize;
  66. std::vector<ui32> CachedValues; // Indices of values that holds temporary cached data and unreachable by dependencies
  67. void DeferWideFieldsInit(ui32 count, std::set<ui32> used) {
  68. Y_DEBUG_ABORT_UNLESS(AllOf(used, [count](ui32 i) { return i < count; }));
  69. WideFieldInitialize.push_back({CurValueIndex, CurWideFieldsIndex, std::move(used)});
  70. CurValueIndex += count;
  71. CurWideFieldsIndex += count;
  72. }
  73. ui32 IncrementWideFieldsIndex(ui32 addend) {
  74. auto cur = CurWideFieldsIndex;
  75. CurWideFieldsIndex += addend;
  76. return cur;
  77. }
  78. };
class THolderFactory;
// Do not reorder: used in LLVM!
// Leading part of the computation context. Per the note above, LLVM-generated code
// depends on this struct's layout, so members must keep their order and types.
// TComputationContext derives from it and appends the remaining state.
// The reference and const unique_ptr members make this struct neither copyable nor assignable.
struct TComputationContextLLVM {
    const THolderFactory& HolderFactory;
    IStatsRegistry *const Stats;
    // Owned array of mutable values; its length is not stored here
    // (presumably derived from TComputationMutables::CurValueIndex — TODO confirm).
    const std::unique_ptr<NUdf::TUnboxedValue[]> MutableValues;
    const NUdf::IValueBuilder *const Builder;
    // Scale factor, starts at 1; presumably maintained by
    // TComputationContext::UpdateUsageAdjustor() for memory-limit checks — TODO confirm.
    float UsageAdjustor = 1.f;
    // NOTE(review): RSS-tracking counter; its exact use is not visible in this header.
    ui32 RssCounter = 0U;
    // Optional source position of the current callee; null until set.
    const NUdf::TSourcePosition* CalleePosition = nullptr;
};
// Full per-graph execution context: the LLVM-visible base plus providers, wide-field
// storage, a copy of the mutables layout and memory-tracking state.
// The constructor, destructor, CheckAdjustedMemLimit and UpdateUsageAdjustor are not
// defined in this header. Member declaration order is also their initialization order
// in that out-of-line constructor, so do not reorder members.
struct TComputationContext : public TComputationContextLLVM {
    IRandomProvider& RandomProvider;
    ITimeProvider& TimeProvider;
    // Presumably toggled through IComputationGraph::SetExecuteLLVM() — TODO confirm.
    bool ExecuteLLVM = false;
    arrow::MemoryPool& ArrowMemoryPool;
    // Pointers into wide-field storage (presumably sized from
    // TComputationMutables::CurWideFieldsIndex — TODO confirm).
    std::vector<NUdf::TUnboxedValue*> WideFields;
    const TTypeEnvironment& TypeEnv;
    // Stored by value (a const copy), not as a reference to the caller's mutables.
    const TComputationMutables Mutables;
    std::shared_ptr<ISpillerFactory> SpillerFactory;
    const NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
    NUdf::ICountersProvider *const CountersProvider;
    const NUdf::ISecureParamsProvider *const SecureParamsProvider;
    TComputationContext(const THolderFactory& holderFactory,
        const NUdf::IValueBuilder* builder,
        const TComputationOptsFull& opts,
        const TComputationMutables& mutables,
        arrow::MemoryPool& arrowMemoryPool);
    ~TComputationContext();
    // Returns true if the current usage delta exceeds the memory limit.
    // The function automatically adjusts the memory limit, taking into account the RSS delta between calls.
    // NOTE(review): TrackRss presumably enables RSS sampling; the definition is not visible here.
    template<bool TrackRss>
    inline bool CheckAdjustedMemLimit(ui64 memLimit, ui64 initMemUsage);
    void UpdateUsageAdjustor(ui64 memLimit);
private:
    // RSS snapshots, presumably maintained by CheckAdjustedMemLimit/UpdateUsageAdjustor — TODO confirm.
    ui64 InitRss = 0ULL;
    ui64 LastRss = 0ULL;
    // NOTE(review): this member exists only when NDEBUG is not defined, so the struct's
    // size and layout differ between debug and release builds. Every translation unit
    // that includes this header must agree on NDEBUG, or the program violates the ODR.
#ifndef NDEBUG
    TInstant LastPrintUsage;
#endif
};
  120. class IArrowKernelComputationNode;
  121. class IComputationNode {
  122. public:
  123. typedef TIntrusivePtr<IComputationNode> TPtr;
  124. typedef std::map<ui32, EValueRepresentation> TIndexesMap;
  125. virtual ~IComputationNode() {}
  126. virtual void InitNode(TComputationContext&) const = 0;
  127. virtual NUdf::TUnboxedValue GetValue(TComputationContext& compCtx) const = 0;
  128. virtual IComputationNode* AddDependence(const IComputationNode* node) = 0;
  129. virtual const IComputationNode* GetSource() const = 0;
  130. virtual void RegisterDependencies() const = 0;
  131. virtual ui32 GetIndex() const = 0;
  132. virtual void CollectDependentIndexes(const IComputationNode* owner, TIndexesMap& dependencies) const = 0;
  133. virtual ui32 GetDependencyWeight() const = 0;
  134. virtual ui32 GetDependencesCount() const = 0;
  135. virtual bool IsTemporaryValue() const = 0;
  136. virtual EValueRepresentation GetRepresentation() const = 0;
  137. virtual void PrepareStageOne() = 0;
  138. virtual void PrepareStageTwo() = 0;
  139. virtual TString DebugString() const = 0;
  140. virtual void Ref() = 0;
  141. virtual void UnRef() = 0;
  142. virtual ui32 RefCount() const = 0;
  143. virtual std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const;
  144. };
  145. class IComputationExternalNode : public IComputationNode {
  146. public:
  147. virtual NUdf::TUnboxedValue& RefValue(TComputationContext& compCtx) const = 0;
  148. virtual void SetValue(TComputationContext& compCtx, NUdf::TUnboxedValue&& newValue) const = 0;
  149. virtual void SetOwner(const IComputationNode* node) = 0;
  150. using TGetter = std::function<NUdf::TUnboxedValue(TComputationContext&)>;
  151. virtual void SetGetter(TGetter&& getter) = 0;
  152. virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
  153. };
  154. enum class EFetchResult : i32 {
  155. Finish = -1,
  156. Yield = 0,
  157. One = 1
  158. };
  159. class IComputationWideFlowNode : public IComputationNode {
  160. public:
  161. virtual EFetchResult FetchValues(TComputationContext& compCtx, NUdf::TUnboxedValue*const* values) const = 0;
  162. };
  163. class IComputationWideFlowProxyNode : public IComputationWideFlowNode {
  164. public:
  165. using TFetcher = std::function<EFetchResult(TComputationContext&, NUdf::TUnboxedValue*const*)>;
  166. virtual void SetFetcher(TFetcher&& fetcher) = 0;
  167. virtual void SetOwner(const IComputationNode* node) = 0;
  168. virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
  169. };
  170. using TDatumProvider = std::function<arrow::Datum()>;
  171. TDatumProvider MakeDatumProvider(const arrow::Datum& datum);
  172. TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx);
  173. class IArrowKernelComputationNode {
  174. public:
  175. virtual ~IArrowKernelComputationNode() = default;
  176. virtual TStringBuf GetKernelName() const = 0;
  177. virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
  178. virtual const std::vector<arrow::ValueDescr>& GetArgsDesc() const = 0;
  179. virtual const IComputationNode* GetArgument(ui32 index) const = 0;
  180. };
  181. struct TArrowKernelsTopologyItem {
  182. std::vector<ui32> Inputs;
  183. std::unique_ptr<IArrowKernelComputationNode> Node;
  184. };
  185. struct TArrowKernelsTopology {
  186. ui32 InputArgsCount = 0;
  187. std::vector<TArrowKernelsTopologyItem> Items;
  188. };
  189. using TComputationNodePtrVector = std::vector<IComputationNode*, TMKQLAllocator<IComputationNode*>>;
  190. using TComputationWideFlowNodePtrVector = std::vector<IComputationWideFlowNode*, TMKQLAllocator<IComputationWideFlowNode*>>;
  191. using TComputationExternalNodePtrVector = std::vector<IComputationExternalNode*, TMKQLAllocator<IComputationExternalNode*>>;
  192. using TConstComputationNodePtrVector = std::vector<const IComputationNode*, TMKQLAllocator<const IComputationNode*>>;
  193. using TComputationNodePtrDeque = std::deque<IComputationNode::TPtr, TMKQLAllocator<IComputationNode::TPtr>>;
  194. using TComputationNodeOnNodeMap = std::unordered_map<const IComputationNode*, IComputationNode*, std::hash<const IComputationNode*>, std::equal_to<const IComputationNode*>, TMKQLAllocator<std::pair<const IComputationNode *const, IComputationNode*>>>;
  195. class IComputationGraph {
  196. public:
  197. virtual ~IComputationGraph() {}
  198. virtual void Prepare() = 0;
  199. virtual NUdf::TUnboxedValue GetValue() = 0;
  200. virtual TComputationContext& GetContext() = 0;
  201. virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0;
  202. virtual const TArrowKernelsTopology* GetKernelsTopology() = 0;
  203. virtual const TComputationNodePtrDeque& GetNodes() const = 0;
  204. virtual void Invalidate() = 0; // Invalidate all mutable values in graph (may lead to udf recreation)
  205. virtual void InvalidateCaches() = 0; // Invalidate only cached values
  206. virtual TMemoryUsageInfo& GetMemInfo() const = 0;
  207. virtual const THolderFactory& GetHolderFactory() const = 0;
  208. virtual ITerminator* GetTerminator() const = 0;
  209. virtual bool SetExecuteLLVM(bool value) = 0;
  210. virtual TString SaveGraphState() = 0;
  211. virtual void LoadGraphState(TStringBuf state) = 0;
  212. };
  213. class TNodeFactory;
  214. typedef std::function<IComputationNode* (TNode* node, bool pop)> TNodeLocator;
  215. typedef std::function<void (IComputationNode*)> TNodePushBack;
  216. struct TComputationNodeFactoryContext {
  217. TNodeLocator NodeLocator;
  218. const IFunctionRegistry& FunctionRegistry;
  219. const TTypeEnvironment& Env;
  220. NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
  221. NUdf::ICountersProvider* CountersProvider;
  222. const NUdf::ISecureParamsProvider* SecureParamsProvider;
  223. const TNodeFactory& NodeFactory;
  224. const THolderFactory& HolderFactory;
  225. const NUdf::IValueBuilder *const Builder;
  226. NUdf::EValidateMode ValidateMode;
  227. NUdf::EValidatePolicy ValidatePolicy;
  228. EGraphPerProcess GraphPerProcess;
  229. TComputationMutables& Mutables;
  230. TComputationNodeOnNodeMap& ElementsCache;
  231. const TNodePushBack NodePushBack;
  232. TComputationNodeFactoryContext(
  233. const TNodeLocator& nodeLocator,
  234. const IFunctionRegistry& functionRegistry,
  235. const TTypeEnvironment& env,
  236. NUdf::ITypeInfoHelper::TPtr typeInfoHelper,
  237. NUdf::ICountersProvider* countersProvider,
  238. const NUdf::ISecureParamsProvider* secureParamsProvider,
  239. const TNodeFactory& nodeFactory,
  240. const THolderFactory& holderFactory,
  241. const NUdf::IValueBuilder* builder,
  242. NUdf::EValidateMode validateMode,
  243. NUdf::EValidatePolicy validatePolicy,
  244. EGraphPerProcess graphPerProcess,
  245. TComputationMutables& mutables,
  246. TComputationNodeOnNodeMap& elementsCache,
  247. TNodePushBack&& nodePushBack
  248. )
  249. : NodeLocator(nodeLocator)
  250. , FunctionRegistry(functionRegistry)
  251. , Env(env)
  252. , TypeInfoHelper(typeInfoHelper)
  253. , CountersProvider(countersProvider)
  254. , SecureParamsProvider(secureParamsProvider)
  255. , NodeFactory(nodeFactory)
  256. , HolderFactory(holderFactory)
  257. , Builder(builder)
  258. , ValidateMode(validateMode)
  259. , ValidatePolicy(validatePolicy)
  260. , GraphPerProcess(graphPerProcess)
  261. , Mutables(mutables)
  262. , ElementsCache(elementsCache)
  263. , NodePushBack(std::move(nodePushBack))
  264. {}
  265. };
  266. using TComputationNodeFactory = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext&)>;
  267. using TStreamEmitter = std::function<void(NUdf::TUnboxedValue&&)>;
  268. struct TPatternCacheEntry;
  269. struct TComputationPatternOpts {
  270. TComputationPatternOpts(TAllocState& allocState, const TTypeEnvironment& env)
  271. : AllocState(allocState)
  272. , Env(env)
  273. {}
  274. TComputationPatternOpts(
  275. TAllocState& allocState,
  276. const TTypeEnvironment& env,
  277. TComputationNodeFactory factory,
  278. const IFunctionRegistry* functionRegistry,
  279. NUdf::EValidateMode validateMode,
  280. NUdf::EValidatePolicy validatePolicy,
  281. const TString& optLLVM,
  282. EGraphPerProcess graphPerProcess,
  283. IStatsRegistry* stats = nullptr,
  284. NUdf::ICountersProvider* countersProvider = nullptr,
  285. const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr)
  286. : AllocState(allocState)
  287. , Env(env)
  288. , Factory(factory)
  289. , FunctionRegistry(functionRegistry)
  290. , ValidateMode(validateMode)
  291. , ValidatePolicy(validatePolicy)
  292. , OptLLVM(optLLVM)
  293. , GraphPerProcess(graphPerProcess)
  294. , Stats(stats)
  295. , CountersProvider(countersProvider)
  296. , SecureParamsProvider(secureParamsProvider)
  297. {}
  298. void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry,
  299. NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy,
  300. const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr,
  301. NUdf::ICountersProvider* counters = nullptr,
  302. const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) {
  303. Factory = factory;
  304. FunctionRegistry = functionRegistry;
  305. ValidateMode = validateMode;
  306. ValidatePolicy = validatePolicy;
  307. OptLLVM = optLLVM;
  308. GraphPerProcess = graphPerProcess;
  309. Stats = stats;
  310. CountersProvider = counters;
  311. SecureParamsProvider = secureParamsProvider;
  312. }
  313. void SetPatternEnv(std::shared_ptr<TPatternCacheEntry> cacheEnv) {
  314. PatternEnv = std::move(cacheEnv);
  315. }
  316. mutable std::shared_ptr<TPatternCacheEntry> PatternEnv;
  317. TAllocState& AllocState;
  318. const TTypeEnvironment& Env;
  319. TComputationNodeFactory Factory;
  320. const IFunctionRegistry* FunctionRegistry = nullptr;
  321. NUdf::EValidateMode ValidateMode = NUdf::EValidateMode::None;
  322. NUdf::EValidatePolicy ValidatePolicy = NUdf::EValidatePolicy::Fail;
  323. TString OptLLVM;
  324. EGraphPerProcess GraphPerProcess = EGraphPerProcess::Multi;
  325. IStatsRegistry* Stats = nullptr;
  326. NUdf::ICountersProvider* CountersProvider = nullptr;
  327. const NUdf::ISecureParamsProvider* SecureParamsProvider = nullptr;
  328. TComputationOptsFull ToComputationOptions(IRandomProvider& randomProvider, ITimeProvider& timeProvider, TAllocState* allocStatePtr = nullptr) const {
  329. return TComputationOptsFull(Stats, allocStatePtr ? *allocStatePtr : AllocState, Env, randomProvider, timeProvider, ValidatePolicy, SecureParamsProvider, CountersProvider);
  330. }
  331. };
  332. class IComputationPattern: public TAtomicRefCount<IComputationPattern> {
  333. public:
  334. typedef TIntrusivePtr<IComputationPattern> TPtr;
  335. virtual ~IComputationPattern() = default;
  336. virtual void Compile(TString optLLVM, IStatsRegistry* stats) = 0;
  337. virtual bool IsCompiled() const = 0;
  338. virtual size_t CompiledCodeSize() const = 0;
  339. virtual void RemoveCompiledCode() = 0;
  340. virtual THolder<IComputationGraph> Clone(const TComputationOptsFull& compOpts) = 0;
  341. virtual bool GetSuitableForCache() const = 0;
  342. };
  343. // node cookie's will be clean up when graph will be destroyed, explorer must not be changed/destroyed until that time
  344. IComputationPattern::TPtr MakeComputationPattern(
  345. TExploringNodeVisitor& explorer,
  346. const TRuntimeNode& root,
  347. const std::vector<TNode*>& entryPoints,
  348. const TComputationPatternOpts& opts);
  349. std::unique_ptr<NUdf::ISecureParamsProvider> MakeSimpleSecureParamsProvider(const THashMap<TString, TString>& secureParams);
  350. using TCallableComputationNodeBuilder = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext& ctx)>;
  351. template<typename... Ts>
  352. TCallableComputationNodeBuilder WrapComputationBuilder(IComputationNode* (*f)(const TComputationNodeFactoryContext&, Ts...)){
  353. return [f](TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  354. MKQL_ENSURE(callable.GetInputsCount() == sizeof...(Ts), "Incorrect number of inputs");
  355. return CallComputationBuilderWithArgs(f, callable, ctx, std::make_index_sequence<sizeof...(Ts)>());
  356. };
  357. }
  358. template<typename F, size_t... Is>
  359. auto CallComputationBuilderWithArgs(F* f, TCallable& callable, const TComputationNodeFactoryContext& ctx,
  360. const std::integer_sequence<size_t, Is...> &) {
  361. return f(ctx, callable.GetInput(Is)...);
  362. }
  363. } // namespace NMiniKQL
  364. } // namespace NKikimr