mkql_wide_combine.cpp 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996
  1. #include "mkql_counters.h"
  2. #include "mkql_rh_hash.h"
  3. #include "mkql_wide_combine.h"
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
  6. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  7. #include <yql/essentials/minikql/computation/mkql_spiller_adapter.h>
  8. #include <yql/essentials/minikql/computation/mkql_spiller.h>
  9. #include <yql/essentials/minikql/mkql_node_builder.h>
  10. #include <yql/essentials/minikql/mkql_node_cast.h>
  11. #include <yql/essentials/minikql/mkql_runtime_version.h>
  12. #include <yql/essentials/minikql/mkql_stats_registry.h>
  13. #include <yql/essentials/minikql/defs.h>
  14. #include <yql/essentials/utils/cast.h>
  15. #include <yql/essentials/utils/log/log.h>
  16. #include <util/string/cast.h>
  17. #include <contrib/libs/xxhash/xxhash.h>
  18. namespace NKikimr {
  19. namespace NMiniKQL {
  20. using NYql::EnsureDynamicCast;
  21. using NYql::TChunkedBuffer;
  22. extern TStatKey Combine_FlushesCount;
  23. extern TStatKey Combine_MaxRowsCount;
  24. namespace {
  25. struct TMyValueEqual {
  26. TMyValueEqual(const TKeyTypes& types)
  27. : Types(types)
  28. {}
  29. bool operator()(const NUdf::TUnboxedValuePod* left, const NUdf::TUnboxedValuePod* right) const {
  30. for (ui32 i = 0U; i < Types.size(); ++i)
  31. if (CompareValues(Types[i].first, true, Types[i].second, left[i], right[i]))
  32. return false;
  33. return true;
  34. }
  35. const TKeyTypes& Types;
  36. };
  37. struct TMyValueHasher {
  38. TMyValueHasher(const TKeyTypes& types)
  39. : Types(types)
  40. {}
  41. NUdf::THashType operator()(const NUdf::TUnboxedValuePod* values) const {
  42. if (Types.size() == 1U)
  43. if (const auto v = *values)
  44. return NUdf::GetValueHash(Types.front().first, v);
  45. else
  46. return HashOfNull;
  47. NUdf::THashType hash = 0ULL;
  48. for (const auto& type : Types) {
  49. if (const auto v = *values++)
  50. hash = CombineHashes(hash, NUdf::GetValueHash(type.first, v));
  51. else
  52. hash = CombineHashes(hash, HashOfNull);
  53. }
  54. return hash;
  55. }
  56. const TKeyTypes& Types;
  57. };
  58. using TEqualsPtr = bool(*)(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
  59. using THashPtr = NUdf::THashType(*)(const NUdf::TUnboxedValuePod*);
  60. using TEqualsFunc = std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>;
  61. using THashFunc = std::function<NUdf::THashType(const NUdf::TUnboxedValuePod*)>;
  62. using TDependsOn = std::function<void(IComputationNode*)>;
  63. using TOwn = std::function<void(IComputationExternalNode*)>;
// Bundle of computation-graph nodes used by the wide combiner:
//  - ItemNodes:   external nodes fed with the incoming wide row;
//  - KeyNodes:    external nodes holding the extracted key components;
//  - StateNodes:  external nodes fed with the current aggregation state;
//  - FinishNodes: external nodes fed with the state when producing output;
//  - *ResultNodes: expressions computing keys, initial state, updated state
//    and the final output row, respectively.
// The "Passtrought" maps (spelling kept from the surrounding codebase) record
// which result expressions are plain pass-throughs of input/key/state nodes,
// allowing copies to be skipped.
struct TCombinerNodes {
TComputationExternalNodePtrVector ItemNodes, KeyNodes, StateNodes, FinishNodes;
TComputationNodePtrVector KeyResultNodes, InitResultNodes, UpdateResultNodes, FinishResultNodes;
TPasstroughtMap
    KeysOnItems,
    InitOnKeys,
    InitOnItems,
    UpdateOnKeys,
    UpdateOnItems,
    UpdateOnState,
    StateOnUpdate,
    ItemsOnResult,
    ResultOnItems;
// PasstroughtItems[i] is true when input item i flows unchanged into some
// key/init/update result.
std::vector<bool> PasstroughtItems;

// Precomputes all pass-through maps; must be called once after the node
// vectors are filled.
void BuildMaps() {
    KeysOnItems = GetPasstroughtMap(KeyResultNodes, ItemNodes);
    InitOnKeys = GetPasstroughtMap(InitResultNodes, KeyNodes);
    InitOnItems = GetPasstroughtMap(InitResultNodes, ItemNodes);
    UpdateOnKeys = GetPasstroughtMap(UpdateResultNodes, KeyNodes);
    UpdateOnItems = GetPasstroughtMap(UpdateResultNodes, ItemNodes);
    UpdateOnState = GetPasstroughtMap(UpdateResultNodes, StateNodes);
    StateOnUpdate = GetPasstroughtMap(StateNodes, UpdateResultNodes);
    ItemsOnResult = GetPasstroughtMap(FinishNodes, FinishResultNodes);
    ResultOnItems = GetPasstroughtMap(FinishResultNodes, FinishNodes);
    PasstroughtItems.resize(ItemNodes.size());
    // An item is a pass-through if it appears verbatim in any of the
    // key/init/update result sets.
    auto anyResults = KeyResultNodes;
    anyResults.insert(anyResults.cend(), InitResultNodes.cbegin(), InitResultNodes.cend());
    anyResults.insert(anyResults.cend(), UpdateResultNodes.cbegin(), UpdateResultNodes.cend());
    const auto itemsOnResults = GetPasstroughtMap(ItemNodes, anyResults);
    std::transform(itemsOnResults.cbegin(), itemsOnResults.cend(), PasstroughtItems.begin(), [](const TPasstroughtMap::value_type& v) { return v.has_value(); });
}

// True when input item i is consumed either by dependent expressions or as a
// pass-through value.
bool IsInputItemNodeUsed(size_t i) const {
    return (ItemNodes[i]->GetDependencesCount() > 0U || PasstroughtItems[i]);
}

// Returns the storage slot of item i, or nullptr when the item is unused and
// its value need not be fetched at all.
NUdf::TUnboxedValue* GetUsedInputItemNodePtrOrNull(TComputationContext& ctx, size_t i) const {
    return IsInputItemNodeUsed(i) ?
           &ItemNodes[i]->RefValue(ctx) :
           nullptr;
}

// Moves the fetched row into the item nodes and evaluates the key
// expressions, writing each key component both into the key node cache and
// into the caller-provided `keys` buffer.
void ExtractKey(TComputationContext& ctx, NUdf::TUnboxedValue** values, NUdf::TUnboxedValue* keys) const {
    std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), [&](IComputationExternalNode* item) {
        if (const auto pointer = *values++)
            item->SetValue(ctx, std::move(*pointer));
    });
    for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
        auto& key = KeyNodes[i]->RefValue(ctx);
        *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
    }
}

// Stores the raw (unprocessed) input row into `to` without evaluating keys;
// `keys` is zero-filled. Used when the row is buffered for later spilling.
void ConsumeRawData(TComputationContext& /*ctx*/, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue** from, NUdf::TUnboxedValue* to) const {
    std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
    for (ui32 i = 0U; i < ItemNodes.size(); ++i) {
        if (from[i] && IsInputItemNodeUsed(i)) {
            to[i] = std::move(*(from[i]));
        }
    }
}

// Inverse of ConsumeRawData: loads a previously buffered raw row back into
// the item nodes and (re)computes the key components into `keys`.
void ExtractRawData(TComputationContext& ctx, NUdf::TUnboxedValue* from, NUdf::TUnboxedValue* keys) const {
    for (ui32 i = 0U; i != ItemNodes.size(); ++i) {
        if (IsInputItemNodeUsed(i)) {
            ItemNodes[i]->SetValue(ctx, std::move(from[i]));
        }
    }
    for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
        auto& key = KeyNodes[i]->RefValue(ctx);
        *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
    }
}

// Applies one input row to the aggregation state in place.
// keys != nullptr: the key already exists — feed the old state into the
// state nodes and overwrite `state` with the update expressions (the key
// buffer is cleared since ownership moved into the hash table).
// keys == nullptr: new key — fill `state` from the init expressions.
void ProcessItem(TComputationContext& ctx, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue* state) const {
    if (keys) {
        std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
        auto source = state;
        std::for_each(StateNodes.cbegin(), StateNodes.cend(), [&](IComputationExternalNode* item){ item->SetValue(ctx, std::move(*source++)); });
        std::transform(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
    } else {
        std::transform(InitResultNodes.cbegin(), InitResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
    }
}

// Produces one output row: moves the final state into the finish nodes and
// evaluates each requested (non-null) output slot.
void FinishItem(TComputationContext& ctx, NUdf::TUnboxedValue* state, NUdf::TUnboxedValue*const* output) const {
    std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), [&](IComputationExternalNode* item) { item->SetValue(ctx, std::move(*state++)); });
    for (const auto node : FinishResultNodes)
        if (const auto out = *output++)
            *out = node->GetValue(ctx);
}

// Wires the nodes into the computation graph: external nodes are owned,
// result expressions are dependencies.
void RegisterDependencies(const TDependsOn& dependsOn, const TOwn& own) const {
    std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), own);
    std::for_each(KeyNodes.cbegin(), KeyNodes.cend(), own);
    std::for_each(StateNodes.cbegin(), StateNodes.cend(), own);
    std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), own);
    std::for_each(KeyResultNodes.cbegin(), KeyResultNodes.cend(), dependsOn);
    std::for_each(InitResultNodes.cbegin(), InitResultNodes.cend(), dependsOn);
    std::for_each(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), dependsOn);
    std::for_each(FinishResultNodes.cbegin(), FinishResultNodes.cend(), dependsOn);
}
};
  159. class TState : public TComputationValue<TState> {
  160. typedef TComputationValue<TState> TBase;
  161. private:
  162. using TStates = TRobinHoodHashSet<NUdf::TUnboxedValuePod*, TEqualsFunc, THashFunc, TMKQLAllocator<char, EMemorySubPool::Temporary>>;
  163. using TRow = std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>>;
  164. using TStorage = std::deque<TRow, TMKQLAllocator<TRow>>;
  165. class TStorageIterator {
  166. private:
  167. TStorage& Storage;
  168. const ui32 RowSize = 0;
  169. const ui64 Count = 0;
  170. ui64 Ready = 0;
  171. TStorage::iterator ItStorage;
  172. TRow::iterator ItRow;
  173. public:
  174. TStorageIterator(TStorage& storage, const ui32 rowSize, const ui64 count)
  175. : Storage(storage)
  176. , RowSize(rowSize)
  177. , Count(count)
  178. {
  179. ItStorage = Storage.begin();
  180. if (ItStorage != Storage.end()) {
  181. ItRow = ItStorage->begin();
  182. }
  183. }
  184. bool IsValid() {
  185. return Ready < Count;
  186. }
  187. bool Next() {
  188. if (++Ready >= Count) {
  189. return false;
  190. }
  191. ItRow += RowSize;
  192. if (ItRow == ItStorage->end()) {
  193. ++ItStorage;
  194. ItRow = ItStorage->begin();
  195. }
  196. return true;
  197. }
  198. NUdf::TUnboxedValuePod* GetValuePtr() const {
  199. return &*ItRow;
  200. }
  201. };
  202. static constexpr ui32 CountRowsOnPage = 128;
  203. ui32 RowSize() const {
  204. return KeyWidth + StateWidth;
  205. }
  206. public:
  207. TState(TMemoryUsageInfo* memInfo, ui32 keyWidth, ui32 stateWidth, const THashFunc& hash, const TEqualsFunc& equal, bool allowOutOfMemory = false)
  208. : TBase(memInfo), KeyWidth(keyWidth), StateWidth(stateWidth), AllowOutOfMemory(allowOutOfMemory), States(hash, equal, CountRowsOnPage) {
  209. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  210. CurrentPosition = 0;
  211. Tongue = CurrentPage->data();
  212. }
  213. ~TState() {
  214. //Workaround for YQL-16663, consider to rework this class in a safe manner
  215. while (auto row = Extract()) {
  216. for (size_t i = 0; i != RowSize(); ++i) {
  217. row[i].UnRef();
  218. }
  219. }
  220. ExtractIt.reset();
  221. Storage.clear();
  222. States.Clear();
  223. CleanupCurrentContext();
  224. }
  225. bool TasteIt() {
  226. Y_ABORT_UNLESS(!ExtractIt);
  227. bool isNew = false;
  228. auto itInsert = States.Insert(Tongue, isNew);
  229. if (isNew) {
  230. CurrentPosition += RowSize();
  231. if (CurrentPosition == CurrentPage->size()) {
  232. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  233. CurrentPosition = 0;
  234. }
  235. Tongue = CurrentPage->data() + CurrentPosition;
  236. }
  237. Throat = States.GetKey(itInsert) + KeyWidth;
  238. if (isNew) {
  239. GrowStates();
  240. }
  241. return isNew;
  242. }
  243. void GrowStates() {
  244. try {
  245. States.CheckGrow();
  246. } catch (TMemoryLimitExceededException) {
  247. YQL_LOG(INFO) << "State failed to grow";
  248. if (IsOutOfMemory || !AllowOutOfMemory) {
  249. throw;
  250. } else {
  251. IsOutOfMemory = true;
  252. }
  253. }
  254. }
  255. bool CheckIsOutOfMemory() const {
  256. return IsOutOfMemory;
  257. }
  258. template<bool SkipYields>
  259. bool ReadMore() {
  260. if constexpr (SkipYields) {
  261. if (EFetchResult::Yield == InputStatus)
  262. return true;
  263. }
  264. if (!States.Empty())
  265. return false;
  266. {
  267. TStorage localStorage;
  268. std::swap(localStorage, Storage);
  269. }
  270. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  271. CurrentPosition = 0;
  272. Tongue = CurrentPage->data();
  273. StoredDataSize = 0;
  274. CleanupCurrentContext();
  275. return true;
  276. }
  277. void PushStat(IStatsRegistry* stats) const {
  278. if (!States.Empty()) {
  279. MKQL_SET_MAX_STAT(stats, Combine_MaxRowsCount, static_cast<i64>(States.GetSize()));
  280. MKQL_INC_STAT(stats, Combine_FlushesCount);
  281. }
  282. }
  283. NUdf::TUnboxedValuePod* Extract() {
  284. if (!ExtractIt) {
  285. ExtractIt.emplace(Storage, RowSize(), States.GetSize());
  286. } else {
  287. ExtractIt->Next();
  288. }
  289. if (!ExtractIt->IsValid()) {
  290. ExtractIt.reset();
  291. States.Clear();
  292. return nullptr;
  293. }
  294. NUdf::TUnboxedValuePod* result = ExtractIt->GetValuePtr();
  295. CounterOutputRows_.Inc();
  296. return result;
  297. }
  298. EFetchResult InputStatus = EFetchResult::One;
  299. NUdf::TUnboxedValuePod* Tongue = nullptr;
  300. NUdf::TUnboxedValuePod* Throat = nullptr;
  301. i64 StoredDataSize = 0;
  302. NYql::NUdf::TCounter CounterOutputRows_;
  303. private:
  304. std::optional<TStorageIterator> ExtractIt;
  305. const ui32 KeyWidth, StateWidth;
  306. const bool AllowOutOfMemory;
  307. bool IsOutOfMemory = false;
  308. ui64 CurrentPosition = 0;
  309. TRow* CurrentPage = nullptr;
  310. TStorage Storage;
  311. TStates States;
  312. };
  313. class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
  314. typedef TComputationValue<TSpillingSupportState> TBase;
  315. typedef std::optional<NThreading::TFuture<ISpiller::TKey>> TAsyncWriteOperation;
  316. typedef std::optional<NThreading::TFuture<std::optional<TChunkedBuffer>>> TAsyncReadOperation;
