mkql_wide_combine.cpp 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016
  1. #include "mkql_counters.h"
  2. #include "mkql_rh_hash.h"
  3. #include "mkql_wide_combine.h"
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
  6. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  7. #include <yql/essentials/minikql/computation/mkql_spiller_adapter.h>
  8. #include <yql/essentials/minikql/computation/mkql_spiller.h>
  9. #include <yql/essentials/minikql/mkql_node_builder.h>
  10. #include <yql/essentials/minikql/mkql_node_cast.h>
  11. #include <yql/essentials/minikql/mkql_runtime_version.h>
  12. #include <yql/essentials/minikql/mkql_stats_registry.h>
  13. #include <yql/essentials/minikql/defs.h>
  14. #include <yql/essentials/utils/cast.h>
  15. #include <yql/essentials/utils/log/log.h>
  16. #include <util/string/cast.h>
  17. #include <contrib/libs/xxhash/xxhash.h>
  18. namespace NKikimr {
  19. namespace NMiniKQL {
  20. using NYql::EnsureDynamicCast;
  21. using NYql::TChunkedBuffer;
  22. extern TStatKey Combine_FlushesCount;
  23. extern TStatKey Combine_MaxRowsCount;
  24. namespace {
  25. bool HasMemoryForProcessing() {
  26. return !TlsAllocState->IsMemoryYellowZoneEnabled();
  27. }
  28. struct TMyValueEqual {
  29. TMyValueEqual(const TKeyTypes& types)
  30. : Types(types)
  31. {}
  32. bool operator()(const NUdf::TUnboxedValuePod* left, const NUdf::TUnboxedValuePod* right) const {
  33. for (ui32 i = 0U; i < Types.size(); ++i)
  34. if (CompareValues(Types[i].first, true, Types[i].second, left[i], right[i]))
  35. return false;
  36. return true;
  37. }
  38. const TKeyTypes& Types;
  39. };
  40. struct TMyValueHasher {
  41. TMyValueHasher(const TKeyTypes& types)
  42. : Types(types)
  43. {}
  44. NUdf::THashType operator()(const NUdf::TUnboxedValuePod* values) const {
  45. if (Types.size() == 1U)
  46. if (const auto v = *values)
  47. return NUdf::GetValueHash(Types.front().first, v);
  48. else
  49. return HashOfNull;
  50. NUdf::THashType hash = 0ULL;
  51. for (const auto& type : Types) {
  52. if (const auto v = *values++)
  53. hash = CombineHashes(hash, NUdf::GetValueHash(type.first, v));
  54. else
  55. hash = CombineHashes(hash, HashOfNull);
  56. }
  57. return hash;
  58. }
  59. const TKeyTypes& Types;
  60. };
// Raw function-pointer signatures for key equality/hashing (codegen path).
using TEqualsPtr = bool(*)(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
using THashPtr = NUdf::THashType(*)(const NUdf::TUnboxedValuePod*);
// Type-erased equivalents used where a std::function is required.
using TEqualsFunc = std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>;
using THashFunc = std::function<NUdf::THashType(const NUdf::TUnboxedValuePod*)>;
// Callbacks used when registering computation-node dependencies.
using TDependsOn = std::function<void(IComputationNode*)>;
using TOwn = std::function<void(IComputationExternalNode*)>;
// Wiring of the computation nodes that make up one combiner: external nodes
// for input items, keys, state and finish stage, the result expressions for
// each stage, plus precomputed pass-through maps that record when a result
// is just a verbatim copy of some input/key/state node (so it can be moved
// instead of re-evaluated).
struct TCombinerNodes {
    TComputationExternalNodePtrVector ItemNodes, KeyNodes, StateNodes, FinishNodes;
    TComputationNodePtrVector KeyResultNodes, InitResultNodes, UpdateResultNodes, FinishResultNodes;

    // Each map is "left collection element -> index in right collection"
    // when the left node passes through unchanged (see GetPasstroughtMap).
    TPasstroughtMap
        KeysOnItems,
        InitOnKeys,
        InitOnItems,
        UpdateOnKeys,
        UpdateOnItems,
        UpdateOnState,
        StateOnUpdate,
        ItemsOnResult,
        ResultOnItems;

    // PasstroughtItems[i] is true when input item i flows unchanged into any
    // of the key/init/update results.
    std::vector<bool> PasstroughtItems;

    // Precomputes all pass-through maps; must be called once after the node
    // vectors are filled.
    void BuildMaps() {
        KeysOnItems = GetPasstroughtMap(KeyResultNodes, ItemNodes);
        InitOnKeys = GetPasstroughtMap(InitResultNodes, KeyNodes);
        InitOnItems = GetPasstroughtMap(InitResultNodes, ItemNodes);
        UpdateOnKeys = GetPasstroughtMap(UpdateResultNodes, KeyNodes);
        UpdateOnItems = GetPasstroughtMap(UpdateResultNodes, ItemNodes);
        UpdateOnState = GetPasstroughtMap(UpdateResultNodes, StateNodes);
        StateOnUpdate = GetPasstroughtMap(StateNodes, UpdateResultNodes);
        ItemsOnResult = GetPasstroughtMap(FinishNodes, FinishResultNodes);
        ResultOnItems = GetPasstroughtMap(FinishResultNodes, FinishNodes);

        PasstroughtItems.resize(ItemNodes.size());
        // An item is "pass-through" if it appears verbatim in any of the
        // key, init or update results.
        auto anyResults = KeyResultNodes;
        anyResults.insert(anyResults.cend(), InitResultNodes.cbegin(), InitResultNodes.cend());
        anyResults.insert(anyResults.cend(), UpdateResultNodes.cbegin(), UpdateResultNodes.cend());
        const auto itemsOnResults = GetPasstroughtMap(ItemNodes, anyResults);
        std::transform(itemsOnResults.cbegin(), itemsOnResults.cend(), PasstroughtItems.begin(), [](const TPasstroughtMap::value_type& v) { return v.has_value(); });
    }

    // An input item is needed either because some expression depends on it
    // or because it is copied straight into a result.
    bool IsInputItemNodeUsed(size_t i) const {
        return (ItemNodes[i]->GetDependencesCount() > 0U || PasstroughtItems[i]);
    }

    // Returns a writable slot for input item i, or nullptr when the item is
    // unused and the fetcher may skip producing it.
    NUdf::TUnboxedValue* GetUsedInputItemNodePtrOrNull(TComputationContext& ctx, size_t i) const {
        return IsInputItemNodeUsed(i) ?
               &ItemNodes[i]->RefValue(ctx) :
               nullptr;
    }

    // Moves fetched values into the item nodes, then evaluates the key
    // expressions, storing each key both in its key node and in `keys`.
    void ExtractKey(TComputationContext& ctx, NUdf::TUnboxedValue** values, NUdf::TUnboxedValue* keys) const {
        std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), [&](IComputationExternalNode* item) {
            if (const auto pointer = *values++)
                item->SetValue(ctx, std::move(*pointer));
        });
        for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
            auto& key = KeyNodes[i]->RefValue(ctx);
            *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
        }
    }

    // Moves the used raw input values from `from` to `to` without evaluating
    // anything; key slots are cleared (raw rows carry no materialized keys).
    void ConsumeRawData(TComputationContext& /*ctx*/, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue** from, NUdf::TUnboxedValue* to) const {
        std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
        for (ui32 i = 0U; i < ItemNodes.size(); ++i) {
            if (from[i] && IsInputItemNodeUsed(i)) {
                to[i] = std::move(*(from[i]));
            }
        }
    }

    // Loads previously buffered raw values back into the item nodes and
    // re-evaluates the key expressions into `keys`.
    void ExtractRawData(TComputationContext& ctx, NUdf::TUnboxedValue* from, NUdf::TUnboxedValue* keys) const {
        for (ui32 i = 0U; i != ItemNodes.size(); ++i) {
            if (IsInputItemNodeUsed(i)) {
                ItemNodes[i]->SetValue(ctx, std::move(from[i]));
            }
        }
        for (ui32 i = 0U; i < KeyNodes.size(); ++i) {
            auto& key = KeyNodes[i]->RefValue(ctx);
            *keys++ = key = KeyResultNodes[i]->GetValue(ctx);
        }
    }

    // Runs one aggregation step. With `keys` set (existing group) the current
    // state is loaded into the state nodes and the update expressions write
    // the new state back; with `keys` null (new group) the init expressions
    // produce the initial state.
    void ProcessItem(TComputationContext& ctx, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue* state) const {
        if (keys) {
            std::fill_n(keys, KeyResultNodes.size(), NUdf::TUnboxedValuePod());
            auto source = state;
            std::for_each(StateNodes.cbegin(), StateNodes.cend(), [&](IComputationExternalNode* item){ item->SetValue(ctx, std::move(*source++)); });
            std::transform(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
        } else {
            std::transform(InitResultNodes.cbegin(), InitResultNodes.cend(), state, [&](IComputationNode* node) { return node->GetValue(ctx); });
        }
    }

    // Produces one output row: loads the final state into the finish nodes
    // and evaluates the finish expressions into the requested output slots
    // (null slots are skipped).
    void FinishItem(TComputationContext& ctx, NUdf::TUnboxedValue* state, NUdf::TUnboxedValue*const* output) const {
        std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), [&](IComputationExternalNode* item) { item->SetValue(ctx, std::move(*state++)); });
        for (const auto node : FinishResultNodes)
            if (const auto out = *output++)
                *out = node->GetValue(ctx);
    }

    // Registers ownership of the external nodes and dependencies on the
    // result expressions with the computation graph.
    void RegisterDependencies(const TDependsOn& dependsOn, const TOwn& own) const {
        std::for_each(ItemNodes.cbegin(), ItemNodes.cend(), own);
        std::for_each(KeyNodes.cbegin(), KeyNodes.cend(), own);
        std::for_each(StateNodes.cbegin(), StateNodes.cend(), own);
        std::for_each(FinishNodes.cbegin(), FinishNodes.cend(), own);

        std::for_each(KeyResultNodes.cbegin(), KeyResultNodes.cend(), dependsOn);
        std::for_each(InitResultNodes.cbegin(), InitResultNodes.cend(), dependsOn);
        std::for_each(UpdateResultNodes.cbegin(), UpdateResultNodes.cend(), dependsOn);
        std::for_each(FinishResultNodes.cbegin(), FinishResultNodes.cend(), dependsOn);
    }
};
// In-memory aggregation state: a Robin-Hood hash set of row pointers over a
// paged storage of flat rows (key columns followed by state columns).
//
// Protocol: the caller writes the next key into the row pointed to by
// `Tongue`, calls TasteIt(), then reads/writes the state part through
// `Throat`. Rows are later drained one-by-one via Extract().
class TState : public TComputationValue<TState> {
    typedef TComputationValue<TState> TBase;
private:
    using TStates = TRobinHoodHashSet<NUdf::TUnboxedValuePod*, TEqualsFunc, THashFunc, TMKQLAllocator<char, EMemorySubPool::Temporary>>;
    // One page holds CountRowsOnPage rows laid out contiguously.
    using TRow = std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>>;
    using TStorage = std::deque<TRow, TMKQLAllocator<TRow>>;

