// mkql_combine.cpp (47 KB)
  1. #include "mkql_combine.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_stats_registry.h>
  6. #include <yql/essentials/minikql/defs.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  // Statistics keys updated by TState::PushStat when the hash table is non-empty:
  // Combine_FlushesCount is incremented and Combine_MaxRowsCount is raised to the table size.
  // NOTE(review): the meaning of the boolean constructor argument is not visible in this file — confirm against TStatKey.
  9. TStatKey Combine_FlushesCount("Combine_FlushesCount", true);
  10. TStatKey Combine_MaxRowsCount("Combine_MaxRowsCount", false);
  11. namespace {
  // Raw function-pointer signatures for key equality / key hashing; the wrappers below fill them
  // from codegen.GetPointerToFunction(...) results.
  12. using TEqualsPtr = bool(*)(NUdf::TUnboxedValuePod, NUdf::TUnboxedValuePod);
  13. using THashPtr = NYql::NUdf::THashType(*)(NUdf::TUnboxedValuePod);
  // Type-erased equality / hash callables used as the predicate and hasher of TState's hash map.
  14. using TEqualsFunc = std::function<bool(NUdf::TUnboxedValuePod, NUdf::TUnboxedValuePod)>;
  15. using THashFunc = std::function<NYql::NUdf::THashType(NUdf::TUnboxedValuePod)>;
  // Callbacks that TCombineCoreNodes::RegisterDependencies invokes for result nodes (TDependsOn)
  // and for external argument nodes (TOwn).
  16. using TDependsOn = std::function<void(IComputationNode*)>;
  17. using TOwn = std::function<void(IComputationExternalNode*)>;
  18. struct TCombineCoreNodes {
  19. IComputationExternalNode* ItemNode;
  20. IComputationExternalNode* KeyNode;
  21. IComputationExternalNode* StateNode;
  22. IComputationNode* KeyResultNode;
  23. IComputationNode* InitResultNode;
  24. IComputationNode* UpdateResultNode;
  25. IComputationNode* FinishResultNode;
  26. NUdf::TUnboxedValuePod ExtractKey(TComputationContext& compCtx, NUdf::TUnboxedValue value) const {
  27. ItemNode->SetValue(compCtx, std::move(value));
  28. auto key = KeyResultNode->GetValue(compCtx);
  29. const auto result = static_cast<const NUdf::TUnboxedValuePod&>(key);
  30. KeyNode->SetValue(compCtx, std::move(key));
  31. return result;
  32. }
  33. void ProcessItem(TComputationContext& compCtx, NUdf::TUnboxedValuePod& state) const {
  34. if (auto& st = static_cast<NUdf::TUnboxedValue&>(state); state.IsInvalid()) {
  35. st = InitResultNode->GetValue(compCtx);
  36. } else {
  37. StateNode->SetValue(compCtx, std::move(st));
  38. st = UpdateResultNode->GetValue(compCtx);
  39. }
  40. }
  41. NUdf::TUnboxedValuePod FinishItem(TComputationContext& compCtx, NUdf::TUnboxedValue& key, NUdf::TUnboxedValue& state) const {
  42. KeyNode->SetValue(compCtx, std::move(key));
  43. StateNode->SetValue(compCtx, std::move(state));
  44. return FinishResultNode->GetValue(compCtx).Release();
  45. }
  46. void RegisterDependencies(const TDependsOn& dependsOn, const TOwn& own) const {
  47. own(ItemNode);
  48. own(KeyNode);
  49. own(StateNode);
  50. dependsOn(KeyResultNode);
  51. dependsOn(InitResultNode);
  52. dependsOn(UpdateResultNode);
  53. dependsOn(FinishResultNode);
  54. }
  55. };
  56. class TState: public TComputationValue<TState> {
  57. using TBase = TComputationValue<TState>;
  58. using TStateMap = std::unordered_map<
  59. NUdf::TUnboxedValuePod, NUdf::TUnboxedValuePod,
  60. THashFunc, TEqualsFunc,
  61. TMKQLAllocator<std::pair<const NUdf::TUnboxedValuePod, NUdf::TUnboxedValuePod>>>;
  62. public:
  63. TState(TMemoryUsageInfo* memInfo, const THashFunc& hash, const TEqualsFunc& equal)
  64. : TBase(memInfo), States(0, hash, equal) {
  65. States.max_load_factor(1.2f);
  66. }
  67. NUdf::TUnboxedValuePod& At(const NUdf::TUnboxedValuePod key) {
  68. const auto ins = States.emplace(key, NUdf::TUnboxedValuePod::Invalid());
  69. if (ins.second) {
  70. key.Ref();
  71. }
  72. return ins.first->second;
  73. }
  74. bool IsEmpty() const {
  75. if (!States.empty()) {
  76. return false;
  77. }
  78. CleanupCurrentContext();
  79. return true;
  80. }
  81. void PushStat(IStatsRegistry* stats) const {
  82. if (!States.empty()) {
  83. MKQL_SET_MAX_STAT(stats, Combine_MaxRowsCount, static_cast<i64>(States.size()));
  84. MKQL_INC_STAT(stats, Combine_FlushesCount);
  85. }
  86. }
  87. bool Extract(NUdf::TUnboxedValue& key, NUdf::TUnboxedValue& state) {
  88. if (States.empty()) {
  89. return false;
  90. }
  91. const auto& node = States.extract(States.cbegin());
  92. static_cast<NUdf::TUnboxedValuePod&>(key) = node.key();
  93. static_cast<NUdf::TUnboxedValuePod&>(state) = node.mapped();
  94. return true;
  95. }
  96. NUdf::EFetchStatus InputStatus = NUdf::EFetchStatus::Ok;
  97. private:
  98. TStateMap States;
  99. };
  100. #ifndef MKQL_DISABLE_CODEGEN
  101. class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
  102. private:
  103. using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
  104. llvm::PointerType* StructPtrType;
  105. llvm::IntegerType* StatusType;
  106. protected:
  107. using TBase::Context;
  108. public:
  109. std::vector<llvm::Type*> GetFieldsArray() {
  110. std::vector<llvm::Type*> result = TBase::GetFields();
  111. result.emplace_back(StatusType); // status
  112. result.emplace_back(StructPtrType); // map
  113. return result;
  114. }
  115. llvm::Constant* GetStatus() {
  116. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
  117. }
  118. llvm::Constant* GetMap() {
  119. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
  120. }
  121. TLLVMFieldsStructureState(llvm::LLVMContext& context)
  122. : TBase(context)
  123. , StructPtrType(PointerType::getUnqual(StructType::get(context)))
  124. , StatusType(Type::getInt32Ty(context)) {
  125. }
  126. };
  127. #endif
  128. template <bool IsMultiRowState, bool StateContainerOpt, bool TrackRss>
  129. class TCombineCoreFlowWrapper: public std::conditional_t<IsMultiRowState,
  130. TPairStateFlowCodegeneratorNode<TCombineCoreFlowWrapper<IsMultiRowState, StateContainerOpt, TrackRss>>,
  131. TStatefulFlowCodegeneratorNode<TCombineCoreFlowWrapper<IsMultiRowState, StateContainerOpt, TrackRss>>>
  132. #ifndef MKQL_DISABLE_CODEGEN
  133. , public ICodegeneratorRootNode
  134. #endif
  135. {
  136. using TBaseComputation = std::conditional_t<IsMultiRowState,
  137. TPairStateFlowCodegeneratorNode<TCombineCoreFlowWrapper<IsMultiRowState, StateContainerOpt, TrackRss>>,
  138. TStatefulFlowCodegeneratorNode<TCombineCoreFlowWrapper<IsMultiRowState, StateContainerOpt, TrackRss>>>;
  139. public:
  // Constructs the flow variant of the combiner.
  //   mutables, kind - forwarded to the base computation node together with `flow` and EValueRepresentation::Any.
  //   flow           - upstream flow node the items are pulled from.
  //   nodes          - argument/result nodes of the combiner lambdas (stored by copy).
  //   keyTypes       - key type description, moved in; later handed to the hash/equality builders.
  //   isTuple        - handed to the hash/equality builders alongside keyTypes.
  //   memLimit       - limit passed to CheckAdjustedMemLimit; the memory baseline is only sampled when it is non-zero.
  140. TCombineCoreFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, const TCombineCoreNodes& nodes, TKeyTypes&& keyTypes, bool isTuple, ui64 memLimit)
  141. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Any)
  142. , Flow(flow)
  143. , Nodes(nodes)
  144. , KeyTypes(std::move(keyTypes))
  145. , IsTuple(isTuple)
  146. , MemLimit(memLimit)
  147. {}
  148. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  149. if (state.IsInvalid()) {
  150. MakeState(ctx, state);
  151. }
  152. while (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) {
  153. if (ptr->IsEmpty()) {
  154. switch (ptr->InputStatus) {
  155. case NUdf::EFetchStatus::Ok: break;
  156. case NUdf::EFetchStatus::Finish:
  157. return NUdf::TUnboxedValuePod::MakeFinish();
  158. case NUdf::EFetchStatus::Yield:
  159. ptr->InputStatus = NUdf::EFetchStatus::Ok;
  160. return NUdf::TUnboxedValuePod::MakeYield();
  161. }
  162. const auto initUsage = MemLimit ? ctx.HolderFactory.GetMemoryUsed() : 0ULL;
  163. do {
  164. auto item = Flow->GetValue(ctx);
  165. if (item.IsYield()) {
  166. ptr->InputStatus = NUdf::EFetchStatus::Yield;
  167. break;
  168. } else if (item.IsFinish()) {
  169. ptr->InputStatus = NUdf::EFetchStatus::Finish;
  170. break;
  171. }
  172. const auto key = Nodes.ExtractKey(ctx, item);
  173. Nodes.ProcessItem(ctx, ptr->At(key));
  174. } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage));
  175. ptr->PushStat(ctx.Stats);
  176. }
  177. if (NUdf::TUnboxedValue key, state; ptr->Extract(key, state)) {
  178. if (const auto out = Nodes.FinishItem(ctx, key, state)) {
  179. return out.template GetOptionalValueIf<!IsMultiRowState && StateContainerOpt>();
  180. }
  181. }
  182. }
  183. Y_UNREACHABLE();
  184. }
  185. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, NUdf::TUnboxedValue& current, TComputationContext& ctx) const {
  186. while (true) {
  187. if (current.HasValue()) {
  188. if constexpr (StateContainerOpt) {
  189. NUdf::TUnboxedValue result;
  190. switch (const auto status = current.Fetch(result)) {
  191. case NUdf::EFetchStatus::Ok: return result.Release();
  192. case NUdf::EFetchStatus::Yield: return NUdf::TUnboxedValuePod::MakeYield();
  193. case NUdf::EFetchStatus::Finish: break;
  194. }
  195. } else if (NUdf::TUnboxedValue result; current.Next(result)) {
  196. return result.Release();
  197. }
  198. current.Clear();
  199. }
  200. if (NUdf::TUnboxedValue output = DoCalculate(state, ctx); output.IsSpecial()) {
  201. return output.Release();
  202. } else {
  203. current = StateContainerOpt ? std::move(output) : output.GetListIterator();
  204. }
  205. }
  206. }
  207. #ifndef MKQL_DISABLE_CODEGEN
  // Emits the LLVM IR counterpart of DoCalculate(state, ctx).
  // Returns the PHI node (placed in the "over" block) that carries the produced value;
  // on return `block` is set to that "over" block.
  // Block layout: make -> main -> more -> { next: refill the table | full: extract one entry } -> over.
  208. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  209. auto& context = ctx.Codegen.GetContext();
  // The lambda argument nodes must support codegen so their values can be set from generated code.
  210. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.ItemNode);
  211. const auto codegenKeyArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.KeyNode);
  212. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.StateNode);
  213. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  214. MKQL_ENSURE(codegenKeyArg, "Key arg must be codegenerator node.");
  215. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  // Types: i128 for unboxed values, i32 for the fetch status, and a struct type built from
  // TLLVMFieldsStructureState describing the TState object.
  216. const auto valueType = Type::getInt128Ty(context);
  217. const auto ptrValueType = PointerType::getUnqual(valueType);
  218. const auto statusType = Type::getInt32Ty(context);
  219. TLLVMFieldsStructureState fieldsStruct(context);
  220. const auto stateType = StructType::get(context, fieldsStruct.GetFieldsArray());
  221. const auto statePtrType = PointerType::getUnqual(stateType);
  // Stack slots receiving the extracted key ("one_ptr") and state ("two_ptr").
  // The allocas are inserted into the function's entry block; the zeroing stores go to the current block.
  222. const auto onePtr = new AllocaInst(valueType, 0U, "one_ptr", &ctx.Func->getEntryBlock().back());
  223. const auto twoPtr = new AllocaInst(valueType, 0U, "two_ptr", &ctx.Func->getEntryBlock().back());
  224. new StoreInst(ConstantInt::get(valueType, 0), onePtr, block);
  225. new StoreInst(ConstantInt::get(valueType, 0), twoPtr, block);
  226. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  227. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  228. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  // Create the state lazily: only when the state slot is still invalid.
  229. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  230. block = make;
  // "make": call this->MakeState(ctx, statePtr); both `this` and the method address are embedded as integer constants.
  231. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  232. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  233. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TCombineCoreFlowWrapper::MakeState));
  234. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  235. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  236. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  237. BranchInst::Create(main, block);
  238. block = main;
  // "main": load the state value, truncate it to 64 bits and use that as the TState pointer.
  239. const auto state = new LoadInst(valueType, statePtr, "state", block);
  240. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  241. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  242. BranchInst::Create(more, block);
  243. block = more;
  // "more": loop head. `over` is the single exit block and `result` merges every value that can be returned.
  244. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  245. const auto result = PHINode::Create(valueType, 3U, "result", over);
  // Call TState::IsEmpty(): empty -> "next" (refill), otherwise -> "full" (extract).
  246. const auto isEmptyFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::IsEmpty));
  247. const auto isEmptyFuncType = FunctionType::get(Type::getInt1Ty(context), { statePtrType }, false);
  248. const auto isEmptyFuncPtr = CastInst::Create(Instruction::IntToPtr, isEmptyFunc, PointerType::getUnqual(isEmptyFuncType), "cast", block);
  249. const auto empty = CallInst::Create(isEmptyFuncType, isEmptyFuncPtr, { stateArg }, "empty", block);
  250. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  251. const auto full = BasicBlock::Create(context, "full", ctx.Func);
  252. BranchInst::Create(next, full, empty, block);
  253. {
  254. block = next;
  255. const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
  256. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  257. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  258. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  259. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  // Read the saved input status: Yield -> "rest", Finish -> "over" (returning the finish marker
  // registered in the PHI here), any other value -> "pull".
  260. const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { fieldsStruct.This(), fieldsStruct.GetStatus() }, "last", block);
  261. const auto last = new LoadInst(statusType, statusPtr, "last", block);
  262. result->addIncoming(GetFinish(context), block);
  263. const auto choise = SwitchInst::Create(last, pull, 2U, block);
  264. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), rest);
  265. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), over);
  266. block = rest;
  // "rest": a remembered yield is reported once — reset the status to Ok and return the yield marker.
  267. new StoreInst(ConstantInt::get(last->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), statusPtr, block);
  268. result->addIncoming(GetYield(context), block);
  269. BranchInst::Create(over, block);
  270. block = pull;
  // "pull": take the memory-usage baseline once before the fill loop.
  271. const auto used = GetMemoryUsed(MemLimit, ctx, block);
  272. BranchInst::Create(loop, block);
  273. block = loop;
  // "loop": fetch the next input item and store Finish / Yield / Ok into the status field;
  // a special item leaves the loop through "done".
  274. const auto item = GetNodeValue(Flow, ctx, block);
  275. const auto finsh = IsFinish(item, block, context);
  276. const auto yield = IsYield(item, block, context);
  277. const auto special = BinaryOperator::CreateOr(finsh, yield, "special", block);
  278. const auto fin = SelectInst::Create(finsh, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "fin", block);
  279. const auto save = SelectInst::Create(yield, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), fin, "save", block);
  280. new StoreInst(save, statusPtr, block);
  281. BranchInst::Create(done, good, special, block);
  282. block = good;
  // "good": publish the item, evaluate and publish the key, then call TState::At(key) to get the state slot.
  283. codegenItemArg->CreateSetValue(ctx, block, item);
  284. const auto key = GetNodeValue(Nodes.KeyResultNode, ctx, block);
  285. codegenKeyArg->CreateSetValue(ctx, block, key);
  286. const auto keyParam = key;
  287. const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::At));
  288. const auto atType = FunctionType::get(ptrValueType, {stateArg->getType(), keyParam->getType()}, false);
  289. const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
  290. const auto place = CallInst::Create(atType, atPtr, {stateArg, keyParam}, "place", block);
  291. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  292. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  293. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  // Slot still invalid -> "init" (evaluate the init node into it);
  // otherwise -> "next" (publish the old state, evaluate the update node into the slot).
  294. BranchInst::Create(init, next, IsInvalid(place, block, context), block);
  295. block = init;
  296. GetNodeValue(place, Nodes.InitResultNode, ctx, block);
  297. BranchInst::Create(test, block);
  298. block = next;
  299. codegenStateArg->CreateSetValue(ctx, block, place);
  300. GetNodeValue(place, Nodes.UpdateResultNode, ctx, block);
  301. BranchInst::Create(test, block);
  302. block = test;
  // "test": stop filling once the adjusted memory-limit check fires, otherwise pull the next item.
  303. const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, used, ctx, block);
  304. BranchInst::Create(done, loop, check, block);
  305. block = done;
  // "done": report statistics through TState::PushStat, then go on to extraction.
  306. const auto stat = ctx.GetStat();
  307. const auto statFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::PushStat));
  308. const auto statType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType(), stat->getType()}, false);
  309. const auto statPtr = CastInst::Create(Instruction::IntToPtr, statFunc, PointerType::getUnqual(statType), "stat", block);
  310. CallInst::Create(statType, statPtr, {stateArg, stat}, "", block);
  311. BranchInst::Create(full, block);
  312. }
  313. {
  314. block = full;
  // "full": call TState::Extract(&one, &two); when nothing was extracted go back to "more".
  315. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  316. const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
  317. const auto extractType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType(), onePtr->getType(), twoPtr->getType()}, false);
  318. const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
  319. const auto has = CallInst::Create(extractType, extractPtr, {stateArg, onePtr, twoPtr}, "has", block);
  320. BranchInst::Create(good, more, has, block);
  321. block = good;
  // "good": publish the extracted key and state from the stack slots and evaluate the finish node.
  322. codegenKeyArg->CreateSetValue(ctx, block, onePtr);
  323. codegenStateArg->CreateSetValue(ctx, block, twoPtr);
  324. const auto value = GetNodeValue(Nodes.FinishResultNode, ctx, block);
  // Multi-row: always return the finish value.
  // Single-row with StateContainerOpt: an empty value loops back to "more", otherwise the optional is unwrapped.
  // Single-row otherwise: an empty value loops back to "more", otherwise the value is returned as is.
  325. if constexpr (IsMultiRowState) {
  326. result->addIncoming(value, block);
  327. BranchInst::Create(over, block);
  328. } else if constexpr (StateContainerOpt) {
  329. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  330. BranchInst::Create(more, exit, IsEmpty(value, block, context), block);
  331. block = exit;
  332. const auto strip = GetOptionalValue(context, value, block);
  333. result->addIncoming(strip, block);
  334. BranchInst::Create(over, block);
  335. } else {
  336. result->addIncoming(value, block);
  337. BranchInst::Create(more, over, IsEmpty(value, block, context), block);
  338. }
  339. }
  340. block = over;
  341. return result;
  342. }
  // Emits the LLVM IR counterpart of DoCalculate(state, current, ctx) (multi-row variant).
  // `currentPtr` points at the slot holding the container being drained. Returns the PHI node
  // in the "over" block; on return `block` is set to that block.
  // Block layout: more -> { pull: read a row from the current container | skip: obtain the next container } -> over.
  343. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, Value* currentPtr, BasicBlock*& block) const {
  344. auto& context = ctx.Codegen.GetContext();
  345. const auto statusType = Type::getInt32Ty(context);
  346. const auto valueType = Type::getInt128Ty(context);
  // Stack slot receiving a fetched row; the alloca goes to the entry block, the zeroing store to the current block.
  347. const auto valuePtr = new AllocaInst(valueType, 0U, "value_ptr", &ctx.Func->getEntryBlock().back());
  348. new StoreInst(ConstantInt::get(valueType, 0), valuePtr, block);
  349. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  350. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  351. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  352. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  // The PHI reserves room for one more incoming value (the yield marker) in the StateContainerOpt case.
  353. const auto result = PHINode::Create(valueType, StateContainerOpt ? 3U : 2U, "result", over);
  354. BranchInst::Create(more, block);
  355. block = more;
  // "more": if the current slot has a value read from it ("pull"), otherwise get a new container ("skip").
  356. const auto current = new LoadInst(valueType, currentPtr, "current", block);
  357. BranchInst::Create(pull, skip, HasValue(current, block, context), block);
  358. {
  359. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  360. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  361. block = pull;
  // "pull": StateContainerOpt -> call Fetch: Yield returns the yield marker, Finish goes to "next",
  // any other status goes to "good". Otherwise call Next: true -> "good", false -> "next".
  362. if constexpr (StateContainerOpt) {
  363. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, current, ctx.Codegen, block, valuePtr);
  364. result->addIncoming(GetYield(context), block);
  365. const auto choise = SwitchInst::Create(status, good, 2U, block);
  366. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), over);
  367. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), next);
  368. } else {
  369. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), current, ctx.Codegen, block, valuePtr);
  370. BranchInst::Create(good, next, status, block);
  371. }
  372. block = good;
  // "good": load the fetched row from the stack slot, apply ValueRelease and return it.
  373. const auto value = new LoadInst(valueType, valuePtr, "value", block);
  374. ValueRelease(static_cast<const IComputationNode*>(this)->GetRepresentation(), value, ctx, block);
  375. result->addIncoming(value, block);
  376. BranchInst::Create(over, block);
  377. block = next;
  // "next": the container is exhausted — unref it, zero the slot and continue with "skip".
  378. UnRefBoxed(current, ctx, block);
  379. new StoreInst(ConstantInt::get(current->getType(), 0), currentPtr, block);
  380. BranchInst::Create(skip, block);
  381. }
  382. {
  383. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  384. block = skip;
  // "skip": inline the single-row generator; a special (yield/finish) result is returned as is.
  385. const auto list = DoGenerateGetValue(ctx, statePtr, block);
  386. result->addIncoming(list, block);
  387. BranchInst::Create(over, good, IsSpecial(list, block, context), block);
  388. block = good;
  // "good": make the produced container current — store it directly with an added reference
  // (StateContainerOpt), or store its list iterator and clean up the list — then loop to "more".
  389. if constexpr (StateContainerOpt) {
  390. new StoreInst(list, currentPtr, block);
  391. AddRefBoxed(list, ctx, block);
  392. } else {
  393. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(currentPtr, list, ctx.Codegen, block);
  394. CleanupBoxed(list, ctx, block);
  395. }
  396. BranchInst::Create(more, block);
  397. }
  398. block = over;
  399. return result;
  400. }
  401. #endif
  402. private:
  403. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  404. #ifdef MKQL_DISABLE_CODEGEN
  405. state = ctx.HolderFactory.Create<TState>(TValueHasher(KeyTypes, IsTuple, nullptr), TValueEqual(KeyTypes, IsTuple, nullptr));
  406. #else
  407. state = ctx.HolderFactory.Create<TState>(
  408. ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TValueHasher(KeyTypes, IsTuple, nullptr)),
  409. ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TValueEqual(KeyTypes, IsTuple, nullptr))
  410. );
  411. #endif
  412. }
  413. void RegisterDependencies() const final {
  414. if (const auto flow = this->FlowDependsOn(Flow)) {
  415. Nodes.RegisterDependencies(
  416. [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
  417. [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
  418. );
  419. }
  420. }
  421. IComputationNode* const Flow;
  422. const TCombineCoreNodes Nodes;
  423. const TKeyTypes KeyTypes;
  424. const bool IsTuple;
  425. const ui64 MemLimit;
  426. #ifndef MKQL_DISABLE_CODEGEN
  427. TEqualsPtr Equals = nullptr;
  428. THashPtr Hash = nullptr;
  429. Function* EqualsFunc = nullptr;
  430. Function* HashFunc = nullptr;
  431. template <bool EqualsOrHash>
  432. TString MakeName() const {
  433. TStringStream out;
  434. out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
  435. return out.Str();
  436. }
  437. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  438. if (EqualsFunc) {
  439. Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
  440. }
  441. if (HashFunc) {
  442. Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
  443. }
  444. }
  445. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  446. codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), IsTuple, KeyTypes));
  447. codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), IsTuple, KeyTypes));
  448. }
  449. #endif
  450. };
  451. template <bool IsMultiRowState, bool StateContainerOpt, bool TrackRss>
  452. class TCombineCoreWrapper: public TCustomValueCodegeneratorNode<TCombineCoreWrapper<IsMultiRowState, StateContainerOpt, TrackRss>> {
  453. typedef TCustomValueCodegeneratorNode<TCombineCoreWrapper<IsMultiRowState, StateContainerOpt, TrackRss>> TBaseComputation;
  454. #ifndef MKQL_DISABLE_CODEGEN
  455. using TCodegenValue = std::conditional_t<IsMultiRowState, TStreamCodegenSelfStatePlusValue<TState>, TStreamCodegenSelfStateValue<TState>>;
  456. #endif
  457. public:
  458. class TStreamValue : public TState {
  459. public:
  460. TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& stream, const TCombineCoreNodes& nodes, ui64 memLimit, TComputationContext& compCtx, const THashFunc& hash, const TEqualsFunc& equal)
  461. : TState(memInfo, hash, equal)
  462. , Stream(std::move(stream))
  463. , Nodes(nodes)
  464. , MemLimit(memLimit)
  465. , CompCtx(compCtx)
  466. {}
  467. private:
  468. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  469. for (;;) {
  470. if (IsMultiRowState && Iterator) {
  471. if constexpr (StateContainerOpt) {
  472. const auto status = Iterator.Fetch(result);
  473. if (status != NUdf::EFetchStatus::Finish) {
  474. return status;
  475. }
  476. Iterator.Clear();
  477. } else if (Iterator.Next(result)) {
  478. return NUdf::EFetchStatus::Ok;
  479. }
  480. Iterator.Clear();
  481. }
  482. if (IsEmpty()) {
  483. switch (InputStatus) {
  484. case NUdf::EFetchStatus::Ok: break;
  485. case NUdf::EFetchStatus::Finish:
  486. return NUdf::EFetchStatus::Finish;
  487. case NUdf::EFetchStatus::Yield:
  488. InputStatus = NUdf::EFetchStatus::Ok;
  489. return NUdf::EFetchStatus::Yield;
  490. }
  491. const auto initUsage = MemLimit ? CompCtx.HolderFactory.GetMemoryUsed() : 0ULL;
  492. do {
  493. NUdf::TUnboxedValue item;
  494. InputStatus = Stream.Fetch(item);
  495. if (NUdf::EFetchStatus::Ok != InputStatus) {
  496. break;
  497. }
  498. const auto key = Nodes.ExtractKey(CompCtx, item);
  499. Nodes.ProcessItem(CompCtx, At(key));
  500. } while (!CompCtx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage));
  501. PushStat(CompCtx.Stats);
  502. }
  503. if (NUdf::TUnboxedValue key, state; Extract(key, state)) {
  504. NUdf::TUnboxedValue finishItem = Nodes.FinishItem(CompCtx, key, state);
  505. if constexpr (IsMultiRowState) {
  506. Iterator = StateContainerOpt ? std::move(finishItem) : finishItem.GetListIterator();
  507. } else {
  508. result = finishItem.Release().GetOptionalValueIf<StateContainerOpt>();
  509. return NUdf::EFetchStatus::Ok;
  510. }
  511. }
  512. }
  513. }
  514. const NUdf::TUnboxedValue Stream;
  515. NUdf::TUnboxedValue Iterator;
  516. const TCombineCoreNodes Nodes;
  517. const ui64 MemLimit;
  518. TComputationContext& CompCtx;
  519. };
  // Constructs the stream variant of the combiner.
  //   mutables - forwarded to the base computation node.
  //   stream   - node producing the input stream.
  //   nodes    - argument/result nodes of the combiner lambdas (stored by copy).
  //   keyTypes - key type description, moved in; later handed to the hash/equality builders.
  //   isTuple  - handed to the hash/equality builders alongside keyTypes.
  //   memLimit - memory limit passed on to the created stream value.
  520. TCombineCoreWrapper(TComputationMutables& mutables, IComputationNode* stream, const TCombineCoreNodes& nodes, TKeyTypes&& keyTypes, bool isTuple, ui64 memLimit)
  521. : TBaseComputation(mutables)
  522. , Stream(stream)
  523. , Nodes(nodes)
  524. , KeyTypes(std::move(keyTypes))
  525. , IsTuple(isTuple)
  526. , MemLimit(memLimit)
  527. {}
  528. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  529. #ifndef MKQL_DISABLE_CODEGEN
  530. if (ctx.ExecuteLLVM && Combine)
  531. return ctx.HolderFactory.Create<TCodegenValue>(Combine, &ctx, Stream->GetValue(ctx),
  532. ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TValueHasher(KeyTypes, IsTuple, nullptr)),
  533. ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TValueEqual(KeyTypes, IsTuple, nullptr))
  534. );
  535. #endif
  536. return ctx.HolderFactory.Create<TStreamValue>(Stream->GetValue(ctx), Nodes, MemLimit, ctx,
  537. TValueHasher(KeyTypes, IsTuple, nullptr), TValueEqual(KeyTypes, IsTuple, nullptr));
  538. }
  539. private:
  540. void RegisterDependencies() const final {
  541. this->DependsOn(Stream);
  542. Nodes.RegisterDependencies(
  543. [this](IComputationNode* node){ this->DependsOn(node); },
  544. [this](IComputationExternalNode* node){ this->Own(node); }
  545. );
  546. }
  547. #ifndef MKQL_DISABLE_CODEGEN
  548. template <bool EqualsOrHash>
  549. TString MakeFuncName() const {
  550. TStringStream out;
  551. out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
  552. return out.Str();
  553. }
  554. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  555. codegen.ExportSymbol(CombineFunc = GenerateCombine(codegen));
  556. codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeFuncName<true>(), IsTuple, KeyTypes));
  557. codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeFuncName<false>(), IsTuple, KeyTypes));
  558. }
  559. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  560. if (CombineFunc) {
  561. Combine = reinterpret_cast<TCombinePtr>(codegen.GetPointerToFunction(CombineFunc));
  562. }
  563. if (EqualsFunc) {
  564. Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
  565. }
  566. if (HashFunc) {
  567. Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
  568. }
  569. }
  // Emits (or reuses) the LLVM function that yields the next output value of the combiner.
  // Generated signature: i32 status (ctx*, i128 container, state*, [i128* current - only when
  // IsMultiRowState], i128* value). The returned i32 carries NUdf::EFetchStatus values.
  // Requires the item, key and state argument nodes to be codegenerator nodes.
  570. Function* GenerateCombine(NYql::NCodegen::ICodegen& codegen) const {
  571. auto& module = codegen.GetModule();
  572. auto& context = codegen.GetContext();
  573. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.ItemNode);
  574. const auto codegenKeyArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.KeyNode);
  575. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(Nodes.StateNode);
  576. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  577. MKQL_ENSURE(codegenKeyArg, "Key arg must be codegenerator node.");
  578. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  // Reuse the function when the module already contains one under this name.
  579. const auto& name = this->MakeName("Fetch");
  580. if (const auto f = module.getFunction(name.c_str()))
  581. return f;
  // LLVM types used by the generated function; values are passed around as i128.
  582. const auto valueType = Type::getInt128Ty(context);
  583. const auto ptrValueType = PointerType::getUnqual(valueType);
  584. const auto containerType = static_cast<Type*>(valueType);
  585. const auto contextType = GetCompContextType(context);
  586. const auto statusType = Type::getInt32Ty(context);
  587. TLLVMFieldsStructureState fieldsStruct(context);
  588. const auto stateType = StructType::get(context, fieldsStruct.GetFieldsArray());
  589. const auto statePtrType = PointerType::getUnqual(stateType);
  // The multi-row variant takes one extra i128* argument ("current").
  590. const auto funcType = IsMultiRowState ?
  591. FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, statePtrType, ptrValueType, ptrValueType}, false):
  592. FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, statePtrType, ptrValueType}, false);
  593. TCodegenContext ctx(codegen);
  594. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  595. DISubprogramAnnotator annotator(ctx, ctx.Func);
  // Unpack the arguments in declaration order; currArg is null for single-row states.
  596. auto args = ctx.Func->arg_begin();
  597. ctx.Ctx = &*args;
  598. const auto containerArg = &*++args;
  599. const auto stateArg = &*++args;
  600. const auto currArg = IsMultiRowState ? &*++args : nullptr;
  601. const auto valuePtr = &*++args;
  // "main": allocate two zero-initialised scratch slots, then jump to "more",
  // the block every retry path branches back to.
  602. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  603. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  604. auto block = main;
  605. const auto onePtr = new AllocaInst(valueType, 0U, "one_ptr", block);
  606. const auto twoPtr = new AllocaInst(valueType, 0U, "two_ptr", block);
  607. new StoreInst(ConstantInt::get(valueType, 0), onePtr, block);
  608. new StoreInst(ConstantInt::get(valueType, 0), twoPtr, block);
  609. BranchInst::Create(more, block);
  610. block = more;
  // Multi-row state: first drain the finish result (a stream, or a list iterator)
  // that an earlier pass stored in currArg.
  611. if constexpr (IsMultiRowState) {
  612. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  613. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  614. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  615. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  616. const auto current = new LoadInst(valueType, currArg, "current", block);
  // IsEmpty(current) -> nothing stored, go to "skip"; otherwise pull from it.
  617. BranchInst::Create(skip, pull, IsEmpty(current, block, context), block);
  618. block = pull;
  // Fetch (i32 status) when StateContainerOpt, otherwise Next (i1) on the iterator;
  // in both cases the produced item is written to valuePtr.
  619. const auto status = StateContainerOpt ?
  620. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, current, codegen, block, valuePtr):
  621. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), current, codegen, block, valuePtr);
  // Continue to "good" unless Fetch reported Finish / Next returned false.
  622. const auto icmp = StateContainerOpt ?
  623. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Finish)), "cond", block): status;
  624. BranchInst::Create(good, next, icmp, block);
  625. block = good;
  // Return the fetched status as is, or Ok for the iterator case.
  626. ReturnInst::Create(context, StateContainerOpt ? status : ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  627. block = next;
  // Exhausted: release the stored value and reset currArg to zero.
  628. UnRefBoxed(current, ctx, block);
  629. new StoreInst(ConstantInt::get(current->getType(), 0), currArg, block);
  630. BranchInst::Create(skip, block);
  631. block = skip;
  632. }
  // Call TState::IsEmpty through its raw method pointer: true -> "next" (read input),
  // false -> "full" (extract accumulated entries).
  633. const auto isEmptyFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::IsEmpty));
  634. const auto isEmptyFuncType = FunctionType::get(Type::getInt1Ty(context), { statePtrType }, false);
  635. const auto isEmptyFuncPtr = CastInst::Create(Instruction::IntToPtr, isEmptyFunc, PointerType::getUnqual(isEmptyFuncType), "cast", block);
  636. const auto empty = CallInst::Create(isEmptyFuncType, isEmptyFuncPtr, { stateArg }, "empty", block);
  637. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  638. const auto full = BasicBlock::Create(context, "full", ctx.Func);
  639. BranchInst::Create(next, full, empty, block);
  640. {
  641. block = next;
  642. const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
  643. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  644. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  645. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  646. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  647. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  // Dispatch on the input status saved in the state by the previous fetch:
  //   Yield  -> "rest": reset the saved status to Ok, then return the loaded Yield;
  //   Finish -> "exit": return it unchanged;
  //   anything else -> "pull": read more input.
  648. const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { fieldsStruct.This(), fieldsStruct.GetStatus() }, "last", block);
  649. const auto last = new LoadInst(statusType, statusPtr, "last", block);
  650. const auto choise = SwitchInst::Create(last, pull, 2U, block);
  651. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), rest);
  652. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), exit);
  653. block = rest;
  654. new StoreInst(ConstantInt::get(last->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), statusPtr, block);
  655. BranchInst::Create(exit, block);
  656. block = exit;
  657. ReturnInst::Create(context, last, block);
  658. block = pull;
  // "used" is the GetMemoryUsed result taken before the loop; it is handed to
  // CheckAdjustedMemLimit after every processed item.
  659. const auto used = GetMemoryUsed(MemLimit, ctx, block);
  660. const auto stream = static_cast<Value*>(containerArg);
  661. BranchInst::Create(loop, block);
  662. block = loop;
  // Fetch one input item into one_ptr and save the status in the state;
  // only Ok continues to "good", any other status leaves the loop via "done".
  663. const auto fetch = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, stream, codegen, block, onePtr);
  664. const auto ok = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, fetch, ConstantInt::get(fetch->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), "ok", block);
  665. new StoreInst(fetch, statusPtr, block);
  666. BranchInst::Create(good, done, ok, block);
  667. block = good;
  // Bind the item, evaluate the key, bind it and look up its slot via TState::At.
  668. codegenItemArg->CreateSetValue(ctx, block, onePtr);
  669. const auto key = GetNodeValue(Nodes.KeyResultNode, ctx, block);
  670. codegenKeyArg->CreateSetValue(ctx, block, key);
  671. const auto keyParam = key;
  672. const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::At));
  673. const auto atType = FunctionType::get(ptrValueType, {stateArg->getType(), keyParam->getType()}, false);
  674. const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
  675. const auto place = CallInst::Create(atType, atPtr, {stateArg, keyParam}, "place", block);
  676. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  677. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  678. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  // IsInvalid(place) -> "init": evaluate the init result straight into the slot.
  // Otherwise "next": bind the slot to the state arg and evaluate the update result into it.
  679. BranchInst::Create(init, next, IsInvalid(place, block, context), block);
  680. block = init;
  681. GetNodeValue(place, Nodes.InitResultNode, ctx, block);
  682. BranchInst::Create(test, block);
  683. block = next;
  684. codegenStateArg->CreateSetValue(ctx, block, place);
  685. GetNodeValue(place, Nodes.UpdateResultNode, ctx, block);
  686. BranchInst::Create(test, block);
  687. block = test;
  // Leave the loop when CheckAdjustedMemLimit yields true, otherwise fetch the next item.
  688. const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, used, ctx, block);
  689. BranchInst::Create(done, loop, check, block);
  690. block = done;
  // Pass the statistics object to TState::PushStat, then continue with extraction.
  691. const auto stat = ctx.GetStat();
  692. const auto statFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::PushStat));
  693. const auto statType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType(), stat->getType()}, false);
  694. const auto statPtr = CastInst::Create(Instruction::IntToPtr, statFunc, PointerType::getUnqual(statType), "stat", block);
  695. CallInst::Create(statType, statPtr, {stateArg, stat}, "", block);
  696. BranchInst::Create(full, block);
  697. }
  698. {
  699. block = full;
  700. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  // TState::Extract(state, one_ptr, two_ptr): true -> "good" with the slots filled,
  // false -> back to "more".
  701. const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
  702. const auto extractType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType(), onePtr->getType(), twoPtr->getType()}, false);
  703. const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
  704. const auto has = CallInst::Create(extractType, extractPtr, {stateArg, onePtr, twoPtr}, "has", block);
  705. BranchInst::Create(good, more, has, block);
  706. block = good;
  // Bind one_ptr to the key arg and two_ptr to the state arg before evaluating finish.
  707. codegenKeyArg->CreateSetValue(ctx, block, onePtr);
  708. codegenStateArg->CreateSetValue(ctx, block, twoPtr);
  709. if constexpr (IsMultiRowState) {
  // Multi-row: store the finish result (or the iterator of the finish list) in currArg
  // and go back to "more", which drains it.
  710. if constexpr (StateContainerOpt) {
  711. GetNodeValue(currArg, Nodes.FinishResultNode, ctx, block);
  712. BranchInst::Create(more, block);
  713. } else {
  714. const auto list = GetNodeValue(Nodes.FinishResultNode, ctx, block);
  715. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(currArg, list, codegen, block);
  // Clean up the list value when the finish node reports it as temporary.
  716. if (Nodes.FinishResultNode->IsTemporaryValue())
  717. CleanupBoxed(list, ctx, block);
  718. BranchInst::Create(more, block);
  719. }
  720. } else {
  // Single value: release what valuePtr held, evaluate finish into it; when the result
  // IsEmpty go back to "more", otherwise return Ok.
  721. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  722. GetNodeValue(valuePtr, Nodes.FinishResultNode, ctx, block);
  723. const auto value = new LoadInst(valueType, valuePtr, "value", block);
  724. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  725. BranchInst::Create(more, exit, IsEmpty(value, block, context), block);
  726. block = exit;
  // StateContainerOpt: the returned value is replaced with GetOptionalValue(value).
  727. if constexpr (StateContainerOpt) {
  728. const auto strip = GetOptionalValue(context, value, block);
  729. new StoreInst(strip, valuePtr, block);
  730. }
  731. ReturnInst::Create(context, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  732. }
  733. }
  734. return ctx.Func;
  735. }
  // Pointer type of the combine routine, taken from TCodegenValue.
  736. using TCombinePtr = typename TCodegenValue::TFetchPtr;
  // LLVM function handles assigned in GenerateFunctions.
  737. Function* CombineFunc = nullptr;
  738. Function* EqualsFunc = nullptr;
  739. Function* HashFunc = nullptr;
  // Native entry points assigned in FinalizeFunctions. While Combine is null DoCalculate
  // creates the interpreted stream value; null Hash / Equals make it use
  // TValueHasher / TValueEqual instead.
  740. TCombinePtr Combine = nullptr;
  741. TEqualsPtr Equals = nullptr;
  742. THashPtr Hash = nullptr;
  743. #endif
  // Input node; DoCalculate passes its value to the created stream value.
  744. IComputationNode* const Stream;
  // Item / key / state argument nodes and key / init / update / finish result nodes.
  745. const TCombineCoreNodes Nodes;
  // Key description handed to TValueHasher / TValueEqual and to the generated
  // equals / hash functions.
  746. const TKeyTypes KeyTypes;
  747. const bool IsTuple;
  // Limit passed to TStreamValue and to GetMemoryUsed / CheckAdjustedMemLimit in generated code.
  748. const ui64 MemLimit;
  749. };
  750. }
  751. IComputationNode* WrapCombineCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  752. MKQL_ENSURE(callable.GetInputsCount() == 9U, "Expected 9 args");
  753. const auto type = callable.GetType()->GetReturnType();
  754. const auto finishResultType = callable.GetInput(7).GetStaticType();
  755. MKQL_ENSURE(finishResultType->IsList() || finishResultType->IsOptional() || finishResultType->IsStream(), "Expected list, stream or optional");
  756. const auto keyType = callable.GetInput(2).GetStaticType();
  757. TKeyTypes keyTypes;
  758. bool isTuple;
  759. bool encoded;
  760. bool useIHash;
  761. GetDictionaryKeyTypes(keyType, keyTypes, isTuple, encoded, useIHash);
  762. Y_ENSURE(!encoded, "TODO");
  763. const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(8))->AsValue().Get<ui64>();
  764. const bool trackRss = EGraphPerProcess::Single == ctx.GraphPerProcess;
  765. const auto stream = LocateNode(ctx.NodeLocator, callable, 0);
  766. const auto keyExtractorResultNode = LocateNode(ctx.NodeLocator, callable, 2);
  767. const auto initResultNode = LocateNode(ctx.NodeLocator, callable, 4);
  768. const auto updateResultNode = LocateNode(ctx.NodeLocator, callable, 6);
  769. const auto finishResultNode = LocateNode(ctx.NodeLocator, callable, 7);
  770. const auto itemNode = LocateExternalNode(ctx.NodeLocator, callable, 1);
  771. const auto keyNode = LocateExternalNode(ctx.NodeLocator, callable, 3);
  772. const auto stateNode = LocateExternalNode(ctx.NodeLocator, callable, 5);
  773. const TCombineCoreNodes nodes = {
  774. itemNode,
  775. keyNode,
  776. stateNode,
  777. keyExtractorResultNode,
  778. initResultNode,
  779. updateResultNode,
  780. finishResultNode
  781. };
  782. if (type->IsFlow()) {
  783. const auto kind = GetValueRepresentation(AS_TYPE(TFlowType, type)->GetItemType());
  784. if (finishResultType->IsStream()) {
  785. if (trackRss)
  786. return new TCombineCoreFlowWrapper<true, true, true>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  787. else
  788. return new TCombineCoreFlowWrapper<true, true, false>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  789. } else if (finishResultType->IsList()) {
  790. if (trackRss)
  791. return new TCombineCoreFlowWrapper<true, false, true>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  792. else
  793. return new TCombineCoreFlowWrapper<true, false, false>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  794. } else if (finishResultType->IsOptional()) {
  795. if (AS_TYPE(TOptionalType, finishResultType)->GetItemType()->IsOptional()) {
  796. if (trackRss)
  797. return new TCombineCoreFlowWrapper<false, true, true>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  798. else
  799. return new TCombineCoreFlowWrapper<false, true, false>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  800. } else {
  801. if (trackRss)
  802. return new TCombineCoreFlowWrapper<false, false, true>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  803. else
  804. return new TCombineCoreFlowWrapper<false, false, false>(ctx.Mutables, kind, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  805. }
  806. }
  807. } else if (type->IsStream()) {
  808. if (finishResultType->IsStream()) {
  809. if (trackRss)
  810. return new TCombineCoreWrapper<true, true, true>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  811. else
  812. return new TCombineCoreWrapper<true, true, false>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  813. } else if (finishResultType->IsList()) {
  814. if (trackRss)
  815. return new TCombineCoreWrapper<true, false, true>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  816. else
  817. return new TCombineCoreWrapper<true, false, false>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  818. } else if (finishResultType->IsOptional()) {
  819. if (AS_TYPE(TOptionalType, finishResultType)->GetItemType()->IsOptional()) {
  820. if (trackRss)
  821. return new TCombineCoreWrapper<false, true, true>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  822. else
  823. return new TCombineCoreWrapper<false, true, false>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  824. } else {
  825. if (trackRss)
  826. return new TCombineCoreWrapper<false, false, true>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  827. else
  828. return new TCombineCoreWrapper<false, false, false>(ctx.Mutables, stream, nodes, std::move(keyTypes), isTuple, memLimit);
  829. }
  830. }
  831. }
  832. THROW yexception() << "Expected flow or stream.";
  833. }
  834. }
  835. }