// Per-bucket spilling bookkeeping. A bucket starts InMemory; under memory
// pressure its aggregation state is written out (SpillingState) and after
// that raw input rows routed to it are written out directly (SpillingData).
struct TSpilledBucket {
    std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledState; //state collected before switching to spilling mode
    std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledData; //data collected in spilling mode
    std::unique_ptr<TState> InMemoryProcessingState;
    TAsyncWriteOperation AsyncWriteOperation; // in-flight spill write, if any

    enum class EBucketState {
        InMemory,
        SpillingState,
        SpillingData
    };

    EBucketState BucketState = EBucketState::InMemory;
    ui64 LineCount = 0; // rows attributed to this bucket
};

// Overall operator mode: purely in-memory; splitting the accumulated state
// into buckets; spilling incoming data; replaying spilled buckets.
enum class EOperatingMode {
    InMemory,
    SplittingState,
    Spilling,
    ProcessSpilled
};
public:
// Result of TasteIt(): start a new state, update an existing one, or stash
// the raw row for a spilled bucket.
enum class ETasteResult: i8 {
    Init = -1,
    Update,
    ConsumeRawData
};
// Result of Update(): what the caller should do next.
enum class EUpdateResult: i8 {
    Yield = -1,
    ExtractRawData,
    ReadInput,
    Extract,
    Finish
};
// Builds the spilling-capable combiner state. The in-memory state is allowed
// to tolerate an out-of-memory grow only when spilling is actually possible
// (allowSpilling && ctx.SpillerFactory). Tongue/Throat initially alias the
// in-memory state's buffers.
TSpillingSupportState(
    TMemoryUsageInfo* memInfo,
    const TMultiType* usedInputItemType, const TMultiType* keyAndStateType, ui32 keyWidth, size_t itemNodesSize,
    const THashFunc& hash, const TEqualsFunc& equal, bool allowSpilling, TComputationContext& ctx
)
    : TBase(memInfo)
    , InMemoryProcessingState(memInfo, keyWidth, keyAndStateType->GetElementsCount() - keyWidth, hash, equal, allowSpilling && ctx.SpillerFactory)
    , UsedInputItemType(usedInputItemType)
    , KeyAndStateType(keyAndStateType)
    , KeyWidth(keyWidth)
    , ItemNodesSize(itemNodesSize)
    , Hasher(hash)
    , Mode(EOperatingMode::InMemory)
    , ViewForKeyAndState(keyAndStateType->GetElementsCount())
    , MemInfo(memInfo)
    , Equal(equal)
    , AllowSpilling(allowSpilling)
    , Ctx(ctx)
{
    BufferForUsedInputItems.reserve(usedInputItemType->GetElementsCount());
    Tongue = InMemoryProcessingState.Tongue;
    Throat = InMemoryProcessingState.Throat;

    if (ctx.CountersProvider) {
        // id will be assigned externally in future versions
        TString id = TString(Operator_Aggregation) + "0";
        CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
    }
}
// Advances the operator state machine and tells the caller what to do next.
// Mode transitions recurse into Update() so the new mode's logic runs
// immediately (at most a couple of levels deep: InMemory -> SplittingState
// -> Spilling).
EUpdateResult Update() {
    if (IsEverythingExtracted) {
        return EUpdateResult::Finish;
    }

    switch (GetMode()) {
        case EOperatingMode::InMemory: {
            // Re-sync the tongue: the in-memory state may have moved it.
            Tongue = InMemoryProcessingState.Tongue;
            if (CheckMemoryAndSwitchToSpilling()) {
                return Update();
            }
            if (InputStatus == EFetchResult::Finish) return EUpdateResult::Extract;

            return EUpdateResult::ReadInput;
        }
        case EOperatingMode::SplittingState: {
            if (SplitStateIntoBucketsAndWait()) return EUpdateResult::Yield;
            return Update();
        }
        case EOperatingMode::Spilling: {
            UpdateSpillingBuckets();

            // Under memory pressure (and while input remains), try to free
            // memory by spilling more; yield while that is in flight.
            if (!HasMemoryForProcessing() && InputStatus != EFetchResult::Finish && TryToReduceMemoryAndWait()) {
                return EUpdateResult::Yield;
            }

            // A raw row buffered by TasteIt() is still waiting to be written
            // to its bucket's data spiller.
            if (BufferForUsedInputItems.size()) {
                auto& bucket = SpilledBuckets[BufferForUsedInputItemsBucketId];
                if (bucket.AsyncWriteOperation.has_value()) return EUpdateResult::Yield;

                bucket.AsyncWriteOperation = bucket.SpilledData->WriteWideItem(BufferForUsedInputItems);
                BufferForUsedInputItems.resize(0); //for freeing allocated key value asap
            }

            if (InputStatus == EFetchResult::Finish) return FlushSpillingBuffersAndWait();

            return EUpdateResult::ReadInput;
        }
        case EOperatingMode::ProcessSpilled:
            return ProcessSpilledData();
    }
}
// Tries the current key against the appropriate state and reports whether
// the caller must initialize a fresh state (Init), update an existing one
// (Update), or hand over the raw input row for spilling (ConsumeRawData).
// In every case Throat is repositioned to where the caller should write.
ETasteResult TasteIt() {
    if (GetMode() == EOperatingMode::InMemory) {
        bool isNew = InMemoryProcessingState.TasteIt();
        // The state signalled an out-of-memory grow: request a switch to
        // spilling on the next Update().
        if (InMemoryProcessingState.CheckIsOutOfMemory()) {
            StateWantsToSpill = true;
        }
        Throat = InMemoryProcessingState.Throat;
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }
    if (GetMode() == EOperatingMode::ProcessSpilled) {
        // While restoring, buckets are processed one by one starting from
        // the front of the queue.
        bool isNew = SpilledBuckets.front().InMemoryProcessingState->TasteIt();
        Throat = SpilledBuckets.front().InMemoryProcessingState->Throat;
        BufferForUsedInputItems.resize(0);
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }

    // Spilling mode: route the row to its bucket by key hash.
    auto bucketId = ChooseBucket(ViewForKeyAndState.data());
    auto& bucket = SpilledBuckets[bucketId];
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
        // Bucket still lives in memory: copy the key into its tongue and
        // aggregate as usual.
        std::copy_n(ViewForKeyAndState.data(), KeyWidth, static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Tongue));
        bool isNew = bucket.InMemoryProcessingState->TasteIt();
        Throat = bucket.InMemoryProcessingState->Throat;
        bucket.LineCount += isNew;
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }
    bucket.LineCount++;

    // Prepare space for raw data
    MKQL_ENSURE(BufferForUsedInputItems.size() == 0, "Internal logic error");
    BufferForUsedInputItems.resize(ItemNodesSize);
    BufferForUsedInputItemsBucketId = bucketId;

    Throat = BufferForUsedInputItems.data();
    return ETasteResult::ConsumeRawData;
}
  445. NUdf::TUnboxedValuePod* Extract() {
  446. NUdf::TUnboxedValue* value = nullptr;
  447. if (GetMode() == EOperatingMode::InMemory) {
  448. value = static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
  449. if (value) {
  450. CounterOutputRows_.Inc();
  451. } else {
  452. IsEverythingExtracted = true;
  453. }
  454. return value;
  455. }
  456. MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");
  457. MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error");
  458. value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
  459. if (value) {
  460. CounterOutputRows_.Inc();
  461. } else {
  462. SpilledBuckets.front().InMemoryProcessingState->ReadMore<false>();
  463. SpilledBuckets.pop_front();
  464. if (SpilledBuckets.empty()) IsEverythingExtracted = true;
  465. }
  466. return value;
  467. }
  468. private:
  469. ui64 ChooseBucket(const NUdf::TUnboxedValuePod *const key) {
  470. auto provided_hash = Hasher(key);
  471. XXH64_hash_t bucket = XXH64(&provided_hash, sizeof(provided_hash), 0) % SpilledBucketCount;
  472. return bucket;
  473. }