    // Walks the first `count` rows of the paged storage in insertion order.
    class TStorageIterator {
    private:
        TStorage& Storage;
        const ui32 RowSize = 0;
        const ui64 Count = 0;
        ui64 Ready = 0;         // number of rows already yielded
        TStorage::iterator ItStorage;
        TRow::iterator ItRow;
    public:
        TStorageIterator(TStorage& storage, const ui32 rowSize, const ui64 count)
            : Storage(storage)
            , RowSize(rowSize)
            , Count(count)
        {
            ItStorage = Storage.begin();
            if (ItStorage != Storage.end()) {
                ItRow = ItStorage->begin();
            }
        }

        bool IsValid() {
            return Ready < Count;
        }

        // Advances to the next row, hopping to the next page when the current
        // one is exhausted. Returns false once all `Count` rows were visited.
        bool Next() {
            if (++Ready >= Count) {
                return false;
            }
            ItRow += RowSize;
            if (ItRow == ItStorage->end()) {
                ++ItStorage;
                ItRow = ItStorage->begin();
            }
            return true;
        }

        NUdf::TUnboxedValuePod* GetValuePtr() const {
            return &*ItRow;
        }
    };

    static constexpr ui32 CountRowsOnPage = 128;

    // Width of one flat row: key columns followed by state columns.
    ui32 RowSize() const {
        return KeyWidth + StateWidth;
    }
public:
    TState(TMemoryUsageInfo* memInfo, ui32 keyWidth, ui32 stateWidth, const THashFunc& hash, const TEqualsFunc& equal, bool allowOutOfMemory = true)
        : TBase(memInfo), KeyWidth(keyWidth), StateWidth(stateWidth), AllowOutOfMemory(allowOutOfMemory), Hash(hash), Equal(equal) {
        CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
        CurrentPosition = 0;
        Tongue = CurrentPage->data();
        States = std::make_unique<TStates>(Hash, Equal, CountRowsOnPage);
    }

    ~TState() {
        //Workaround for YQL-16663, consider to rework this class in a safe manner
        // Rows hold values with manually managed refcounts; drop them all.
        while (auto row = Extract()) {
            for (size_t i = 0; i != RowSize(); ++i) {
                row[i].UnRef();
            }
        }

        ExtractIt.reset();
        Storage.clear();
        States->Clear();

        CleanupCurrentContext();
    }

    // Tries to insert the key currently staged at Tongue. On a new key the
    // write cursor advances (allocating a fresh page when the current one is
    // full); either way Throat is set to the state part of the matched row.
    // Returns true when the key was not present before.
    bool TasteIt() {
        Y_ABORT_UNLESS(!ExtractIt);
        bool isNew = false;
        auto itInsert = States->Insert(Tongue, isNew);
        if (isNew) {
            CurrentPosition += RowSize();
            if (CurrentPosition == CurrentPage->size()) {
                CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
                CurrentPosition = 0;
            }
            Tongue = CurrentPage->data() + CurrentPosition;
        }
        Throat = States->GetKey(itInsert) + KeyWidth;
        if (isNew) {
            GrowStates();
        }
        // Latch the out-of-memory flag once the allocator reports pressure
        // and the table is already non-trivially sized.
        IsOutOfMemory = IsOutOfMemory || (!HasMemoryForProcessing() && States->GetSize() > 1000);
        return isNew;
    }

    // Grows the hash table; on allocation failure latches IsOutOfMemory
    // instead of propagating, unless spilling is disallowed or the flag was
    // already set (then the exception escapes).
    void GrowStates() {
        try {
            States->CheckGrow();
        } catch (TMemoryLimitExceededException) {
            YQL_LOG(INFO) << "State failed to grow";
            if (IsOutOfMemory || !AllowOutOfMemory) {
                throw;
            } else {
                IsOutOfMemory = true;
            }
        }
    }

    // Returns true when more input should be read. When the table still has
    // rows, returns false (they must be drained first); otherwise resets the
    // storage to a single empty page. SkipYields additionally treats a
    // pending Yield as "read more".
    template<bool SkipYields>
    bool ReadMore() {
        if constexpr (SkipYields) {
            if (EFetchResult::Yield == InputStatus)
                return true;
        }

        if (!States->Empty())
            return false;

        {
            TStorage localStorage;
            std::swap(localStorage, Storage);
        }

        if (IsOutOfMemory) {
            // Recreate the table to shrink its bucket array after pressure.
            States = std::make_unique<TStates>(Hash, Equal, CountRowsOnPage);
        }

        CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod());
        CurrentPosition = 0;
        Tongue = CurrentPage->data();
        StoredDataSize = 0;
        IsOutOfMemory = false;

        CleanupCurrentContext();
        return true;
    }

    void PushStat(IStatsRegistry* stats) const {
        if (!States->Empty()) {
            MKQL_SET_MAX_STAT(stats, Combine_MaxRowsCount, static_cast<i64>(States->GetSize()));
            MKQL_INC_STAT(stats, Combine_FlushesCount);
        }
    }

    // Yields rows in insertion order, one per call; returns nullptr after the
    // last row and clears the hash table at that point.
    NUdf::TUnboxedValuePod* Extract() {
        if (!ExtractIt) {
            ExtractIt.emplace(Storage, RowSize(), States->GetSize());
        } else {
            ExtractIt->Next();
        }
        if (!ExtractIt->IsValid()) {
            ExtractIt.reset();
            States->Clear();
            return nullptr;
        }

        NUdf::TUnboxedValuePod* result = ExtractIt->GetValuePtr();
        CounterOutputRows_.Inc();

        return result;
    }

    EFetchResult InputStatus = EFetchResult::One;
    // Staging slot for the next key (inside the current page).
    NUdf::TUnboxedValuePod* Tongue = nullptr;
    // State part of the row matched by the last TasteIt().
    NUdf::TUnboxedValuePod* Throat = nullptr;
    i64 StoredDataSize = 0;
    bool IsOutOfMemory = false;
    NYql::NUdf::TCounter CounterOutputRows_;

private:
    std::optional<TStorageIterator> ExtractIt;
    const ui32 KeyWidth, StateWidth;
    const bool AllowOutOfMemory;
    ui64 CurrentPosition = 0;
    TRow* CurrentPage = nullptr;
    TStorage Storage;
    std::unique_ptr<TStates> States;
    const THashFunc Hash;
    const TEqualsFunc Equal;
};
  321. class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
  322. typedef TComputationValue<TSpillingSupportState> TBase;
  323. typedef std::optional<NThreading::TFuture<ISpiller::TKey>> TAsyncWriteOperation;
  324. typedef std::optional<NThreading::TFuture<std::optional<TChunkedBuffer>>> TAsyncReadOperation;
    // Per-bucket bookkeeping for spilled aggregation. A bucket starts fully
    // in memory and can progress to spilling its accumulated state and then
    // its raw (unaggregated) data.
    struct TSpilledBucket {
        std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledState; //state collected before switching to spilling mode
        std::unique_ptr<TWideUnboxedValuesSpillerAdapter> SpilledData; //data collected in spilling mode
        std::unique_ptr<TState> InMemoryProcessingState;
        // In-flight asynchronous spill write, if any.
        TAsyncWriteOperation AsyncWriteOperation;

        enum class EBucketState {
            InMemory,
            SpillingState,
            SpillingData
        };

        EBucketState BucketState = EBucketState::InMemory;
        // Number of lines (rows) accounted to this bucket.
        ui64 LineCount = 0;
    };
    // Overall mode of the combiner's spilling state machine.
    enum class EOperatingMode {
        InMemory,       // everything aggregated in InMemoryProcessingState
        SplittingState, // distributing in-memory state across buckets
        Spilling,       // buckets active; some data goes to the spiller
        ProcessSpilled  // input done; restoring and draining buckets
    };
