// mkql_wide_combine.cpp
  1. #include "mkql_counters.h"
  2. #include "mkql_rh_hash.h"
  3. #include "mkql_wide_combine.h"
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
  6. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  7. #include <yql/essentials/minikql/computation/mkql_spiller_adapter.h>
  8. #include <yql/essentials/minikql/computation/mkql_spiller.h>
  9. #include <yql/essentials/minikql/mkql_node_builder.h>
  10. #include <yql/essentials/minikql/mkql_node_cast.h>
  11. #include <yql/essentials/minikql/mkql_runtime_version.h>
  12. #include <yql/essentials/minikql/mkql_stats_registry.h>
  13. #include <yql/essentials/minikql/defs.h>
  14. #include <yql/essentials/utils/cast.h>
  15. #include <yql/essentials/utils/log/log.h>
  16. #include <util/string/cast.h>
  17. #include <contrib/libs/xxhash/xxhash.h>
  18. namespace NKikimr {
  19. namespace NMiniKQL {
  20. using NYql::EnsureDynamicCast;
  21. using NYql::TChunkedBuffer;
  22. extern TStatKey Combine_FlushesCount;
  23. extern TStatKey Combine_MaxRowsCount;
  24. namespace {
  25. bool HasMemoryForProcessing() {
  26. return !TlsAllocState->IsMemoryYellowZoneEnabled();
  27. }
// Deep-equality functor over rows of key columns.
// left/right point to arrays of Types.size() unboxed values; rows are equal
// iff every column compares equal under its stored type descriptor.
struct TMyValueEqual {
    TMyValueEqual(const TKeyTypes& types)
        : Types(types)
    {}

    bool operator()(const NUdf::TUnboxedValuePod* left, const NUdf::TUnboxedValuePod* right) const {
        for (ui32 i = 0U; i < Types.size(); ++i)
            // CompareValues yields a non-zero ordering for unequal values.
            if (CompareValues(Types[i].first, true, Types[i].second, left[i], right[i]))
                return false;
        return true;
    }

    const TKeyTypes& Types;
};
// Hash functor over rows of key columns.
// A single-column key is hashed directly (null maps to HashOfNull);
// multi-column keys fold per-column hashes with CombineHashes.
struct TMyValueHasher {
    TMyValueHasher(const TKeyTypes& types)
        : Types(types)
    {}

    NUdf::THashType operator()(const NUdf::TUnboxedValuePod* values) const {
        if (Types.size() == 1U)
            // NB: the 'else' below binds to this inner null-check 'if'
            // (dangling-else), which is the intended behavior here.
            if (const auto v = *values)
                return NUdf::GetValueHash(Types.front().first, v);
            else
                return HashOfNull;

        NUdf::THashType hash = 0ULL;
        for (const auto& type : Types) {
            if (const auto v = *values++)
                hash = CombineHashes(hash, NUdf::GetValueHash(type.first, v));
            else
                hash = CombineHashes(hash, HashOfNull);
        }
        return hash;
    }

    const TKeyTypes& Types;
};
  61. using TEqualsPtr = bool(*)(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
  62. using THashPtr = NUdf::THashType(*)(const NUdf::TUnboxedValuePod*);
  63. using TEqualsFunc = std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>;
  64. using THashFunc = std::function<NUdf::THashType(const NUdf::TUnboxedValuePod*)>;
  65. using TDependsOn = std::function<void(IComputationNode*)>;
  66. using TOwn = std::function<void(IComputationExternalNode*)>;
// Bundle of the wide combiner's computation-graph nodes plus precomputed
// passthrough maps between them.
// - ItemNodes/KeyNodes/StateNodes/FinishNodes: external nodes that receive
//   input row items, extracted keys, stored state and finish-stage values.
// - KeyResultNodes/InitResultNodes/UpdateResultNodes/FinishResultNodes:
//   nodes computing the key, initial state, updated state and output row.
struct TCombinerNodes {
    TComputationExternalNodePtrVector ItemNodes, KeyNodes, StateNodes, FinishNodes;
    TComputationNodePtrVector KeyResultNodes, InitResultNodes, UpdateResultNodes, FinishResultNodes;

    // "Result X is a plain passthrough of node Y" maps (see GetPasstroughtMap);
    // they let values be forwarded instead of recomputed.
    TPasstroughtMap
        KeysOnItems,
        InitOnKeys,
        InitOnItems,
        UpdateOnKeys,
        UpdateOnItems,
        UpdateOnState,
        StateOnUpdate,
        ItemsOnResult,
        ResultOnItems;
    // PasstroughtItems[i]: input item i flows unchanged into some
    // key/init/update result node.
    std::vector<bool> PasstroughtItems;

    // Precomputes every passthrough map above; call once after the node
    // vectors are filled.
    void BuildMaps() {
        KeysOnItems = GetPasstroughtMap(KeyResultNodes, ItemNodes);
        InitOnKeys = GetPasstroughtMap(InitResultNodes, KeyNodes);
        InitOnItems = GetPasstroughtMap(InitResultNodes, ItemNodes);
        UpdateOnKeys = GetPasstroughtMap(UpdateResultNodes, KeyNodes);
        UpdateOnItems = GetPasstroughtMap(UpdateResultNodes, ItemNodes);
        UpdateOnState = GetPasstroughtMap(UpdateResultNodes, StateNodes);
        StateOnUpdate = GetPasstroughtMap(StateNodes, UpdateResultNodes);
        ItemsOnResult = GetPasstroughtMap(FinishNodes, FinishResultNodes);
        ResultOnItems = GetPasstroughtMap(FinishResultNodes, FinishNodes);

        PasstroughtItems.resize(ItemNodes.size());
        // An item is a passthrough if it maps onto any key/init/update result.
        auto anyResults = KeyResultNodes;
        anyResults.insert(anyResults.cend(), InitResultNodes.cbegin(), InitResultNodes.cend());
        anyResults.insert(anyResults.cend(), UpdateResultNodes.cbegin(), UpdateResultNodes.cend());
        const auto itemsOnResults = GetPasstroughtMap(ItemNodes, anyResults);
        std::transform(itemsOnResults.cbegin(), itemsOnResults.cend(), PasstroughtItems.begin(), [](const TPasstroughtMap::value_type& v) { return v.has_value(); });
    }

    // An input item is "used" if some node depends on it or it is a passthrough.
    bool IsInputItemNodeUsed(size_t i) const {
        return (ItemNodes[i]->GetDependencesCount() > 0U || PasstroughtItems[i]);
    }

    NUdf::TUnboxedValue* GetUsedInputItemNodePtrOrNull(TComputationContext& ctx, size_t i) const {
        return IsInputItemNodeUsed(i) ?
            &ItemNodes[i]->RefValue(ctx) :
            nullptr;
    }

    // Moves fetched input values into the item nodes and computes the key
    // columns into both the key nodes and the caller-provided keys buffer.
    void ExtractKey(TComputationContext& ctx, NUdf::TUnboxedValue** values, NUdf::TUnboxedValue* keys) const {
        std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), [&](IComputationExternalNode* item) {
            if (const auto pointer = *values++)
                item->SetValue(ctx, std::move(*pointer));
        });
        for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
            auto& key = KeyNodes[i]->RefValue(ctx);
            *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
        }
    }

    // Moves the used raw input values to `to` without computing anything;
    // the keys buffer is reset to empty values.
    void ConsumeRawData(TComputationContext& /*ctx*/, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue** from, NUdf::TUnboxedValue* to) const {
        std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
        for (ui32 i = 0U; i < ItemNodes.size(); ++i) {
            if (from[i] && IsInputItemNodeUsed(i)) {
                to[i] = std::move(*(from[i]));
            }
        }
    }

    // Restores previously consumed raw values into the item nodes and
    // recomputes the key columns into `keys`.
    void ExtractRawData(TComputationContext& ctx, NUdf::TUnboxedValue* from, NUdf::TUnboxedValue* keys) const {
        for (ui32 i = 0U; i != ItemNodes.size(); ++i) {
            if (IsInputItemNodeUsed(i)) {
                ItemNodes[i]->SetValue(ctx, std::move(from[i]));
            }
        }
        for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
            auto& key = KeyNodes[i]->RefValue(ctx);
            *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
        }
    }

    // Computes the aggregation state for one input row.
    // keys != nullptr: existing key — feed the old state into the state nodes
    // and write the updated state back; keys == nullptr: new key — write the
    // initial state.
    void ProcessItem(TComputationContext& ctx, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue* state) const {
        if (keys) {
            std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
            auto source = state;
            std::for_each(StateNodes.cbegin(), StateNodes.cend(), [&](IComputationExternalNode* item){ item->SetValue(ctx, std::move(*source++)); });
            std::transform(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
        } else {
            std::transform(InitResultNodes.cbegin(), InitResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
        }
    }

    // Produces one output row from a stored state; only requested (non-null)
    // output slots are filled.
    void FinishItem(TComputationContext& ctx, NUdf::TUnboxedValue* state, NUdf::TUnboxedValue*const* output) const {
        std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), [&](IComputationExternalNode* item) { item->SetValue(ctx, std::move(*state++)); });
        for (const auto node : FinishResultNodes)
            if (const auto out = *output++)
                *out = node->GetValue(ctx);
    }

    // Wires the nodes into the computation graph: external nodes are owned,
    // result nodes are dependencies.
    void RegisterDependencies(const TDependsOn& dependsOn, const TOwn& own) const {
        std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), own);
        std::for_each(KeyNodes.cbegin(), KeyNodes.cend(), own);
        std::for_each(StateNodes.cbegin(), StateNodes.cend(), own);
        std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), own);
        std::for_each(KeyResultNodes.cbegin(), KeyResultNodes.cend(), dependsOn);
        std::for_each(InitResultNodes.cbegin(), InitResultNodes.cend(), dependsOn);
        std::for_each(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), dependsOn);
        std::for_each(FinishResultNodes.cbegin(), FinishResultNodes.cend(), dependsOn);
    }
};
  162. class TState : public TComputationValue<TState> {
  163. typedef TComputationValue<TState> TBase;
  164. private:
  165. using TStates = TRobinHoodHashSet<NUdf::TUnboxedValuePod*, TEqualsFunc, THashFunc, TMKQLAllocator<char, EMemorySubPool::Temporary>>;
  166. using TRow = std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>>;
  167. using TStorage = std::deque<TRow, TMKQLAllocator<TRow>>;
  168. class TStorageIterator {
  169. private:
  170. TStorage& Storage;
  171. const ui32 RowSize = 0;
  172. const ui64 Count = 0;
  173. ui64 Ready = 0;
  174. TStorage::iterator ItStorage;
  175. TRow::iterator ItRow;
  176. public:
  177. TStorageIterator(TStorage& storage, const ui32 rowSize, const ui64 count)
  178. : Storage(storage)
  179. , RowSize(rowSize)
  180. , Count(count)
  181. {
  182. ItStorage = Storage.begin();
  183. if (ItStorage != Storage.end()) {
  184. ItRow = ItStorage->begin();
  185. }
  186. }
  187. bool IsValid() {
  188. return Ready < Count;
  189. }
  190. bool Next() {
  191. if (++Ready >= Count) {
  192. return false;
  193. }
  194. ItRow += RowSize;
  195. if (ItRow == ItStorage->end()) {
  196. ++ItStorage;
  197. ItRow = ItStorage->begin();
  198. }
  199. return true;
  200. }
  201. NUdf::TUnboxedValuePod* GetValuePtr() const {
  202. return &*ItRow;
  203. }
  204. };
  205. static constexpr ui32 CountRowsOnPage = 128;
  206. ui32 RowSize() const {
  207. return KeyWidth + StateWidth;
  208. }
  209. public:
  210. TState(TMemoryUsageInfo* memInfo, ui32 keyWidth, ui32 stateWidth, const THashFunc& hash, const TEqualsFunc& equal, bool allowOutOfMemory = true)
  211. : TBase(memInfo), KeyWidth(keyWidth), StateWidth(stateWidth), AllowOutOfMemory(allowOutOfMemory), Hash(hash), Equal(equal) {
  212. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  213. CurrentPosition = 0;
  214. Tongue = CurrentPage->data();
  215. States = std::make_unique<TStates>(Hash, Equal, CountRowsOnPage);
  216. }
  217. ~TState() {
  218. //Workaround for YQL-16663, consider to rework this class in a safe manner
  219. while (auto row = Extract()) {
  220. for (size_t i = 0; i != RowSize(); ++i) {
  221. row[i].UnRef();
  222. }
  223. }
  224. ExtractIt.reset();
  225. Storage.clear();
  226. States->Clear();
  227. CleanupCurrentContext();
  228. }
  229. bool TasteIt() {
  230. Y_ABORT_UNLESS(!ExtractIt);
  231. bool isNew = false;
  232. auto itInsert = States->Insert(Tongue, isNew);
  233. if (isNew) {
  234. CurrentPosition += RowSize();
  235. if (CurrentPosition == CurrentPage->size()) {
  236. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  237. CurrentPosition = 0;
  238. }
  239. Tongue = CurrentPage->data() + CurrentPosition;
  240. }
  241. Throat = States->GetKey(itInsert) + KeyWidth;
  242. if (isNew) {
  243. GrowStates();
  244. }
  245. IsOutOfMemory = IsOutOfMemory || (!HasMemoryForProcessing() && States->GetSize() > 1000);
  246. return isNew;
  247. }
  248. void GrowStates() {
  249. try {
  250. States->CheckGrow();
  251. } catch (TMemoryLimitExceededException) {
  252. YQL_LOG(INFO) << "State failed to grow";
  253. if (IsOutOfMemory || !AllowOutOfMemory) {
  254. throw;
  255. } else {
  256. IsOutOfMemory = true;
  257. }
  258. }
  259. }
  260. template<bool SkipYields>
  261. bool ReadMore() {
  262. if constexpr (SkipYields) {
  263. if (EFetchResult::Yield == InputStatus)
  264. return true;
  265. }
  266. if (!States->Empty())
  267. return false;
  268. {
  269. TStorage localStorage;
  270. std::swap(localStorage, Storage);
  271. }
  272. if (IsOutOfMemory) {
  273. States = std::make_unique<TStates>(Hash, Equal, CountRowsOnPage);
  274. }
  275. CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
  276. CurrentPosition = 0;
  277. Tongue = CurrentPage->data();
  278. StoredDataSize = 0;
  279. IsOutOfMemory = false;
  280. CleanupCurrentContext();
  281. return true;
  282. }
  283. void PushStat(IStatsRegistry* stats) const {
  284. if (!States->Empty()) {
  285. MKQL_SET_MAX_STAT(stats, Combine_MaxRowsCount, static_cast<i64>(States->GetSize()));
  286. MKQL_INC_STAT(stats, Combine_FlushesCount);
  287. }
  288. }
  289. NUdf::TUnboxedValuePod* Extract() {
  290. if (!ExtractIt) {
  291. ExtractIt.emplace(Storage, RowSize(), States->GetSize());
  292. } else {
  293. ExtractIt->Next();
  294. }
  295. if (!ExtractIt->IsValid()) {
  296. ExtractIt.reset();
  297. States->Clear();
  298. return nullptr;
  299. }
  300. NUdf::TUnboxedValuePod* result = ExtractIt->GetValuePtr();
  301. CounterOutputRows_.Inc();
  302. return result;
  303. }
  304. EFetchResult InputStatus = EFetchResult::One;
  305. NUdf::TUnboxedValuePod* Tongue = nullptr;
  306. NUdf::TUnboxedValuePod* Throat = nullptr;
  307. i64 StoredDataSize = 0;
  308. bool IsOutOfMemory = false;
  309. NYql::NUdf::TCounter CounterOutputRows_;
  310. private:
  311. std::optional<TStorageIterator> ExtractIt;
  312. const ui32 KeyWidth, StateWidth;
  313. const bool AllowOutOfMemory;
  314. ui64 CurrentPosition = 0;
  315. TRow* CurrentPage = nullptr;
  316. TStorage Storage;
  317. std::unique_ptr<TStates> States;
  318. const THashFunc Hash;
  319. const TEqualsFunc Equal;
  320. };
  321. class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
  322. typedef TComputationValue<TSpillingSupportState> TBase;
  323. typedef std::optional<NThreading::TFuture<ISpiller::TKey>> TAsyncWriteOperation;
  324. typedef std::optional<NThreading::TFuture<std::optional<TChunkedBuffer>>> TAsyncReadOperation;
  325. struct TSpilledBucket {
  326. std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledState; //state collected before switching to spilling mode
  327. std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledData; //data collected in spilling mode
  328. std::unique_ptr<TState> InMemoryProcessingState;
  329. TAsyncWriteOperation AsyncWriteOperation;
  330. enum class EBucketState {
  331. InMemory,
  332. SpillingState,
  333. SpillingData
  334. };
  335. EBucketState BucketState = EBucketState::InMemory;
  336. ui64 LineCount = 0;
  337. };
  338. enum class EOperatingMode {
  339. InMemory,
  340. SplittingState,
  341. Spilling,
  342. ProcessSpilled
  343. };
  344. public:
  345. enum class ETasteResult: i8 {
  346. Init = -1,
  347. Update,
  348. ConsumeRawData
  349. };
  350. enum class EUpdateResult: i8 {
  351. Yield = -1,
  352. ExtractRawData,
  353. ReadInput,
  354. Extract,
  355. Finish
  356. };