// Finalizes the data spillers of all buckets once the input is finished.
// Yields until every bucket's writes are complete, then switches to
// ProcessSpilled and starts replaying.
EUpdateResult FlushSpillingBuffersAndWait() {
    UpdateSpillingBuckets();

    ui64 finishedCount = 0;
    for (auto& bucket : SpilledBuckets) {
        // State spilling must already be over by this point.
        MKQL_ENSURE(bucket.BucketState != TSpilledBucket::EBucketState::SpillingState, "Internal logic error");
        if (!bucket.AsyncWriteOperation.has_value()) {
            auto writeOperation = bucket.SpilledData->FinishWriting();
            if (!writeOperation) {
                // Nothing left to flush for this bucket.
                ++finishedCount;
            } else {
                bucket.AsyncWriteOperation = writeOperation;
            }
        }
    }

    if (finishedCount != SpilledBuckets.size()) return EUpdateResult::Yield;

    SwitchMode(EOperatingMode::ProcessSpilled);

    return ProcessSpilledData();
}
  492. ui32 GetLargestInMemoryBucketNumber() const {
  493. ui64 maxSize = 0;
  494. ui32 largestInMemoryBucketNum = (ui32)-1;
  495. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  496. if (SpilledBuckets[i].BucketState == TSpilledBucket::EBucketState::InMemory) {
  497. if (SpilledBuckets[i].LineCount >= maxSize) {
  498. largestInMemoryBucketNum = i;
  499. maxSize = SpilledBuckets[i].LineCount;
  500. }
  501. }
  502. }
  503. return largestInMemoryBucketNum;
  504. }
// Whether additional buckets may be spilled while the state is being split;
// currently unconditionally allowed.
bool IsSpillingWhileStateSplitAllowed() const {
    // TODO: Write better condition here. For example: InMemorybuckets > 64
    return true;
}
// Redistributes the accumulated in-memory state into the spill buckets.
// Returns true to request a Yield (an async spill write is in flight) and
// false when the split is complete and the operator can enter Spilling mode.
// SplitStateSpillingBucket remembers which bucket's drain was interrupted by
// an async write so it can be resumed on the next call.
bool SplitStateIntoBucketsAndWait() {
    // Resume a previously interrupted bucket drain, if any.
    if (SplitStateSpillingBucket != -1) {
        auto& bucket = SpilledBuckets[SplitStateSpillingBucket];
        MKQL_ENSURE(bucket.AsyncWriteOperation.has_value(), "Internal logic error");
        if (!bucket.AsyncWriteOperation->HasValue()) return true;
        bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
        bucket.AsyncWriteOperation = std::nullopt;

        while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
            bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
            for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                //releasing values stored in unsafe TUnboxedValue buffer
                keyAndState[i].UnRef();
            }
            if (bucket.AsyncWriteOperation) return true;
        }

        SplitStateSpillingBucket = -1;
    }
    // Move every (key, state) row of the main in-memory state into its bucket.
    while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) {
        auto bucketId = ChooseBucket(keyAndState); // This uses only key for hashing
        auto& bucket = SpilledBuckets[bucketId];

        bucket.LineCount++;

        if (bucket.BucketState != TSpilledBucket::EBucketState::InMemory) {
            // Target bucket is already spilling: write the row straight to
            // its state spiller.
            if (bucket.BucketState != TSpilledBucket::EBucketState::SpillingState) {
                bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
                SpillingBucketsCount++;
            }
            bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
            for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                //releasing values stored in unsafe TUnboxedValue buffer
                keyAndState[i].UnRef();
            }
            if (bucket.AsyncWriteOperation) {
                SplitStateSpillingBucket = bucketId;
                return true;
            }
            continue;
        }

        // In-memory bucket: re-insert the key and move the state across.
        auto& processingState = *bucket.InMemoryProcessingState;

        for (size_t i = 0; i < KeyWidth; ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(processingState.Tongue[i]) = std::move(keyAndState[i]);
        }
        processingState.TasteIt();
        for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(processingState.Throat[i - KeyWidth]) = std::move(keyAndState[i]);
        }

        // If memory is tight, evict the biggest in-memory bucket.
        if (InMemoryBucketsCount && !HasMemoryForProcessing() && IsSpillingWhileStateSplitAllowed()) {
            ui32 bucketNumToSpill = GetLargestInMemoryBucketNumber();
            SplitStateSpillingBucket = bucketNumToSpill;
            auto& bucket = SpilledBuckets[bucketNumToSpill];
            bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
            SpillingBucketsCount++;
            InMemoryBucketsCount--;

            while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
                bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
                for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                    //releasing values stored in unsafe TUnboxedValue buffer
                    keyAndState[i].UnRef();
                }
                if (bucket.AsyncWriteOperation) return true;
            }

            bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
            if (bucket.AsyncWriteOperation) return true;
        }
    }

    // Finalize the state spillers of all buckets that were spilling state.
    for (ui64 i = 0; i < SpilledBucketCount; ++i) {
        auto& bucket = SpilledBuckets[i];
        if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
            if (bucket.AsyncWriteOperation.has_value()) {
                if (!bucket.AsyncWriteOperation->HasValue()) return true;
                bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
                bucket.AsyncWriteOperation = std::nullopt;
            }
            bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
            if (bucket.AsyncWriteOperation) return true;
            bucket.InMemoryProcessingState->ReadMore<false>();
            bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
            SpillingBucketsCount--;
        }
    }

    InMemoryProcessingState.ReadMore<false>();
    IsInMemoryProcessingStateSplitted = true;
    SwitchMode(EOperatingMode::Spilling);
    return false;
}
  595. bool CheckMemoryAndSwitchToSpilling() {
  596. if (!(AllowSpilling && Ctx.SpillerFactory)) {
  597. return false;
  598. }
  599. if (StateWantsToSpill || IsSwitchToSpillingModeCondition()) {
  600. StateWantsToSpill = false;
  601. LogMemoryUsage();
  602. SwitchMode(EOperatingMode::SplittingState);
  603. return true;
  604. }
  605. return false;
  606. }
  607. void LogMemoryUsage() const {
  608. const auto used = TlsAllocState->GetUsed();
  609. const auto limit = TlsAllocState->GetLimit();
  610. TStringBuilder logmsg;
  611. logmsg << "Memory usage: ";
  612. if (limit) {
  613. logmsg << (used*100/limit) << "%=";
  614. }
  615. logmsg << (used/1_MB) << "MB/" << (limit/1_MB) << "MB";
  616. YQL_LOG(INFO) << logmsg;
  617. }