public:
    // Outcome of TasteIt(): how the caller must treat the current row.
    enum class ETasteResult: i8 {
        Init = -1,      // new key: evaluate init expressions into Throat
        Update,         // existing key: run update expressions
        ConsumeRawData  // spilling: copy raw input values into Throat buffer
    };
    // Outcome of Update(): what the caller should do next.
    enum class EUpdateResult: i8 {
        Yield = -1,     // async spill operation pending, come back later
        ExtractRawData,
        ReadInput,
        Extract,
        Finish
    };
    // Sets up the in-memory aggregation state and the scratch buffers used
    // for spilling. `keyAndStateType` describes the flat key+state row;
    // `usedInputItemType` describes the raw input row buffered while a
    // bucket is being spilled. Spilling is enabled only when both
    // `allowSpilling` and a spiller factory in `ctx` are present.
    TSpillingSupportState(
        TMemoryUsageInfo* memInfo,
        const TMultiType* usedInputItemType, const TMultiType* keyAndStateType, ui32 keyWidth, size_t itemNodesSize,
        const THashFunc& hash, const TEqualsFunc& equal, bool allowSpilling, TComputationContext& ctx
    )
        : TBase(memInfo)
        , InMemoryProcessingState(memInfo, keyWidth, keyAndStateType->GetElementsCount() - keyWidth, hash, equal, allowSpilling && ctx.SpillerFactory)
        , UsedInputItemType(usedInputItemType)
        , KeyAndStateType(keyAndStateType)
        , KeyWidth(keyWidth)
        , ItemNodesSize(itemNodesSize)
        , Hasher(hash)
        , Mode(EOperatingMode::InMemory)
        , ViewForKeyAndState(keyAndStateType->GetElementsCount())
        , MemInfo(memInfo)
        , Equal(equal)
        , AllowSpilling(allowSpilling)
        , Ctx(ctx)
    {
        BufferForUsedInputItems.reserve(usedInputItemType->GetElementsCount());
        // Expose the in-memory state's staging pointers as our own.
        Tongue = InMemoryProcessingState.Tongue;
        Throat = InMemoryProcessingState.Throat;

        if (ctx.CountersProvider) {
            // id will be assigned externally in future versions
            TString id = TString(Operator_Aggregation) + "0";
            CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
        }
    }
    // Drives the spilling state machine one step and tells the caller what
    // to do next. Recurses (at most once per mode transition) after a mode
    // switch so the new mode's logic runs immediately.
    EUpdateResult Update() {
        if (IsEverythingExtracted) {
            return EUpdateResult::Finish;
        }

        switch (GetMode()) {
            case EOperatingMode::InMemory: {
                Tongue = InMemoryProcessingState.Tongue;
                // Memory pressure may flip us into SplittingState; re-run.
                if (CheckMemoryAndSwitchToSpilling()) {
                    return Update();
                }
                if (InputStatus == EFetchResult::Finish) return EUpdateResult::Extract;

                return EUpdateResult::ReadInput;
            }
            case EOperatingMode::SplittingState: {
                if (SplitStateIntoBucketsAndWait()) return EUpdateResult::Yield;
                // Split finished; mode is now Spilling, handle it right away.
                return Update();
            }
            case EOperatingMode::Spilling: {
                UpdateSpillingBuckets();

                if (!HasMemoryForProcessing() && InputStatus != EFetchResult::Finish && TryToReduceMemoryAndWait()) {
                    return EUpdateResult::Yield;
                }

                // A raw-data row buffered by TasteIt() still awaits its write.
                if (BufferForUsedInputItems.size()) {
                    auto& bucket = SpilledBuckets[BufferForUsedInputItemsBucketId];
                    if (bucket.AsyncWriteOperation.has_value()) return EUpdateResult::Yield;

                    bucket.AsyncWriteOperation = bucket.SpilledData->WriteWideItem(BufferForUsedInputItems);
                    BufferForUsedInputItems.resize(0); //for freeing allocated key value asap
                }

                if (InputStatus == EFetchResult::Finish) return FlushSpillingBuffersAndWait();

                return EUpdateResult::ReadInput;
            }
            case EOperatingMode::ProcessSpilled:
                return ProcessSpilledData();
        }
    }
    // Classifies the staged row: new key (Init), existing key (Update), or —
    // when its bucket is already spilling — a raw row the caller must copy
    // into BufferForUsedInputItems (ConsumeRawData). Always repoints Throat
    // at the destination the caller should write to.
    ETasteResult TasteIt() {
        if (GetMode() == EOperatingMode::InMemory) {
            bool isNew = InMemoryProcessingState.TasteIt();
            if (InMemoryProcessingState.IsOutOfMemory) {
                StateWantsToSpill = true;
            }
            Throat = InMemoryProcessingState.Throat;
            return isNew ? ETasteResult::Init : ETasteResult::Update;
        }
        if (GetMode() == EOperatingMode::ProcessSpilled) {
            // During restoration we process buckets one by one, starting from
            // the first in the queue.
            bool isNew = SpilledBuckets.front().InMemoryProcessingState->TasteIt();
            Throat = SpilledBuckets.front().InMemoryProcessingState->Throat;
            return isNew ? ETasteResult::Init : ETasteResult::Update;
        }

        auto bucketId = ChooseBucket(ViewForKeyAndState.data());
        auto& bucket = SpilledBuckets[bucketId];

        if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
            // Copy the key into the bucket's own staging row and aggregate.
            std::copy_n(ViewForKeyAndState.data(), KeyWidth, static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Tongue));

            bool isNew = bucket.InMemoryProcessingState->TasteIt();
            Throat = bucket.InMemoryProcessingState->Throat;
            bucket.LineCount += isNew;

            return isNew ? ETasteResult::Init : ETasteResult::Update;
        }

        bucket.LineCount++;

        // Prepare space for raw data
        MKQL_ENSURE(BufferForUsedInputItems.size() == 0, "Internal logic error");
        BufferForUsedInputItems.resize(ItemNodesSize);
        BufferForUsedInputItemsBucketId = bucketId;

        Throat = BufferForUsedInputItems.data();

        return ETasteResult::ConsumeRawData;
    }
  452. NUdf::TUnboxedValuePod* Extract() {
  453. NUdf::TUnboxedValue* value = nullptr;
  454. if (GetMode() == EOperatingMode::InMemory) {
  455. value = static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
  456. if (value) {
  457. CounterOutputRows_.Inc();
  458. } else {
  459. IsEverythingExtracted = true;
  460. }
  461. return value;
  462. }
  463. MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");
  464. MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error");
  465. value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
  466. if (value) {
  467. CounterOutputRows_.Inc();
  468. } else {
  469. SpilledBuckets.front().InMemoryProcessingState->ReadMore<false>();
  470. SpilledBuckets.pop_front();
  471. if (SpilledBuckets.empty()) IsEverythingExtracted = true;
  472. }
  473. return value;
  474. }
  475. private:
  476. ui64 ChooseBucket(const NUdf::TUnboxedValuePod *const key) {
  477. auto provided_hash = Hasher(key);
  478. XXH64_hash_t bucket = XXH64(&provided_hash, sizeof(provided_hash), 0) % SpilledBucketCount;
  479. return bucket;
  480. }
    // Input is finished: asks every bucket without a pending write to finish
    // writing its spilled data. Yields until all buckets have completed,
    // then switches to ProcessSpilled and starts restoring.
    EUpdateResult FlushSpillingBuffersAndWait() {
        UpdateSpillingBuckets();

        ui64 finishedCount = 0;
        for (auto& bucket : SpilledBuckets) {
            // By this point no bucket may still be spilling its state.
            MKQL_ENSURE(bucket.BucketState != TSpilledBucket::EBucketState::SpillingState, "Internal logic error");
            if (!bucket.AsyncWriteOperation.has_value()) {
                auto writeOperation = bucket.SpilledData->FinishWriting();
                if (!writeOperation) {
                    // Nothing left to flush for this bucket.
                    ++finishedCount;
                } else {
                    bucket.AsyncWriteOperation = writeOperation;
                }
            }
        }

        if (finishedCount != SpilledBuckets.size()) return EUpdateResult::Yield;

        SwitchMode(EOperatingMode::ProcessSpilled);

        return ProcessSpilledData();
    }
  499. ui32 GetLargestInMemoryBucketNumber() const {
  500. ui64 maxSize = 0;
  501. ui32 largestInMemoryBucketNum = (ui32)-1;
  502. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  503. if (SpilledBuckets[i].BucketState == TSpilledBucket::EBucketState::InMemory) {
  504. if (SpilledBuckets[i].LineCount >= maxSize) {
  505. largestInMemoryBucketNum = i;
  506. maxSize = SpilledBuckets[i].LineCount;
  507. }
  508. }
  509. }
  510. return largestInMemoryBucketNum;
  511. }
  512. bool IsSpillingWhileStateSplitAllowed() const {
  513. // TODO: Write better condition here. For example: InMemorybuckets > 64
  514. return true;
  515. }
    // Redistributes every row of the global in-memory state into the per-key
    // buckets, spilling bucket state when memory runs short. Returns true
    // when an async spill write is in flight (caller must Yield and re-enter)
    // and false when the split is complete and the mode moved to Spilling.
    bool SplitStateIntoBucketsAndWait() {
        // Resume a bucket whose state spill was interrupted by an async wait.
        if (SplitStateSpillingBucket != -1) {
            auto& bucket = SpilledBuckets[SplitStateSpillingBucket];
            MKQL_ENSURE(bucket.AsyncWriteOperation.has_value(), "Internal logic error");
            if (!bucket.AsyncWriteOperation->HasValue()) return true;
            bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
            bucket.AsyncWriteOperation = std::nullopt;

            while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
                bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
                for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                    //releasing values stored in unsafe TUnboxedValue buffer
                    keyAndState[i].UnRef();
                }
                if (bucket.AsyncWriteOperation) return true;
            }

            SplitStateSpillingBucket = -1;
        }

        // Drain the global in-memory state row by row into buckets.
        while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) {
            auto bucketId = ChooseBucket(keyAndState); // This uses only key for hashing
            auto& bucket = SpilledBuckets[bucketId];

            bucket.LineCount++;

            // Bucket already spilling: write the row straight to the spiller.
            if (bucket.BucketState != TSpilledBucket::EBucketState::InMemory) {
                if (bucket.BucketState != TSpilledBucket::EBucketState::SpillingState) {
                    bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
                    SpillingBucketsCount++;
                }
                bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
                for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                    //releasing values stored in unsafe TUnboxedValue buffer
                    keyAndState[i].UnRef();
                }
                if (bucket.AsyncWriteOperation) {
                    SplitStateSpillingBucket = bucketId;
                    return true;
                }
                continue;
            }

            // Bucket in memory: move key+state into its own processing state.
            auto& processingState = *bucket.InMemoryProcessingState;

            for (size_t i = 0; i < KeyWidth; ++i) {
                //jumping into unsafe world, refusing ownership
                static_cast<NUdf::TUnboxedValue&>(processingState.Tongue[i]) = std::move(keyAndState[i]);
            }
            processingState.TasteIt();
            for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
                //jumping into unsafe world, refusing ownership
                static_cast<NUdf::TUnboxedValue&>(processingState.Throat[i - KeyWidth]) = std::move(keyAndState[i]);
            }

            // Under memory pressure, pick the largest in-memory bucket and
            // spill its whole state.
            if (InMemoryBucketsCount && !HasMemoryForProcessing() && IsSpillingWhileStateSplitAllowed()) {
                ui32 bucketNumToSpill = GetLargestInMemoryBucketNumber();
                SplitStateSpillingBucket = bucketNumToSpill;

                auto& bucket = SpilledBuckets[bucketNumToSpill];
                bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
                SpillingBucketsCount++;
                InMemoryBucketsCount--;

                while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
                    bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
                    for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
                        //releasing values stored in unsafe TUnboxedValue buffer
                        keyAndState[i].UnRef();
                    }
                    if (bucket.AsyncWriteOperation) return true;
                }

                bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
                if (bucket.AsyncWriteOperation) return true;
            }
        }

        // Finalize every bucket that was spilling its state.
        for (ui64 i = 0; i < SpilledBucketCount; ++i) {
            auto& bucket = SpilledBuckets[i];
            if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
                if (bucket.AsyncWriteOperation.has_value()) {
                    if (!bucket.AsyncWriteOperation->HasValue()) return true;
                    bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
                    bucket.AsyncWriteOperation = std::nullopt;
                }

                bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
                if (bucket.AsyncWriteOperation) return true;
                bucket.InMemoryProcessingState->ReadMore<false>();

                bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
                SpillingBucketsCount--;
            }
        }

        InMemoryProcessingState.ReadMore<false>();
        IsInMemoryProcessingStateSplitted = true;
        SwitchMode(EOperatingMode::Spilling);
        return false;
    }
  602. bool CheckMemoryAndSwitchToSpilling() {
  603. if (!(AllowSpilling && Ctx.SpillerFactory)) {
  604. return false;
  605. }
  606. if (StateWantsToSpill || IsSwitchToSpillingModeCondition()) {
  607. StateWantsToSpill = false;
  608. LogMemoryUsage();
  609. SwitchMode(EOperatingMode::SplittingState);
  610. return true;
  611. }
  612. return false;
  613. }
  614. void LogMemoryUsage() const {
  615. const auto used = TlsAllocState->GetUsed();
  616. const auto limit = TlsAllocState->GetLimit();
  617. TStringBuilder logmsg;
  618. logmsg << "Memory usage: ";
  619. if (limit) {
  620. logmsg << (used*100/limit) << "%=";
  621. }
  622. logmsg << (used/1_MB) << "MB/" << (limit/1_MB) << "MB";
  623. YQL_LOG(INFO) << logmsg;
  624. }