// Combiner state with spilling support. Starts in InMemory mode delegating to
// InMemoryProcessingState; under memory pressure the state can be split into
// SpilledBucketCount buckets and spilled via the context's spiller factory.
TSpillingSupportState(
    TMemoryUsageInfo* memInfo,
    const TMultiType* usedInputItemType, const TMultiType* keyAndStateType, ui32 keyWidth, size_t itemNodesSize,
    const THashFunc& hash, const TEqualsFunc& equal, bool allowSpilling, TComputationContext& ctx
)
    : TBase(memInfo)
    // Out-of-memory tolerance is useful only when we can actually switch to
    // spilling (spilling allowed and a spiller factory is present).
    , InMemoryProcessingState(memInfo, keyWidth, keyAndStateType->GetElementsCount() - keyWidth, hash, equal, allowSpilling && ctx.SpillerFactory)
    , UsedInputItemType(usedInputItemType)
    , KeyAndStateType(keyAndStateType)
    , KeyWidth(keyWidth)
    , ItemNodesSize(itemNodesSize)
    , Hasher(hash)
    , Mode(EOperatingMode::InMemory)
    , ViewForKeyAndState(keyAndStateType->GetElementsCount())
    , MemInfo(memInfo)
    , Equal(equal)
    , AllowSpilling(allowSpilling)
    , Ctx(ctx)
{
    BufferForUsedInputItems.reserve(usedInputItemType->GetElementsCount());
    // Expose the in-memory state's key/state slots to the caller.
    Tongue = InMemoryProcessingState.Tongue;
    Throat = InMemoryProcessingState.Throat;
    if (ctx.CountersProvider) {
        // id will be assigned externally in future versions
        TString id = TString(Operator_Aggregation) + "0";
        CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
    }
}
// Drives the operating-mode state machine and tells the caller what to do
// next: read more input, extract results, yield on pending async spilling,
// or finish.
EUpdateResult Update() {
    if (IsEverythingExtracted) {
        return EUpdateResult::Finish;
    }
    switch (GetMode()) {
        case EOperatingMode::InMemory: {
            Tongue = InMemoryProcessingState.Tongue;
            // Memory pressure may flip us into SplittingState; re-enter to
            // handle the new mode immediately.
            if (CheckMemoryAndSwitchToSpilling()) {
                return Update();
            }
            if (InputStatus == EFetchResult::Finish) return EUpdateResult::Extract;
            return EUpdateResult::ReadInput;
        }
        case EOperatingMode::SplittingState: {
            // Splitting may block on an async write; yield until it completes.
            if (SplitStateIntoBucketsAndWait()) return EUpdateResult::Yield;
            return Update();
        }
        case EOperatingMode::Spilling: {
            UpdateSpillingBuckets();
            if (!HasMemoryForProcessing() && InputStatus != EFetchResult::Finish && TryToReduceMemoryAndWait()) {
                return EUpdateResult::Yield;
            }
            // Flush raw input items buffered for an already-spilled bucket.
            if (BufferForUsedInputItems.size()) {
                auto& bucket = SpilledBuckets[BufferForUsedInputItemsBucketId];
                if (bucket.AsyncWriteOperation.has_value()) return EUpdateResult::Yield;
                bucket.AsyncWriteOperation = bucket.SpilledData->WriteWideItem(BufferForUsedInputItems);
                BufferForUsedInputItems.resize(0); //for freeing allocated key value asap
            }
            if (InputStatus == EFetchResult::Finish) return FlushSpillingBuffersAndWait();
            return EUpdateResult::ReadInput;
        }
        case EOperatingMode::ProcessSpilled:
            return ProcessSpilledData();
    }
}
// Routes the key currently in Tongue/ViewForKeyAndState and tells the caller
// how to supply the state:
//   Init           — new key: write the initial state behind Throat;
//   Update         — existing key: update the state behind Throat;
//   ConsumeRawData — destination bucket already spilled: copy the raw input
//                    items to Throat for later replay.
ETasteResult TasteIt() {
    if (GetMode() == EOperatingMode::InMemory) {
        bool isNew = InMemoryProcessingState.TasteIt();
        if (InMemoryProcessingState.IsOutOfMemory) {
            StateWantsToSpill = true;
        }
        Throat = InMemoryProcessingState.Throat;
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }
    if (GetMode() == EOperatingMode::ProcessSpilled) {
        // while restoration we process buckets one by one starting from the first in a queue
        bool isNew = SpilledBuckets.front().InMemoryProcessingState->TasteIt();
        Throat = SpilledBuckets.front().InMemoryProcessingState->Throat;
        BufferForUsedInputItems.resize(0);
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }
    // Spilling mode: pick the destination bucket by key hash.
    auto bucketId = ChooseBucket(ViewForKeyAndState.data());
    auto& bucket = SpilledBuckets[bucketId];
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
        std::copy_n(ViewForKeyAndState.data(), KeyWidth, static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Tongue));
        bool isNew = bucket.InMemoryProcessingState->TasteIt();
        Throat = bucket.InMemoryProcessingState->Throat;
        bucket.LineCount += isNew;
        return isNew ? ETasteResult::Init : ETasteResult::Update;
    }
    bucket.LineCount++;
    // Prepare space for raw data
    MKQL_ENSURE(BufferForUsedInputItems.size() == 0, "Internal logic error");
    BufferForUsedInputItems.resize(ItemNodesSize);
    BufferForUsedInputItemsBucketId = bucketId;
    Throat = BufferForUsedInputItems.data();
    return ETasteResult::ConsumeRawData;
}
  453. NUdf::TUnboxedValuePod* Extract() {
  454. NUdf::TUnboxedValue* value = nullptr;
  455. if (GetMode() == EOperatingMode::InMemory) {
  456. value = static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
  457. if (value) {
  458. CounterOutputRows_.Inc();
  459. } else {
  460. IsEverythingExtracted = true;
  461. }
  462. return value;
  463. }
  464. MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");
  465. MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error");
  466. value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
  467. if (value) {
  468. CounterOutputRows_.Inc();
  469. } else {
  470. SpilledBuckets.front().InMemoryProcessingState->ReadMore<false>();
  471. SpilledBuckets.pop_front();
  472. if (SpilledBuckets.empty()) IsEverythingExtracted = true;
  473. }
  474. return value;
  475. }
  476. private:
  477. ui64 ChooseBucket(const NUdf::TUnboxedValuePod *const key) {
  478. auto provided_hash = Hasher(key);
  479. XXH64_hash_t bucket = XXH64(&provided_hash, sizeof(provided_hash), 0) % SpilledBucketCount;
  480. return bucket;
  481. }