// Pushes the in-memory aggregation state of `bucket` to the spiller,
// one wide (key+state) row at a time. May return early with
// bucket.AsyncWriteOperation set, in which case the caller must wait for the
// async write to complete and call this function again to continue draining.
// When the state is fully written, finishes the write stream, resets the
// in-memory hash table and advances the bucket to SpillingData.
void SpillMoreStateFromBucket(TSpilledBucket& bucket) {
    // A previous async write must have been completed before resuming.
    MKQL_ENSURE(!bucket.AsyncWriteOperation.has_value(), "Internal logic error");
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
        // First call for this bucket: account for its state transition.
        bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
        SpillingBucketsCount++;
        InMemoryBucketsCount--;
    }
    // Extract() yields rows from the bucket's hash table until it is empty.
    while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
        bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
        for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
            //releasing values stored in unsafe TUnboxedValue buffer
            keyAndState[i].UnRef();
        }
        // The write went async — suspend; the caller resumes after completion.
        if (bucket.AsyncWriteOperation) return;
    }
    bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
    if (bucket.AsyncWriteOperation) return;
    // Fully spilled: reset the in-memory table and start spilling raw data.
    bucket.InMemoryProcessingState->ReadMore<false>();
    bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
    SpillingBucketsCount--;
}
  639. void UpdateSpillingBuckets() {
  640. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  641. auto& bucket = SpilledBuckets[i];
  642. if (bucket.AsyncWriteOperation.has_value() && bucket.AsyncWriteOperation->HasValue()) {
  643. if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  644. bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  645. bucket.AsyncWriteOperation = std::nullopt;
  646. SpillMoreStateFromBucket(bucket);
  647. } else {
  648. bucket.SpilledData->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  649. bucket.AsyncWriteOperation = std::nullopt;
  650. }
  651. }
  652. }
  653. }
  654. bool TryToReduceMemoryAndWait() {
  655. if (SpillingBucketsCount > 0) {
  656. return true;
  657. }
  658. while (InMemoryBucketsCount > 0) {
  659. ui32 maxLineBucketInd = GetLargestInMemoryBucketNumber();
  660. MKQL_ENSURE(maxLineBucketInd != (ui32)-1, "Internal logic error");
  661. auto& bucketToSpill = SpilledBuckets[maxLineBucketInd];
  662. SpillMoreStateFromBucket(bucketToSpill);
  663. if (bucketToSpill.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  664. return true;
  665. }
  666. }
  667. return false;
  668. }
// Restores the front spilled bucket: first recovers its previously spilled
// aggregation state back into the bucket's in-memory hash table, then replays
// its spilled raw input rows. Returns:
//   - Yield          while an asynchronous read is in flight;
//   - ExtractRawData when one raw input row is ready in BufferForUsedInputItems
//                    and must be pushed through the aggregation again;
//   - Extract        when the bucket is fully in memory and can be drained.
EUpdateResult ProcessSpilledData() {
    // Complete a previously started async read; RecoverState tells whether it
    // belonged to the state stream or to the raw-data stream.
    if (AsyncReadOperation) {
        if (!AsyncReadOperation->HasValue()) return EUpdateResult::Yield;
        if (RecoverState) {
            SpilledBuckets[0].SpilledState->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
        } else {
            SpilledBuckets[0].SpilledData->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
        }
        AsyncReadOperation = std::nullopt;
    }
    auto& bucket = SpilledBuckets.front();
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) return EUpdateResult::Extract;
    //recover spilled state
    while (!bucket.SpilledState->Empty()) {
        RecoverState = true;
        // NOTE(review): this buffer is a loop-local temporary; assumes
        // ExtractWideItem either fills it synchronously or does not retain a
        // pointer to it when going async — confirm against the adapter.
        TTemporaryUnboxedValueVector bufferForKeyAndState(KeyAndStateType->GetElementsCount());
        AsyncReadOperation = bucket.SpilledState->ExtractWideItem(bufferForKeyAndState);
        if (AsyncReadOperation) {
            return EUpdateResult::Yield;
        }
        for (size_t i = 0; i < KeyWidth; ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Tongue[i]) = std::move(bufferForKeyAndState[i]);
        }
        // Spilled state rows carry distinct keys, so insertion must create
        // a new entry every time.
        auto isNew = bucket.InMemoryProcessingState->TasteIt();
        MKQL_ENSURE(isNew, "Internal logic error");
        for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Throat[i - KeyWidth]) = std::move(bufferForKeyAndState[i]);
        }
    }
    //process spilled data
    if (!bucket.SpilledData->Empty()) {
        RecoverState = false;
        BufferForUsedInputItems.resize(UsedInputItemType->GetElementsCount());
        AsyncReadOperation = bucket.SpilledData->ExtractWideItem(BufferForUsedInputItems);
        if (AsyncReadOperation) {
            return EUpdateResult::Yield;
        }
        // Expose the raw row and the bucket's key slot to the caller.
        Throat = BufferForUsedInputItems.data();
        Tongue = bucket.InMemoryProcessingState->Tongue;
        return EUpdateResult::ExtractRawData;
    }
    bucket.BucketState = TSpilledBucket::EBucketState::InMemory;
    return EUpdateResult::Extract;
}
// Accessor for the current operating mode of the spilling state machine.
EOperatingMode GetMode() const {
    return Mode;
}
// Performs a mode transition and the setup it requires. Valid transitions
// (enforced by MKQL_ENSURE):
//   InMemory -> SplittingState -> Spilling -> ProcessSpilled
// (Spilling may also be entered directly from InMemory.) Switching back to
// InMemory is a logic error.
void SwitchMode(EOperatingMode mode) {
    switch (mode) {
        case EOperatingMode::InMemory: {
            YQL_LOG(INFO) << "switching Memory mode to InMemory";
            // The state machine never returns to pure in-memory processing.
            MKQL_ENSURE(false, "Internal logic error");
            break;
        }
        case EOperatingMode::SplittingState: {
            YQL_LOG(INFO) << "switching Memory mode to SplittingState";
            MKQL_ENSURE(EOperatingMode::InMemory == Mode, "Internal logic error");
            // Lazily create the buckets; all of them share a single spiller,
            // each with its own state/data streams and hash table.
            SpilledBuckets.resize(SpilledBucketCount);
            auto spiller = Ctx.SpillerFactory->CreateSpiller();
            for (auto &b: SpilledBuckets) {
                b.SpilledState = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, KeyAndStateType, 5_MB);
                b.SpilledData = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, UsedInputItemType, 5_MB);
                b.InMemoryProcessingState = std::make_unique<TState>(MemInfo, KeyWidth, KeyAndStateType->GetElementsCount() - KeyWidth, Hasher, Equal);
            }
            break;
        }
        case EOperatingMode::Spilling: {
            YQL_LOG(INFO) << "switching Memory mode to Spilling";
            MKQL_ENSURE(EOperatingMode::SplittingState == Mode || EOperatingMode::InMemory == Mode, "Internal logic error");
            // From now on incoming keys are written into the view buffer.
            Tongue = ViewForKeyAndState.data();
            break;
        }
        case EOperatingMode::ProcessSpilled: {
            YQL_LOG(INFO) << "switching Memory mode to ProcessSpilled";
            MKQL_ENSURE(EOperatingMode::Spilling == Mode, "Internal logic error");
            MKQL_ENSURE(SpilledBuckets.size() == SpilledBucketCount, "Internal logic error");
            // Move buckets that never left memory to the front so they are
            // drained first, without any restore reads.
            std::sort(SpilledBuckets.begin(), SpilledBuckets.end(), [](const TSpilledBucket& lhs, const TSpilledBucket& rhs) {
                bool lhs_in_memory = lhs.BucketState == TSpilledBucket::EBucketState::InMemory;
                bool rhs_in_memory = rhs.BucketState == TSpilledBucket::EBucketState::InMemory;
                return lhs_in_memory > rhs_in_memory;
            });
            break;
        }
    }
    Mode = mode;
}
  757. bool HasMemoryForProcessing() const {
  758. return !TlsAllocState->IsMemoryYellowZoneEnabled();
  759. }
  760. bool IsSwitchToSpillingModeCondition() const {
  761. return !HasMemoryForProcessing() || TlsAllocState->GetMaximumLimitValueReached();
  762. }