// Drains the given bucket's in-memory aggregation state into its state
// spiller. May return early with a pending bucket.AsyncWriteOperation; the
// caller must re-enter this function after the asynchronous write completes.
// Once everything is written, the bucket transitions to SpillingData and the
// in-memory state is released.
void SpillMoreStateFromBucket(TSpilledBucket& bucket) {
    MKQL_ENSURE(!bucket.AsyncWriteOperation.has_value(), "Internal logic error");

    // First entry for this bucket: account the InMemory -> SpillingState transition.
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) {
        bucket.BucketState = TSpilledBucket::EBucketState::SpillingState;
        SpillingBucketsCount++;
        InMemoryBucketsCount--;
    }

    // Extract key+state rows one by one and hand them to the spiller.
    while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
        bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()});
        for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) {
            //releasing values stored in unsafe TUnboxedValue buffer
            keyAndState[i].UnRef();
        }
        // The write went asynchronous: suspend; remaining rows are spilled on re-entry.
        if (bucket.AsyncWriteOperation) return;
    }

    bucket.AsyncWriteOperation = bucket.SpilledState->FinishWriting();
    if (bucket.AsyncWriteOperation) return;

    // Whole state is spilled: reset the in-memory state and start spilling raw data.
    bucket.InMemoryProcessingState->ReadMore<false>();
    bucket.BucketState = TSpilledBucket::EBucketState::SpillingData;
    SpillingBucketsCount--;
}
  646. void UpdateSpillingBuckets() {
  647. for (ui64 i = 0; i < SpilledBucketCount; ++i) {
  648. auto& bucket = SpilledBuckets[i];
  649. if (bucket.AsyncWriteOperation.has_value() && bucket.AsyncWriteOperation->HasValue()) {
  650. if (bucket.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  651. bucket.SpilledState->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  652. bucket.AsyncWriteOperation = std::nullopt;
  653. SpillMoreStateFromBucket(bucket);
  654. } else {
  655. bucket.SpilledData->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
  656. bucket.AsyncWriteOperation = std::nullopt;
  657. }
  658. }
  659. }
  660. }
  661. bool TryToReduceMemoryAndWait() {
  662. if (SpillingBucketsCount > 0) {
  663. return true;
  664. }
  665. while (InMemoryBucketsCount > 0) {
  666. ui32 maxLineBucketInd = GetLargestInMemoryBucketNumber();
  667. MKQL_ENSURE(maxLineBucketInd != (ui32)-1, "Internal logic error");
  668. auto& bucketToSpill = SpilledBuckets[maxLineBucketInd];
  669. SpillMoreStateFromBucket(bucketToSpill);
  670. if (bucketToSpill.BucketState == TSpilledBucket::EBucketState::SpillingState) {
  671. return true;
  672. }
  673. }
  674. return false;
  675. }
// Restores spilled buckets. First completes a pending asynchronous read (if
// any), then re-loads the front bucket's spilled key+state rows into its
// in-memory state, then replays the bucket's spilled raw input rows.
// Returns:
//   Yield          - an asynchronous read is in flight, re-enter later;
//   ExtractRawData - one raw input row was loaded into BufferForUsedInputItems
//                    and must be pushed through the aggregation;
//   Extract        - the front bucket is fully in memory and ready to emit.
EUpdateResult ProcessSpilledData() {
    if (AsyncReadOperation) {
        if (!AsyncReadOperation->HasValue()) return EUpdateResult::Yield;
        // RecoverState tells which of the two spillers issued the pending read.
        if (RecoverState) {
            SpilledBuckets[0].SpilledState->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
        } else {
            SpilledBuckets[0].SpilledData->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), Ctx.HolderFactory);
        }
        AsyncReadOperation = std::nullopt;
    }
    auto& bucket = SpilledBuckets.front();
    if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) return EUpdateResult::Extract;
    //recover spilled state
    while(!bucket.SpilledState->Empty()) {
        RecoverState = true;
        TTemporaryUnboxedValueVector bufferForKeyAndState(KeyAndStateType->GetElementsCount());
        AsyncReadOperation = bucket.SpilledState->ExtractWideItem(bufferForKeyAndState);
        if (AsyncReadOperation) {
            return EUpdateResult::Yield;
        }
        for (size_t i = 0; i< KeyWidth; ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Tongue[i]) = std::move(bufferForKeyAndState[i]);
        }
        // Each spilled key must be new to the freshly rebuilt state — the
        // ENSURE below asserts this invariant.
        auto isNew = bucket.InMemoryProcessingState->TasteIt();
        MKQL_ENSURE(isNew, "Internal logic error");
        for (size_t i = KeyWidth; i < KeyAndStateType->GetElementsCount(); ++i) {
            //jumping into unsafe world, refusing ownership
            static_cast<NUdf::TUnboxedValue&>(bucket.InMemoryProcessingState->Throat[i - KeyWidth]) = std::move(bufferForKeyAndState[i]);
        }
    }
    //process spilled data
    if (!bucket.SpilledData->Empty()) {
        RecoverState = false;
        std::fill(BufferForUsedInputItems.begin(), BufferForUsedInputItems.end(), NUdf::TUnboxedValuePod());
        AsyncReadOperation = bucket.SpilledData->ExtractWideItem(BufferForUsedInputItems);
        if (AsyncReadOperation) {
            return EUpdateResult::Yield;
        }
        // Expose the raw row and the bucket's key slot to the caller.
        Throat = BufferForUsedInputItems.data();
        Tongue = bucket.InMemoryProcessingState->Tongue;
        return EUpdateResult::ExtractRawData;
    }
    bucket.BucketState = TSpilledBucket::EBucketState::InMemory;
    return EUpdateResult::Extract;
}
// Current operating mode of the combiner state machine.
EOperatingMode GetMode() const {
    return Mode;
}
// Performs a mode transition and validates that it is legal. The enforced
// flow is InMemory -> SplittingState -> Spilling -> ProcessSpilled
// (InMemory -> Spilling directly is also accepted). Each case additionally
// prepares the structures the new mode needs.
void SwitchMode(EOperatingMode mode) {
    switch(mode) {
    case EOperatingMode::InMemory: {
        // Nothing is allowed to switch back to InMemory.
        YQL_LOG(INFO) << "switching Memory mode to InMemory";
        MKQL_ENSURE(false, "Internal logic error");
        break;
    }
    case EOperatingMode::SplittingState: {
        YQL_LOG(INFO) << "switching Memory mode to SplittingState";
        MKQL_ENSURE(EOperatingMode::InMemory == Mode, "Internal logic error");
        SpilledBuckets.resize(SpilledBucketCount);
        // A single spiller instance is shared by all buckets.
        auto spiller = Ctx.SpillerFactory->CreateSpiller();
        for (auto &b: SpilledBuckets) {
            b.SpilledState = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, KeyAndStateType, 5_MB);
            b.SpilledData = std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, UsedInputItemType, 5_MB);
            b.InMemoryProcessingState = std::make_unique<TState>(MemInfo, KeyWidth, KeyAndStateType->GetElementsCount() - KeyWidth, Hasher, Equal, false);
        }
        break;
    }
    case EOperatingMode::Spilling: {
        YQL_LOG(INFO) << "switching Memory mode to Spilling";
        MKQL_ENSURE(EOperatingMode::SplittingState == Mode || EOperatingMode::InMemory == Mode, "Internal logic error");
        // From now on incoming keys land in the local view buffer.
        Tongue = ViewForKeyAndState.data();
        break;
    }
    case EOperatingMode::ProcessSpilled: {
        YQL_LOG(INFO) << "switching Memory mode to ProcessSpilled";
        MKQL_ENSURE(EOperatingMode::Spilling == Mode, "Internal logic error");
        MKQL_ENSURE(SpilledBuckets.size() == SpilledBucketCount, "Internal logic error");
        MKQL_ENSURE(BufferForUsedInputItems.empty(), "Internal logic error");
        BufferForUsedInputItems.resize(UsedInputItemType->GetElementsCount());
        // Move the still-in-memory buckets to the front so they are emitted
        // first, before any disk reads are needed.
        std::sort(SpilledBuckets.begin(), SpilledBuckets.end(), [](const TSpilledBucket& lhs, const TSpilledBucket& rhs) {
            bool lhs_in_memory = lhs.BucketState == TSpilledBucket::EBucketState::InMemory;
            bool rhs_in_memory = rhs.BucketState == TSpilledBucket::EBucketState::InMemory;
            return lhs_in_memory > rhs_in_memory;
        });
        break;
    }
    }
    Mode = mode;
}
  766. bool IsSwitchToSpillingModeCondition() const {
  767. return !HasMemoryForProcessing() || TlsAllocState->GetMaximumLimitValueReached();
  768. }