// Finalizes the data streams of all buckets once the input is exhausted.
// Returns Yield while any bucket still has an async write in flight;
// otherwise switches to ProcessSpilled and starts restoring the data.
EUpdateResult FlushSpillingBuffersAndWait() {
    UpdateSpillingBuckets();
    ui64 finishedCount = 0;
    for (auto& bucket : SpilledBuckets) {
        // State spilling must have completed before data streams are closed.
        MKQL_ENSURE(bucket.BucketState != TSpilledBucket::EBucketState::SpillingState, "Internal logic error");
        if (!bucket.AsyncWriteOperation.has_value()) {
            auto writeOperation = bucket.SpilledData->FinishWriting();
            if (!writeOperation) {
                ++finishedCount;
            } else {
                bucket.AsyncWriteOperation = writeOperation;
            }
        }
    }
    if (finishedCount != SpilledBuckets.size()) return EUpdateResult::Yield;
    SwitchMode(EOperatingMode::ProcessSpilled);
    return ProcessSpilledData();
}
  500. ui32 GetLargestInMemoryBucketNumber() const {
  501. ui64 maxSize = 0;
  502. ui32 largestInMemoryBucketNum = (ui32)-1;
  503. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  504. if (SpilledBuckets[i].BucketState == TSpilledBucket::EBucketState::InMemory) {
  505. if (SpilledBuckets[i].LineCount >= maxSize) {
  506. largestInMemoryBucketNum = i;
  507. maxSize = SpilledBuckets[i].LineCount;
  508. }
  509. }
  510. }
  511. return largestInMemoryBucketNum;
  512. }
  513. bool IsSpillingWhileStateSplitAllowed() const {
  514. // TODO: Write better condition here. For example: InMemorybuckets > 64
  515. return true;
  516. }