public:
    // Status of the last fetch from the upstream flow; updated by the wrapper.
    EFetchResult InputStatus = EFetchResult::One;
    // Raw slot where the wrapper writes the current key columns
    // (points into the hash table or into ViewForKeyAndState, see SwitchMode).
    NUdf::TUnboxedValuePod* Tongue = nullptr;
    // Raw slot for the current aggregation state / raw-row buffer
    // (see ProcessSpilledData, which points it at BufferForUsedInputItems).
    NUdf::TUnboxedValuePod* Throat = nullptr;
private:
    // Set externally to force spilling on the next check; cleared once acted on.
    bool StateWantsToSpill = false;
    bool IsEverythingExtracted = false;
    // Hash table used before any spilling starts.
    TState InMemoryProcessingState;
    // True once InMemoryProcessingState has been split into SpilledBuckets.
    bool IsInMemoryProcessingStateSplitted = false;
    const TMultiType* const UsedInputItemType;
    const TMultiType* const KeyAndStateType;
    // Number of leading key columns within a key+state row.
    const size_t KeyWidth;
    const size_t ItemNodesSize;
    THashFunc const Hasher;
    EOperatingMode Mode;
    bool RecoverState; //sub mode for ProcessSpilledData
    TAsyncReadOperation AsyncReadOperation = std::nullopt;
    static constexpr size_t SpilledBucketCount = 128;
    std::deque<TSpilledBucket> SpilledBuckets;
    // Counters below partition SpilledBucketCount between the two phases.
    ui32 SpillingBucketsCount = 0;
    ui32 InMemoryBucketsCount = SpilledBucketCount;
    // presumably the bucket that owns BufferForUsedInputItems — not set in the
    // visible code; verify against the fetch path.
    ui64 BufferForUsedInputItemsBucketId;
    TUnboxedValueVector BufferForUsedInputItems;
    // Scratch row the Tongue pointer aliases while in Spilling mode.
    std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>> ViewForKeyAndState;
    // Bucket currently being spilled during state split; -1 when none.
    i64 SplitStateSpillingBucket = -1;
    TMemoryUsageInfo* MemInfo = nullptr;
    TEqualsFunc const Equal;
    const bool AllowSpilling;
    TComputationContext& Ctx;
    NYql::NUdf::TCounter CounterOutputRows_;
};
  794. #ifndef MKQL_DISABLE_CODEGEN
  795. class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
  796. private:
  797. using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
  798. llvm::IntegerType* ValueType;
  799. llvm::PointerType* PtrValueType;
  800. llvm::IntegerType* StatusType;
  801. llvm::IntegerType* StoredType;
  802. protected:
  803. using TBase::Context;
  804. public:
  805. std::vector<llvm::Type*> GetFieldsArray() {
  806. std::vector<llvm::Type*> result = TBase::GetFields();
  807. result.emplace_back(StatusType); //status
  808. result.emplace_back(PtrValueType); //tongue
  809. result.emplace_back(PtrValueType); //throat
  810. result.emplace_back(StoredType); //StoredDataSize
  811. result.emplace_back(Type::getInt32Ty(Context)); //size
  812. result.emplace_back(Type::getInt32Ty(Context)); //size
  813. return result;
  814. }
  815. llvm::Constant* GetStatus() {
  816. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
  817. }
  818. llvm::Constant* GetTongue() {
  819. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
  820. }
  821. llvm::Constant* GetThroat() {
  822. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
  823. }
  824. llvm::Constant* GetStored() {
  825. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
  826. }
  827. TLLVMFieldsStructureState(llvm::LLVMContext& context)
  828. : TBase(context)
  829. , ValueType(Type::getInt128Ty(Context))
  830. , PtrValueType(PointerType::getUnqual(ValueType))
  831. , StatusType(Type::getInt32Ty(Context))
  832. , StoredType(Type::getInt64Ty(Context)) {
  833. }
  834. };
  835. #endif
// Wide-flow combiner node: aggregates incoming wide rows into a hash table
// (TState) keyed by the key columns, flushing accumulated results downstream
// whenever the memory budget is exhausted or the input finishes.
// TrackRss selects RSS-aware memory accounting; SkipYields makes upstream
// Yields transparent (accumulation continues across them) instead of being
// propagated to the consumer.
template <bool TrackRss, bool SkipYields>
class TWideCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>
#ifndef MKQL_DISABLE_CODEGEN
    , public ICodegeneratorRootNode
#endif
{
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>;
public:
    TWideCombinerWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TCombinerNodes&& nodes, TKeyTypes&& keyTypes, ui64 memLimit)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow(flow)
        , Nodes(std::move(nodes))
        , KeyTypes(std::move(keyTypes))
        , MemLimit(memLimit)
        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
    {}

    // Interpreter path. Alternates between an accumulation phase (fetch rows,
    // fold them into the hash table until the adjusted memory limit trips)
    // and a drain phase (Extract() rows one by one to the consumer).
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
        }
        while (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) {
            // ReadMore is true when the previous drain finished and more input
            // should be accumulated.
            if (ptr->ReadMore<SkipYields>()) {
                switch (ptr->InputStatus) {
                case EFetchResult::One:
                    break;
                case EFetchResult::Yield:
                    ptr->InputStatus = EFetchResult::One;
                    // With SkipYields the stored Yield was already reported;
                    // otherwise propagate it now. (Intentional fallthrough is
                    // avoided by the return.)
                    if constexpr (SkipYields)
                        break;
                    else
                        return EFetchResult::Yield;
                case EFetchResult::Finish:
                    return EFetchResult::Finish;
                }
                // Memory used before this accumulation round; only tracked
                // when a limit is configured.
                const auto initUsage = MemLimit ? ctx.HolderFactory.GetMemoryUsed() : 0ULL;
                auto **fields = ctx.WideFields.data() + WideFieldsIndex;
                do {
                    // Only request columns that are actually consumed
                    // (referenced or passed through).
                    for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
                        if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U || Nodes.PasstroughtItems[i])
                            fields[i] = &Nodes.ItemNodes[i]->RefValue(ctx);
                    ptr->InputStatus = Flow->FetchValues(ctx, fields);
                    if constexpr (SkipYields) {
                        if (EFetchResult::Yield == ptr->InputStatus) {
                            // Remember how much memory this partial round
                            // consumed so the limit check stays fair after
                            // resumption.
                            if (MemLimit) {
                                const auto currentUsage = ctx.HolderFactory.GetMemoryUsed();
                                ptr->StoredDataSize += currentUsage > initUsage ? currentUsage - initUsage : 0;
                            }
                            return EFetchResult::Yield;
                        } else if (EFetchResult::Finish == ptr->InputStatus) {
                            break;
                        }
                    } else {
                        if (EFetchResult::One != ptr->InputStatus) {
                            break;
                        }
                    }
                    // Write the key into Tongue; TasteIt() returns true for a
                    // brand-new key (then ProcessItem initializes the state,
                    // signalled by the nullptr key argument).
                    Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                    Nodes.ProcessItem(ctx, ptr->TasteIt() ? nullptr : static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage - ptr->StoredDataSize));
                ptr->PushStat(ctx.Stats);
            }
            // Drain phase: emit one finished row per call.
            if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
                Nodes.FinishItem(ctx, values, output);
                return EFetchResult::One;
            }
        }
        Y_UNREACHABLE();
    }