public:
    EFetchResult InputStatus = EFetchResult::One; // result of the last upstream fetch
    NUdf::TUnboxedValuePod* Tongue = nullptr;     // slot the key columns are written into
    NUdf::TUnboxedValuePod* Throat = nullptr;     // slot the state columns are written into
private:
    bool StateWantsToSpill = false;               // explicit spill request, consumed on mode switch
    bool IsEverythingExtracted = false;
    TState InMemoryProcessingState;               // aggregation state used before any spilling
    bool IsInMemoryProcessingStateSplitted = false;
    const TMultiType* const UsedInputItemType;
    const TMultiType* const KeyAndStateType;
    const size_t KeyWidth;                        // number of key columns at the front of key+state
    const size_t ItemNodesSize;
    THashFunc const Hasher;
    EOperatingMode Mode;
    bool RecoverState; //sub mode for ProcessSpilledData
    TAsyncReadOperation AsyncReadOperation = std::nullopt; // pending read shared by both spillers
    static constexpr size_t SpilledBucketCount = 128;
    std::deque<TSpilledBucket> SpilledBuckets;
    ui32 SpillingBucketsCount = 0;                // buckets currently in SpillingState
    ui32 InMemoryBucketsCount = SpilledBucketCount;
    ui64 BufferForUsedInputItemsBucketId;
    TUnboxedValueVector BufferForUsedInputItems;  // staging buffer for one raw input row
    std::vector<NUdf::TUnboxedValuePod, TMKQLAllocator<NUdf::TUnboxedValuePod>> ViewForKeyAndState;
    i64 SplitStateSpillingBucket = -1;            // bucket being force-spilled; -1 when none
    TMemoryUsageInfo* MemInfo = nullptr;
    TEqualsFunc const Equal;
    const bool AllowSpilling;
    TComputationContext& Ctx;
    NYql::NUdf::TCounter CounterOutputRows_;
  799. };
  800. #ifndef MKQL_DISABLE_CODEGEN
  801. class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
  802. private:
  803. using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
  804. llvm::IntegerType* ValueType;
  805. llvm::PointerType* PtrValueType;
  806. llvm::IntegerType* StatusType;
  807. llvm::IntegerType* StoredType;
  808. llvm::IntegerType* BoolType;
  809. protected:
  810. using TBase::Context;
  811. public:
  812. std::vector<llvm::Type*> GetFieldsArray() {
  813. std::vector<llvm::Type*> result = TBase::GetFields();
  814. result.emplace_back(StatusType); //status
  815. result.emplace_back(PtrValueType); //tongue
  816. result.emplace_back(PtrValueType); //throat
  817. result.emplace_back(StoredType); //StoredDataSize
  818. result.emplace_back(BoolType); //IsOutOfMemory
  819. result.emplace_back(Type::getInt32Ty(Context)); //size
  820. result.emplace_back(Type::getInt32Ty(Context)); //size
  821. return result;
  822. }
  823. llvm::Constant* GetStatus() {
  824. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
  825. }
  826. llvm::Constant* GetTongue() {
  827. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
  828. }
  829. llvm::Constant* GetThroat() {
  830. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
  831. }
  832. llvm::Constant* GetStored() {
  833. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
  834. }
  835. llvm::Constant* GetIsOutOfMemory() {
  836. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 4);
  837. }
  838. TLLVMFieldsStructureState(llvm::LLVMContext& context)
  839. : TBase(context)
  840. , ValueType(Type::getInt128Ty(Context))
  841. , PtrValueType(PointerType::getUnqual(ValueType))
  842. , StatusType(Type::getInt32Ty(Context))
  843. , StoredType(Type::getInt64Ty(Context))
  844. , BoolType(Type::getInt1Ty(Context)) {
  845. }
  846. };
  847. #endif
// Wide combiner computation node (non-spilling variant). Pulls wide rows from
// Flow, aggregates them into TState keyed by the key columns, and emits the
// aggregated rows once the input is exhausted, yields, or the memory limit is
// reached. TrackRss selects the memory-limit check flavor; SkipYields makes
// upstream Yield accumulate into StoredDataSize instead of flushing.
template <bool TrackRss, bool SkipYields>
class TWideCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>
#ifndef MKQL_DISABLE_CODEGEN
    , public ICodegeneratorRootNode
#endif
{
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideCombinerWrapper<TrackRss, SkipYields>>;
public:
    TWideCombinerWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TCombinerNodes&& nodes, TKeyTypes&& keyTypes, ui64 memLimit)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow(flow)
        , Nodes(std::move(nodes))
        , KeyTypes(std::move(keyTypes))
        , MemLimit(memLimit)
        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
    {}

    // Interpreted implementation: fill the state until the adjusted memory
    // limit (or IsOutOfMemory) stops the inner loop, then drain the state row
    // by row via Extract().
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
        }
        while (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) {
            // ReadMore returns true when the state is drained and more input
            // should be consumed.
            if (ptr->ReadMore<SkipYields>()) {
                switch (ptr->InputStatus) {
                case EFetchResult::One:
                    break;
                case EFetchResult::Yield:
                    ptr->InputStatus = EFetchResult::One;
                    if constexpr (SkipYields)
                        break;
                    else
                        return EFetchResult::Yield;
                case EFetchResult::Finish:
                    return EFetchResult::Finish;
                }
                // Memory accounting starts from the usage at loop entry.
                const auto initUsage = MemLimit ? ctx.HolderFactory.GetMemoryUsed() : 0ULL;
                auto **fields = ctx.WideFields.data() + WideFieldsIndex;
                do {
                    for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
                        if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U || Nodes.PasstroughtItems[i])
                            fields[i] = &Nodes.ItemNodes[i]->RefValue(ctx);
                    ptr->InputStatus = Flow->FetchValues(ctx, fields);
                    if constexpr (SkipYields) {
                        if (EFetchResult::Yield == ptr->InputStatus) {
                            // Remember how much memory this partial batch consumed
                            // so the next pass does not count it against the limit.
                            if (MemLimit) {
                                const auto currentUsage = ctx.HolderFactory.GetMemoryUsed();
                                ptr->StoredDataSize += currentUsage > initUsage ? currentUsage - initUsage : 0;
                            }
                            return EFetchResult::Yield;
                        } else if (EFetchResult::Finish == ptr->InputStatus) {
                            break;
                        }
                    } else {
                        if (EFetchResult::One != ptr->InputStatus) {
                            break;
                        }
                    }
                    // TasteIt() returns true for a new key; in that case the key
                    // slot must not be cleaned up by ProcessItem.
                    Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                    Nodes.ProcessItem(ctx, ptr->TasteIt() ? nullptr : static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage - ptr->StoredDataSize) && !ptr->IsOutOfMemory);
                ptr->PushStat(ctx.Stats);
            }
            if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
                Nodes.FinishItem(ctx, values, output);
                return EFetchResult::One;
            }
            // Nothing extracted: loop back to ReadMore and pull more input.
        }
        Y_UNREACHABLE();
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM IR twin of DoCalculate. Block structure: make (lazy MakeState) ->
    // main -> more (ReadMore call) -> next (fetch/aggregate loop) / full
    // (Extract and emit) -> over (result phi).
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valueType = Type::getInt128Ty(context);
        const auto ptrValueType = PointerType::getUnqual(valueType);
        const auto statusType = Type::getInt32Ty(context);
        const auto storedType = Type::getInt64Ty(context);
        TLLVMFieldsStructureState stateFields(context);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        // make: call MakeState through a raw method pointer.
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideCombinerWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        // main: unbox the TState pointer out of the TUnboxedValue.
        block = main;
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        BranchInst::Create(more, block);
        // more: ask the state whether more input has to be consumed.
        block = more;
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        const auto result = PHINode::Create(statusType, 3U, "result", over);
        const auto readMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ReadMore<SkipYields>));
        const auto readMoreFuncType = FunctionType::get(Type::getInt1Ty(context), { statePtrType }, false);
        const auto readMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, readMoreFunc, PointerType::getUnqual(readMoreFuncType), "read_more_func", block);
        const auto readMore = CallInst::Create(readMoreFuncType, readMoreFuncPtr, { stateArg }, "read_more", block);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto full = BasicBlock::Create(context, "full", ctx.Func);
        BranchInst::Create(next, full, readMore, block);
        {
            // next: dispatch on the stored InputStatus, then fetch & aggregate.
            block = next;
            const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
            const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
            const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto done = BasicBlock::Create(context, "done", ctx.Func);
            const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
            const auto last = new LoadInst(statusType, statusPtr, "last", block);
            result->addIncoming(last, block);
            const auto choise = SwitchInst::Create(last, pull, 2U, block);
            choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), rest);
            choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), over);
            // rest: the previous call ended with Yield — reset the status.
            block = rest;
            new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::One)), statusPtr, block);
            if constexpr (SkipYields) {
                // NOTE(review): this store duplicates the unconditional one just
                // above for the SkipYields case — confirm whether it is intentional.
                new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), statusPtr, block);
                BranchInst::Create(pull, block);
            } else {
                result->addIncoming(last, block);
                BranchInst::Create(over, block);
            }
            // pull: snapshot memory usage before the fetch loop.
            block = pull;
            const auto used = GetMemoryUsed(MemLimit, ctx, block);
            BranchInst::Create(loop, block);
            block = loop;
            const auto getres = GetNodeValues(Flow, ctx, block);
            if constexpr (SkipYields) {
                const auto save = BasicBlock::Create(context, "save", ctx.Func);
                const auto way = SwitchInst::Create(getres.first, good, 2U, block);
                way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), save);
                way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), done);
                // save: on Yield, accumulate the memory grown so far into
                // StoredDataSize (mirrors the interpreted path).
                block = save;
                if (MemLimit) {
                    const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
                    const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
                    const auto currentUsage = GetMemoryUsed(MemLimit, ctx, block);
                    const auto skipSavingUsed = BasicBlock::Create(context, "skip_saving_used", ctx.Func);
                    const auto saveUsed = BasicBlock::Create(context, "save_used", ctx.Func);
                    const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, currentUsage, used, "check", block);
                    BranchInst::Create(saveUsed, skipSavingUsed, check, block);
                    block = saveUsed;
                    const auto usedMemory = BinaryOperator::CreateSub(GetMemoryUsed(MemLimit, ctx, block), used, "used_memory", block);
                    const auto inc = BinaryOperator::CreateAdd(lastStored, usedMemory, "inc", block);
                    new StoreInst(inc, storedPtr, block);
                    BranchInst::Create(skipSavingUsed, block);
                    block = skipSavingUsed;
                }
                new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), statusPtr, block);
                result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
                BranchInst::Create(over, block);
            } else {
                const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
                BranchInst::Create(done, good, special, block);
            }
            // good: one input row fetched — materialize items, write the key
            // into the tongue, then init or update the aggregation state.
            block = good;
            std::vector<Value*> items(Nodes.ItemNodes.size(), nullptr);
            for (ui32 i = 0U; i < items.size(); ++i) {
                if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i] = getres.second[i](ctx, block));
                else if (Nodes.PasstroughtItems[i])
                    items[i] = getres.second[i](ctx, block);
            }
            const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
            const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);
            std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                auto& key = keys[i];
                const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
                if (const auto map = Nodes.KeysOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    key = it;
                } else {
                    key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
                }
                if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
                new StoreInst(key, keyPtr, block);
            }
            // TasteIt() tells whether the key is new (init) or known (update).
            const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::TasteIt));
            const auto atType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
            const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
            const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block);
            const auto init = BasicBlock::Create(context, "init", ctx.Func);
            const auto next = BasicBlock::Create(context, "next", ctx.Func);
            const auto test = BasicBlock::Create(context, "test", ctx.Func);
            const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
            const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);
            std::vector<Value*> pointers;
            pointers.reserve(Nodes.StateNodes.size());
            for (ui32 i = 0U; i < Nodes.StateNodes.size(); ++i) {
                pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
            }
            BranchInst::Create(init, next, newKey, block);
            // init: first time this key is seen — keep the key refs and fill
            // the state slots from the init expressions.
            block = init;
            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
            }
            for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
                if (const auto map = Nodes.InitOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    new StoreInst(it, pointers[i], block);
                    ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), it, ctx, block);
                } else if (const auto map = Nodes.InitOnKeys[i]) {
                    const auto key = keys[*map];
                    new StoreInst(key, pointers[i], block);
                    ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
                } else {
                    GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
                }
            }
            BranchInst::Create(test, block);
            // next: key already present — drop temporary key values and run
            // the update expressions over the existing state.
            block = next;
            for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
                if (Nodes.KeysOnItems[i] || Nodes.KeyResultNodes[i]->IsTemporaryValue())
                    ValueCleanup(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
            }
            std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
            for (ui32 i = 0U; i < stored.size(); ++i) {
                const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
                if (const auto map = Nodes.StateOnUpdate[i]) {
                    if (hasDependency || i != *map) {
                        stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
                        if (hasDependency)
                            EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
                    }
                } else if (hasDependency) {
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
                } else {
                    ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
                }
            }
            for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
                if (const auto map = Nodes.UpdateOnState[i]) {
                    if (const auto j = *map; i != j) {
                        auto& it = stored[j];
                        if (!it)
                            it = new LoadInst(valueType, pointers[j], (TString("state_") += ToString(j)).c_str(), block);
                        new StoreInst(it, pointers[i], block);
                        if (i != *Nodes.StateOnUpdate[j])
                            ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                    }
                } else if (const auto map = Nodes.UpdateOnItems[i]) {
                    auto& it = items[*map];
                    if (!it)
                        it = getres.second[*map](ctx, block);
                    new StoreInst(it, pointers[i], block);
                    ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                } else if (const auto map = Nodes.UpdateOnKeys[i]) {
                    const auto key = keys[*map];
                    new StoreInst(key, pointers[i], block);
                    ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
                } else {
                    GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
                }
            }
            BranchInst::Create(test, block);
            // test: continue the fetch loop unless the memory limit or the
            // IsOutOfMemory flag stops it.
            block = test;
            auto totalUsed = used;
            if (MemLimit) {
                const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block);
                const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block);
                totalUsed = BinaryOperator::CreateSub(used, lastStored, "decr", block);
            }
            const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, totalUsed, ctx, block);
            const auto isOutOfMemoryPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsOutOfMemory() }, "is_out_of_memory_ptr", block);
            const auto isOutOfMemory = new LoadInst(Type::getInt1Ty(context), isOutOfMemoryPtr, "is_out_of_memory", block);
            const auto checkIsOutOfMemory = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, isOutOfMemory, ConstantInt::getTrue(context), "check_is_out_of_memory", block);
            const auto any = BinaryOperator::CreateOr(check, checkIsOutOfMemory, "any", block);
            BranchInst::Create(done, loop, any, block);
            // done: persist the fetch status and push statistics.
            block = done;
            new StoreInst(getres.first, statusPtr, block);
            const auto stat = ctx.GetStat();
            const auto statFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::PushStat));
            const auto statType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType(), stat->getType()}, false);
            const auto statPtr = CastInst::Create(Instruction::IntToPtr, statFunc, PointerType::getUnqual(statType), "stat", block);
            CallInst::Create(statType, statPtr, {stateArg, stat}, "", block);
            BranchInst::Create(full, block);
        }
        {
            // full: try to Extract() one aggregated row; empty extraction
            // loops back to 'more'.
            block = full;
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
            const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
            const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
            const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
            const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
            BranchInst::Create(good, more, has, block);
            block = good;
            for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
                const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
                if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
                    EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
                else
                    ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
            }
            result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
            BranchInst::Create(over, block);
        }
        block = over;
        ICodegeneratorInlineWideNode::TGettersList getters;
        getters.reserve(Nodes.FinishResultNodes.size());
        std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
            return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
        });
        return {result, std::move(getters)};
    }