// Distributes the rows of InMemoryProcessingState across the spilled buckets,
// spilling bucket contents when memory is tight. Returns true when blocked on
// an async write (caller must yield and re-enter), false once the whole state
// is split and the mode has switched to Spilling.
bool SplitStateIntoBucketsAndWait() {
    // Resume spilling of the bucket we were draining when we last yielded.
    if (SplitStateSpillingBucket != -1) {
        auto& bucket = SpilledBuckets[SplitStateSpillingBucket];
        MKQL_ENSURE(bucket.AsyncWriteOperation.has_value(), "Internal logic error");
        if (!bucket.AsyncWriteOperation->HasValue()) return true;
        bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
        bucket.AsyncWriteOperation = std::nullopt;

        while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
            bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
            for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                //releasing values stored in unsafe TUnboxedValue buffer
                keyAndState[i].UnRef();
            }
            if (bucket.AsyncWriteOperation) return true;
        }
        SplitStateSpillingBucket = -1;
    }
    // Move rows from the global in-memory state into per-key buckets.
    while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) {
        auto bucketId = ChooseBucket(keyAndState); // This uses only key for hashing
        auto& bucket = SpilledBuckets[bucketId];
        bucket.LineCount++;
        if (bucket.BucketState != TSpilledBucket::EBucketState::InMemory) {
            // Destination bucket is already spilling: write the row to disk.
            if (bucket.BucketState != TSpilledBucket::EBucketState::SpillingState) {
                bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
                SpillingBucketsCount++;
            }
            bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
            for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                //releasing values stored in unsafe TUnboxedValue buffer
                keyAndState[i].UnRef();
            }
            if (bucket.AsyncWriteOperation) {
                SplitStateSpillingBucket = bucketId;
                return true;
            }
            continue;
        }
        // Destination bucket is in memory: re-insert the key, move the state.
        auto& processingState = *bucket.InMemoryProcessingState;
        for (size_t i = 0; i < KeyWidth; ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(processingState.Tongue[i]) = std::move(keyAndState[i]);
        }
        processingState.TasteIt();
        for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(processingState.Throat[i - KeyWidth]) = std::move(keyAndState[i]);
        }
        // Under memory pressure, start spilling the largest in-memory bucket.
        if (InMemoryBucketsCount && !HasMemoryForProcessing() && IsSpillingWhileStateSplitAllowed()) {
            ui32 bucketNumToSpill = GetLargestInMemoryBucketNumber();
            SplitStateSpillingBucket = bucketNumToSpill;
            auto& bucket = SpilledBuckets[bucketNumToSpill];
            bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
            SpillingBucketsCount++;
            InMemoryBucketsCount--;
            while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
                bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
                for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                    //releasing values stored in unsafe TUnboxedValue buffer
                    keyAndState[i].UnRef();
                }
                if (bucket.AsyncWriteOperation) return true;
            }
            bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
            if (bucket.AsyncWriteOperation) return true;
        }
    }
    // Finish the state streams of every bucket that was spilling state.
    for (ui64 i = 0; i < SpilledBucketCount; ++i) {
        auto& bucket = SpilledBuckets[i];
        if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
            if (bucket.AsyncWriteOperation.has_value()) {
                if (!bucket.AsyncWriteOperation->HasValue()) return true;
                bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
                bucket.AsyncWriteOperation = std::nullopt;
            }
            bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
            if (bucket.AsyncWriteOperation) return true;
            bucket.InMemoryProcessingState->ReadMore<false>();
            bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
            SpillingBucketsCount--;
        }
    }
    InMemoryProcessingState.ReadMore<false>();
    IsInMemoryProcessingStateSplitted = true;
    SwitchMode(EOperatingMode::Spilling);
    return false;
}
  603. bool CheckMemoryAndSwitchToSpilling() {
  604. if (!(AllowSpilling && Ctx.SpillerFactory)) {
  605. return false;
  606. }
  607. if (StateWantsToSpill || IsSwitchToSpillingModeCondition()) {
  608. StateWantsToSpill = false;
  609. LogMemoryUsage();
  610. SwitchMode(EOperatingMode::SplittingState);
  611. return true;
  612. }
  613. return false;
  614. }
  615. void LogMemoryUsage() const {
  616. const auto used = TlsAllocState->GetUsed();
  617. const auto limit = TlsAllocState->GetLimit();
  618. TStringBuilder logmsg;
  619. logmsg << "Memory usage: ";
  620. if (limit) {
  621. logmsg << (used*100/limit) << "%=";
  622. }
  623. logmsg << (used/1_MB) << "MB/" << (limit/1_MB) << "MB";
  624. YQL_LOG(INFO) << logmsg;
  625. }
    // Drains the in-memory aggregation state of one bucket into its
    // SpilledState stream. May return early with bucket.AsyncWriteOperation
    // set when a write goes asynchronous; the caller is expected to re-enter
    // (via UpdateSpillingBuckets) once the operation completes, and this
    // function resumes extraction from where it stopped.
    void SpillMoreStateFromBucket(TSpilledBucket& bucket) {
        MKQL_ENSURE(!bucket.AsyncWriteOperation.has_value(), "Internal logic error");
        // First entry for this bucket: flip it from InMemory to SpillingState
        // and keep the bucket counters consistent.
        if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
            bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
            SpillingBucketsCount++;
            InMemoryBucketsCount--;
        }
        // Extract key+state rows one at a time and hand them to the spiller.
        while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
            bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
            for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                //releasing values stored in unsafe TUnboxedValue buffer
                keyAndState[i].UnRef();
            }
            // Write went asynchronous — suspend; we resume from the loop top.
            if (bucket.AsyncWriteOperation) return;
        }
        bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
        if (bucket.AsyncWriteOperation) return;
        // Whole state is on disk: reset the in-memory state and switch the
        // bucket to spilling raw input data from now on.
        bucket.InMemoryProcessingState->ReadMore<false>();
        bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
        SpillingBucketsCount--;
    }
  647. void UpdateSpillingBuckets() {
  648. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  649. auto& bucket = SpilledBuckets[i];
  650. if (bucket.AsyncWriteOperation.has_value() && bucket.AsyncWriteOperation->HasValue()) {
  651. if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  652. bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  653. bucket.AsyncWriteOperation = std::nullopt;
  654. SpillMoreStateFromBucket(bucket);
  655. } else {
  656. bucket.SpilledData->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  657. bucket.AsyncWriteOperation = std::nullopt;
  658. }
  659. }
  660. }
  661. }
  662. bool TryToReduceMemoryAndWait() {
  663. if (SpillingBucketsCount > 0) {
  664. return true;
  665. }
  666. while (InMemoryBucketsCount > 0) {
  667. ui32 maxLineBucketInd = GetLargestInMemoryBucketNumber();
  668. MKQL_ENSURE(maxLineBucketInd != (ui32)-1, "Internal logic error");
  669. auto& bucketToSpill = SpilledBuckets[maxLineBucketInd];
  670. SpillMoreStateFromBucket(bucketToSpill);
  671. if (bucketToSpill.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  672. return true;
  673. }
  674. }
  675. return false;
  676. }
    // Restores spilled buckets one at a time (always the front of the sorted
    // SpilledBuckets queue): first re-reads the spilled key+state rows back
    // into the bucket's in-memory state, then replays spilled raw input rows.
    // Yields whenever a read goes asynchronous; re-entry resumes the pending
    // operation first.
    EUpdateResult ProcessSpilledData() {
        // Complete a read started on a previous call, if any.
        if (AsyncReadOperation) {
            if (!AsyncReadOperation->HasValue()) return EUpdateResult::Yield;
            // RecoverState tells which of the two streams the read belongs to.
            if (RecoverState) {
                SpilledBuckets[0].SpilledState->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
            } else {
                SpilledBuckets[0].SpilledData->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
            }
            AsyncReadOperation = std::nullopt;
        }
        auto& bucket = SpilledBuckets.front();
        // Bucket was never spilled — its state can be extracted directly.
        if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) return EUpdateResult::Extract;
        //recover spilled state
        while(!bucket.SpilledState->Empty()) {
            RecoverState = true;
            TTemporaryUnboxedValueVector bufferForKeyAndState(KeyAndStateType->GetElementsCount());
            AsyncReadOperation = bucket.SpilledState->ExtractWideItem(bufferForKeyAndState);
            if (AsyncReadOperation) {
                return EUpdateResult::Yield;
            }
            for (size_t i = 0; i< KeyWidth; ++i) {
                //jumping into unsafe world, refusing ownership
                static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Tongue[i]) = std::move(bufferForKeyAndState[i]);
            }
            // Keys recovered from the spilled state must be unique — each one
            // must create a fresh entry in the hash state.
            auto isNew = bucket.InMemoryProcessingState->TasteIt();
            MKQL_ENSURE(isNew, "Internal logic error");
            for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
                //jumping into unsafe world, refusing ownership
                static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Throat[i - KeyWidth]) = std::move(bufferForKeyAndState[i]);
            }
        }
        //process spilled data
        if (!bucket.SpilledData->Empty()) {
            RecoverState = false;
            BufferForUsedInputItems.resize(UsedInputItemType->GetElementsCount());
            AsyncReadOperation = bucket.SpilledData->ExtractWideItem(BufferForUsedInputItems);
            if (AsyncReadOperation) {
                return EUpdateResult::Yield;
            }
            // Expose the recovered raw row so the caller can re-run the
            // aggregation logic on it.
            Throat = BufferForUsedInputItems.data();
            Tongue = bucket.InMemoryProcessingState->Tongue;
            return EUpdateResult::ExtractRawData;
        }
        // Bucket fully restored — ready for extraction of aggregated rows.
        bucket.BucketState = TSpilledBucket::EBucketState::InMemory;
        return EUpdateResult::Extract;
    }
    // Returns the current operating mode of the combiner state machine.
    EOperatingMode GetMode() const {
        return Mode;
    }
    // Transitions the state machine to the given mode, asserting that the
    // transition is legal and performing the mode's one-time setup.
    void SwitchMode(EOperatingMode mode) {
        switch(mode) {
        case EOperatingMode::InMemory: {
            YQL_LOG(INFO) << "switching Memory mode to InMemory";
            // InMemory is the initial mode; switching back into it is illegal.
            MKQL_ENSURE(false, "Internal logic error");
            break;
        }
        case EOperatingMode::SplittingState: {
            YQL_LOG(INFO) << "switching Memory mode to SplittingState";
            MKQL_ENSURE(EOperatingMode::InMemory == Mode, "Internal logic error");
            SpilledBuckets.resize(SpilledBucketCount);
            // All buckets share one spiller; each gets its own adapters for
            // the key+state stream, the raw-data stream, and a private
            // in-memory aggregation state.
            auto spiller = Ctx.SpillerFactory->CreateSpiller();
            for (auto &b: SpilledBuckets) {
                b.SpilledState = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, KeyAndStateType, 5_MB);
                b.SpilledData = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, UsedInputItemType, 5_MB);
                b.InMemoryProcessingState = std::make_unique<TState>(MemInfo, KeyWidth, KeyAndStateType->GetElementsCount() - KeyWidth, Hasher, Equal, false);
            }
            break;
        }
        case EOperatingMode::Spilling: {
            YQL_LOG(INFO) << "switching Memory mode to Spilling";
            MKQL_ENSURE(EOperatingMode::SplittingState == Mode || EOperatingMode::InMemory == Mode, "Internal logic error");
            // Incoming keys are now written into the raw view buffer.
            Tongue = ViewForKeyAndState.data();
            break;
        }
        case EOperatingMode::ProcessSpilled: {
            YQL_LOG(INFO) << "switching Memory mode to ProcessSpilled";
            MKQL_ENSURE(EOperatingMode::Spilling == Mode, "Internal logic error");
            MKQL_ENSURE(SpilledBuckets.size() == SpilledBucketCount, "Internal logic error");
            // Put never-spilled buckets first so they are drained without
            // touching the disk.
            std::sort(SpilledBuckets.begin(), SpilledBuckets.end(), [](const TSpilledBucket& lhs, const TSpilledBucket& rhs) {
                bool lhs_in_memory = lhs.BucketState == TSpilledBucket::EBucketState::InMemory;
                bool rhs_in_memory = rhs.BucketState == TSpilledBucket::EBucketState::InMemory;
                return lhs_in_memory > rhs_in_memory;
            });
            break;
        }
        }
        Mode = mode;
    }
  765. bool IsSwitchToSpillingModeCondition() const {
  766. return !HasMemoryForProcessing() || TlsAllocState->GetMaximumLimitValueReached();
  767. }