#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: emits LLVM IR equivalent to DoCalculate above. The CFG
    // mirrors the interpreter: make (lazy state creation) -> main -> more
    // (ReadMore check) -> next (accumulation loop: pull/loop/good/done) or
    // full (drain via Extract) -> over (result PHI).
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto valueType = Type::getInt128Ty(context);
        const auto ptrValueType = PointerType::getUnqual(valueType);
        const auto statusType = Type::getInt32Ty(context);
        const auto storedType = Type::getInt64Ty(context);

        TLLVMFieldsStructureState stateFields(context);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);

        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);

        // Create the boxed TState on first use via a direct call to MakeState.
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        block = make;

        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideCombinerWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);

        block = main;

        // Unbox the TState* from the low 64 bits of the TUnboxedValue.
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        BranchInst::Create(more, block);

        block = more;

        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        const auto result = PHINode::Create(statusType, 3U, "result", over);

        // Call TState::ReadMore<SkipYields> to decide accumulate vs drain.
        const auto readMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ReadMore<SkipYields>));
        const auto readMoreFuncType = FunctionType::get(Type::getInt1Ty(context), { statePtrType }, false);
        const auto readMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, readMoreFunc, PointerType::getUnqual(readMoreFuncType), "read_more_func", block);
        const auto readMore = CallInst::Create(readMoreFuncType, readMoreFuncPtr, { stateArg }, "read_more", block);

        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto full = BasicBlock::Create(context, "full", ctx.Func);

        BranchInst::Create(next, full, readMore, block);

        {
            // --- Accumulation phase ---
            block = next;

            const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
            const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
            const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto done = BasicBlock::Create(context, "done", ctx.Func);

            // Dispatch on the stored InputStatus (mirrors the switch in
            // DoCalculate): Yield -> rest, Finish -> over, One -> pull.
            const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
            const auto last = new LoadInst(statusType, statusPtr, "last", block);
            result->addIncoming(last, block);
            const auto choise = SwitchInst::Create(last, pull, 2U, block);
            choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), rest);
            choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), over);

            block = rest;
            new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::One)), statusPtr, block);
            if constexpr (SkipYields) {
                // NOTE(review): this second store of EFetchResult::One is
                // redundant with the one just above — confirm and drop one.
                new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), statusPtr, block);
                BranchInst::Create(pull, block);
            } else {
                result->addIncoming(last, block);
                BranchInst::Create(over, block);
            }

            block = pull;
            // Snapshot of memory usage before the accumulation loop.
            const auto used = GetMemoryUsed(MemLimit, ctx, block);
            BranchInst::Create(loop, block);

            block = loop;
            const auto getres = GetNodeValues(Flow, ctx, block);

            if constexpr (SkipYields) {
                const auto save = BasicBlock::Create(context, "save", ctx.Func);

                const auto way = SwitchInst::Create(getres.first, good, 2U, block);
                way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), save);
                way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), done);

                block = save;
                // Persist (currentUsage - used) into StoredDataSize, clamped
                // at zero, exactly as the interpreter does.
                if (MemLimit) {
                    const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
                    const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
                    const auto currentUsage = GetMemoryUsed(MemLimit, ctx, block);

                    const auto skipSavingUsed = BasicBlock::Create(context, "skip_saving_used", ctx.Func);
                    const auto saveUsed = BasicBlock::Create(context, "save_used", ctx.Func);

                    const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, currentUsage, used, "check", block);
                    BranchInst::Create(saveUsed, skipSavingUsed, check, block);

                    block = saveUsed;
                    const auto usedMemory = BinaryOperator::CreateSub(GetMemoryUsed(MemLimit, ctx, block), used, "used_memory", block);
                    const auto inc = BinaryOperator::CreateAdd(lastStored, usedMemory, "inc", block);
                    new StoreInst(inc, storedPtr, block);
                    BranchInst::Create(skipSavingUsed, block);

                    block = skipSavingUsed;
                }
                new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), statusPtr, block);
                result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
                BranchInst::Create(over, block);
            } else {
                // Yield and Finish are both <= Yield numerically; exit loop.
                const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
                BranchInst::Create(done, good, special, block);
            }

            block = good;

            // Materialize only consumed input columns.
            std::vector<Value*> items(Nodes.ItemNodes.size(), nullptr);
            for (ui32 i = 0U; i < items.size(); ++i) {
                if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i] = getres.second[i](ctx, block));
                else if (Nodes.PasstroughtItems[i])
                    items[i] = getres.second[i](ctx, block);
            }

            // Store key columns into the Tongue buffer.
            const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
            const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);

            std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                auto& key = keys[i];
                const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
                if (const auto map = Nodes.KeysOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    key = it;
                } else {
                    key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
                }
                if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
                new StoreInst(key, keyPtr, block);
            }

            // TasteIt(): true => new key, init state; false => update state.
            const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::TasteIt));
            const auto atType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
            const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
            const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block);

            const auto init = BasicBlock::Create(context, "init", ctx.Func);
            const auto next = BasicBlock::Create(context, "next", ctx.Func);
            const auto test = BasicBlock::Create(context, "test", ctx.Func);

            const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
            const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);

            std::vector<Value*> pointers;
            pointers.reserve(Nodes.StateNodes.size());
            for (ui32 i = 0U; i < Nodes.StateNodes.size(); ++i) {
                pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
            }

            BranchInst::Create(init, next, newKey, block);

            // init: first time this key is seen — the table now owns the key
            // values (AddRef) and the state slots get their initial values.
            block = init;

            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
            }

            for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
                if (const auto map = Nodes.InitOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    new StoreInst(it, pointers[i], block);
                    ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), it, ctx, block);
                } else if (const auto map = Nodes.InitOnKeys[i]) {
                    const auto key = keys[*map];
                    new StoreInst(key, pointers[i], block);
                    ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
                } else {
                    GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
                }
            }

            BranchInst::Create(test, block);

            // next: existing key — drop temporary key copies, read the old
            // state, compute and store the updated state.
            block = next;

            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                if (Nodes.KeysOnItems[i] || Nodes.KeyResultNodes[i]->IsTemporaryValue())
                    ValueCleanup(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
            }

            std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
            for (ui32 i = 0U; i < stored.size(); ++i) {
                const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
                if (const auto map = Nodes.StateOnUpdate[i]) {
                    if (hasDependency || i != *map) {
                        stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
                        if (hasDependency)
                            EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
                    }
                } else if (hasDependency) {
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
                } else {
                    ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
                }
            }

            for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
                if (const auto map = Nodes.UpdateOnState[i]) {
                    if (const auto j = *map; i != j) {
                        auto& it = stored[j];
                        if (!it)
                            it = new LoadInst(valueType, pointers[j], (TString("state_") += ToString(j)).c_str(), block);
                        new StoreInst(it, pointers[i], block);
                        if (i != *Nodes.StateOnUpdate[j])
                            ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                    }
                } else if (const auto map = Nodes.UpdateOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    new StoreInst(it, pointers[i], block);
                    ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                } else if (const auto map = Nodes.UpdateOnKeys[i]) {
                    const auto key = keys[*map];
                    new StoreInst(key, pointers[i], block);
                    ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
                } else {
                    GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
                }
            }

            BranchInst::Create(test, block);

            // test: loop again unless the adjusted memory limit has tripped.
            block = test;

            auto totalUsed = used;
            if (MemLimit) {
                // Subtract StoredDataSize so memory carried over from previous
                // yielded rounds is not counted twice.
                const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
                const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
                totalUsed = BinaryOperator::CreateSub(used, lastStored, "decr", block);
            }

            const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, totalUsed, ctx, block);
            BranchInst::Create(done, loop, check, block);

            block = done;

            new StoreInst(getres.first, statusPtr, block);

            const auto stat = ctx.GetStat();
            const auto statFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::PushStat));
            const auto statType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType(), stat->getType()}, false);
            const auto statPtr = CastInst::Create(Instruction::IntToPtr, statFunc, PointerType::getUnqual(statType), "stat", block);
            CallInst::Create(statType, statPtr, {stateArg, stat}, "", block);
            BranchInst::Create(full, block);
        }

        {
            // --- Drain phase: Extract() one finished row or go accumulate. ---
            block = full;

            const auto good = BasicBlock::Create(context, "good", ctx.Func);

            const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
            const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
            const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
            const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
            const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);

            BranchInst::Create(good, more, has, block);

            block = good;

            // Hand extracted values to the finish nodes (or drop ownership).
            for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
                const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
                if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
                else
                    ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
            }

            result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
            BranchInst::Create(over, block);
        }

        block = over;

        ICodegeneratorInlineWideNode::TGettersList getters;
        getters.reserve(Nodes.FinishResultNodes.size());
        std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
            return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
        });
        return {result, std::move(getters)};
    }
#endif
private:
    // Creates the boxed TState; under codegen, prefers JIT-compiled hash and
    // equality functions when they are available.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
#ifdef MKQL_DISABLE_CODEGEN
        state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes));
#else
        state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(),
            ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
            ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes))
        );
#endif
        if (ctx.CountersProvider) {
            const auto ptr = static_cast<TState*>(state.AsBoxed().Get());
            // id will be assigned externally in future versions
            TString id = TString(Operator_Aggregation) + "0";
            ptr->CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
        }
    }

    void RegisterDependencies() const final {
        if (const auto flow = this->FlowDependsOn(Flow)) {
            Nodes.RegisterDependencies(
                [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
                [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
            );
        }
    }

    IComputationWideFlowNode *const Flow;
    const TCombinerNodes Nodes;
    const TKeyTypes KeyTypes;
    // Memory budget for one accumulation round; 0 disables tracking.
    const ui64 MemLimit;
    const ui32 WideFieldsIndex;

#ifndef MKQL_DISABLE_CODEGEN
    // JIT-compiled key equality/hash, filled in by FinalizeFunctions.
    TEqualsPtr Equals = nullptr;
    THashPtr Hash = nullptr;

    Function* EqualsFunc = nullptr;
    Function* HashFunc = nullptr;

    // Unique symbol name for the generated Equals/Hash function.
    template <bool EqualsOrHash>
    TString MakeName() const {
        TStringStream out;
        out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
        return out.Str();
    }

    void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
        if (EqualsFunc) {
            Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
        }
        if (HashFunc) {
            Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
        }
    }

    void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
        codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
        codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
    }
#endif
};
// Wide combiner that emits output only after the whole input flow has been
// consumed (the "last" combiner): rows are grouped by key into per-key states
// held in TSpillingSupportState, which — when AllowSpilling is set — may move
// data to external storage under memory pressure (exact policy lives in the
// state class). With codegen enabled the node is also a codegenerator root so
// it can JIT the key hash/equality functions and the fetch loop itself.
class TWideLastCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>
#ifndef MKQL_DISABLE_CODEGEN
    , public ICodegeneratorRootNode