#endif
private:
    // Creates the boxed TState. With codegen enabled, the LLVM-compiled
    // hash/equals functions are preferred when available.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
#ifdef MKQL_DISABLE_CODEGEN
        state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes));
#else
        // NOTE(review): std::ptr_fun was removed in C++17 — relies on the
        // toolchain/compat layer still providing it; confirm against the build.
        state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(),
            ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
            ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes))
        );
#endif
        if (ctx.CountersProvider) {
            const auto ptr = static_cast<TState*>(state.AsBoxed().Get());
            // id will be assigned externally in future versions
            TString id = TString(Operator_Aggregation) + "0";
            ptr->CounterOutputRows_ = ctx.CountersProvider->GetCounter(id, Counter_OutputRows, false);
        }
    }

    void RegisterDependencies() const final {
        if (const auto flow = this->FlowDependsOn(Flow)) {
            Nodes.RegisterDependencies(
                [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
                [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
            );
        }
    }

    IComputationWideFlowNode *const Flow;
    const TCombinerNodes Nodes;
    const TKeyTypes KeyTypes;
    const ui64 MemLimit;
    const ui32 WideFieldsIndex;
#ifndef MKQL_DISABLE_CODEGEN
    TEqualsPtr Equals = nullptr;
    THashPtr Hash = nullptr;
    Function* EqualsFunc = nullptr;
    Function* HashFunc = nullptr;

    // Unique symbol name for the generated hash/equals function of this node.
    template <bool EqualsOrHash>
    TString MakeName() const {
        TStringStream out;
        out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
        return out.Str();
    }

    void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
        if (EqualsFunc) {
            Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
        }
        if (HashFunc) {
            Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
        }
    }

    void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
        codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
        codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
    }