public:
    // Result of the last upstream fetch; drives the ReadMore loop.
    EFetchResult InputStatus = EFetchResult::One;
    // Destination buffer for the current row's key columns.
    NUdf::TUnboxedValuePod* Tongue = nullptr;
    // Destination buffer for the current row's state columns.
    NUdf::TUnboxedValuePod* Throat = nullptr;
private:
    // Set externally to force a switch into the spilling pipeline.
    bool StateWantsToSpill = false;
    bool IsEverythingExtracted = false;
    // Aggregation state used before any spilling starts.
    TState InMemoryProcessingState;
    bool IsInMemoryProcessingStateSplitted = false;
    const TMultiType* const UsedInputItemType;
    const TMultiType* const KeyAndStateType;
    // Number of leading key columns inside a key+state row.
    const size_t KeyWidth;
    const size_t ItemNodesSize;
    THashFunc const Hasher;
    EOperatingMode Mode;
    bool RecoverState; //sub mode for ProcessSpilledData
    TAsyncReadOperation AsyncReadOperation = std::nullopt;
    // Fixed number of hash buckets the state is split into when spilling.
    static constexpr size_t SpilledBucketCount = 128;
    std::deque<TSpilledBucket> SpilledBuckets;
    // Counters below always satisfy: spilling + in-memory <= SpilledBucketCount.
    ui32 SpillingBucketsCount = 0;
    ui32 InMemoryBucketsCount = SpilledBucketCount;
    ui64 BufferForUsedInputItemsBucketId;
    TUnboxedValueVector BufferForUsedInputItems;
    // Raw (non-owning) key+state row buffer exposed via Tongue in Spilling mode.
    std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>> ViewForKeyAndState;
    // Bucket whose state split was interrupted by an async write; -1 if none.
    i64 SplitStateSpillingBucket = -1;
    TMemoryUsageInfo* MemInfo = nullptr;
    TEqualsFunc const Equal;
    const bool AllowSpilling;
    TComputationContext& Ctx;
    NYql::NUdf::TCounter CounterOutputRows_;
  798. };
  799. #ifndef MKQL_DISABLE_CODEGEN
// Describes the LLVM struct layout of TState for generated code: the base
// TComputationValue fields followed by the fields the combiner reads/writes
// directly from IR (status, tongue, throat, StoredDataSize, IsOutOfMemory).
// The Get*() methods return the constant field indices used with GEP.
class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
private:
    using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
    llvm::IntegerType* ValueType;
    llvm::PointerType* PtrValueType;
    llvm::IntegerType* StatusType;
    llvm::IntegerType* StoredType;
    llvm::IntegerType* BoolType;
protected:
    using TBase::Context;