#endif
{
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>;
public:
    TWideLastCombinerWrapper(
        TComputationMutables& mutables,
        IComputationWideFlowNode* flow,
        TCombinerNodes&& nodes,
        const TMultiType* usedInputItemType,   // row type actually read from the input (spilling serialization)
        TKeyTypes&& keyTypes,
        const TMultiType* keyAndStateType,     // key columns followed by state columns (spilling serialization)
        bool allowSpilling)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow(flow)
        , Nodes(std::move(nodes))
        , KeyTypes(std::move(keyTypes))
        , UsedInputItemType(usedInputItemType)
        , KeyAndStateType(keyAndStateType)
        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
        , AllowSpilling(allowSpilling)
    {}
    // Interpreted execution. Drives the TSpillingSupportState state machine:
    // Update() dictates the next action (read input / yield / re-extract raw
    // data / extract results / finish), then TasteIt() classifies the freshly
    // prepared key (new key, existing key, or raw pass-through).
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
        }
        if (const auto ptr = static_cast<TSpillingSupportState*>(state.AsBoxed().Get())) {
            // This node's column buffer slots inside the context.
            auto **fields = ctx.WideFields.data() + WideFieldsIndex;
            while (true) {
                switch(ptr->Update()) {
                    case TSpillingSupportState::EUpdateResult::ReadInput: {
                        // Only columns used downstream get a real buffer; the rest are null.
                        for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
                            fields[i] = Nodes.GetUsedInputItemNodePtrOrNull(ctx, i);
                        switch (ptr->InputStatus = Flow->FetchValues(ctx, fields)) {
                            case EFetchResult::One:
                                break;
                            case EFetchResult::Finish:
                                // Input exhausted: loop back so Update() can switch to extraction.
                                continue;
                            case EFetchResult::Yield:
                                return EFetchResult::Yield;
                        }
                        // Compute the key columns into the state's key buffer ("tongue").
                        Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                        break;
                    }
                    case TSpillingSupportState::EUpdateResult::Yield:
                        return EFetchResult::Yield;
                    case TSpillingSupportState::EUpdateResult::ExtractRawData:
                        // Replay a previously buffered raw row ("throat") into the key buffer.
                        Nodes.ExtractRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Throat), static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                        break;
                    case TSpillingSupportState::EUpdateResult::Extract:
                        // Aggregation done: emit one finished row per stored key.
                        if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
                            Nodes.FinishItem(ctx, values, output);
                            return EFetchResult::One;
                        }
                        continue;
                    case TSpillingSupportState::EUpdateResult::Finish:
                        return EFetchResult::Finish;
                }
                switch(ptr->TasteIt()) {
                    case TSpillingSupportState::ETasteResult::Init:
                        // First row for this key: initialize its aggregation state.
                        Nodes.ProcessItem(ctx, nullptr, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                        break;
                    case TSpillingSupportState::ETasteResult::Update:
                        // Known key: fold the row into the existing state.
                        Nodes.ProcessItem(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                        break;
                    case TSpillingSupportState::ETasteResult::ConsumeRawData:
                        // Spilling path: store the raw row for later re-processing.
                        Nodes.ConsumeRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), fields, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                        break;
                }
            }
        }
        Y_UNREACHABLE();
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM IR twin of DoCalculate. Basic blocks mirror the interpreted state
    // machine: "more" calls Update(), "pull" fetches from the upstream flow,
    // "load" replays a raw row, "test" computes the key and calls TasteIt(),
    // "init"/"next"/"save" correspond to Init/Update/ConsumeRawData,
    // "fill"/"data" extract result rows, "done"/"over" terminate. The state's
    // methods are called through raw method pointers (GetMethodPtr) cast to
    // matching function types.
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valueType = Type::getInt128Ty(context);          // NUdf::TUnboxedValue as i128
        const auto ptrValueType = PointerType::getUnqual(valueType);
        const auto statusType = Type::getInt32Ty(context);          // EFetchResult
        const auto wayType = Type::getInt8Ty(context);              // EUpdateResult / ETasteResult
        TLLVMFieldsStructureState stateFields(context);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);
        // Lazily create the state object via MakeState on first invocation.
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideLastCombinerWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        block = main;
        // Unbox the state object pointer (low 64 bits of the i128 value).
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        BranchInst::Create(more, block);
        const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
        const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
        const auto test = BasicBlock::Create(context, "test", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto load = BasicBlock::Create(context, "load", ctx.Func);
        const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
        const auto data = BasicBlock::Create(context, "data", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        const auto stub = BasicBlock::Create(context, "stub", ctx.Func);
        // Default switch target: the enums are exhaustive, so this is dead.
        new UnreachableInst(context, stub);
        // Four incoming edges: Yield from "more", Yield from "pull", One from "data", Finish from "done".
        const auto result = PHINode::Create(statusType, 4U, "result", over);
        // One PHI per *used* input column, merged from "load" (replayed raw
        // row) and "good" (fresh fetch); unused columns keep nullptr.
        std::vector<PHINode*> phis(Nodes.ItemNodes.size(), nullptr);
        auto j = 0U;
        std::generate(phis.begin(), phis.end(), [&]() {
            return Nodes.IsInputItemNodeUsed(j++) ?
                PHINode::Create(valueType, 2U, (TString("item_") += ToString(j)).c_str(), test) : nullptr;
        });
        block = more;
        // way = state->Update();
        const auto updateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Update));
        const auto updateType = FunctionType::get(wayType, {stateArg->getType()}, false);
        const auto updateFuncPtr = CastInst::Create(Instruction::IntToPtr, updateFunc, PointerType::getUnqual(updateType), "update_func", block);
        const auto update = CallInst::Create(updateType, updateFuncPtr, { stateArg }, "update", block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
        const auto updateWay = SwitchInst::Create(update, stub, 5U, block);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Yield)), over);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Extract)), fill);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Finish)), done);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ReadInput)), pull);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ExtractRawData)), load);
        // --- "load": replay a buffered raw row from the throat buffer. ---
        block = load;
        const auto extractorPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "extractor_ptr", block);
        const auto extractor = new LoadInst(ptrValueType, extractorPtr, "extractor", block);
        std::vector<Value*> items(phis.size(), nullptr);
        for (ui32 i = 0U; i < items.size(); ++i) {
            const auto ptr = GetElementPtrInst::CreateInBounds(valueType, extractor, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("load_ptr_") += ToString(i)).c_str(), block);
            if (phis[i])
                items[i] = new LoadInst(valueType, ptr, (TString("load_") += ToString(i)).c_str(), block);
            // NOTE(review): items[i] stays null when phis[i] is null; this
            // assumes a node with dependencies is always a used column — confirm.
            if (i < Nodes.ItemNodes.size() && Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
        }
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto phi = phis[i]) {
                phi->addIncoming(items[i], block);
            }
        }
        BranchInst::Create(test, block);
        // --- "pull": fetch the next row from the upstream flow. ---
        block = pull;
        const auto getres = GetNodeValues(Flow, ctx, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
        const auto choise = SwitchInst::Create(getres.first, good, 2U, block);
        choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), over);
        choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest);
        // --- "rest": record input exhaustion, then re-enter Update(). ---
        block = rest;
        const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
        new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
        BranchInst::Create(more, block);
        // --- "good": materialize the fetched columns. ---
        block = good;
        for (ui32 i = 0U; i < items.size(); ++i) {
            if (phis[i])
                items[i] = getres.second[i](ctx, block);
            if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
        }
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto phi = phis[i]) {
                phi->addIncoming(items[i], block);
            }
        }
        BranchInst::Create(test, block);
        // --- "test": write the key into the tongue buffer and call TasteIt(). ---
        block = test;
        const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
        const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);
        std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
        for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
            auto& key = keys[i];
            const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
            if (const auto map = Nodes.KeysOnItems[i]) {
                // Key is a plain copy of an input column: reuse its PHI.
                key = phis[*map];
            } else {
                key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
            }
            if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
            new StoreInst(key, keyPtr, block);
        }
        // taste = state->TasteIt();
        const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt));
        const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false);
        const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
        const auto taste = CallInst::Create(atType, atPtr, {stateArg}, "taste", block);
        const auto init = BasicBlock::Create(context, "init", ctx.Func);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto save = BasicBlock::Create(context, "save", ctx.Func);
        // Element pointers into the throat (state/raw-row) buffer, shared by
        // the init/next/save branches below.
        const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
        const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);
        std::vector<Value*> pointers;
        const auto width = std::max(Nodes.StateNodes.size(), phis.size());
        pointers.reserve(width);
        for (ui32 i = 0U; i < width; ++i) {
            pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
        }
        const auto way = SwitchInst::Create(taste, stub, 3U, block);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Init)), init);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Update)), next);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::ConsumeRawData)), save);
        // --- "init": new key — keep the key refs and compute the initial state. ---
        block = init;
        for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
            ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
        }
        for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
            if (const auto map = Nodes.InitOnItems[i]) {
                const auto item = phis[*map];
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), item, ctx, block);
            } else if (const auto map = Nodes.InitOnKeys[i]) {
                const auto key = keys[*map];
                new StoreInst(key, pointers[i], block);
                ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
            } else {
                GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
            }
        }
        BranchInst::Create(more, block);
        // --- "next": existing key — fold the row into the stored state. ---
        block = next;
        std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
        for (ui32 i = 0U; i < stored.size(); ++i) {
            const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
            if (const auto map = Nodes.StateOnUpdate[i]) {
                // State slot feeds an update expression: load it (unless it is
                // an in-place pass-through with no dependencies).
                if (hasDependency || i != *map) {
                    stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
                    if (hasDependency)
                        EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
                }
            } else if (hasDependency) {
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
            } else {
                // Slot is neither read nor bound: drop its reference before overwrite.
                ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
            }
        }
        for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
            if (const auto map = Nodes.UpdateOnState[i]) {
                // New state value is a (possibly permuted) copy of an old slot.
                if (const auto j = *map; i != j) {
                    const auto it = stored[j];
                    new StoreInst(it, pointers[i], block);
                    if (i != *Nodes.StateOnUpdate[j])
                        ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                }
            } else if (const auto map = Nodes.UpdateOnItems[i]) {
                const auto item = phis[*map];
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), item, ctx, block);
            } else if (const auto map = Nodes.UpdateOnKeys[i]) {
                const auto key = keys[*map];
                new StoreInst(key, pointers[i], block);
                ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
            } else {
                GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
            }
        }
        BranchInst::Create(more, block);
        // --- "save": spilling path — stash the raw input columns. ---
        block = save;
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto item = phis[i]) {
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.ItemNodes[i]->GetRepresentation(), item, ctx, block);
            }
        }
        BranchInst::Create(more, block);
        // --- "fill": out = state->Extract(); null means nothing (yet) to emit. ---
        block = fill;
        const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Extract));
        const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
        const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
        const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
        const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
        BranchInst::Create(data, more, has, block);
        // --- "data": bind the extracted key+state row to the finish nodes. ---
        block = data;
        for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
            const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
            if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
            else
                ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
        }
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = done;
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
        BranchInst::Create(over, block);
        block = over;
        // Output getters simply evaluate the finish-result nodes.
        ICodegeneratorInlineWideNode::TGettersList getters;
        getters.reserve(Nodes.FinishResultNodes.size());
        std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
            return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
        });
        return {result, std::move(getters)};
    }