#endif
};
// Wide "last" combiner: consumes the WHOLE upstream flow first, accumulating
// per-key aggregate state in a hash table (TSpillingSupportState, which can
// presumably spill to external storage under memory pressure when
// AllowSpilling is set — confirm in its definition), and only then streams
// the accumulated rows out. Contrast with the streaming TWideCombinerWrapper
// above, which emits as soon as its memory budget is reached.
class TWideLastCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>
#ifndef MKQL_DISABLE_CODEGEN
, public ICodegeneratorRootNode
#endif
{
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>;
public:
    // usedInputItemType / keyAndStateType describe the row layouts handed to
    // the spilling machinery (raw input rows and key+state rows respectively).
    TWideLastCombinerWrapper(
        TComputationMutables& mutables,
        IComputationWideFlowNode* flow,
        TCombinerNodes&& nodes,
        const TMultiType* usedInputItemType,
        TKeyTypes&& keyTypes,
        const TMultiType* keyAndStateType,
        bool allowSpilling)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow(flow)
        , Nodes(std::move(nodes))
        , KeyTypes(std::move(keyTypes))
        , UsedInputItemType(usedInputItemType)
        , KeyAndStateType(keyAndStateType)
        , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
        , AllowSpilling(allowSpilling)
    {}
    // Interpreted driver of the aggregation state machine.
    // Returns One per emitted output row, Yield to cooperatively suspend
    // (input yielded or the state machine asked for it), Finish at the end.
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
        }
        if (const auto ptr = static_cast<TSpillingSupportState*>(state.AsBoxed().Get())) {
            auto **fields = ctx.WideFields.data() + WideFieldsIndex;
            while (true) {
                // Update() tells us what the state machine needs next.
                switch(ptr->Update()) {
                case TSpillingSupportState::EUpdateResult::ReadInput: {
                    // Refresh the wide-field pointers (only used columns get a
                    // destination slot) and pull the next input row.
                    for (auto i = 0U; i < Nodes.ItemNodes.size(); ++i)
                        fields[i] = Nodes.GetUsedInputItemNodePtrOrNull(ctx, i);
                    switch (ptr->InputStatus = Flow->FetchValues(ctx, fields)) {
                        case EFetchResult::One:
                            break;
                        case EFetchResult::Finish:
                            // Input exhausted: let Update() switch to draining.
                            continue;
                        case EFetchResult::Yield:
                            return EFetchResult::Yield;
                    }
                    // Compute the key columns into the state's key buffer
                    // ("Tongue"), then fall through to TasteIt() below.
                    Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                    break;
                }
                case TSpillingSupportState::EUpdateResult::Yield:
                    return EFetchResult::Yield;
                case TSpillingSupportState::EUpdateResult::ExtractRawData:
                    // Re-materialize a previously buffered raw row (Throat)
                    // back into the item nodes and its key into Tongue.
                    Nodes.ExtractRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Throat), static_cast<NUdf::TUnboxedValue*>(ptr->Tongue));
                    break;
                case TSpillingSupportState::EUpdateResult::Extract:
                    // Drain phase: emit one accumulated key+state row, or loop
                    // again when this bucket is exhausted.
                    if (const auto values = static_cast<NUdf::TUnboxedValue*>(ptr->Extract())) {
                        Nodes.FinishItem(ctx, values, output);
                        return EFetchResult::One;
                    }
                    continue;
                case TSpillingSupportState::EUpdateResult::Finish:
                    return EFetchResult::Finish;
                }
                // TasteIt() probes the hash table with the key in Tongue.
                switch(ptr->TasteIt()) {
                case TSpillingSupportState::ETasteResult::Init:
                    // New key: initialize its aggregate state in Throat.
                    Nodes.ProcessItem(ctx, nullptr, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                    break;
                case TSpillingSupportState::ETasteResult::Update:
                    // Existing key: fold the current row into its state.
                    Nodes.ProcessItem(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                    break;
                case TSpillingSupportState::ETasteResult::ConsumeRawData:
                    // Presumably the spilling path: store the raw row for
                    // later re-processing — confirm in TSpillingSupportState.
                    Nodes.ConsumeRawData(ctx, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), fields, static_cast<NUdf::TUnboxedValue*>(ptr->Throat));
                    break;
                }
            }
        }
        Y_UNREACHABLE();
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM codegen equivalent of DoCalculate above. Block map:
    //   make      - run MakeState when the state value is still invalid
    //   more      - call TSpillingSupportState::Update, dispatch on result
    //   load      - ExtractRawData: reload a buffered raw row from Throat
    //   pull      - fetch the next row from the upstream flow
    //   rest      - upstream finished: latch Finish into the state's status
    //   good      - materialize the fetched row's used columns
    //   test      - compute keys into Tongue, call TasteIt, dispatch
    //   init/next/save - new-key init / existing-key update / raw-row buffer
    //   fill/data - Extract: emit one accumulated row, bind finish nodes
    //   done      - Finish; over - common exit merging the result PHI
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valueType = Type::getInt128Ty(context);
        const auto ptrValueType = PointerType::getUnqual(valueType);
        const auto statusType = Type::getInt32Ty(context);
        const auto wayType = Type::getInt8Ty(context);
        TLLVMFieldsStructureState stateFields(context);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        block = make;
        // Indirect call of this->MakeState(ctx, state) through a method pointer.
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideLastCombinerWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        block = main;
        // Unpack the TSpillingSupportState* from the boxed unboxed value
        // (pointer lives in the low 64 bits of the i128).
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        BranchInst::Create(more, block);
        const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
        const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
        const auto test = BasicBlock::Create(context, "test", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto load = BasicBlock::Create(context, "load", ctx.Func);
        const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
        const auto data = BasicBlock::Create(context, "data", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        const auto stub = BasicBlock::Create(context, "stub", ctx.Func);
        new UnreachableInst(context, stub);
        const auto result = PHINode::Create(statusType, 4U, "result", over);
        // One PHI per used input column; rows reach `test` either from `good`
        // (fresh fetch) or `load` (restored raw row).
        std::vector<PHINode*> phis(Nodes.ItemNodes.size(), nullptr);
        auto j = 0U;
        std::generate(phis.begin(), phis.end(), [&]() {
            // NOTE: the "item_" label uses the already-incremented index.
            return Nodes.IsInputItemNodeUsed(j++) ?
                PHINode::Create(valueType, 2U, (TString("item_") += ToString(j)).c_str(), test) : nullptr;
        });
        block = more;
        // switch (state->Update()) { ... }
        const auto updateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Update));
        const auto updateType = FunctionType::get(wayType, {stateArg->getType()}, false);
        const auto updateFuncPtr = CastInst::Create(Instruction::IntToPtr, updateFunc, PointerType::getUnqual(updateType), "update_func", block);
        const auto update = CallInst::Create(updateType, updateFuncPtr, { stateArg }, "update", block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
        const auto updateWay = SwitchInst::Create(update, stub, 5U, block);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Yield)), over);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Extract)), fill);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::Finish)), done);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ReadInput)), pull);
        updateWay->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::EUpdateResult::ExtractRawData)), load);
        block = load;
        // ExtractRawData path: reload each used column from the row buffered
        // at Throat and republish it to dependent item nodes.
        const auto extractorPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "extractor_ptr", block);
        const auto extractor = new LoadInst(ptrValueType, extractorPtr, "extractor", block);
        std::vector<Value*> items(phis.size(), nullptr);
        for (ui32 i = 0U; i < items.size(); ++i) {
            const auto ptr = GetElementPtrInst::CreateInBounds(valueType, extractor, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("load_ptr_") += ToString(i)).c_str(), block);
            if (phis[i])
                items[i] = new LoadInst(valueType, ptr, (TString("load_") += ToString(i)).c_str(), block);
            if (i < Nodes.ItemNodes.size() && Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
        }
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto phi = phis[i]) {
                phi->addIncoming(items[i], block);
            }
        }
        BranchInst::Create(test, block);
        block = pull;
        // Fetch a row from the upstream wide flow.
        const auto getres = GetNodeValues(Flow, ctx, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
        const auto choise = SwitchInst::Create(getres.first, good, 2U, block);
        choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), over);
        choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest);
        block = rest;
        // Upstream finished: record Finish in the state and re-enter Update().
        const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
        new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
        BranchInst::Create(more, block);
        block = good;
        // Materialize only the used columns; publish to dependent item nodes.
        for (ui32 i = 0U; i < items.size(); ++i) {
            if (phis[i])
                items[i] = getres.second[i](ctx, block);
            if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.ItemNodes[i])->CreateSetValue(ctx, block, items[i]);
        }
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto phi = phis[i]) {
                phi->addIncoming(items[i], block);
            }
        }
        BranchInst::Create(test, block);
        block = test;
        // Compute the key columns and store them into Tongue, reusing a
        // column value directly when the key is a plain pass-through (map).
        const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block);
        const auto tongue = new LoadInst(ptrValueType, tonguePtr, "tongue", block);
        std::vector<Value*> keyPointers(Nodes.KeyResultNodes.size(), nullptr), keys(Nodes.KeyResultNodes.size(), nullptr);
        for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
            auto& key = keys[i];
            const auto keyPtr = keyPointers[i] = GetElementPtrInst::CreateInBounds(valueType, tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("key_") += ToString(i)).c_str(), block);
            if (const auto map = Nodes.KeysOnItems[i]) {
                key = phis[*map];
            } else {
                key = GetNodeValue(Nodes.KeyResultNodes[i], ctx, block);
            }
            if (Nodes.KeyNodes[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.KeyNodes[i])->CreateSetValue(ctx, block, key);
            new StoreInst(key, keyPtr, block);
        }
        // switch (state->TasteIt()) { Init / Update / ConsumeRawData }
        const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt));
        const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false);
        const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
        const auto taste = CallInst::Create(atType, atPtr, {stateArg}, "taste", block);
        const auto init = BasicBlock::Create(context, "init", ctx.Func);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto save = BasicBlock::Create(context, "save", ctx.Func);
        // Element pointers into the state row at Throat; sized to cover both
        // the aggregate-state layout and the raw-row layout (save path).
        const auto throatPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetThroat() }, "throat_ptr", block);
        const auto throat = new LoadInst(ptrValueType, throatPtr, "throat", block);
        std::vector<Value*> pointers;
        const auto width = std::max(Nodes.StateNodes.size(), phis.size());
        pointers.reserve(width);
        for (ui32 i = 0U; i < width; ++i) {
            pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
        }
        const auto way = SwitchInst::Create(taste, stub, 3U, block);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Init)), init);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Update)), next);
        way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::ConsumeRawData)), save);
        block = init;
        // New key: the table now owns the key values — add refs, then compute
        // the initial aggregate state (reusing item/key values when mapped).
        for (ui32 i = 0U; i < Nodes.KeyResultNodes.size(); ++i) {
            ValueAddRef(Nodes.KeyResultNodes[i]->GetRepresentation(), keyPointers[i], ctx, block);
        }
        for (ui32 i = 0U; i < Nodes.InitResultNodes.size(); ++i) {
            if (const auto map = Nodes.InitOnItems[i]) {
                const auto item = phis[*map];
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), item, ctx, block);
            } else if (const auto map = Nodes.InitOnKeys[i]) {
                const auto key = keys[*map];
                new StoreInst(key, pointers[i], block);
                ValueAddRef(Nodes.InitResultNodes[i]->GetRepresentation(), key, ctx, block);
            } else {
                GetNodeValue(pointers[i], Nodes.InitResultNodes[i], ctx, block);
            }
        }
        BranchInst::Create(more, block);
        block = next;
        // Existing key: load current state values (only where needed), publish
        // them to dependent state nodes, then write back the updated state.
        std::vector<Value*> stored(Nodes.StateNodes.size(), nullptr);
        for (ui32 i = 0U; i < stored.size(); ++i) {
            const bool hasDependency = Nodes.StateNodes[i]->GetDependencesCount() > 0U;
            if (const auto map = Nodes.StateOnUpdate[i]) {
                if (hasDependency || i != *map) {
                    stored[i] = new LoadInst(valueType, pointers[i], (TString("state_") += ToString(i)).c_str(), block);
                    if (hasDependency)
                        EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, stored[i]);
                }
            } else if (hasDependency) {
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.StateNodes[i])->CreateSetValue(ctx, block, pointers[i]);
            } else {
                ValueUnRef(Nodes.StateNodes[i]->GetRepresentation(), pointers[i], ctx, block);
            }
        }
        for (ui32 i = 0U; i < Nodes.UpdateResultNodes.size(); ++i) {
            if (const auto map = Nodes.UpdateOnState[i]) {
                if (const auto j = *map; i != j) {
                    const auto it = stored[j];
                    new StoreInst(it, pointers[i], block);
                    if (i != *Nodes.StateOnUpdate[j])
                        ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), it, ctx, block);
                }
            } else if (const auto map = Nodes.UpdateOnItems[i]) {
                const auto item = phis[*map];
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), item, ctx, block);
            } else if (const auto map = Nodes.UpdateOnKeys[i]) {
                const auto key = keys[*map];
                new StoreInst(key, pointers[i], block);
                ValueAddRef(Nodes.UpdateResultNodes[i]->GetRepresentation(), key, ctx, block);
            } else {
                GetNodeValue(pointers[i], Nodes.UpdateResultNodes[i], ctx, block);
            }
        }
        BranchInst::Create(more, block);
        block = save;
        // ConsumeRawData: copy the used input columns into the raw-row buffer.
        for (ui32 i = 0U; i < phis.size(); ++i) {
            if (const auto item = phis[i]) {
                new StoreInst(item, pointers[i], block);
                ValueAddRef(Nodes.ItemNodes[i]->GetRepresentation(), item, ctx, block);
            }
        }
        BranchInst::Create(more, block);
        block = fill;
        // Extract path: state->Extract() yields the next accumulated row, or
        // null when the current batch is drained (then re-enter Update()).
        const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Extract));
        const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
        const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
        const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
        const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
        BranchInst::Create(data, more, has, block);
        block = data;
        // Bind the extracted key+state columns to the finish nodes; columns
        // with no consumer are released immediately.
        for (ui32 i = 0U; i < Nodes.FinishNodes.size(); ++i) {
            const auto ptr = GetElementPtrInst::CreateInBounds(valueType, out, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("out_key_") += ToString(i)).c_str(), block);
            if (Nodes.FinishNodes[i]->GetDependencesCount() > 0 || Nodes.ItemsOnResult[i])
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Nodes.FinishNodes[i])->CreateSetValue(ctx, block, ptr);
            else
                ValueUnRef(Nodes.FinishNodes[i]->GetRepresentation(), ptr, ctx, block);
        }
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = done;
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
        BranchInst::Create(over, block);
        block = over;
        // Output getters simply evaluate the finish result nodes on demand.
        ICodegeneratorInlineWideNode::TGettersList getters;
        getters.reserve(Nodes.FinishResultNodes.size());
        std::transform(Nodes.FinishResultNodes.cbegin(), Nodes.FinishResultNodes.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
            return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
        });
        return {result, std::move(getters)};
    }