public:
    std::vector<llvm::Type*> GetFieldsArray() {
        std::vector<llvm::Type*> result = TBase::GetFields();
        result.emplace_back(StatusType); //status
        result.emplace_back(PtrValueType); //tongue
        result.emplace_back(PtrValueType); //throat
        result.emplace_back(StoredType); //StoredDataSize
        result.emplace_back(BoolType); //IsOutOfMemory
        result.emplace_back(Type::getInt32Ty(Context)); //size
        result.emplace_back(Type::getInt32Ty(Context)); //size
        return result;
    }
    llvm::Constant* GetStatus() {
        return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
    }
    llvm::Constant* GetTongue() {
        return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
    }
    llvm::Constant* GetThroat() {
        return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
    }
    llvm::Constant* GetStored() {
        return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
    }
    llvm::Constant* GetIsOutOfMemory() {
        return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 4);
    }
    TLLVMFieldsStructureState(llvm::LLVMContext& context)
        : TBase(context)
        , ValueType(Type::getInt128Ty(Context))
        , PtrValueType(PointerType::getUnqual(ValueType))
        , StatusType(Type::getInt32Ty(Context))
        , StoredType(Type::getInt64Ty(Context))
        , BoolType(Type::getInt1Ty(Context)) {
    }
};
  846. #endif
  847. template <bool TrackRss, bool SkipYields>
  848. class TWideCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>
  849. #ifndef MKQL_DISABLE_CODEGEN
  850. , public ICodegeneratorRootNode
  851. #endif
  852. {
  853. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>;
  854. public:
  855. TWideCombinerWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TCombinerNodes&& nodes, TKeyTypes&& keyTypes, ui64 memLimit)
  856. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  857. , Flow(flow)
  858. , Nodes(std::move(nodes))
  859. , KeyTypes(std::move(keyTypes))
  860. , MemLimit(memLimit)
  861. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
  862. {}
  863. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  864. if (state.IsInvalid()) {
  865. MakeState(ctx, state);
  866. }
  867. while (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) {
  868. if (ptr->ReadMore<SkipYields>()) {
  869. switch (ptr->InputStatus) {
  870. case EFetchResult::One:
  871. break;
  872. case EFetchResult::Yield:
  873. ptr->InputStatus = EFetchResult::One;
  874. if constexpr (SkipYields)
  875. break;
  876. else
  877. return EFetchResult::Yield;
  878. case EFetchResult::Finish:
  879. return EFetchResult::Finish;
  880. }
  881. const auto initUsage = MemLimit ? ctx.HolderFactory.GetMemoryUsed() : 0ULL;
  882. auto **fields = ctx.WideFields.data() + WideFieldsIndex;
  883. do {
  884. for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
  885. if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U || Nodes.PasstroughtItems[i])
  886. fields[i] = &Nodes.ItemNodes[i]->RefValue(ctx);
  887. ptr->InputStatus = Flow->FetchValues(ctx, fields);
  888. if constexpr (SkipYields) {
  889. if (EFetchResult::Yield == ptr->InputStatus) {
  890. if (MemLimit) {
  891. const auto currentUsage = ctx.HolderFactory.GetMemoryUsed();
  892. ptr->StoredDataSize += currentUsage > initUsage ? currentUsage - initUsage : 0;
  893. }
  894. return EFetchResult::Yield;
  895. } else if (EFetchResult::Finish == ptr->InputStatus) {
  896. break;
  897. }
  898. } else {
  899. if (EFetchResult::One != ptr->InputStatus) {
  900. break;
  901. }
  902. }
  903. Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
  904. Nodes.ProcessItem(ctx, ptr->TasteIt() ? nullptr : static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
  905. } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage - ptr->StoredDataSize) && !ptr->IsOutOfMemory);
  906. ptr->PushStat(ctx.Stats);
  907. }
  908. if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
  909. Nodes.FinishItem(ctx, values, output);
  910. return EFetchResult::One;
  911. }
  912. }
  913. Y_UNREACHABLE();
  914. }
  915. #ifndef MKQL_DISABLE_CODEGEN
  916. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  917. auto& context = ctx.Codegen.GetContext();
  918. const auto valueType = Type::getInt128Ty(context);
  919. const auto ptrValueType = PointerType::getUnqual(valueType);
  920. const auto statusType = Type::getInt32Ty(context);
  921. const auto storedType = Type::getInt64Ty(context);
  922. TLLVMFieldsStructureState stateFields(context);
  923. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  924. const auto statePtrType = PointerType::getUnqual(stateType);
  925. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  926. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  927. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  928. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  929. block = make;
  930. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  931. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  932. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideCombinerWrapper::MakeState));
  933. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  934. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  935. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  936. BranchInst::Create(main, block);
  937. block = main;
  938. const auto state = new LoadInst(valueType, statePtr, "state", block);
  939. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  940. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  941. BranchInst::Create(more, block);
  942. block = more;
  943. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  944. const auto result = PHINode::Create(statusType, 3U, "result", over);
  945. const auto readMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ReadMore<SkipYields>));
  946. const auto readMoreFuncType = FunctionType::get(Type::getInt1Ty(context), { statePtrType }, false);
  947. const auto readMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, readMoreFunc, PointerType::getUnqual(readMoreFuncType), "read_more_func", block);
  948. const auto readMore = CallInst::Create(readMoreFuncType, readMoreFuncPtr, { stateArg }, "read_more", block);
  949. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  950. const auto full = BasicBlock::Create(context, "full", ctx.Func);
  951. BranchInst::Create(next, full, readMore, block);
  952. {
  953. block = next;
  954. const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
  955. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  956. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  957. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  958. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  959. const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
  960. const auto last = new LoadInst(statusType, statusPtr, "last", block);
  961. result->addIncoming(last, block);
  962. const auto choise = SwitchInst::Create(last, pull, 2U, block);
  963. choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), rest);
  964. choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), over);
  965. block = rest;
  966. new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::One)), statusPtr, block);
  967. if constexpr (SkipYields) {
  968. new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), statusPtr, block);
  969. BranchInst::Create(pull, block);
  970. } else {
  971. result->addIncoming(last, block);
  972. BranchInst::Create(over, block);
  973. }
  974. block = pull;
  975. const auto used = GetMemoryUsed(MemLimit, ctx, block);
  976. BranchInst::Create(loop, block);
  977. block = loop;
  978. const auto getres = GetNodeValues(Flow, ctx, block);
  979. if constexpr (SkipYields) {
  980. const auto save = BasicBlock::Create(context, "save", ctx.Func);
  981. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  982. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), save);
  983. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), done);
  984. block = save;
  985. if (MemLimit) {
  986. const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
  987. const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
  988. const auto currentUsage = GetMemoryUsed(MemLimit, ctx, block);
  989. const auto skipSavingUsed = BasicBlock::Create(context, "skip_saving_used", ctx.Func);
  990. const auto saveUsed = BasicBlock::Create(context, "save_used", ctx.Func);
  991. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, currentUsage, used, "check", block);
  992. BranchInst::Create(saveUsed, skipSavingUsed, check, block);
  993. block = saveUsed;
  994. const auto usedMemory = BinaryOperator::CreateSub(GetMemoryUsed(MemLimit, ctx, block), used, "used_memory", block);
  995. const auto inc = BinaryOperator::CreateAdd(lastStored, usedMemory, "inc", block);
  996. new StoreInst(inc, storedPtr, block);
  997. BranchInst::Create(skipSavingUsed, block);
  998. block = skipSavingUsed;
  999. }
  1000. new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), statusPtr, block);
  1001. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  1002. BranchInst::Create(over, block);
  1003. } else {
  1004. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
  1005. BranchInst::Create(done, good, special, block);
  1006. }
  1007. block = good;
  1008. std::vector<Value*> items(Nodes.ItemNodes.size(), nullptr);
  1009. for (ui32 i = 0U; i < items.size(); ++i) {
  1010. if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
  1011. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i] = getres.second[i](ctx, block));
  1012. else if (Nodes.PasstroughtItems[i])
  1013. items[i] = getres.second[i](ctx, block);
  1014. }
  1015. const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
  1016. const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);
  1017. std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
  1018. for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
  1019. auto& key = keys[i];
  1020. const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
  1021. if (const auto map = Nodes.KeysOnItems[i]) {
  1022. auto& it = items[*map];
  1023. if (!it)
  1024. it = getres.second[*map](ctx, block);
  1025. key = it;
  1026. } else {
  1027. key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
  1028. }
  1029. if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
  1030. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
  1031. new StoreInst(key, keyPtr, block);
  1032. }
  1033. const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::TasteIt));
  1034. const auto atType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
  1035. const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
  1036. const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block);
  1037. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  1038. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  1039. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  1040. const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
  1041. const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);
  1042. std::vector<Value*> pointers;
  1043. pointers.reserve(Nodes.StateNodes.size());
  1044. for (ui32 i = 0U; i < Nodes.StateNodes.size(); ++i) {
  1045. pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
  1046. }
  1047. BranchInst::Create(init, next, newKey, block);
  1048. block = init;
  1049. for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
  1050. ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
  1051. }
  1052. for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
  1053. if (const auto map = Nodes.InitOnItems[i]) {
  1054. auto& it = items[*map];
  1055. if (!it)
  1056. it = getres.second[*map](ctx, block);
  1057. new StoreInst(it, pointers[i], block);
  1058. ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), it, ctx, block);
  1059. } else if (const auto map = Nodes.InitOnKeys[i]) {
  1060. const auto key = keys[*map];
  1061. new StoreInst(key, pointers[i], block);
  1062. ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
  1063. } else {
  1064. GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
  1065. }
  1066. }
  1067. BranchInst::Create(test, block);
  1068. block = next;
  1069. for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
  1070. if (Nodes.KeysOnItems[i] || Nodes.KeyResultNodes[i]->IsTemporaryValue())
  1071. ValueCleanup(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
  1072. }
  1073. std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
  1074. for (ui32 i = 0U; i < stored.size(); ++i) {
  1075. const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
  1076. if (const auto map = Nodes.StateOnUpdate[i]) {
  1077. if (hasDependency || i != *map) {
  1078. stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
  1079. if (hasDependency)
  1080. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
  1081. }
  1082. } else if (hasDependency) {
  1083. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
  1084. } else {
  1085. ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
  1086. }
  1087. }
  1088. for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
  1089. if (const auto map = Nodes.UpdateOnState[i]) {
  1090. if (const auto j = *map; i != j) {
  1091. auto& it = stored[j];
  1092. if (!it)
  1093. it = new LoadInst(valueType, pointers[j], (TString("state_") += ToString(j)).c_str(), block);
  1094. new StoreInst(it, pointers[i], block);
  1095. if (i != *Nodes.StateOnUpdate[j])
  1096. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
  1097. }
  1098. } else if (const auto map = Nodes.UpdateOnItems[i]) {
  1099. auto& it = items[*map];
  1100. if (!it)
  1101. it = getres.second[*map](ctx, block);
  1102. new StoreInst(it, pointers[i], block);
  1103. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
  1104. } else if (const auto map = Nodes.UpdateOnKeys[i]) {
  1105. const auto key = keys[*map];
  1106. new StoreInst(key, pointers[i], block);
  1107. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
  1108. } else {
  1109. GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
  1110. }
  1111. }
  1112. BranchInst::Create(test, block);
  1113. block = test;
  1114. auto totalUsed = used;
  1115. if (MemLimit) {
  1116. const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
  1117. const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
  1118. totalUsed = BinaryOperator::CreateSub(used, lastStored, "decr", block);
  1119. }
  1120. const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, totalUsed, ctx, block);
  1121. const auto isOutOfMemoryPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsOutOfMemory() }, "is_out_of_memory_ptr", block);
  1122. const auto isOutOfMemory = new LoadInst(Type::getInt1Ty(context), isOutOfMemoryPtr, "is_out_of_memory", block);
  1123. const auto checkIsOutOfMemory = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, isOutOfMemory, ConstantInt::getTrue(context), "check_is_out_of_memory", block);
  1124. const auto any = BinaryOperator::CreateOr(check, checkIsOutOfMemory, "any", block);
  1125. BranchInst::Create(done, loop, any, block);
  1126. block = done;
  1127. new StoreInst(getres.first, statusPtr, block);
  1128. const auto stat = ctx.GetStat();
  1129. const auto statFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::PushStat));
  1130. const auto statType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType(), stat->getType()}, false);
  1131. const auto statPtr = CastInst::Create(Instruction::IntToPtr, statFunc, PointerType::getUnqual(statType), "stat", block);
  1132. CallInst::Create(statType, statPtr, {stateArg, stat}, "", block);
  1133. BranchInst::Create(full, block);
  1134. }
  1135. {
  1136. block = full;
  1137. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  1138. const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
  1139. const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
  1140. const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
  1141. const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
  1142. const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
  1143. BranchInst::Create(good, more, has, block);
  1144. block = good;
  1145. for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
  1146. const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
  1147. if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
  1148. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
  1149. else
  1150. ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
  1151. }
  1152. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  1153. BranchInst::Create(over, block);
  1154. }
  1155. block = over;
  1156. ICodegeneratorInlineWideNode::TGettersList getters;
  1157. getters.reserve(Nodes.FinishResultNodes.size());
  1158. std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
  1159. return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
  1160. });
  1161. return {result, std::move(getters)};
  1162. }
  1163. #endif
  1164. private:
  1165. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  1166. #ifdef MKQL_DISABLE_CODEGEN
  1167. state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes));
  1168. #else
  1169. state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(),
  1170. ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
  1171. ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes))
  1172. );
  1173. #endif
  1174. if (ctx.CountersProvider) {
  1175. const auto ptr = static_cast<TState*>(state.AsBoxed().Get());
  1176. // id will be assigned externally in future versions
  1177. TString id = TString(Operator_Aggregation) + "0";
  1178. ptr->CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
  1179. }
  1180. }
  1181. void RegisterDependencies() const final {
  1182. if (const auto flow = this->FlowDependsOn(Flow)) {
  1183. Nodes.RegisterDependencies(
  1184. [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
  1185. [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
  1186. );
  1187. }
  1188. }
IComputationWideFlowNode *const Flow;   // upstream wide flow being aggregated
const TCombinerNodes Nodes;             // item/key/state/finish node bundle
const TKeyTypes KeyTypes;               // key slots used for hashing/equality
const ui64 MemLimit;                    // memory limit supplied by the factory
const ui32 WideFieldsIndex;             // base offset into ctx.WideFields
#ifndef MKQL_DISABLE_CODEGEN
// JIT-compiled hash/equals thunks; null until FinalizeFunctions resolves them.
TEqualsPtr Equals = nullptr;
THashPtr Hash = nullptr;
Function* EqualsFunc = nullptr;
Function* HashFunc = nullptr;
  1199. template <bool EqualsOrHash>
  1200. TString MakeName() const {
  1201. TStringStream out;
  1202. out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
  1203. return out.Str();
  1204. }
  1205. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1206. if (EqualsFunc) {
  1207. Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
  1208. }
  1209. if (HashFunc) {
  1210. Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
  1211. }
  1212. }
  1213. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1214. codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
  1215. codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
  1216. }
  1217. #endif
  1218. };
  1219. class TWideLastCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>
  1220. #ifndef MKQL_DISABLE_CODEGEN
  1221. , public ICodegeneratorRootNode
  1222. #endif
  1223. {
  1224. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>;
  1225. public:
  1226. TWideLastCombinerWrapper(
  1227. TComputationMutables& mutables,
  1228. IComputationWideFlowNode* flow,
  1229. TCombinerNodes&& nodes,
  1230. const TMultiType* usedInputItemType,
  1231. TKeyTypes&& keyTypes,
  1232. const TMultiType* keyAndStateType,
  1233. bool allowSpilling)
  1234. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  1235. , Flow(flow)
  1236. , Nodes(std::move(nodes))
  1237. , KeyTypes(std::move(keyTypes))
  1238. , UsedInputItemType(usedInputItemType)
  1239. , KeyAndStateType(keyAndStateType)
  1240. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
  1241. , AllowSpilling(allowSpilling)
  1242. {}
  1243. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  1244. if (state.IsInvalid()) {
  1245. MakeState(ctx, state);
  1246. }
  1247. if (const auto ptr = static_cast<TSpillingSupportState*>(state.AsBoxed().Get())) {
  1248. auto **fields = ctx.WideFields.data() + WideFieldsIndex;
  1249. while (true) {
  1250. switch(ptr->Update()) {
  1251. case TSpillingSupportState::EUpdateResult::ReadInput: {
  1252. for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
  1253. fields[i] = Nodes.GetUsedInputItemNodePtrOrNull(ctx, i);
  1254. switch (ptr->InputStatus = Flow->FetchValues(ctx, fields)) {
  1255. case EFetchResult::One:
  1256. break;
  1257. case EFetchResult::Finish:
  1258. continue;
  1259. case EFetchResult::Yield:
  1260. return EFetchResult::Yield;
  1261. }
  1262. Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
  1263. break;
  1264. }
  1265. case TSpillingSupportState::EUpdateResult::Yield:
  1266. return EFetchResult::Yield;
  1267. case TSpillingSupportState::EUpdateResult::ExtractRawData:
  1268. Nodes.ExtractRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Throat), static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
  1269. break;
  1270. case TSpillingSupportState::EUpdateResult::Extract:
  1271. if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
  1272. Nodes.FinishItem(ctx, values, output);
  1273. return EFetchResult::One;
  1274. }
  1275. continue;
  1276. case TSpillingSupportState::EUpdateResult::Finish:
  1277. return EFetchResult::Finish;
  1278. }
  1279. switch(ptr->TasteIt()) {
  1280. case TSpillingSupportState::ETasteResult::Init:
  1281. Nodes.ProcessItem(ctx, nullptr, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
  1282. break;
  1283. case TSpillingSupportState::ETasteResult::Update:
  1284. Nodes.ProcessItem(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
  1285. break;
  1286. case TSpillingSupportState::ETasteResult::ConsumeRawData:
  1287. Nodes.ConsumeRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), fields, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
  1288. break;
  1289. }
  1290. }
  1291. }
  1292. Y_UNREACHABLE();
  1293. }
  1294. #ifndef MKQL_DISABLE_CODEGEN
  1295. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  1296. auto& context = ctx.Codegen.GetContext();
  1297. const auto valueType = Type::getInt128Ty(context);
  1298. const auto ptrValueType = PointerType::getUnqual(valueType);
  1299. const auto statusType = Type::getInt32Ty(context);
  1300. const auto wayType = Type::getInt8Ty(context);
  1301. TLLVMFieldsStructureState stateFields(context);
  1302. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  1303. const auto statePtrType = PointerType::getUnqual(stateType);
  1304. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  1305. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  1306. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  1307. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  1308. block = make;
  1309. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  1310. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  1311. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideLastCombinerWrapper::MakeState));
  1312. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  1313. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  1314. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  1315. BranchInst::Create(main, block);
  1316. block = main;
  1317. const auto state = new LoadInst(valueType, statePtr, "state", block);
  1318. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  1319. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  1320. BranchInst::Create(more, block);
  1321. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  1322. const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
  1323. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  1324. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  1325. const auto load = BasicBlock::Create(context, "load", ctx.Func);
  1326. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  1327. const auto data = BasicBlock::Create(context, "data", ctx.Func);
  1328. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  1329. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  1330. const auto stub = BasicBlock::Create(context, "stub", ctx.Func);
  1331. new UnreachableInst(context, stub);
  1332. const auto result = PHINode::Create(statusType, 4U, "result", over);
  1333. std::vector<PHINode*> phis(Nodes.ItemNodes.size(), nullptr);
  1334. auto j = 0U;
  1335. std::generate(phis.begin(), phis.end(), [&]() {
  1336. return Nodes.IsInputItemNodeUsed(j++) ?
  1337. PHINode::Create(valueType, 2U, (TString("item_") += ToString(j)).c_str(), test) : nullptr;
  1338. });
  1339. block = more;
  1340. const auto updateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Update));
  1341. const auto updateType = FunctionType::get(wayType, {stateArg->getType()}, false);
  1342. const auto updateFuncPtr = CastInst::Create(Instruction::IntToPtr, updateFunc, PointerType::getUnqual(updateType), "update_func", block);
  1343. const auto update = CallInst::Create(updateType, updateFuncPtr, { stateArg }, "update", block);
  1344. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  1345. const auto updateWay = SwitchInst::Create(update, stub, 5U, block);
  1346. updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Yield)), over);
  1347. updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Extract)), fill);
  1348. updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Finish)), done);
  1349. updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ReadInput)), pull);
  1350. updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ExtractRawData)), load);
  1351. block = load;
  1352. const auto extractorPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "extractor_ptr", block);
  1353. const auto extractor = new LoadInst(ptrValueType, extractorPtr, "extractor", block);
  1354. std::vector<Value*> items(phis.size(), nullptr);
  1355. for (ui32 i = 0U; i < items.size(); ++i) {
  1356. const auto ptr = GetElementPtrInst::CreateInBounds(valueType, extractor, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("load_ptr_") += ToString(i)).c_str(), block);
  1357. if (phis[i])
  1358. items[i] = new LoadInst(valueType, ptr, (TString("load_") += ToString(i)).c_str(), block);
  1359. if (i < Nodes.ItemNodes.size() && Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
  1360. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
  1361. }
  1362. for (ui32 i = 0U; i < phis.size(); ++i) {
  1363. if (const auto phi = phis[i]) {
  1364. phi->addIncoming(items[i], block);
  1365. }
  1366. }
  1367. BranchInst::Create(test, block);
  1368. block = pull;
  1369. const auto getres = GetNodeValues(Flow, ctx, block);
  1370. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  1371. const auto choise = SwitchInst::Create(getres.first, good, 2U, block);
  1372. choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), over);
  1373. choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest);
  1374. block = rest;
  1375. const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
  1376. new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
  1377. BranchInst::Create(more, block);
  1378. block = good;
  1379. for (ui32 i = 0U; i < items.size(); ++i) {
  1380. if (phis[i])
  1381. items[i] = getres.second[i](ctx, block);
  1382. if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
  1383. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
  1384. }
  1385. for (ui32 i = 0U; i < phis.size(); ++i) {
  1386. if (const auto phi = phis[i]) {
  1387. phi->addIncoming(items[i], block);
  1388. }
  1389. }
  1390. BranchInst::Create(test, block);
  1391. block = test;
  1392. const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
  1393. const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);
  1394. std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
  1395. for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
  1396. auto& key = keys[i];
  1397. const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
  1398. if (const auto map = Nodes.KeysOnItems[i]) {
  1399. key = phis[*map];
  1400. } else {
  1401. key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
  1402. }
  1403. if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
  1404. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
  1405. new StoreInst(key, keyPtr, block);
  1406. }
  1407. const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt));
  1408. const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false);
  1409. const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
  1410. const auto taste= CallInst::Create(atType, atPtr, {stateArg}, "taste", block);
  1411. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  1412. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  1413. const auto save = BasicBlock::Create(context, "save", ctx.Func);
  1414. const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
  1415. const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);
  1416. std::vector<Value*> pointers;
  1417. const auto width = std::max(Nodes.StateNodes.size(), phis.size());
  1418. pointers.reserve(width);
  1419. for (ui32 i = 0U; i < width; ++i) {
  1420. pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
  1421. }
  1422. const auto way = SwitchInst::Create(taste, stub, 3U, block);
  1423. way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Init)), init);
  1424. way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Update)), next);
  1425. way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::ConsumeRawData)), save);
  1426. block = init;
  1427. for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
  1428. ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
  1429. }
  1430. for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
  1431. if (const auto map = Nodes.InitOnItems[i]) {
  1432. const auto item = phis[*map];
  1433. new StoreInst(item, pointers[i], block);
  1434. ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), item, ctx, block);
  1435. } else if (const auto map = Nodes.InitOnKeys[i]) {
  1436. const auto key = keys[*map];
  1437. new StoreInst(key, pointers[i], block);
  1438. ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
  1439. } else {
  1440. GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
  1441. }
  1442. }
  1443. BranchInst::Create(more, block);
  1444. block = next;
  1445. std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
  1446. for (ui32 i = 0U; i < stored.size(); ++i) {
  1447. const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
  1448. if (const auto map = Nodes.StateOnUpdate[i]) {
  1449. if (hasDependency || i != *map) {
  1450. stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
  1451. if (hasDependency)
  1452. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
  1453. }
  1454. } else if (hasDependency) {
  1455. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
  1456. } else {
  1457. ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
  1458. }
  1459. }
  1460. for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
  1461. if (const auto map = Nodes.UpdateOnState[i]) {
  1462. if (const auto j = *map; i != j) {
  1463. const auto it = stored[j];
  1464. new StoreInst(it, pointers[i], block);
  1465. if (i != *Nodes.StateOnUpdate[j])
  1466. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
  1467. }
  1468. } else if (const auto map = Nodes.UpdateOnItems[i]) {
  1469. const auto item = phis[*map];
  1470. new StoreInst(item, pointers[i], block);
  1471. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), item, ctx, block);
  1472. } else if (const auto map = Nodes.UpdateOnKeys[i]) {
  1473. const auto key = keys[*map];
  1474. new StoreInst(key, pointers[i], block);
  1475. ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
  1476. } else {
  1477. GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
  1478. }
  1479. }
  1480. BranchInst::Create(more, block);
  1481. block = save;
  1482. for (ui32 i = 0U; i < phis.size(); ++i) {
  1483. if (const auto item = phis[i]) {
  1484. new StoreInst(item, pointers[i], block);
  1485. ValueAddRef(Nodes.ItemNodes[i]->GetRepresentation(), item, ctx, block);
  1486. }
  1487. }
  1488. BranchInst::Create(more, block);
  1489. block = fill;
  1490. const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Extract));
  1491. const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
  1492. const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
  1493. const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
  1494. const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
  1495. BranchInst::Create(data, more, has, block);
  1496. block = data;
  1497. for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
  1498. const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
  1499. if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
  1500. EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
  1501. else
  1502. ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
  1503. }
  1504. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  1505. BranchInst::Create(over, block);
  1506. block = done;
  1507. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  1508. BranchInst::Create(over, block);
  1509. block = over;
  1510. ICodegeneratorInlineWideNode::TGettersList getters;
  1511. getters.reserve(Nodes.FinishResultNodes.size());
  1512. std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
  1513. return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
  1514. });
  1515. return {result, std::move(getters)};
  1516. }
  1517. #endif
  1518. private:
  1519. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  1520. state = ctx.HolderFactory.Create<TSpillingSupportState>(UsedInputItemType, KeyAndStateType,
  1521. Nodes.KeyNodes.size(),
  1522. Nodes.ItemNodes.size(),
  1523. #ifdef MKQL_DISABLE_CODEGEN
  1524. TMyValueHasher(KeyTypes),
  1525. TMyValueEqual(KeyTypes),
  1526. #else
  1527. ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
  1528. ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)),
  1529. #endif
  1530. AllowSpilling,
  1531. ctx
  1532. );
  1533. }
  1534. void RegisterDependencies() const final {
  1535. if (const auto flow = this->FlowDependsOn(Flow)) {
  1536. Nodes.RegisterDependencies(
  1537. [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
  1538. [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
  1539. );
  1540. }
  1541. }
  1542. IComputationWideFlowNode *const Flow;
  1543. const TCombinerNodes Nodes;
  1544. const TKeyTypes KeyTypes;
  1545. const TMultiType* const UsedInputItemType;
  1546. const TMultiType* const KeyAndStateType;
  1547. const ui32 WideFieldsIndex;
  1548. const bool AllowSpilling;
  1549. #ifndef MKQL_DISABLE_CODEGEN
  1550. TEqualsPtr Equals = nullptr;
  1551. THashPtr Hash = nullptr;
  1552. Function* EqualsFunc = nullptr;
  1553. Function* HashFunc = nullptr;
  1554. template <bool EqualsOrHash>
  1555. TString MakeName() const {
  1556. TStringStream out;
  1557. out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
  1558. return out.Str();
  1559. }
  1560. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1561. if (EqualsFunc) {
  1562. Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
  1563. }
  1564. if (HashFunc) {
  1565. Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
  1566. }
  1567. }
  1568. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1569. codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
  1570. codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
  1571. }
  1572. #endif
  1573. };
  1574. }