#endif
private:
    // Creates the TSpillingSupportState. With codegen, prefers the
    // JIT-compiled hash/equals when LLVM execution is enabled, falling back
    // to the interpreted functors otherwise.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
        state = ctx.HolderFactory.Create<TSpillingSupportState>(UsedInputItemType, KeyAndStateType,
            Nodes.KeyNodes.size(),
            Nodes.ItemNodes.size(),
#ifdef MKQL_DISABLE_CODEGEN
            TMyValueHasher(KeyTypes),
            TMyValueEqual(KeyTypes),
#else
            // NOTE(review): std::ptr_fun was removed in C++17; the bare
            // function pointer is accepted directly — verify toolchain mode.
            ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
            ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)),
#endif
            AllowSpilling,
            ctx
        );
    }
    // Standard wiring: internal nodes depend on the flow, externals are owned by it.
    void RegisterDependencies() const final {
        if (const auto flow = this->FlowDependsOn(Flow)) {
            Nodes.RegisterDependencies(
                [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
                [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
            );
        }
    }
    // Upstream wide flow this combiner drains.
    IComputationWideFlowNode *const Flow;
    // Item/key/state/finish computation nodes plus their cross-mappings.
    const TCombinerNodes Nodes;
    // Data slots (+ optionality) of the grouping keys; feeds hash/equals.
    const TKeyTypes KeyTypes;
    // Row type actually read from the input (used by spilling serialization).
    const TMultiType* const UsedInputItemType;
    // Combined key+state row type (used by spilling serialization).
    const TMultiType* const KeyAndStateType;
    // First slot of this node's column buffers inside ctx.WideFields.
    const ui32 WideFieldsIndex;
    // Whether the state may spill to external storage under memory pressure.
    const bool AllowSpilling;
#ifndef MKQL_DISABLE_CODEGEN
    // JIT-compiled key equality/hash entry points (set in FinalizeFunctions).
    TEqualsPtr Equals = nullptr;
    THashPtr Hash = nullptr;
    // LLVM IR functions produced by GenerateFunctions.
    Function* EqualsFunc = nullptr;
    Function* HashFunc = nullptr;
    // Unique symbol name for the generated Equals/Hash function; the node
    // address keeps names distinct across instances.
    template <bool EqualsOrHash>
    TString MakeName() const {
        TStringStream out;
        out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
        return out.Str();
    }
    // Resolve generated IR functions to native pointers after compilation.
    void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
        if (EqualsFunc) {
            Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
        }
        if (HashFunc) {
            Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
        }
    }
    // Emit and export the key hash/equality functions into the codegen module.
    void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
        codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
        codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
    }
#endif
};
  1559. }
// Shared factory for WideCombiner (Last = false) and WideLastCombiner
// (Last = true) callables.
//
// Callable input layout, derived from the index arithmetic below
// (W = inputWidth, K = keysSize, S = stateSize; base B = Last ? 3 : 4):
//   0            : input wide flow
//   1 (!Last)    : memory limit literal
//   next two     : keysSize, stateSize literals
//   [B,      B+W)       : item argument nodes (external)
//   [B+W,    B+W+K)     : key result nodes
//   [B+W+K,  B+W+2K)    : key argument nodes (external)
//   [B+W+2K, +S)        : state init result nodes
//   [...,    +S)        : state argument nodes (external)
//   [...,    +S)        : state update result nodes
//   [...,    +K+S)      : finish argument nodes (external)
//   [...,    +outputWidth) : finish result nodes
template<bool Last>
IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool allowSpilling) {
    MKQL_ENSURE(callable.GetInputsCount() >= (Last ? 3U : 4U), "Expected more arguments.");
    const auto inputType = AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType());
    const auto inputWidth = GetWideComponentsCount(inputType);
    const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
    const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
    auto index = Last ? 0U : 1U;
    const auto keysSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    const auto stateSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    // Skip the item argument nodes: advance past the size literal and W items.
    ++index += inputWidth;
    std::vector<TType*> keyAndStateItemTypes;
    keyAndStateItemTypes.reserve(keysSize + stateSize);
    TKeyTypes keyTypes;
    keyTypes.reserve(keysSize);
    // Collect key column types from the key result nodes (index not advanced here).
    for (ui32 i = index; i < index + keysSize; ++i) {
        TType *type = callable.GetInput(i).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        bool optional;
        keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
    }
    TCombinerNodes nodes;
    nodes.KeyResultNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyResultNodes), keysSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize; // skip key argument nodes (located below as KeyNodes)
    nodes.InitResultNodes.reserve(stateSize);
    // Collect state column types alongside the init result nodes.
    for (size_t i = 0; i != stateSize; ++i) {
        TType *type = callable.GetInput(index).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        nodes.InitResultNodes.push_back(LocateNode(ctx.NodeLocator, callable, index++));
    }
    index += stateSize; // skip state argument nodes (located below as StateNodes)
    nodes.UpdateResultNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.UpdateResultNodes), stateSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize + stateSize; // skip finish argument nodes (located below as FinishNodes)
    nodes.FinishResultNodes.reserve(outputWidth);
    std::generate_n(std::back_inserter(nodes.FinishResultNodes), outputWidth, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    // Second pass: locate the external (argument) nodes skipped above.
    index = Last ? 3U : 4U;
    nodes.ItemNodes.reserve(inputWidth);
    std::generate_n(std::back_inserter(nodes.ItemNodes), inputWidth, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize; // skip key result nodes
    nodes.KeyNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyNodes), keysSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += stateSize; // skip init result nodes
    nodes.StateNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.StateNodes), stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += stateSize; // skip update result nodes
    nodes.FinishNodes.reserve(keysSize + stateSize);
    std::generate_n(std::back_inserter(nodes.FinishNodes), keysSize + stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    // Precompute pass-through mappings (key-on-item, init-on-key, ...).
    nodes.BuildMaps();
    if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
        if constexpr (Last) {
            const auto inputItemTypes = GetWideComponents(inputType);
            return new TWideLastCombinerWrapper(ctx.Mutables, wide, std::move(nodes),
                TMultiType::Create(inputItemTypes.size(), inputItemTypes.data(), ctx.Env),
                std::move(keyTypes),
                TMultiType::Create(keyAndStateItemTypes.size(),keyAndStateItemTypes.data(), ctx.Env),
                allowSpilling
            );
        } else {
            if constexpr (RuntimeVersion < 46U) {
                // Older runtimes: unsigned memory limit, never track yields.
                const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<ui64>();
                if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                    return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
                else
                    return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
            } else {
                // Newer runtimes: the limit is signed; a negative value flips
                // the wrapper's second template flag and passes the magnitude.
                if (const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<i64>(); memLimit >= 0)
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                    else
                        return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                else
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
                    else
                        return new TWideCombinerWrapper<false, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
            }
        }
    }
    THROW yexception() << "Expected wide flow.";
}
  1642. IComputationNode* WrapWideCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1643. return WrapWideCombinerT<false>(callable, ctx, false);
  1644. }
  1645. IComputationNode* WrapWideLastCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1646. YQL_LOG(INFO) << "Found non-serializable type, spilling is disabled";
  1647. return WrapWideCombinerT<true>(callable, ctx, false);
  1648. }
  1649. IComputationNode* WrapWideLastCombinerWithSpilling(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1650. return WrapWideCombinerT<true>(callable, ctx, true);
  1651. }
  1652. }
  1653. }