#endif
private:
    // Creates the spilling-aware aggregation state; JIT-compiled hash/equals
    // are preferred when available and LLVM execution is on for this context.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
        state = ctx.HolderFactory.Create<TSpillingSupportState>(UsedInputItemType, KeyAndStateType,
            Nodes.KeyNodes.size(),
            Nodes.ItemNodes.size(),
#ifdef MKQL_DISABLE_CODEGEN
            TMyValueHasher(KeyTypes),
            TMyValueEqual(KeyTypes),
#else
            // NOTE(review): std::ptr_fun was removed in C++17; relies on the
            // toolchain still providing it.
            ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
            ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)),
#endif
            AllowSpilling,
            ctx
        );
    }
    // Wires dependencies: result nodes depend on the upstream flow, external
    // argument nodes are owned relative to it.
    void RegisterDependencies() const final {
        if (const auto flow = this->FlowDependsOn(Flow)) {
            Nodes.RegisterDependencies(
                [this, flow](IComputationNode* node){ this->DependsOn(flow, node); },
                [this, flow](IComputationExternalNode* node){ this->Own(flow, node); }
            );
        }
    }
    IComputationWideFlowNode *const Flow;      // upstream wide flow
    const TCombinerNodes Nodes;                // item/key/state/finish node bundle
    const TKeyTypes KeyTypes;                  // key slots for hashing/equality
    const TMultiType* const UsedInputItemType; // layout of buffered raw input rows
    const TMultiType* const KeyAndStateType;   // layout of key+state rows
    const ui32 WideFieldsIndex;                // base slot in ctx.WideFields
    const bool AllowSpilling;
#ifndef MKQL_DISABLE_CODEGEN
    TEqualsPtr Equals = nullptr;  // resolved JIT entry points (may stay null)
    THashPtr Hash = nullptr;
    Function* EqualsFunc = nullptr;
    Function* HashFunc = nullptr;
    // Unique per-instance symbol name for the generated Equals/Hash function.
    template <bool EqualsOrHash>
    TString MakeName() const {
        TStringStream out;
        out << this->DebugString() << "::" << (EqualsOrHash ? "Equals" : "Hash") << "_(" << static_cast<const void*>(this) << ").";
        return out.Str();
    }
    void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
        if (EqualsFunc) {
            Equals = reinterpret_cast<TEqualsPtr>(codegen.GetPointerToFunction(EqualsFunc));
        }
        if (HashFunc) {
            Hash = reinterpret_cast<THashPtr>(codegen.GetPointerToFunction(HashFunc));
        }
    }
    void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
        codegen.ExportSymbol(HashFunc = GenerateHashFunction(codegen, MakeName<false>(), KeyTypes));
        codegen.ExportSymbol(EqualsFunc = GenerateEqualsFunction(codegen, MakeName<true>(), KeyTypes));
    }
#endif
};
  1575. }
// Shared factory for the WideCombiner (streaming) and WideLastCombiner
// (consume-all) callables.
//
// Callable input layout, as established by both traversal passes below
// (W = input width, K = keysSize, S = stateSize):
//   [0] input flow, [1] memory limit (streaming variant only),
//   K-count literal, S-count literal,
//   W item arguments, K key results, K key arguments,
//   S init results, S state arguments, S update results,
//   K+S finish arguments, outputWidth finish results.
// Pass 1 locates the *result* (computation) nodes while skipping argument
// slots; pass 2 (after `index = Last ? 3U : 4U`) locates the *argument*
// (external) nodes while skipping result slots.
template<bool Last>
IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool allowSpilling) {
    MKQL_ENSURE(callable.GetInputsCount() >= (Last ? 3U : 4U), "Expected more arguments.");
    const auto inputType = AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType());
    const auto inputWidth = GetWideComponentsCount(inputType);
    const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
    const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
    auto index = Last ? 0U : 1U;
    const auto keysSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    const auto stateSize = AS_VALUE(TDataLiteral, callable.GetInput(++index))->AsValue().Get<ui32>();
    // Equivalent to `index += 1U + inputWidth`: step past the state-size
    // literal and the W item-argument slots, landing on the first key result.
    ++index += inputWidth;
    std::vector<TType*> keyAndStateItemTypes;
    keyAndStateItemTypes.reserve(keysSize + stateSize);
    TKeyTypes keyTypes;
    keyTypes.reserve(keysSize);
    // Key column types come from the key-result slots.
    for (ui32 i = index; i < index + keysSize; ++i) {
        TType *type = callable.GetInput(i).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        bool optional;
        keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
    }
    TCombinerNodes nodes;
    nodes.KeyResultNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyResultNodes), keysSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize; // skip the K key-argument slots
    nodes.InitResultNodes.reserve(stateSize);
    for (size_t i = 0; i != stateSize; ++i) {
        TType *type = callable.GetInput(index).GetStaticType();
        keyAndStateItemTypes.push_back(type);
        nodes.InitResultNodes.push_back(LocateNode(ctx.NodeLocator, callable, index++));
    }
    index += stateSize; // skip the S state-argument slots
    nodes.UpdateResultNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.UpdateResultNodes), stateSize, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize + stateSize; // skip the K+S finish-argument slots
    nodes.FinishResultNodes.reserve(outputWidth);
    std::generate_n(std::back_inserter(nodes.FinishResultNodes), outputWidth, [&](){ return LocateNode(ctx.NodeLocator, callable, index++); } );
    // Pass 2: rewind to the first item-argument slot and collect the
    // external (argument) nodes, skipping the result slots in between.
    index = Last ? 3U : 4U;
    nodes.ItemNodes.reserve(inputWidth);
    std::generate_n(std::back_inserter(nodes.ItemNodes), inputWidth, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += keysSize; // skip the K key-result slots
    nodes.KeyNodes.reserve(keysSize);
    std::generate_n(std::back_inserter(nodes.KeyNodes), keysSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += stateSize; // skip the S init-result slots
    nodes.StateNodes.reserve(stateSize);
    std::generate_n(std::back_inserter(nodes.StateNodes), stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    index += stateSize; // skip the S update-result slots
    nodes.FinishNodes.reserve(keysSize + stateSize);
    std::generate_n(std::back_inserter(nodes.FinishNodes), keysSize + stateSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, index++); } );
    nodes.BuildMaps();
    if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
        if constexpr (Last) {
            const auto inputItemTypes = GetWideComponents(inputType);
            return new TWideLastCombinerWrapper(ctx.Mutables, wide, std::move(nodes),
                TMultiType::Create(inputItemTypes.size(), inputItemTypes.data(), ctx.Env),
                std::move(keyTypes),
                TMultiType::Create(keyAndStateItemTypes.size(),keyAndStateItemTypes.data(), ctx.Env),
                allowSpilling
            );
        } else {
            if constexpr (RuntimeVersion < 46U) {
                // Older runtimes encode the memory limit as an unsigned value.
                const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<ui64>();
                if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                    return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
                else
                    return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), memLimit);
            } else {
                // Since runtime 46 the limit is signed: a negative value keeps
                // its magnitude as the limit but selects the <.., true> wrapper
                // variant — presumably an alternate flush mode; confirm against
                // TWideCombinerWrapper's second template parameter.
                if (const auto memLimit = AS_VALUE(TDataLiteral, callable.GetInput(1U))->AsValue().Get<i64>(); memLimit >= 0)
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                    else
                        return new TWideCombinerWrapper<false, false>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(memLimit));
                else
                    if (EGraphPerProcess::Single == ctx.GraphPerProcess)
                        return new TWideCombinerWrapper<true, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
                    else
                        return new TWideCombinerWrapper<false, true>(ctx.Mutables, wide, std::move(nodes), std::move(keyTypes), ui64(-memLimit));
            }
        }
    }
    THROW yexception() << "Expected wide flow.";
}
// Factory entry point for the streaming WideCombiner callable (no spilling).
IComputationNode* WrapWideCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    return WrapWideCombinerT<false>(callable, ctx, false);
}
  1661. IComputationNode* WrapWideLastCombiner(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1662. YQL_LOG(INFO) << "Found non-serializable type, spilling is disabled";
  1663. return WrapWideCombinerT<true>(callable, ctx, false);
  1664. }
// Factory entry point for the WideLastCombiner callable with spilling
// enabled (allowSpilling = true).
IComputationNode* WrapWideLastCombinerWithSpilling(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    return WrapWideCombinerT<true>(callable, ctx, true);
}
  1668. }
  1669. }