// Shared factory for WideCombiner / WideLastCombiner callables.
// The callable inputs are walked twice with a running `index`: first pass
// collects the computation (result) nodes, second pass rewinds and collects
// the external argument nodes that were skipped over — the skip amounts below
// mirror each other, so each `index += ...` jumps over a group located by the
// other pass.
template<bool Last>
IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool allowSpilling) {
    MKQL_ENSURE(callable.GetInputsCount() >= (Last ? 3U : 4U), "Expected more arguments.");

    const auto inputType = AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType());
    const auto inputWidth = GetWideComponentsCount(inputType);
    const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));

    const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);

    // Input 0 is the flow; non-Last variants carry the memory limit at 1,
    // then the keysSize/stateSize literals follow.
    auto index = Last ? 0U : 1U;
    const auto keysSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    const auto stateSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    // Step past the literal just read plus the item argument slots.
    ++index += inputWidth;

    std::vector<TType*> keyAndStateItemTypes;
    keyAndStateItemTypes.reserve(keysSize + stateSize);

    TKeyTypes keyTypes;
    keyTypes.reserve(keysSize);
    // Key result types also form the key part of the key+state row type.
    for (ui32 i = index; i < index + keysSize; ++i) {
        TType *type = callable.GetInput(i).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        bool optional;
        keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
    }

    TCombinerNodes nodes;
    nodes.KeyResultNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyResultNodes), keysSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );

    index += keysSize;  // skip the key argument slots (picked up in pass two)
    nodes.InitResultNodes.reserve(stateSize);
    for (size_t i = 0; i != stateSize; ++i) {
        TType *type = callable.GetInput(index).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        nodes.InitResultNodes.push_back(LocateNode(ctx.NodeLocator, callable, index++));
    }

    index += stateSize;  // skip the state argument slots
    nodes.UpdateResultNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.UpdateResultNodes), stateSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );

    index += keysSize + stateSize;  // skip the finish argument slots
    nodes.FinishResultNodes.reserve(outputWidth);
    std::generate_n(std::back_inserter(nodes.FinishResultNodes), outputWidth, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );

    // Second pass: rewind and collect the external argument nodes.
    index = Last ? 3U : 4U;
    nodes.ItemNodes.reserve(inputWidth);
    std::generate_n(std::back_inserter(nodes.ItemNodes), inputWidth, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );

    index += keysSize;  // skip the key result slots
    nodes.KeyNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyNodes), keysSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );

    index += stateSize;  // skip the init result slots
    nodes.StateNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.StateNodes), stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );

    index += stateSize;  // skip the update result slots
    nodes.FinishNodes.reserve(keysSize + stateSize);
    std::generate_n(std::back_inserter(nodes.FinishNodes), keysSize + stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );

    nodes.BuildMaps();
    if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
        if constexpr (Last) {
            const auto inputItemTypes = GetWideComponents(inputType);
            return new TWideLastCombinerWrapper(ctx.Mutables, wide, std::move(nodes),
                TMultiType::Create(inputItemTypes.size(), inputItemTypes.data(), ctx.Env),
                std::move(keyTypes),
                TMultiType::Create(keyAndStateItemTypes.size(),keyAndStateItemTypes.data(), ctx.Env),
                allowSpilling
            );
        } else {
            if constexpr (RuntimeVersion < 46U) {
                // Older runtimes encode the limit as unsigned; the second
                // template flag of TWideCombinerWrapper is always false here.
                const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<ui64>();
                if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                    return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
                else
                    return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
            } else {
                // Since runtime v46 the limit is signed: a negative value
                // selects the <.., true> wrapper flavour with magnitude |limit|.
                if (const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<i64>(); memLimit >= 0)
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                    else
                        return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                else
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
                    else
                        return new TWideCombinerWrapper<false, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
            }
        }
    }

    THROW yexception() << "Expected wide flow.";
}
// Factory for the streaming WideCombiner callable (non-last, no spilling).
IComputationNode* WrapWideCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    return WrapWideCombinerT<false>(callable, ctx, false);
}
  1660. IComputationNode* WrapWideLastCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1661. YQL_LOG(INFO) << "Found non-serializable type, spilling is disabled";
  1662. return WrapWideCombinerT<true>(callable, ctx, false);
  1663. }
// Factory for the WideLastCombiner callable with spilling allowed; the flag
// is forwarded to TSpillingSupportState, which decides whether to spill.
IComputationNode* WrapWideLastCombinerWithSpilling(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    return WrapWideCombinerT<true>(callable, ctx, true);
}
  1667. }
  1668. }