mkql_block_agg.cpp 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500
  1. #include "mkql_block_agg.h"
  2. #include "mkql_block_agg_factory.h"
  3. #include "mkql_rh_hash.h"
  4. #include <yql/essentials/minikql/computation/mkql_block_reader.h>
  5. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  6. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  7. #include <yql/essentials/minikql/computation/mkql_block_impl_codegen.h> // Y_IGNORE
  8. #include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
  9. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  10. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  11. #include <yql/essentials/minikql/mkql_node_cast.h>
  12. #include <yql/essentials/minikql/mkql_node_builder.h>
  13. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  14. #include <yql/essentials/minikql/arrow/arrow_util.h>
  15. #include <yql/essentials/minikql/arrow/mkql_bit_utils.h>
  16. #include <yql/essentials/utils/prefetch.h>
  17. #include <arrow/scalar.h>
  18. #include <arrow/array/array_primitive.h>
  19. #include <arrow/array/builder_primitive.h>
  20. #include <arrow/chunked_array.h>
  21. //#define USE_STD_UNORDERED
  22. namespace NKikimr {
  23. namespace NMiniKQL {
  24. namespace {
  // Compile-time switch, fixed to false here. Its consumers are not part of the
  // visible portion of this file.
  constexpr bool InlineAggState{false};
  26. #ifdef USE_STD_UNORDERED
  27. template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void>
  28. class TDynamicHashMapImpl {
  29. public:
  30. using TMapType = std::unordered_map<TKey, std::vector<char>, THash, TEqual>;
  31. using const_iterator = typename TMapType::const_iterator;
  32. using iterator = typename TMapType::iterator;
  33. TDynamicHashMapImpl(size_t stateSize, const THash& hasher, const TEqual& equal)
  34. : StateSize_(stateSize)
  35. , Map_(0, hasher, equal)
  36. {}
  37. ui64 GetSize() const {
  38. return Map_.size();
  39. }
  40. const_iterator Begin() const {
  41. return Map_.begin();
  42. }
  43. const_iterator End() const {
  44. return Map_.end();
  45. }
  46. bool IsValid(const_iterator iter) const {
  47. return true;
  48. }
  49. void Advance(const_iterator& iter) const {
  50. ++iter;
  51. }
  52. iterator Insert(const TKey& key, bool& isNew) {
  53. auto res = Map_.emplace(key, std::vector<char>());
  54. isNew = res.second;
  55. if (isNew) {
  56. res.first->second.resize(StateSize_);
  57. }
  58. return res.first;
  59. }
  60. template <typename TSink>
  61. void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
  62. for (size_t index = 0; index < batchRequest.size(); ++index) {
  63. bool isNew;
  64. auto iter = Insert(batchRequest[index].GetKey(), isNew);
  65. sink(index, iter, isNew);
  66. }
  67. }
  68. const TKey& GetKey(const_iterator it) const {
  69. return it->first;
  70. }
  71. char* GetMutablePayload(iterator it) const {
  72. return it->second.data();
  73. }
  74. const char* GetPayload(const_iterator it) const {
  75. return it->second.data();
  76. }
  77. void CheckGrow() {
  78. }
  79. private:
  80. const size_t StateSize_;
  81. TMapType Map_;
  82. };
  83. template <typename TKey, typename TPayload, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void>
  84. class TFixedHashMapImpl {
  85. public:
  86. using TMapType = std::unordered_map<TKey, TPayload, THash, TEqual>;
  87. using const_iterator = typename TMapType::const_iterator;
  88. using iterator = typename TMapType::iterator;
  89. TFixedHashMapImpl(const THash& hasher, const TEqual& equal)
  90. : Map_(0, hasher, equal)
  91. {}
  92. ui64 GetSize() const {
  93. return Map_.size();
  94. }
  95. const_iterator Begin() const {
  96. return Map_.begin();
  97. }
  98. const_iterator End() const {
  99. return Map_.end();
  100. }
  101. bool IsValid(const_iterator iter) const {
  102. return true;
  103. }
  104. void Advance(const_iterator& iter) const {
  105. ++iter;
  106. }
  107. iterator Insert(const TKey& key, bool& isNew) {
  108. auto res = Map_.emplace(key, TPayload());
  109. isNew = res.second;
  110. return res.first;
  111. }
  112. template <typename TSink>
  113. void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
  114. for (size_t index = 0; index < batchRequest.size(); ++index) {
  115. bool isNew;
  116. auto iter = Insert(batchRequest[index].GetKey(), isNew);
  117. sink(index, iter, isNew);
  118. }
  119. }
  120. const TKey& GetKey(const_iterator it) const {
  121. return it->first;
  122. }
  123. char* GetMutablePayload(iterator it) const {
  124. return (char*)&it->second;
  125. }
  126. const char* GetPayload(const_iterator it) const {
  127. return (const char*)&it->second;
  128. }
  129. void CheckGrow() {
  130. }
  131. private:
  132. TMapType Map_;
  133. };
  134. template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void>
  135. class THashSetImpl {
  136. public:
  137. using TSetType = std::unordered_set<TKey, THash, TEqual>;
  138. using const_iterator = typename TSetType::const_iterator;
  139. using iterator = typename TSetType::iterator;
  140. THashSetImpl(const THash& hasher, const TEqual& equal)
  141. : Set_(0, hasher, equal)
  142. {}
  143. ui64 GetSize() const {
  144. return Set_.size();
  145. }
  146. const_iterator Begin() const {
  147. return Set_.begin();
  148. }
  149. const_iterator End() const {
  150. return Set_.end();
  151. }
  152. bool IsValid(const_iterator iter) const {
  153. return true;
  154. }
  155. void Advance(const_iterator& iter) const {
  156. ++iter;
  157. }
  158. iterator Insert(const TKey& key, bool& isNew) {
  159. auto res = Set_.emplace(key);
  160. isNew = res.second;
  161. return res.first;
  162. }
  163. template <typename TSink>
  164. void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
  165. for (size_t index = 0; index < batchRequest.size(); ++index) {
  166. bool isNew;
  167. auto iter = Insert(batchRequest[index].GetKey(), isNew);
  168. sink(index, iter, isNew);
  169. }
  170. }
  171. void CheckGrow() {
  172. }
  173. const TKey& GetKey(const_iterator it) const {
  174. return *it;
  175. }
  176. char* GetMutablePayload(iterator it) const {
  177. Y_UNUSED(it);
  178. return nullptr;
  179. }
  180. const char* GetPayload(const_iterator it) const {
  181. Y_UNUSED(it);
  182. return nullptr;
  183. }
  184. private:
  185. TSetType Set_;
  186. };
  187. #else
  // Default build (USE_STD_UNORDERED not defined): the three *Impl names are plain
  // textual aliases for the Robin Hood containers. Those containers are declared
  // outside this file (presumably in mkql_rh_hash.h, which is included above).
  188. #define TDynamicHashMapImpl TRobinHoodHashMap
  189. #define TFixedHashMapImpl TRobinHoodHashFixedMap
  190. #define THashSetImpl TRobinHoodHashSet
  191. #endif
  192. using TState8 = ui64;
  193. static_assert(sizeof(TState8) == 8);
  194. using TState16 = std::pair<ui64, ui64>;
  195. static_assert(sizeof(TState16) == 16);
  196. using TStateArena = void*;
  197. static_assert(sizeof(TStateArena) == sizeof(void*));
  // Key that only points at bytes stored elsewhere. The byte count is not kept in the
  // key; the std::hash / std::equal_to specializations for this type carry it in
  // their Length member.
  198. struct TExternalFixedSizeKey {
  // mutable: Data can be reassigned through a const reference - presumably so a stored
  // key can be repointed after insertion; TODO confirm against the users of this type.
  199. mutable const char* Data;
  200. };
  // 16-byte key made of two 64-bit halves. The std::hash / std::equal_to
  // specializations for this type hash and compare it field by field.
  201. struct TKey16 {
  202. ui64 Lo;
  203. ui64 Hi;
  204. };
  205. class TSSOKey {
  206. public:
  207. static constexpr size_t SSO_Length = 15;
  208. static_assert(SSO_Length < 128); // should fit into 7 bits
  209. private:
  210. struct TExternal {
  211. ui64 Length_;
  212. const char* Ptr_;
  213. };
  214. struct TInplace {
  215. ui8 SmallLength_;
  216. char Buffer_[SSO_Length];
  217. };
  218. public:
  219. TSSOKey(const TSSOKey& other) {
  220. memcpy(U.A, other.U.A, SSO_Length + 1);
  221. }
  222. TSSOKey& operator=(const TSSOKey& other) {
  223. memcpy(U.A, other.U.A, SSO_Length + 1);
  224. return *this;
  225. }
  226. static bool CanBeInplace(TStringBuf data) {
  227. return data.Size() + 1 <= sizeof(TSSOKey);
  228. }
  229. static TSSOKey Inplace(TStringBuf data) {
  230. Y_ASSERT(CanBeInplace(data));
  231. TSSOKey ret(1 | (data.Size() << 1), 0);
  232. memcpy(ret.U.I.Buffer_, data.Data(), data.Size());
  233. return ret;
  234. }
  235. static TSSOKey External(TStringBuf data) {
  236. return TSSOKey(data.Size() << 1, data.Data());
  237. }
  238. bool IsInplace() const {
  239. return U.I.SmallLength_ & 1;
  240. }
  241. TStringBuf AsView() const {
  242. if (IsInplace()) {
  243. // inplace
  244. return TStringBuf(U.I.Buffer_, U.I.SmallLength_ >> 1);
  245. } else {
  246. // external
  247. return TStringBuf(U.E.Ptr_, U.E.Length_ >> 1);
  248. }
  249. }
  250. void UpdateExternalPointer(const char *ptr) const {
  251. Y_ASSERT(!IsInplace());
  252. const_cast<TExternal&>(U.E).Ptr_ = ptr;
  253. }
  254. private:
  255. TSSOKey(ui64 length, const char* ptr) {
  256. U.E.Length_ = length;
  257. U.E.Ptr_ = ptr;
  258. }
  259. private:
  260. union {
  261. TExternal E;
  262. TInplace I;
  263. char A[SSO_Length + 1];
  264. } U;
  265. };
  266. static_assert(sizeof(TSSOKey) == TSSOKey::SSO_Length + 1);
  267. }
  268. }
  269. }
  270. namespace std {
  271. template <>
  272. struct hash<NKikimr::NMiniKQL::TKey16> {
  273. using argument_type = NKikimr::NMiniKQL::TKey16;
  274. using result_type = size_t;
  275. inline result_type operator()(argument_type const& s) const noexcept {
  276. auto hasher = std::hash<ui64>();
  277. return hasher(s.Hi) * 31 + hasher(s.Lo);
  278. }
  279. };
  280. template <>
  281. struct equal_to<NKikimr::NMiniKQL::TKey16> {
  282. using argument_type = NKikimr::NMiniKQL::TKey16;
  283. bool operator()(argument_type x, argument_type y) const {
  284. return x.Hi == y.Hi && x.Lo == y.Lo;
  285. }
  286. };
  287. template <>
  288. struct hash<NKikimr::NMiniKQL::TSSOKey> {
  289. using argument_type = NKikimr::NMiniKQL::TSSOKey;
  290. using result_type = size_t;
  291. inline result_type operator()(argument_type const& s) const noexcept {
  292. return std::hash<std::string_view>()(s.AsView());
  293. }
  294. };
  295. template <>
  296. struct equal_to<NKikimr::NMiniKQL::TSSOKey> {
  297. using argument_type = NKikimr::NMiniKQL::TSSOKey;
  298. bool operator()(argument_type x, argument_type y) const {
  299. return x.AsView() == y.AsView();
  300. }
  301. bool operator()(argument_type x, TStringBuf y) const {
  302. return x.AsView() == y;
  303. }
  304. using is_transparent = void;
  305. };
  306. template <>
  307. struct hash<NKikimr::NMiniKQL::TExternalFixedSizeKey> {
  308. using argument_type = NKikimr::NMiniKQL::TExternalFixedSizeKey;
  309. using result_type = size_t;
  310. hash(ui32 length)
  311. : Length(length)
  312. {}
  313. inline result_type operator()(argument_type const& s) const noexcept {
  314. return std::hash<std::string_view>()(std::string_view(s.Data, Length));
  315. }
  316. const ui32 Length;
  317. };
  318. template <>
  319. struct equal_to<NKikimr::NMiniKQL::TExternalFixedSizeKey> {
  320. using argument_type = NKikimr::NMiniKQL::TExternalFixedSizeKey;
  321. equal_to(ui32 length)
  322. : Length(length)
  323. {}
  324. bool operator()(argument_type x, argument_type y) const {
  325. return memcmp(x.Data, y.Data, Length) == 0;
  326. }
  327. bool operator()(argument_type x, TStringBuf y) const {
  328. Y_ASSERT(y.Size() <= Length);
  329. return memcmp(x.Data, y.Data(), Length) == 0;
  330. }
  331. using is_transparent = void;
  332. const ui32 Length;
  333. };
  334. }
  335. namespace NKikimr {
  336. namespace NMiniKQL {
  337. namespace {
  338. template <typename T>
  339. struct TAggParams {
  340. std::unique_ptr<IPreparedBlockAggregator<T>> Prepared_;
  341. ui32 Column_ = 0;
  342. TType* StateType_ = nullptr;
  343. TType* ReturnType_ = nullptr;
  344. ui32 Hint_ = 0;
  345. };
  346. struct TKeyParams {
  347. ui32 Index;
  348. TType* Type;
  349. };
  350. size_t GetBitmapPopCount(const std::shared_ptr<arrow::ArrayData>& arr) {
  351. size_t len = (size_t)arr->length;
  352. MKQL_ENSURE(arr->GetNullCount() == 0, "Bitmap block should not have nulls");
  353. const ui8* src = arr->GetValues<ui8>(1);
  354. return GetSparseBitmapPopCount(src, len);
  355. }
  356. size_t CalcMaxBlockLenForOutput(TType* out) {
  357. const auto wideComponents = GetWideComponents(out);
  358. MKQL_ENSURE(wideComponents.size() > 0, "Expecting at least one output column");
  359. size_t maxBlockItemSize = 0;
  360. for (ui32 i = 0; i < wideComponents.size() - 1; ++i) {
  361. auto type = AS_TYPE(TBlockType, wideComponents[i]);
  362. MKQL_ENSURE(type->GetShape() != TBlockType::EShape::Scalar, "Expecting block type");
  363. maxBlockItemSize = std::max(maxBlockItemSize, CalcMaxBlockItemSize(type->GetItemType()));
  364. }
  365. return CalcBlockLen(maxBlockItemSize);
  366. }
  // Code-generation helper shared by the "combine all" wrappers. With codegen enabled
  // it provides DoGenGetValuesImpl(), which emits LLVM IR that creates the state on
  // first use, feeds every fetched row set to a "process input" callback and, once the
  // input flow finishes, asks a "make output" callback for the result. All callbacks
  // are passed in as raw addresses and called through function pointers.
  367. class TBlockCombineAllWrapperCodegenBase {
  368. protected:
  369. #ifndef MKQL_DISABLE_CODEGEN
  // LLVM view of the state object: the fields of the base structure followed by a
  // pointer to an array of `width` 128-bit values and a 1-bit "finished" flag
  // (presumably mirroring the leading Pointer_ / IsFinished_ members of the state
  // struct - TODO confirm).
  370. class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> {
  371. private:
  372. using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
  373. llvm::PointerType*const PointerType;
  374. llvm::IntegerType*const IsFinishedType;
  375. public:
  // Field types of the whole structure: base fields, then pointer, then flag.
  376. std::vector<llvm::Type*> GetFieldsArray() {
  377. std::vector<llvm::Type*> result = TBase::GetFields();
  378. result.emplace_back(PointerType);
  379. result.emplace_back(IsFinishedType);
  380. return result;
  381. }
  // Index (as an i32 constant) of the pointer field: first one after the base fields.
  382. llvm::Constant* GetPointer() {
  383. return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
  384. }
  // Index (as an i32 constant) of the "finished" flag: second one after the base fields.
  385. llvm::Constant* GetIsFinished() {
  386. return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
  387. }
  388. TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
  389. : TBase(context)
  390. , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
  391. , IsFinishedType(llvm::Type::getInt1Ty(Context))
  392. {}
  393. };
  // Emits the IR for fetching values from this node.
  //   statePtr              - location of the node state value
  //   block                 - current insertion block; on return it is the "over" block
  //   flow                  - input flow whose values are consumed
  //   width                 - number of 128-bit slots in the state's value array
  //   aggCount              - number of output getters to produce
  //   getStateMethodPtr     - address called as (state, index) -> value by the getters
  //   makeStateMethodPtr    - address called as (self, statePtr, ctx) to create the state
  //   processInputMethodPtr - address called as (state, factory) after each fetched row set
  //   makeOutputMethodPtr   - address called as (state, factory) -> i1 when the flow finishes
  // Returns the fetch status value together with one getter per aggregate.
  394. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block,
  395. IComputationWideFlowNode* flow, size_t width, size_t aggCount,
  396. uintptr_t getStateMethodPtr, uint64_t makeStateMethodPtr,
  397. uintptr_t processInputMethodPtr, uintptr_t makeOutputMethodPtr) const {
  398. auto& context = ctx.Codegen.GetContext();
  399. const auto valueType = Type::getInt128Ty(context);
  400. const auto statusType = Type::getInt32Ty(context);
  401. const auto indexType = Type::getInt64Ty(context);
  402. const auto flagType = Type::getInt1Ty(context);
  403. const auto arrayType = ArrayType::get(valueType, width);
  404. const auto ptrValuesType = PointerType::getUnqual(arrayType);
  405. TLLVMFieldsStructureState stateFields(context, width);
  406. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  407. const auto statePtrType = PointerType::getUnqual(stateType);
  // Instructions created with atTop are placed relative to the last instruction of
  // the function's entry block rather than in the current block.
  408. const auto atTop = &ctx.Func->getEntryBlock().back();
  409. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr);
  410. const auto getType = FunctionType::get(valueType, {statePtrType, indexType}, false);
  411. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
  // Stack slot holding the state pointer for the getters; starts as null and is only
  // filled in by the "work" block below.
  412. const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
  413. new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
  414. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  415. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  416. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  417. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  418. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  419. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  // Go through "make" only while the state value is still invalid.
  420. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  // make: call the state factory as (self, statePtr, ctx), then continue in "main".
  421. block = make;
  422. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  423. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  424. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr);
  425. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
  426. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  427. CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
  428. BranchInst::Create(main, block);
  // main: take the low 64 bits of the state value as the state object pointer and
  // read its "finished" flag; when set, jump straight to "over" with Finish.
  429. block = main;
  430. const auto state = new LoadInst(valueType, statePtr, "state", block);
  431. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  432. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  433. const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
  434. const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
  // The status returned to the caller is a PHI in "over" with three sources:
  // "main" (Finish), "read" (Yield) and "work" (One or Finish).
  435. const auto result = PHINode::Create(statusType, 3U, "result", over);
  436. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  437. BranchInst::Create(over, read, finished, block);
  // read: load the value array pointer from the state, pass the array to
  // SafeUnRefUnboxedArray, then fetch from the input flow and dispatch on the status:
  // Finish -> "work", Yield -> "over", anything else -> "good".
  438. block = read;
  439. const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
  440. const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
  441. SafeUnRefUnboxedArray(values, arrayType, ctx, block);
  442. const auto getres = GetNodeValues(flow, ctx, block);
  443. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  444. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  445. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work);
  446. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
  // good: evaluate every input getter, AddRefBoxed each value, store them all into
  // the state's value array, call the "process input" callback and loop to "read".
  447. block = good;
  448. Value* array = UndefValue::get(arrayType);
  449. for (auto idx = 0U; idx < getres.second.size(); ++idx) {
  450. const auto value = getres.second[idx](ctx, block);
  451. AddRefBoxed(value, ctx, block);
  452. array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  453. }
  454. new StoreInst(array, values, block);
  455. const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr);
  456. const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
  457. const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
  458. CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
  459. BranchInst::Create(read, block);
  // work: the input is exhausted. Call the "make output" callback; its i1 result
  // selects One (data produced) or Finish. The state pointer is saved to the stack
  // slot so that the getters can reach it.
  460. block = work;
  461. const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), makeOutputMethodPtr);
  462. const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
  463. const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block);
  464. const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block);
  465. const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block);
  466. new StoreInst(stateArg, stateOnStack, block);
  467. result->addIncoming(output, block);
  468. BranchInst::Create(over, block);
  // over: common exit. Each getter reloads the state pointer from the stack slot and
  // calls the getter callback as (state, idx).
  469. block = over;
  470. ICodegeneratorInlineWideNode::TGettersList getters(aggCount);
  471. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  472. getters[idx] = [idx, getType, getPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
  473. Y_UNUSED(ctx);
  474. const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
  475. return CallInst::Create(getType, getPtr, {stateArg, ConstantInt::get(indexType, idx)}, "get", block);
  476. };
  477. }
  478. return {result, std::move(getters)};
  479. }
  480. #endif
  481. };
  482. struct TBlockCombineAllState : public TComputationValue<TBlockCombineAllState> {
  483. NUdf::TUnboxedValue* Pointer_ = nullptr;
  484. bool IsFinished_ = false;
  485. bool HasValues_ = false;
  486. TUnboxedValueVector Values_;
  487. std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
  488. std::vector<char> AggStates_;
  489. const std::optional<ui32> FilterColumn_;
  490. const size_t Width_;
  491. TBlockCombineAllState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
  492. : TComputationValue(memInfo)
  493. , Values_(std::max(width, params.size()))
  494. , FilterColumn_(filterColumn)
  495. , Width_(width)
  496. {
  497. Pointer_ = Values_.data();
  498. ui32 totalStateSize = 0;
  499. for (const auto& p : params) {
  500. Aggs_.emplace_back(p.Prepared_->Make(ctx));
  501. MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch");
  502. totalStateSize += Aggs_.back()->StateSize;
  503. }
  504. AggStates_.resize(totalStateSize);
  505. char* ptr = AggStates_.data();
  506. for (const auto& agg : Aggs_) {
  507. agg->InitState(ptr);
  508. ptr += agg->StateSize;
  509. }
  510. }
  511. void ProcessInput() {
  512. const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
  513. if (!batchLength) {
  514. return;
  515. }
  516. std::optional<ui64> filtered;
  517. if (FilterColumn_) {
  518. const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum();
  519. if (filterDatum.is_scalar()) {
  520. if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
  521. return;
  522. }
  523. } else {
  524. const ui64 popCount = GetBitmapPopCount(filterDatum.array());
  525. if (popCount == 0) {
  526. return;
  527. }
  528. if (popCount < batchLength) {
  529. filtered = popCount;
  530. }
  531. }
  532. }
  533. HasValues_ = true;
  534. char* ptr = AggStates_.data();
  535. for (size_t i = 0; i < Aggs_.size(); ++i) {
  536. Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered);
  537. ptr += Aggs_[i]->StateSize;
  538. }
  539. }
  540. bool MakeOutput() {
  541. IsFinished_ = true;
  542. if (!HasValues_)
  543. return false;
  544. char* ptr = AggStates_.data();
  545. for (size_t i = 0; i < Aggs_.size(); ++i) {
  546. Values_[i] = Aggs_[i]->FinishOne(ptr);
  547. Aggs_[i]->DestroyState(ptr);
  548. ptr += Aggs_[i]->StateSize;
  549. }
  550. return true;
  551. }
  552. NUdf::TUnboxedValuePod Get(size_t index) const {
  553. return Values_[index];
  554. }
  555. };
  556. class TBlockCombineAllWrapperFromFlow : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapperFromFlow>,
  557. protected TBlockCombineAllWrapperCodegenBase {
  558. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapperFromFlow>;
  559. using TState = TBlockCombineAllState;
  560. public:
  561. TBlockCombineAllWrapperFromFlow(TComputationMutables& mutables,
  562. IComputationWideFlowNode* flow,
  563. std::optional<ui32> filterColumn,
  564. size_t width,
  565. std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
  566. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  567. , Flow_(flow)
  568. , FilterColumn_(filterColumn)
  569. , Width_(width)
  570. , AggsParams_(std::move(aggsParams))
  571. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
  572. {
  573. MKQL_ENSURE(Width_ > 0, "Missing block length column");
  574. }
  575. EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
  576. TComputationContext& ctx,
  577. NUdf::TUnboxedValue*const* output) const
  578. {
  579. auto& s = GetState(state, ctx);
  580. if (s.IsFinished_)
  581. return EFetchResult::Finish;
  582. for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
  583. switch (Flow_->FetchValues(ctx, fields)) {
  584. case EFetchResult::Yield:
  585. return EFetchResult::Yield;
  586. case EFetchResult::One:
  587. s.ProcessInput();
  588. continue;
  589. case EFetchResult::Finish:
  590. break;
  591. }
  592. if (s.MakeOutput()) {
  593. for (size_t i = 0; i < AggsParams_.size(); ++i) {
  594. if (const auto out = output[i]) {
  595. *out = s.Get(i);
  596. }
  597. }
  598. return EFetchResult::One;
  599. }
  600. return EFetchResult::Finish;
  601. }
  602. }
  603. #ifndef MKQL_DISABLE_CODEGEN
  604. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  605. return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, AggsParams_.size(),
  606. GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapperFromFlow::MakeState),
  607. GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::MakeOutput));
  608. }
  609. #endif
  610. private:
  611. void RegisterDependencies() const final {
  612. FlowDependsOn(Flow_);
  613. }
  614. void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  615. state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
  616. }
  617. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  618. if (state.IsInvalid()) {
  619. MakeState(state, ctx);
  620. auto& s = *static_cast<TState*>(state.AsBoxed().Get());
  621. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  622. for (size_t i = 0; i < Width_; ++i) {
  623. fields[i] = &s.Values_[i];
  624. }
  625. return s;
  626. }
  627. return *static_cast<TState*>(state.AsBoxed().Get());
  628. }
  629. private:
  630. IComputationWideFlowNode *const Flow_;
  631. const std::optional<ui32> FilterColumn_;
  632. const size_t Width_;
  633. const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_;
  634. const size_t WideFieldsIndex_;
  635. };
  636. class TBlockCombineAllWrapperFromStream : public TMutableComputationNode<TBlockCombineAllWrapperFromStream> {
  637. using TBaseComputation = TMutableComputationNode<TBlockCombineAllWrapperFromStream>;
  638. using TState = TBlockCombineAllState;
  639. public:
  640. TBlockCombineAllWrapperFromStream(TComputationMutables& mutables,
  641. IComputationNode* stream,
  642. std::optional<ui32> filterColumn,
  643. size_t width,
  644. std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
  645. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  646. , Stream_(stream)
  647. , FilterColumn_(filterColumn)
  648. , Width_(width)
  649. , AggsParams_(std::move(aggsParams))
  650. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
  651. {
  652. MKQL_ENSURE(Width_ > 0, "Missing block length column");
  653. }
  654. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const
  655. {
  656. const auto state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
  657. return ctx.HolderFactory.Create<TStreamValue>(std::move(state), std::move(Stream_->GetValue(ctx)));
  658. }
  659. private:
  660. class TStreamValue : public TComputationValue<TStreamValue> {
  661. using TBase = TComputationValue<TStreamValue>;
  662. public:
  663. TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& state, NUdf::TUnboxedValue&& stream)
  664. : TBase(memInfo)
  665. , State_(state)
  666. , Stream_(stream)
  667. {
  668. }
  669. private:
  670. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  671. TState& state = *static_cast<TState*>(State_.AsBoxed().Get());
  672. auto* inputFields = state.Values_.data();
  673. const size_t inputWidth = state.Width_;
  674. if (state.IsFinished_)
  675. return NUdf::EFetchStatus::Finish;
  676. while (true) {
  677. switch (Stream_.WideFetch(inputFields, inputWidth)) {
  678. case NUdf::EFetchStatus::Yield:
  679. return NUdf::EFetchStatus::Yield;
  680. case NUdf::EFetchStatus::Ok:
  681. state.ProcessInput();
  682. continue;
  683. case NUdf::EFetchStatus::Finish:
  684. break;
  685. }
  686. if (state.MakeOutput()) {
  687. for (size_t i = 0; i < width; ++i) {
  688. output[i] = state.Get(i);
  689. }
  690. return NUdf::EFetchStatus::Ok;
  691. }
  692. return NUdf::EFetchStatus::Finish;
  693. }
  694. }
  695. private:
  696. NUdf::TUnboxedValue State_;
  697. NUdf::TUnboxedValue Stream_;
  698. };
  699. private:
  700. void RegisterDependencies() const final {
  701. DependsOn(Stream_);
  702. }
  703. private:
  704. IComputationNode *const Stream_;
  705. const std::optional<ui32> FilterColumn_;
  706. const size_t Width_;
  707. const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_;
  708. const size_t WideFieldsIndex_;
  709. };
  710. template <typename T>
  711. T MakeKey(TStringBuf s, ui32 keyLength) {
  712. Y_UNUSED(keyLength);
  713. Y_ASSERT(s.Size() <= sizeof(T));
  714. return *(const T*)s.Data();
  715. }
  716. template <>
  717. TSSOKey MakeKey(TStringBuf s, ui32 keyLength) {
  718. Y_UNUSED(keyLength);
  719. if (TSSOKey::CanBeInplace(s)) {
  720. return TSSOKey::Inplace(s);
  721. } else {
  722. return TSSOKey::External(s);
  723. }
  724. }
  725. template <>
  726. TExternalFixedSizeKey MakeKey(TStringBuf s, ui32 keyLength) {
  727. Y_ASSERT(s.Size() == keyLength);
  728. return { s.Data() };
  729. }
  730. void MoveKeyToArena(const TSSOKey& key, TPagedArena& arena, ui32 keyLength) {
  731. Y_UNUSED(keyLength);
  732. if (key.IsInplace()) {
  733. return;
  734. }
  735. auto view = key.AsView();
  736. auto ptr = (char*)arena.Alloc(view.Size());
  737. memcpy(ptr, view.Data(), view.Size());
  738. key.UpdateExternalPointer(ptr);
  739. }
  740. void MoveKeyToArena(const TExternalFixedSizeKey& key, TPagedArena& arena, ui32 keyLength) {
  741. auto ptr = (char*)arena.Alloc(keyLength);
  742. memcpy(ptr, key.Data, keyLength);
  743. key.Data = ptr;
  744. }
  745. template <typename T>
  746. TStringBuf GetKeyView(const T& key, ui32 keyLength) {
  747. Y_UNUSED(keyLength);
  748. return TStringBuf((const char*)&key, sizeof(T));
  749. }
  750. template <>
  751. TStringBuf GetKeyView(const TSSOKey& key, ui32 keyLength) {
  752. Y_UNUSED(keyLength);
  753. return key.AsView();
  754. }
  755. template <>
  756. TStringBuf GetKeyView(const TExternalFixedSizeKey& key, ui32 keyLength) {
  757. return TStringBuf(key.Data, keyLength);
  758. }
  759. template <typename T>
  760. std::equal_to<T> MakeEqual(ui32 keyLength) {
  761. Y_UNUSED(keyLength);
  762. return std::equal_to<T>();
  763. }
  764. template <>
  765. std::equal_to<TExternalFixedSizeKey> MakeEqual(ui32 keyLength) {
  766. return std::equal_to<TExternalFixedSizeKey>(keyLength);
  767. }
  768. template <typename T>
  769. std::hash<T> MakeHash(ui32 keyLength) {
  770. Y_UNUSED(keyLength);
  771. return std::hash<T>();
  772. }
  773. template <>
  774. std::hash<TExternalFixedSizeKey> MakeHash(ui32 keyLength) {
  775. return std::hash<TExternalFixedSizeKey>(keyLength);
  776. }
  777. class THashedWrapperCodegenBase {
  778. protected:
  779. #ifndef MKQL_DISABLE_CODEGEN
  780. class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
  781. private:
  782. using TBase = TLLVMFieldsStructureBlockState;
  783. llvm::IntegerType*const WritingOutputType;
  784. llvm::IntegerType*const IsFinishedType;
  785. protected:
  786. using TBase::Context;
  787. public:
  788. std::vector<llvm::Type*> GetFieldsArray() {
  789. std::vector<llvm::Type*> result = TBase::GetFieldsArray();
  790. result.emplace_back(WritingOutputType);
  791. result.emplace_back(IsFinishedType);
  792. return result;
  793. }
  794. llvm::Constant* GetWritingOutput() {
  795. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
  796. }
  797. llvm::Constant* GetIsFinished() {
  798. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
  799. }
  800. TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
  801. : TBase(context, width)
  802. , WritingOutputType(Type::getInt1Ty(Context))
  803. , IsFinishedType(Type::getInt1Ty(Context))
  804. {}
  805. };
  806. Y_NO_INLINE ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(
  807. const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block,
  808. IComputationWideFlowNode* flow, size_t width, size_t outputWidth,
  809. uintptr_t getStateMethodPtr, uintptr_t makeStateMethodPtr,
  810. uintptr_t processInputMethodPtr, uintptr_t finishMethodPtr,
  811. uintptr_t fillOutputMethodPtr, uintptr_t sliceMethodPtr) const {
  812. auto& context = ctx.Codegen.GetContext();
  813. const auto valueType = Type::getInt128Ty(context);
  814. const auto statusType = Type::getInt32Ty(context);
  815. const auto indexType = Type::getInt64Ty(context);
  816. const auto flagType = Type::getInt1Ty(context);
  817. const auto arrayType = ArrayType::get(valueType, width);
  818. const auto ptrValuesType = PointerType::getUnqual(arrayType);
  819. TLLVMFieldsStructureState stateFields(context, width);
  820. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  821. const auto statePtrType = PointerType::getUnqual(stateType);
  822. const auto atTop = &ctx.Func->getEntryBlock().back();
  823. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr);
  824. const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
  825. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
  826. const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
  827. const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
  828. new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
  829. new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
  830. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  831. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  832. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  833. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  834. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  835. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  836. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  837. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  838. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  839. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  840. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  841. block = make;
  842. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  843. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  844. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr);
  845. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
  846. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  847. CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
  848. BranchInst::Create(main, block);
  849. block = main;
  850. const auto state = new LoadInst(valueType, statePtr, "state", block);
  851. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  852. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  853. const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
  854. const auto count = new LoadInst(indexType, countPtr, "count", block);
  855. const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block);
  856. BranchInst::Create(more, fill, none, block);
  857. block = more;
  858. const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
  859. const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
  860. const auto result = PHINode::Create(statusType, 5U, "result", over);
  861. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  862. BranchInst::Create(over, test, finished, block);
  863. block = test;
  864. const auto writingOutputPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetWritingOutput() }, "writing_output_ptr", block);
  865. const auto writingOutput = new LoadInst(flagType, writingOutputPtr, "writing_output", block);
  866. BranchInst::Create(work, read, writingOutput, block);
  867. block = read;
  868. const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
  869. const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
  870. SafeUnRefUnboxedArray(values, arrayType, ctx, block);
  871. const auto getres = GetNodeValues(flow, ctx, block);
  872. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  873. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  874. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), stop);
  875. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
  876. block = good;
  877. Value* array = UndefValue::get(arrayType);
  878. for (auto idx = 0U; idx < getres.second.size(); ++idx) {
  879. const auto value = getres.second[idx](ctx, block);
  880. AddRefBoxed(value, ctx, block);
  881. array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  882. }
  883. new StoreInst(array, values, block);
  884. const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr);
  885. const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
  886. const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
  887. CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
  888. BranchInst::Create(read, block);
  889. block = stop;
  890. const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), finishMethodPtr);
  891. const auto finishType = FunctionType::get(flagType, {statePtrType}, false);
  892. const auto finishPtr = CastInst::Create(Instruction::IntToPtr, finishFunc, PointerType::getUnqual(finishType), "finish_func", block);
  893. const auto hasOutput = CallInst::Create(finishType, finishPtr, {stateArg}, "has_output", block);
  894. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  895. BranchInst::Create(work, over, hasOutput, block);
  896. block = work;
  897. const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), fillOutputMethodPtr);
  898. const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
  899. const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block);
  900. const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block);
  901. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  902. BranchInst::Create(fill, over, hasData, block);
  903. block = fill;
  904. const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), sliceMethodPtr);
  905. const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
  906. const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
  907. const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
  908. new StoreInst(slice, heightPtr, block);
  909. new StoreInst(stateArg, stateOnStack, block);
  910. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  911. BranchInst::Create(over, block);
  912. block = over;
  913. ICodegeneratorInlineWideNode::TGettersList getters(outputWidth);
  914. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  915. getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
  916. const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
  917. const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
  918. return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
  919. };
  920. }
  921. return {result, std::move(getters)};
  922. }
  923. #endif
  924. };
  925. template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
  926. struct THashedWrapperBaseState : public TBlockState {
  927. private:
  928. static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value;
  929. public:
  930. bool WritingOutput_ = false;
  931. bool IsFinished_ = false;
  932. const std::optional<ui32> FilterColumn_;
  933. const std::vector<TKeyParams> Keys_;
  934. const std::vector<TAggParams<TAggregator>>& AggsParams_;
  935. const ui32 KeyLength_;
  936. const ui32 StreamIndex_;
  937. const std::vector<std::vector<ui32>> Streams_;
  938. const size_t MaxBlockLen_;
  939. const size_t Width_;
  940. const size_t OutputWidth_;
  941. template<typename TKeyType>
  942. struct THashSettings {
  943. static constexpr bool CacheHash = std::is_same_v<TKeyType, TSSOKey>;
  944. };
  945. using TDynMapImpl = TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>;
  946. using TSetImpl = THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>;
  947. using TFixedMapImpl = TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>;
  948. ui64 BatchNum_ = 0;
  949. TUnboxedValueVector Values_;
  950. std::vector<std::unique_ptr<TAggregator>> Aggs_;
  951. std::vector<ui32> AggStateOffsets_;
  952. TUnboxedValueVector UnwrappedValues_;
  953. std::vector<std::unique_ptr<IBlockReader>> Readers_;
  954. std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
  955. std::vector<std::unique_ptr<IAggColumnBuilder>> AggBuilders_;
  956. bool HasValues_ = false;
  957. ui32 TotalStateSize_ = 0;
  958. size_t OutputBlockSize_ = 0;
  959. std::unique_ptr<TDynMapImpl> HashMap_;
  960. typename TDynMapImpl::const_iterator HashMapIt_;
  961. std::unique_ptr<TSetImpl> HashSet_;
  962. typename TSetImpl::const_iterator HashSetIt_;
  963. std::unique_ptr<TFixedMapImpl> HashFixedMap_;
  964. typename TFixedMapImpl::const_iterator HashFixedMapIt_;
  965. TPagedArena Arena_;
  966. THashedWrapperBaseState(TMemoryUsageInfo* memInfo, ui32 keyLength, ui32 streamIndex, size_t width, size_t outputWidth, std::optional<ui32> filterColumn, const std::vector<TAggParams<TAggregator>>& params,
  967. const std::vector<std::vector<ui32>>& streams, const std::vector<TKeyParams>& keys, size_t maxBlockLen, TComputationContext& ctx)
  968. : TBlockState(memInfo, outputWidth)
  969. , FilterColumn_(filterColumn)
  970. , Keys_(keys)
  971. , AggsParams_(params)
  972. , KeyLength_(keyLength)
  973. , StreamIndex_(streamIndex)
  974. , Streams_(streams)
  975. , MaxBlockLen_(maxBlockLen)
  976. , Width_(width)
  977. , OutputWidth_(outputWidth)
  978. , Values_(width)
  979. , UnwrappedValues_(width)
  980. , Readers_(keys.size())
  981. , Builders_(keys.size())
  982. , Arena_(TlsAllocState)
  983. {
  984. Pointer_ = Values_.data();
  985. for (size_t i = 0; i < Keys_.size(); ++i) {
  986. auto itemType = AS_TYPE(TBlockType, Keys_[i].Type)->GetItemType();
  987. Readers_[i] = NYql::NUdf::MakeBlockReader(TTypeInfoHelper(), itemType);
  988. Builders_[i] = NYql::NUdf::MakeArrayBuilder(TTypeInfoHelper(), itemType, ctx.ArrowMemoryPool, MaxBlockLen_, &ctx.Builder->GetPgBuilder());
  989. }
  990. if constexpr (Many) {
  991. TotalStateSize_ += Streams_.size();
  992. }
  993. for (const auto& p : AggsParams_) {
  994. Aggs_.emplace_back(p.Prepared_->Make(ctx));
  995. MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch");
  996. AggStateOffsets_.emplace_back(TotalStateSize_);
  997. TotalStateSize_ += Aggs_.back()->StateSize;
  998. }
  999. auto equal = MakeEqual<TKey>(KeyLength_);
  1000. auto hasher = MakeHash<TKey>(KeyLength_);
  1001. if constexpr (UseSet) {
  1002. MKQL_ENSURE(params.empty(), "Only keys are supported");
  1003. HashSet_ = std::make_unique<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal);
  1004. } else {
  1005. if (!InlineAggState) {
  1006. HashFixedMap_ = std::make_unique<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal);
  1007. } else {
  1008. HashMap_ = std::make_unique<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(TotalStateSize_, hasher, equal);
  1009. }
  1010. }
  1011. }
  1012. void ProcessInput(const THolderFactory& holderFactory) {
  1013. ++BatchNum_;
  1014. const auto batchLength = TArrowBlock::From(Values_.back()).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
  1015. if (!batchLength) {
  1016. return;
  1017. }
  1018. const ui8* filterBitmap = nullptr;
  1019. if constexpr (UseFilter) {
  1020. auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum();
  1021. if (filterDatum.is_scalar()) {
  1022. if (!filterDatum.template scalar_as<arrow::UInt8Scalar>().value) {
  1023. return;
  1024. }
  1025. } else {
  1026. const auto& arr = filterDatum.array();
  1027. filterBitmap = arr->template GetValues<ui8>(1);
  1028. ui64 popCount = GetBitmapPopCount(arr);
  1029. if (popCount == 0) {
  1030. return;
  1031. }
  1032. }
  1033. }
  1034. const ui32* streamIndexData = nullptr;
  1035. TMaybe<ui32> streamIndexScalar;
  1036. if constexpr (Many) {
  1037. auto streamIndexDatum = TArrowBlock::From(Values_[StreamIndex_]).GetDatum();
  1038. if (streamIndexDatum.is_scalar()) {
  1039. streamIndexScalar = streamIndexDatum.template scalar_as<arrow::UInt32Scalar>().value;
  1040. } else {
  1041. MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array");
  1042. streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1);
  1043. }
  1044. UnwrappedValues_ = Values_;
  1045. for (const auto& p : AggsParams_) {
  1046. const auto& columnDatum = TArrowBlock::From(UnwrappedValues_[p.Column_]).GetDatum();
  1047. MKQL_ENSURE(columnDatum.is_array(), "Expected array");
  1048. UnwrappedValues_[p.Column_] = holderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_));
  1049. }
  1050. }
  1051. HasValues_ = true;
  1052. std::vector<arrow::Datum> keysDatum;
  1053. keysDatum.reserve(Keys_.size());
  1054. for (ui32 i = 0; i < Keys_.size(); ++i) {
  1055. keysDatum.emplace_back(TArrowBlock::From(Values_[Keys_[i].Index]).GetDatum());
  1056. }
  1057. std::array<TOutputBuffer, PrefetchBatchSize> out;
  1058. for (ui32 i = 0; i < PrefetchBatchSize; ++i) {
  1059. out[i].Resize(sizeof(TKey));
  1060. }
  1061. std::array<TRobinHoodBatchRequestItem<TKey>, PrefetchBatchSize> insertBatch;
  1062. std::array<ui64, PrefetchBatchSize> insertBatchRows;
  1063. std::array<char*, PrefetchBatchSize> insertBatchPayloads;
  1064. std::array<bool, PrefetchBatchSize> insertBatchIsNew;
  1065. ui32 insertBatchLen = 0;
  1066. const auto processInsertBatch = [&]() {
  1067. for (ui32 i = 0; i < insertBatchLen; ++i) {
  1068. auto& r = insertBatch[i];
  1069. TStringBuf str = out[i].Finish();
  1070. TKey key = MakeKey<TKey>(str, KeyLength_);
  1071. r.ConstructKey(key);
  1072. }
  1073. if constexpr (UseSet) {
  1074. HashSet_->BatchInsert({insertBatch.data(), insertBatchLen},[&](size_t index, typename THashedWrapperBaseState::TSetImpl::iterator iter, bool isNew) {
  1075. Y_UNUSED(index);
  1076. if (isNew) {
  1077. if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) {
  1078. MoveKeyToArena(HashSet_->GetKey(iter), Arena_, KeyLength_);
  1079. }
  1080. }
  1081. });
  1082. } else {
  1083. using THashTable = std::conditional_t<InlineAggState, typename THashedWrapperBaseState::TDynMapImpl, typename THashedWrapperBaseState::TFixedMapImpl>;
  1084. THashTable* hash;
  1085. if constexpr (!InlineAggState) {
  1086. hash = HashFixedMap_.get();
  1087. } else {
  1088. hash = HashMap_.get();
  1089. }
  1090. hash->BatchInsert({insertBatch.data(), insertBatchLen}, [&](size_t index, typename THashTable::iterator iter, bool isNew) {
  1091. if (isNew) {
  1092. if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) {
  1093. MoveKeyToArena(hash->GetKey(iter), Arena_, KeyLength_);
  1094. }
  1095. }
  1096. if constexpr (UseArena) {
  1097. // prefetch payloads only
  1098. auto payload = hash->GetPayload(iter);
  1099. char* ptr;
  1100. if (isNew) {
  1101. ptr = (char*)Arena_.Alloc(TotalStateSize_);
  1102. *(char**)payload = ptr;
  1103. } else {
  1104. ptr = *(char**)payload;
  1105. }
  1106. insertBatchIsNew[index] = isNew;
  1107. insertBatchPayloads[index] = ptr;
  1108. NYql::PrefetchForWrite(ptr);
  1109. } else {
  1110. // process insert
  1111. auto payload = (char*)hash->GetPayload(iter);
  1112. auto row = insertBatchRows[index];
  1113. ui32 streamIndex = 0;
  1114. if constexpr (Many) {
  1115. streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row];
  1116. }
  1117. Insert(row, payload, isNew, streamIndex);
  1118. }
  1119. });
  1120. if constexpr (UseArena) {
  1121. for (ui32 i = 0; i < insertBatchLen; ++i) {
  1122. auto row = insertBatchRows[i];
  1123. ui32 streamIndex = 0;
  1124. if constexpr (Many) {
  1125. streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row];
  1126. }
  1127. bool isNew = insertBatchIsNew[i];
  1128. char* payload = insertBatchPayloads[i];
  1129. Insert(row, payload, isNew, streamIndex);
  1130. }
  1131. }
  1132. }
  1133. };
  1134. for (ui64 row = 0; row < batchLength; ++row) {
  1135. if constexpr (UseFilter) {
  1136. if (filterBitmap && !filterBitmap[row]) {
  1137. continue;
  1138. }
  1139. }
  1140. // encode key
  1141. out[insertBatchLen].Rewind();
  1142. for (ui32 i = 0; i < keysDatum.size(); ++i) {
  1143. if (keysDatum[i].is_scalar()) {
  1144. // TODO: more efficient code when grouping by scalar
  1145. Readers_[i]->SaveScalarItem(*keysDatum[i].scalar(), out[insertBatchLen]);
  1146. } else {
  1147. Readers_[i]->SaveItem(*keysDatum[i].array(), row, out[insertBatchLen]);
  1148. }
  1149. }
  1150. insertBatchRows[insertBatchLen] = row;
  1151. ++insertBatchLen;
  1152. if (insertBatchLen == PrefetchBatchSize) {
  1153. processInsertBatch();
  1154. insertBatchLen = 0;
  1155. }
  1156. }
  1157. processInsertBatch();
  1158. }
  1159. bool Finish() {
  1160. if (!HasValues_) {
  1161. IsFinished_ = true;
  1162. return false;
  1163. }
  1164. WritingOutput_ = true;
  1165. OutputBlockSize_ = 0;
  1166. PrepareAggBuilders();
  1167. if constexpr (UseSet) {
  1168. HashSetIt_ = HashSet_->Begin();
  1169. } else {
  1170. if constexpr (!InlineAggState) {
  1171. HashFixedMapIt_ = HashFixedMap_->Begin();
  1172. } else {
  1173. HashMapIt_ = HashMap_->Begin();
  1174. }
  1175. }
  1176. return true;
  1177. }
  1178. bool FillOutput(const THolderFactory& holderFactory) {
  1179. bool exit = false;
  1180. while (WritingOutput_) {
  1181. if constexpr (UseSet) {
  1182. for (;!exit && HashSetIt_ != HashSet_->End(); HashSet_->Advance(HashSetIt_)) {
  1183. if (!HashSet_->IsValid(HashSetIt_)) {
  1184. continue;
  1185. }
  1186. if (OutputBlockSize_ == MaxBlockLen_) {
  1187. Flush(false, holderFactory);
  1188. //return EFetchResult::One;
  1189. exit = true;
  1190. break;
  1191. }
  1192. const TKey& key = HashSet_->GetKey(HashSetIt_);
  1193. TInputBuffer in(GetKeyView<TKey>(key, KeyLength_));
  1194. for (auto& kb : Builders_) {
  1195. kb->Add(in);
  1196. }
  1197. ++OutputBlockSize_;
  1198. }
  1199. break;
  1200. } else {
  1201. const bool done = InlineAggState ?
  1202. Iterate(*HashMap_, HashMapIt_) :
  1203. Iterate(*HashFixedMap_, HashFixedMapIt_);
  1204. if (done) {
  1205. break;
  1206. }
  1207. Flush(false, holderFactory);
  1208. exit = true;
  1209. break;
  1210. }
  1211. }
  1212. if (!exit) {
  1213. IsFinished_ = true;
  1214. WritingOutput_ = false;
  1215. if (!OutputBlockSize_)
  1216. return false;
  1217. Flush(true, holderFactory);
  1218. }
  1219. FillArrays();
  1220. return true;
  1221. }
  1222. private:
  1223. void PrepareAggBuilders() {
  1224. if constexpr (!UseSet) {
  1225. AggBuilders_.clear();
  1226. AggBuilders_.reserve(Aggs_.size());
  1227. for (const auto& a : Aggs_) {
  1228. if constexpr (Finalize) {
  1229. AggBuilders_.emplace_back(a->MakeResultBuilder(MaxBlockLen_));
  1230. } else {
  1231. AggBuilders_.emplace_back(a->MakeStateBuilder(MaxBlockLen_));
  1232. }
  1233. }
  1234. }
  1235. }
  1236. void Flush(bool final, const THolderFactory& holderFactory) {
  1237. if (!OutputBlockSize_) {
  1238. return;
  1239. }
  1240. for (size_t i = 0; i < Builders_.size(); ++i) {
  1241. Values[i] = holderFactory.CreateArrowBlock(Builders_[i]->Build(final));
  1242. }
  1243. if constexpr (!UseSet) {
  1244. for (size_t i = 0; i < Aggs_.size(); ++i) {
  1245. Values[Builders_.size() + i] = AggBuilders_[i]->Build();
  1246. }
  1247. if (!final) {
  1248. PrepareAggBuilders();
  1249. }
  1250. }
  1251. Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(OutputBlockSize_)));
  1252. OutputBlockSize_ = 0;
  1253. }
  1254. void Insert(ui64 row, char* payload, bool isNew, ui32 currentStreamIndex) const {
  1255. char* ptr = payload;
  1256. if (isNew) {
  1257. if constexpr (Many) {
  1258. static_assert(Finalize);
  1259. MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index");
  1260. memset(ptr, 0, Streams_.size());
  1261. ptr[currentStreamIndex] = 1;
  1262. for (auto i : Streams_[currentStreamIndex]) {
  1263. Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row);
  1264. }
  1265. } else {
  1266. for (size_t i = 0; i < Aggs_.size(); ++i) {
  1267. if constexpr (Finalize) {
  1268. Aggs_[i]->LoadState(ptr, BatchNum_, Values_.data(), row);
  1269. } else {
  1270. Aggs_[i]->InitKey(ptr, BatchNum_, Values_.data(), row);
  1271. }
  1272. ptr += Aggs_[i]->StateSize;
  1273. }
  1274. }
  1275. } else {
  1276. if constexpr (Many) {
  1277. static_assert(Finalize);
  1278. MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index");
  1279. bool isNewStream = !ptr[currentStreamIndex];
  1280. ptr[currentStreamIndex] = 1;
  1281. for (auto i : Streams_[currentStreamIndex]) {
  1282. if (isNewStream) {
  1283. Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row);
  1284. } else {
  1285. Aggs_[i]->UpdateState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row);
  1286. }
  1287. }
  1288. } else {
  1289. for (size_t i = 0; i < Aggs_.size(); ++i) {
  1290. if constexpr (Finalize) {
  1291. Aggs_[i]->UpdateState(ptr, BatchNum_, Values_.data(), row);
  1292. } else {
  1293. Aggs_[i]->UpdateKey(ptr, BatchNum_, Values_.data(), row);
  1294. }
  1295. ptr += Aggs_[i]->StateSize;
  1296. }
  1297. }
  1298. }
  1299. }
  1300. template <typename THash>
  1301. bool Iterate(THash& hash, typename THash::const_iterator& iter) {
  1302. MKQL_ENSURE(WritingOutput_, "Supposed to be called at the end");
  1303. std::array<typename THash::const_iterator, PrefetchBatchSize> iters;
  1304. ui32 itersLen = 0;
  1305. auto iterateBatch = [&]() {
  1306. for (ui32 i = 0; i < itersLen; ++i) {
  1307. auto iter = iters[i];
  1308. const TKey& key = hash.GetKey(iter);
  1309. auto payload = (char*)hash.GetPayload(iter);
  1310. char* ptr;
  1311. if constexpr (UseArena) {
  1312. ptr = *(char**)payload;
  1313. } else {
  1314. ptr = payload;
  1315. }
  1316. TInputBuffer in(GetKeyView<TKey>(key, KeyLength_));
  1317. for (auto& kb : Builders_) {
  1318. kb->Add(in);
  1319. }
  1320. if constexpr (Many) {
  1321. for (ui32 i = 0; i < Streams_.size(); ++i) {
  1322. MKQL_ENSURE(ptr[i], "Missing partial aggregation state for stream #" << i);
  1323. }
  1324. ptr += Streams_.size();
  1325. }
  1326. for (size_t i = 0; i < Aggs_.size(); ++i) {
  1327. AggBuilders_[i]->Add(ptr);
  1328. Aggs_[i]->DestroyState(ptr);
  1329. ptr += Aggs_[i]->StateSize;
  1330. }
  1331. }
  1332. };
  1333. for (; iter != hash.End(); hash.Advance(iter)) {
  1334. if (!hash.IsValid(iter)) {
  1335. continue;
  1336. }
  1337. if (OutputBlockSize_ == MaxBlockLen_) {
  1338. iterateBatch();
  1339. return false;
  1340. }
  1341. if (itersLen == iters.size()) {
  1342. iterateBatch();
  1343. itersLen = 0;
  1344. }
  1345. iters[itersLen] = iter;
  1346. ++itersLen;
  1347. ++OutputBlockSize_;
  1348. if constexpr (UseArena) {
  1349. auto payload = (char*)hash.GetPayload(iter);
  1350. auto ptr = *(char**)payload;
  1351. NYql::PrefetchForWrite(ptr);
  1352. }
  1353. if constexpr (std::is_same<TKey, TSSOKey>::value) {
  1354. const auto& key = hash.GetKey(iter);
  1355. if (!key.IsInplace()) {
  1356. NYql::PrefetchForRead(key.AsView().Data());
  1357. }
  1358. } else if constexpr (std::is_same<TKey, TExternalFixedSizeKey>::value) {
  1359. const auto& key = hash.GetKey(iter);
  1360. NYql::PrefetchForRead(key.Data);
  1361. }
  1362. }
  1363. iterateBatch();
  1364. return true;
  1365. }
  1366. };
  1367. template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
  1368. class THashedWrapperBaseFromFlow : public TStatefulWideFlowCodegeneratorNode<TDerived>,
  1369. protected THashedWrapperCodegenBase
  1370. {
  1371. using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>;
  1372. using TState = THashedWrapperBaseState<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>;
  1373. public:
  1374. THashedWrapperBaseFromFlow(TComputationMutables& mutables,
  1375. IComputationWideFlowNode* flow,
  1376. std::optional<ui32> filterColumn,
  1377. size_t width,
  1378. const std::vector<TKeyParams>& keys,
  1379. size_t maxBlockLen,
  1380. ui32 keyLength,
  1381. std::vector<TAggParams<TAggregator>>&& aggsParams,
  1382. ui32 streamIndex,
  1383. std::vector<std::vector<ui32>>&& streams)
  1384. : TComputationBase(mutables, flow, EValueRepresentation::Boxed)
  1385. , Flow_(flow)
  1386. , FilterColumn_(filterColumn)
  1387. , Width_(width)
  1388. , OutputWidth_(keys.size() + aggsParams.size() + 1)
  1389. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
  1390. , Keys_(keys)
  1391. , MaxBlockLen_(maxBlockLen)
  1392. , AggsParams_(std::move(aggsParams))
  1393. , KeyLength_(keyLength)
  1394. , StreamIndex_(streamIndex)
  1395. , Streams_(std::move(streams))
  1396. {
  1397. MKQL_ENSURE(Width_ > 0, "Missing block length column");
  1398. if constexpr (UseFilter) {
  1399. MKQL_ENSURE(filterColumn, "Missing filter column");
  1400. MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize");
  1401. } else {
  1402. MKQL_ENSURE(!filterColumn, "Unexpected filter column");
  1403. }
  1404. }
  1405. EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
  1406. TComputationContext& ctx,
  1407. NUdf::TUnboxedValue*const* output) const
  1408. {
  1409. auto& s = GetState(state, ctx);
  1410. if (!s.Count) {
  1411. if (s.IsFinished_)
  1412. return EFetchResult::Finish;
  1413. while (!s.WritingOutput_) {
  1414. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  1415. s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod());
  1416. switch (Flow_->FetchValues(ctx, fields)) {
  1417. case EFetchResult::Yield:
  1418. return EFetchResult::Yield;
  1419. case EFetchResult::One:
  1420. s.ProcessInput(ctx.HolderFactory);
  1421. continue;
  1422. case EFetchResult::Finish:
  1423. break;
  1424. }
  1425. if (s.Finish())
  1426. break;
  1427. else
  1428. return EFetchResult::Finish;
  1429. }
  1430. if (!s.FillOutput(ctx.HolderFactory))
  1431. return EFetchResult::Finish;
  1432. }
  1433. const auto sliceSize = s.Slice();
  1434. for (size_t i = 0; i < OutputWidth_; ++i) {
  1435. if (const auto out = output[i]) {
  1436. *out = s.Get(sliceSize, ctx.HolderFactory, i);
  1437. }
  1438. }
  1439. return EFetchResult::One;
  1440. }
  1441. #ifndef MKQL_DISABLE_CODEGEN
  1442. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  1443. return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_,
  1444. GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBaseFromFlow::MakeState),
  1445. GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish),
  1446. GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice));
  1447. }
  1448. #endif
  1449. private:
  1450. void RegisterDependencies() const final {
  1451. this->FlowDependsOn(Flow_);
  1452. }
  1453. void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  1454. state = ctx.HolderFactory.Create<TState>(KeyLength_, StreamIndex_, Width_, OutputWidth_, FilterColumn_, AggsParams_, Streams_, Keys_, MaxBlockLen_, ctx);
  1455. }
  1456. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  1457. if (state.IsInvalid()) {
  1458. MakeState(state, ctx);
  1459. auto& s = *static_cast<TState*>(state.AsBoxed().Get());
  1460. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  1461. for (size_t i = 0; i < s.Values_.size(); ++i) {
  1462. fields[i] = &s.Values_[i];
  1463. }
  1464. return s;
  1465. }
  1466. return *static_cast<TState*>(state.AsBoxed().Get());
  1467. }
  1468. IComputationWideFlowNode *const Flow_;
  1469. const std::optional<ui32> FilterColumn_;
  1470. const size_t Width_;
  1471. const size_t OutputWidth_;
  1472. const size_t WideFieldsIndex_;
  1473. const std::vector<TKeyParams> Keys_;
  1474. const size_t MaxBlockLen_;
  1475. const std::vector<TAggParams<TAggregator>> AggsParams_;
  1476. const ui32 KeyLength_;
  1477. const ui32 StreamIndex_;
  1478. const std::vector<std::vector<ui32>> Streams_;
  1479. };
  1480. template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
  1481. class THashedWrapperBaseFromStream : public TMutableComputationNode<TDerived>,
  1482. protected THashedWrapperCodegenBase
  1483. {
  1484. using TComputationBase = TMutableComputationNode<TDerived>;
  1485. using TState = THashedWrapperBaseState<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>;
  1486. public:
  1487. THashedWrapperBaseFromStream(TComputationMutables& mutables,
  1488. IComputationNode* stream,
  1489. std::optional<ui32> filterColumn,
  1490. size_t width,
  1491. const std::vector<TKeyParams>& keys,
  1492. size_t maxBlockLen,
  1493. ui32 keyLength,
  1494. std::vector<TAggParams<TAggregator>>&& aggsParams,
  1495. ui32 streamIndex,
  1496. std::vector<std::vector<ui32>>&& streams)
  1497. : TComputationBase(mutables, EValueRepresentation::Boxed)
  1498. , Stream_(stream)
  1499. , FilterColumn_(filterColumn)
  1500. , Width_(width)
  1501. , OutputWidth_(keys.size() + aggsParams.size() + 1)
  1502. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
  1503. , Keys_(keys)
  1504. , MaxBlockLen_(maxBlockLen)
  1505. , AggsParams_(std::move(aggsParams))
  1506. , KeyLength_(keyLength)
  1507. , StreamIndex_(streamIndex)
  1508. , Streams_(std::move(streams))
  1509. {
  1510. MKQL_ENSURE(Width_ > 0, "Missing block length column");
  1511. if constexpr (UseFilter) {
  1512. MKQL_ENSURE(filterColumn, "Missing filter column");
  1513. MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize");
  1514. } else {
  1515. MKQL_ENSURE(!filterColumn, "Unexpected filter column");
  1516. }
  1517. }
  1518. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const
  1519. {
  1520. const auto state = ctx.HolderFactory.Create<TState>(KeyLength_, StreamIndex_, Width_, OutputWidth_, FilterColumn_, AggsParams_, Streams_, Keys_, MaxBlockLen_, ctx);
  1521. return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory, std::move(state), std::move(Stream_->GetValue(ctx)));
  1522. }
  1523. private:
  1524. class TStreamValue : public TComputationValue<TStreamValue> {
  1525. using TBase = TComputationValue<TStreamValue>;
  1526. public:
  1527. TStreamValue(TMemoryUsageInfo* memInfo, const THolderFactory& holderFactory,
  1528. NUdf::TUnboxedValue&& state, NUdf::TUnboxedValue&& stream)
  1529. : TBase(memInfo)
  1530. , State_(state)
  1531. , Stream_(stream)
  1532. , HolderFactory_(holderFactory)
  1533. {
  1534. }
  1535. private:
  1536. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  1537. TState& state = *static_cast<TState*>(State_.AsBoxed().Get());
  1538. auto* inputFields = state.Values_.data();
  1539. const size_t inputWidth = state.Width_;
  1540. const size_t outputWidth = state.OutputWidth_;
  1541. MKQL_ENSURE(outputWidth == width, "The given width doesn't equal to the result type size");
  1542. if (!state.Count) {
  1543. if (state.IsFinished_)
  1544. return NUdf::EFetchStatus::Finish;
  1545. while (!state.WritingOutput_) {
  1546. switch (Stream_.WideFetch(inputFields, inputWidth)) {
  1547. case NUdf::EFetchStatus::Yield:
  1548. return NUdf::EFetchStatus::Yield;
  1549. case NUdf::EFetchStatus::Ok:
  1550. state.ProcessInput(HolderFactory_);
  1551. continue;
  1552. case NUdf::EFetchStatus::Finish:
  1553. break;
  1554. }
  1555. if (state.Finish())
  1556. break;
  1557. else
  1558. return NUdf::EFetchStatus::Finish;
  1559. }
  1560. if (!state.FillOutput(HolderFactory_))
  1561. return NUdf::EFetchStatus::Finish;
  1562. }
  1563. const auto sliceSize = state.Slice();
  1564. for (size_t i = 0; i < outputWidth; ++i) {
  1565. output[i] = state.Get(sliceSize, HolderFactory_, i);
  1566. }
  1567. return NUdf::EFetchStatus::Ok;
  1568. }
  1569. private:
  1570. NUdf::TUnboxedValue State_;
  1571. NUdf::TUnboxedValue Stream_;
  1572. const THolderFactory& HolderFactory_;
  1573. };
  1574. private:
  1575. void RegisterDependencies() const final {
  1576. this->DependsOn(Stream_);
  1577. }
  1578. IComputationNode *const Stream_;
  1579. const std::optional<ui32> FilterColumn_;
  1580. const size_t Width_;
  1581. const size_t OutputWidth_;
  1582. const size_t WideFieldsIndex_;
  1583. const std::vector<TKeyParams> Keys_;
  1584. const size_t MaxBlockLen_;
  1585. const std::vector<TAggParams<TAggregator>> AggsParams_;
  1586. const ui32 KeyLength_;
  1587. const ui32 StreamIndex_;
  1588. const std::vector<std::vector<ui32>> Streams_;
  1589. };
  // Primary template, intentionally empty: the implementations are the
  // partial specializations selected by the input node type (TInputNode).
  template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter, typename TInputNode>
  class TBlockCombineHashedWrapper
  {
  };
  1592. template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter>
  1593. class TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode>
  1594. : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode>> {
  1595. public:
  1596. using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode>;
  1597. using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>;
  1598. TBlockCombineHashedWrapper(TComputationMutables& mutables,
  1599. IComputationWideFlowNode* flow,
  1600. std::optional<ui32> filterColumn,
  1601. size_t width,
  1602. const std::vector<TKeyParams>& keys,
  1603. size_t maxBlockLen,
  1604. ui32 keyLength,
  1605. std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams)
  1606. : TBase(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {})
  1607. {}
  1608. };
  1609. template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter>
  1610. class TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode>
  1611. : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode>> {
  1612. public:
  1613. using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode>;
  1614. using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>;
  1615. TBlockCombineHashedWrapper(TComputationMutables& mutables,
  1616. IComputationNode* stream,
  1617. std::optional<ui32> filterColumn,
  1618. size_t width,
  1619. const std::vector<TKeyParams>& keys,
  1620. size_t maxBlockLen,
  1621. ui32 keyLength,
  1622. std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams)
  1623. : TBase(mutables, stream, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {})
  1624. {}
  1625. };
  // Primary template, intentionally empty: the implementations are the
  // partial specializations selected by the input node type (TInputNode).
  template <typename TKey, typename TFixedAggState, bool UseSet, typename TInputNode>
  class TBlockMergeFinalizeHashedWrapper
  {
  };
  1628. template <typename TKey, typename TFixedAggState, bool UseSet>
  1629. class TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode>
  1630. : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode>> {
  1631. public:
  1632. using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode>;
  1633. using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>;
  1634. TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables,
  1635. IComputationWideFlowNode* flow,
  1636. size_t width,
  1637. const std::vector<TKeyParams>& keys,
  1638. size_t maxBlockLen,
  1639. ui32 keyLength,
  1640. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams)
  1641. : TBase(mutables, flow, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {})
  1642. {}
  1643. };
  1644. template <typename TKey, typename TFixedAggState, bool UseSet>
  1645. class TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode>
  1646. : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode>> {
  1647. public:
  1648. using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode>;
  1649. using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>;
  1650. TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables,
  1651. IComputationNode* stream,
  1652. size_t width,
  1653. const std::vector<TKeyParams>& keys,
  1654. size_t maxBlockLen,
  1655. ui32 keyLength,
  1656. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams)
  1657. : TBase(mutables, stream, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {})
  1658. {}
  1659. };
  // Primary template, intentionally empty: the implementations are the
  // partial specializations selected by the input node type (TInputNode).
  template <typename TKey, typename TFixedAggState, typename TInputNode>
  class TBlockMergeManyFinalizeHashedWrapper
  {
  };
  1662. template <typename TKey, typename TFixedAggState>
  1663. class TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode>
  1664. : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode>> {
  1665. public:
  1666. using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode>;
  1667. using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>;
  1668. TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables,
  1669. IComputationWideFlowNode* flow,
  1670. size_t width,
  1671. const std::vector<TKeyParams>& keys,
  1672. size_t maxBlockLen,
  1673. ui32 keyLength,
  1674. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
  1675. ui32 streamIndex, std::vector<std::vector<ui32>>&& streams)
  1676. : TBase(mutables, flow, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams))
  1677. {}
  1678. };
  1679. template <typename TKey, typename TFixedAggState>
  1680. class TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode>
  1681. : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode>> {
  1682. public:
  1683. using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode>;
  1684. using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>;
  1685. TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables,
  1686. IComputationNode* stream,
  1687. size_t width,
  1688. const std::vector<TKeyParams>& keys,
  1689. size_t maxBlockLen,
  1690. ui32 keyLength,
  1691. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
  1692. ui32 streamIndex, std::vector<std::vector<ui32>>&& streams)
  1693. : TBase(mutables, stream, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams))
  1694. {}
  1695. };
  1696. template <typename TAggregator>
  1697. std::unique_ptr<IPreparedBlockAggregator<TAggregator>> PrepareBlockAggregator(const IBlockAggregatorFactory& factory,
  1698. TTupleType* tupleType,
  1699. std::optional<ui32> filterColumn,
  1700. const std::vector<ui32>& argsColumns,
  1701. const TTypeEnvironment& env,
  1702. TType* returnType,
  1703. ui32 hint);
  1704. template <>
  1705. std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineAll>> PrepareBlockAggregator<IBlockAggregatorCombineAll>(const IBlockAggregatorFactory& factory,
  1706. TTupleType* tupleType,
  1707. std::optional<ui32> filterColumn,
  1708. const std::vector<ui32>& argsColumns,
  1709. const TTypeEnvironment& env,
  1710. TType* returnType,
  1711. ui32 hint) {
  1712. Y_UNUSED(hint);
  1713. MKQL_ENSURE(!returnType, "Unexpected return type");
  1714. return factory.PrepareCombineAll(tupleType, filterColumn, argsColumns, env);
  1715. }
  1716. template <>
  1717. std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineKeys>> PrepareBlockAggregator<IBlockAggregatorCombineKeys>(const IBlockAggregatorFactory& factory,
  1718. TTupleType* tupleType,
  1719. std::optional<ui32> filterColumn,
  1720. const std::vector<ui32>& argsColumns,
  1721. const TTypeEnvironment& env,
  1722. TType* returnType,
  1723. ui32 hint) {
  1724. Y_UNUSED(hint);
  1725. MKQL_ENSURE(!filterColumn, "Unexpected filter column");
  1726. MKQL_ENSURE(!returnType, "Unexpected return type");
  1727. return factory.PrepareCombineKeys(tupleType, argsColumns, env);
  1728. }
  1729. template <>
  1730. std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareBlockAggregator<IBlockAggregatorFinalizeKeys>(const IBlockAggregatorFactory& factory,
  1731. TTupleType* tupleType,
  1732. std::optional<ui32> filterColumn,
  1733. const std::vector<ui32>& argsColumns,
  1734. const TTypeEnvironment& env,
  1735. TType* returnType,
  1736. ui32 hint) {
  1737. MKQL_ENSURE(!filterColumn, "Unexpected filter column");
  1738. MKQL_ENSURE(returnType, "Missing return type");
  1739. return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType, hint);
  1740. }
  1741. template <typename TAggregator>
  1742. ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, std::vector<TAggParams<TAggregator>>& aggsParams,
  1743. const TTypeEnvironment& env, bool overState, bool many, TArrayRef<TType* const> returnTypes, ui32 keysCount) {
  1744. TTupleType* unwrappedTupleType = tupleType;
  1745. if (many) {
  1746. std::vector<TType*> unwrappedTypes(tupleType->GetElementsCount());
  1747. for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) {
  1748. unwrappedTypes[i] = tupleType->GetElementType(i);
  1749. }
  1750. for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
  1751. auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
  1752. MKQL_ENSURE(aggVal->GetValuesCount() == 2, "Expected only one column");
  1753. auto index = AS_VALUE(TDataLiteral, aggVal->GetValue(1))->AsValue().Get<ui32>();
  1754. MKQL_ENSURE(index < unwrappedTypes.size(), "Bad state column index");
  1755. auto blockType = AS_TYPE(TBlockType, unwrappedTypes[index]);
  1756. MKQL_ENSURE(blockType->GetShape() == TBlockType::EShape::Many, "State must be a block");
  1757. bool isOptional;
  1758. auto unpacked = UnpackOptional(blockType->GetItemType(), isOptional);
  1759. MKQL_ENSURE(isOptional, "State must be optional");
  1760. unwrappedTypes[index] = TBlockType::Create(unpacked, TBlockType::EShape::Many, env);
  1761. }
  1762. unwrappedTupleType = TTupleType::Create(unwrappedTypes.size(), unwrappedTypes.data(), env);
  1763. }
  1764. ui32 totalStateSize = 0;
  1765. for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
  1766. auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
  1767. auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef();
  1768. std::vector<ui32> argColumns;
  1769. for (ui32 j = 1; j < aggVal->GetValuesCount(); ++j) {
  1770. argColumns.push_back(AS_VALUE(TDataLiteral, aggVal->GetValue(j))->AsValue().Get<ui32>());
  1771. }
  1772. TAggParams<TAggregator> p;
  1773. if (overState) {
  1774. MKQL_ENSURE(argColumns.size() == 1, "Expected exactly one column");
  1775. p.Column_ = argColumns[0];
  1776. p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType();
  1777. p.ReturnType_ = returnTypes[i + keysCount];
  1778. TStringBuf left, right;
  1779. if (TStringBuf(name).TrySplit('#', left, right)) {
  1780. p.Hint_ = FromString<ui32>(right);
  1781. }
  1782. }
  1783. p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_, p.Hint_);
  1784. totalStateSize += p.Prepared_->StateSize;
  1785. aggsParams.emplace_back(std::move(p));
  1786. }
  1787. return totalStateSize;
  1788. }
  1789. template <bool UseSet, bool UseFilter, typename TKey, typename TInputNode>
  1790. IComputationNode* MakeBlockCombineHashedWrapper(
  1791. ui32 keyLength,
  1792. ui32 totalStateSize,
  1793. TComputationMutables& mutables,
  1794. TInputNode* streamOrFlow,
  1795. std::optional<ui32> filterColumn,
  1796. size_t width,
  1797. const std::vector<TKeyParams>& keys,
  1798. size_t maxBlockLen,
  1799. std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) {
  1800. if (totalStateSize <= sizeof(TState8)) {
  1801. return new TBlockCombineHashedWrapper<TKey, TState8, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1802. }
  1803. if (totalStateSize <= sizeof(TState16)) {
  1804. return new TBlockCombineHashedWrapper<TKey, TState16, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1805. }
  1806. return new TBlockCombineHashedWrapper<TKey, TStateArena, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1807. }
  1808. template <bool UseSet, bool UseFilter, typename TInputNode>
  1809. IComputationNode* MakeBlockCombineHashedWrapper(
  1810. TMaybe<ui32> totalKeysSize,
  1811. bool isFixed,
  1812. ui32 totalStateSize,
  1813. TComputationMutables& mutables,
  1814. TInputNode* streamOrFlow,
  1815. std::optional<ui32> filterColumn,
  1816. size_t width,
  1817. const std::vector<TKeyParams>& keys,
  1818. size_t maxBlockLen,
  1819. std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) {
  1820. if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) {
  1821. return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui32>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams));
  1822. }
  1823. if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) {
  1824. return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui64>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams));
  1825. }
  1826. if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) {
  1827. return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TKey16>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams));
  1828. }
  1829. if (totalKeysSize && isFixed) {
  1830. return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams));
  1831. }
  1832. return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TSSOKey>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams));
  1833. }
  1834. template <typename TKey, bool UseSet, typename TInputNode>
  1835. IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
  1836. ui32 keyLength,
  1837. ui32 totalStateSize,
  1838. TComputationMutables& mutables,
  1839. TInputNode* streamOrFlow,
  1840. size_t width,
  1841. const std::vector<TKeyParams>& keys,
  1842. size_t maxBlockLen,
  1843. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) {
  1844. if (totalStateSize <= sizeof(TState8)) {
  1845. return new TBlockMergeFinalizeHashedWrapper<TKey, TState8, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1846. }
  1847. if (totalStateSize <= sizeof(TState16)) {
  1848. return new TBlockMergeFinalizeHashedWrapper<TKey, TState16, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1849. }
  1850. return new TBlockMergeFinalizeHashedWrapper<TKey, TStateArena, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams));
  1851. }
  1852. template <bool UseSet, typename TInputNode>
  1853. IComputationNode* MakeBlockMergeFinalizeHashedWrapper(
  1854. TMaybe<ui32> totalKeysSize,
  1855. bool isFixed,
  1856. ui32 totalStateSize,
  1857. TComputationMutables& mutables,
  1858. TInputNode* streamOrFlow,
  1859. size_t width,
  1860. const std::vector<TKeyParams>& keys,
  1861. size_t maxBlockLen,
  1862. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) {
  1863. if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) {
  1864. return MakeBlockMergeFinalizeHashedWrapper<ui32, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams));
  1865. }
  1866. if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) {
  1867. return MakeBlockMergeFinalizeHashedWrapper<ui64, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams));
  1868. }
  1869. if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) {
  1870. return MakeBlockMergeFinalizeHashedWrapper<TKey16, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams));
  1871. }
  1872. if (totalKeysSize && isFixed) {
  1873. return MakeBlockMergeFinalizeHashedWrapper<TExternalFixedSizeKey, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams));
  1874. }
  1875. return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams));
  1876. }
  1877. template <typename TKey, typename TInputNode>
  1878. IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper(
  1879. ui32 keyLength,
  1880. ui32 totalStateSize,
  1881. TComputationMutables& mutables,
  1882. TInputNode* streamOrFlow,
  1883. size_t width,
  1884. const std::vector<TKeyParams>& keys,
  1885. size_t maxBlockLen,
  1886. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
  1887. ui32 streamIndex,
  1888. std::vector<std::vector<ui32>>&& streams) {
  1889. if (totalStateSize <= sizeof(TState8)) {
  1890. return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams));
  1891. }
  1892. if (totalStateSize <= sizeof(TState16)) {
  1893. return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams));
  1894. }
  1895. return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams));
  1896. }
  1897. template <typename TInputNode>
  1898. IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper(
  1899. TMaybe<ui32> totalKeysSize,
  1900. bool isFixed,
  1901. ui32 totalStateSize,
  1902. TComputationMutables& mutables,
  1903. TInputNode* streamOrFlow,
  1904. size_t width,
  1905. const std::vector<TKeyParams>& keys,
  1906. size_t maxBlockLen,
  1907. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams,
  1908. ui32 streamIndex,
  1909. std::vector<std::vector<ui32>>&& streams) {
  1910. if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) {
  1911. return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  1912. }
  1913. if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) {
  1914. return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  1915. }
  1916. if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) {
  1917. return MakeBlockMergeManyFinalizeHashedWrapper<TKey16>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  1918. }
  1919. if (totalKeysSize && isFixed) {
  1920. return MakeBlockMergeManyFinalizeHashedWrapper<TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  1921. }
  1922. return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  1923. }
  1924. void PrepareKeys(const std::vector<TKeyParams>& keys, TMaybe<ui32>& totalKeysSize, bool& isFixed) {
  1925. NYql::NUdf::TBlockItemSerializeProps props;
  1926. for (auto& param : keys) {
  1927. auto type = AS_TYPE(TBlockType, param.Type);
  1928. UpdateBlockItemSerializeProps(TTypeInfoHelper(), type->GetItemType(), props);
  1929. }
  1930. isFixed = props.IsFixed;
  1931. totalKeysSize = props.MaxSize;
  1932. }
  1933. void FillAggStreams(TRuntimeNode streamsNode, std::vector<std::vector<ui32>>& streams) {
  1934. auto streamsVal = AS_VALUE(TTupleLiteral, streamsNode);
  1935. for (ui32 i = 0; i < streamsVal->GetValuesCount(); ++i) {
  1936. streams.emplace_back();
  1937. auto& stream = streams.back();
  1938. auto streamVal = AS_VALUE(TTupleLiteral, streamsVal->GetValue(i));
  1939. for (ui32 j = 0; j < streamVal->GetValuesCount(); ++j) {
  1940. ui32 index = AS_VALUE(TDataLiteral, streamVal->GetValue(j))->AsValue().Get<ui32>();
  1941. stream.emplace_back(index);
  1942. }
  1943. }
  1944. }
  1945. }
  1946. IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1947. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  1948. const bool isStream = callable.GetInput(0).GetStaticType()->IsStream();
  1949. MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream");
  1950. const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType());
  1951. const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env);
  1952. const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType());
  1953. const auto wideFlowOrStream = LocateNode(ctx.NodeLocator, callable, 0);
  1954. auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1));
  1955. std::optional<ui32> filterColumn;
  1956. if (filterColumnVal->HasItem()) {
  1957. filterColumn = AS_VALUE(TDataLiteral, filterColumnVal->GetItem())->AsValue().Get<ui32>();
  1958. }
  1959. auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
  1960. std::vector<TAggParams<IBlockAggregatorCombineAll>> aggsParams;
  1961. FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false, false, returnWideComponents, 0);
  1962. if (isStream) {
  1963. const auto wideStream = wideFlowOrStream;
  1964. return new TBlockCombineAllWrapperFromStream(ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
  1965. } else {
  1966. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(wideFlowOrStream);
  1967. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1968. return new TBlockCombineAllWrapperFromFlow(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
  1969. }
  1970. }
  1971. IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1972. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  1973. const bool isStream = callable.GetInput(0).GetStaticType()->IsStream();
  1974. MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream");
  1975. const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType());
  1976. const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env);
  1977. const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType());
  1978. const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0);
  1979. auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1));
  1980. std::optional<ui32> filterColumn;
  1981. if (filterColumnVal->HasItem()) {
  1982. filterColumn = AS_VALUE(TDataLiteral, filterColumnVal->GetItem())->AsValue().Get<ui32>();
  1983. }
  1984. auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
  1985. std::vector<TKeyParams> keys;
  1986. for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) {
  1987. ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>();
  1988. keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) });
  1989. }
  1990. auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3));
  1991. std::vector<TAggParams<IBlockAggregatorCombineKeys>> aggsParams;
  1992. ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, false, false, returnWideComponents, keys.size());
  1993. TMaybe<ui32> totalKeysSize;
  1994. bool isFixed = false;
  1995. PrepareKeys(keys, totalKeysSize, isFixed);
  1996. const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType());
  1997. if (isStream) {
  1998. const auto wideStream = wideStreamOrFlow;
  1999. if (filterColumn) {
  2000. if (aggsParams.empty()) {
  2001. return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2002. } else {
  2003. return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2004. }
  2005. } else {
  2006. if (aggsParams.empty()) {
  2007. return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2008. } else {
  2009. return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2010. }
  2011. }
  2012. } else {
  2013. const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow);
  2014. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  2015. if (filterColumn) {
  2016. if (aggsParams.empty()) {
  2017. return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2018. } else {
  2019. return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2020. }
  2021. } else {
  2022. if (aggsParams.empty()) {
  2023. return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2024. } else {
  2025. return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2026. }
  2027. }
  2028. }
  2029. }
  2030. IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  2031. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  2032. const bool isStream = callable.GetInput(0).GetStaticType()->IsStream();
  2033. MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream");
  2034. const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType());
  2035. const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env);
  2036. const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType());
  2037. const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0);
  2038. auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1));
  2039. std::vector<TKeyParams> keys;
  2040. for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) {
  2041. ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>();
  2042. keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) });
  2043. }
  2044. auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
  2045. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams;
  2046. ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true, false, returnWideComponents, keys.size());
  2047. TMaybe<ui32> totalKeysSize;
  2048. bool isFixed = false;
  2049. PrepareKeys(keys, totalKeysSize, isFixed);
  2050. const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType());
  2051. if (isStream) {
  2052. const auto wideStream = wideStreamOrFlow;
  2053. if (aggsParams.empty()) {
  2054. return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2055. } else {
  2056. return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2057. }
  2058. } else {
  2059. const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow);
  2060. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  2061. if (aggsParams.empty()) {
  2062. return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2063. } else {
  2064. return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams));
  2065. }
  2066. }
  2067. }
  2068. IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  2069. MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");
  2070. const bool isStream = callable.GetInput(0).GetStaticType()->IsStream();
  2071. MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream");
  2072. const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType());
  2073. const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env);
  2074. const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType());
  2075. const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0);
  2076. auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1));
  2077. std::vector<TKeyParams> keys;
  2078. for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) {
  2079. ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>();
  2080. keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) });
  2081. }
  2082. const auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
  2083. std::vector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams;
  2084. ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true, true, returnWideComponents, keys.size());
  2085. TMaybe<ui32> totalKeysSize;
  2086. bool isFixed = false;
  2087. PrepareKeys(keys, totalKeysSize, isFixed);
  2088. const ui32 streamIndex = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>();
  2089. std::vector<std::vector<ui32>> streams;
  2090. FillAggStreams(callable.GetInput(4), streams);
  2091. totalStateSize += streams.size();
  2092. const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType());
  2093. if (isStream){
  2094. const auto wideStream = wideStreamOrFlow;
  2095. if (aggsParams.empty()) {
  2096. return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(),
  2097. keys, maxBlockLen, std::move(aggsParams));
  2098. } else {
  2099. return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(),
  2100. keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  2101. }
  2102. } else {
  2103. const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow);
  2104. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  2105. if (aggsParams.empty()) {
  2106. return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(),
  2107. keys, maxBlockLen, std::move(aggsParams));
  2108. } else {
  2109. return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(),
  2110. keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams));
  2111. }
  2112. }
  2113. }
  2114. }
  2115. }