mkql_sort.cpp 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784
  1. #include "mkql_sort.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/presort.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/mkql_string_util.h>
  7. #include <yql/essentials/minikql/mkql_type_builder.h>
  8. #include <yql/essentials/utils/sort.h>
  9. #include <algorithm>
  10. #include <iterator>
  11. namespace NKikimr {
  12. namespace NMiniKQL {
  13. namespace {
  14. std::vector<NUdf::EDataSlot> PrepareKeyTypesByScheme(const std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>& keySchemeTypes) {
  15. MKQL_ENSURE(!keySchemeTypes.empty(), "No key types provided");
  16. std::vector<NUdf::EDataSlot> keyTypes;
  17. keyTypes.reserve(keySchemeTypes.size());
  18. for (const auto& schemeType: keySchemeTypes) {
  19. keyTypes.emplace_back(std::get<0>(schemeType));
  20. const auto& info = NUdf::GetDataTypeInfo(keyTypes.back());
  21. MKQL_ENSURE(info.Features & NUdf::CanCompare, "Cannot compare key type: " << info.Name);
  22. }
  23. return keyTypes;
  24. }
  25. class TEncoders : public TComputationValue<TEncoders> {
  26. typedef TComputationValue<TEncoders> TBase;
  27. public:
  28. TEncoders(TMemoryUsageInfo* memInfo, const std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>& keySchemeTypes,
  29. bool allowEncoding)
  30. : TBase(memInfo)
  31. {
  32. Columns.reserve(keySchemeTypes.size());
  33. for (const auto& x : keySchemeTypes) {
  34. Columns.push_back(Nothing());
  35. auto type = std::get<2>(x);
  36. if (allowEncoding && type) {
  37. NeedEncode = true;
  38. Columns.back().ConstructInPlace(type);
  39. }
  40. }
  41. }
  42. std::vector<TMaybe<TGenericPresortEncoder>> Columns;
  43. bool NeedEncode = false;
  44. };
  45. class TGatherIteratorRef {
  46. public:
  47. TGatherIteratorRef(NUdf::TUnboxedValue& first, NUdf::TUnboxedValue& second)
  48. : First(first)
  49. , Second(second)
  50. {}
  51. operator TKeyPayloadPair () const {
  52. return TKeyPayloadPair(First, Second);
  53. }
  54. TGatherIteratorRef& operator=(const TKeyPayloadPair& rhs) {
  55. First = rhs.first;
  56. Second = rhs.second;
  57. return *this;
  58. }
  59. TGatherIteratorRef& operator=(const TGatherIteratorRef& rhs) {
  60. First = rhs.First;
  61. Second = rhs.Second;
  62. return *this;
  63. }
  64. friend void swap(TGatherIteratorRef x, TGatherIteratorRef y) {
  65. std::swap(x.First, y.First);
  66. std::swap(x.Second, y.Second);
  67. }
  68. private:
  69. NUdf::TUnboxedValue& First;
  70. NUdf::TUnboxedValue& Second;
  71. };
  72. class TGatherIterator : public std::iterator<std::random_access_iterator_tag, TKeyPayloadPair,
  73. ptrdiff_t, TKeyPayloadPair*, TGatherIteratorRef>
  74. {
  75. public:
  76. TGatherIterator()
  77. : First(nullptr)
  78. , Second(nullptr)
  79. {
  80. }
  81. TGatherIterator(NUdf::TUnboxedValue* first, NUdf::TUnboxedValue* second)
  82. : First(first)
  83. , Second(second)
  84. {}
  85. TGatherIterator(const TGatherIterator&) = default;
  86. TGatherIterator& operator=(const TGatherIterator&) = default;
  87. TGatherIteratorRef operator*() const& {
  88. return TGatherIteratorRef(*First, *Second);
  89. }
  90. TGatherIterator& operator ++ () {
  91. First++;
  92. Second++;
  93. return *this;
  94. }
  95. TGatherIterator& operator -- () {
  96. First--;
  97. Second--;
  98. return *this;
  99. }
  100. TGatherIterator operator ++ (int) {
  101. TGatherIterator tmp(*this);
  102. First++;
  103. Second++;
  104. return tmp;
  105. }
  106. TGatherIterator operator -- (int) {
  107. TGatherIterator tmp(*this);
  108. First--;
  109. Second--;
  110. return tmp;
  111. }
  112. TGatherIterator& operator += (ptrdiff_t rhs) {
  113. First += rhs;
  114. Second += rhs;
  115. return *this;
  116. }
  117. TGatherIterator& operator -= (ptrdiff_t rhs) {
  118. First -= rhs;
  119. Second -= rhs;
  120. return *this;
  121. }
  122. ptrdiff_t operator - (TGatherIterator& rhs) const& {
  123. return First - rhs.First;
  124. }
  125. TGatherIterator operator + (ptrdiff_t n) const& {
  126. TGatherIterator tmp(*this);
  127. tmp.First += n;
  128. tmp.Second += n;
  129. return tmp;
  130. }
  131. TGatherIterator operator - (ptrdiff_t n) const& {
  132. TGatherIterator tmp(*this);
  133. tmp.First -= n;
  134. tmp.Second -= n;
  135. return tmp;
  136. }
  137. bool operator==(const TGatherIterator& rhs) const& {
  138. return First == rhs.First;
  139. }
  140. bool operator!=(const TGatherIterator& rhs) const& {
  141. return First != rhs.First;
  142. }
  143. bool operator<(TGatherIterator& rhs) const& {
  144. return First < rhs.First;
  145. }
  146. bool operator<=(TGatherIterator& rhs) const& {
  147. return First <= rhs.First;
  148. }
  149. bool operator>(TGatherIterator& rhs) const& {
  150. return First > rhs.First;
  151. }
  152. bool operator>=(TGatherIterator& rhs) const& {
  153. return First >= rhs.First;
  154. }
  155. private:
  156. NUdf::TUnboxedValue* First;
  157. NUdf::TUnboxedValue* Second;
  158. };
  159. using TComparator = std::function<bool(const TKeyPayloadPairVector::value_type&, const TKeyPayloadPairVector::value_type&)>;
  160. using TAlgorithm = void(*)(TKeyPayloadPairVector::iterator, TKeyPayloadPairVector::iterator, TComparator);
  161. using TAlgorithmInplace = void(*)(TGatherIterator, TGatherIterator, TComparator);
  162. using TNthAlgorithm = void(*)(TKeyPayloadPairVector::iterator, TKeyPayloadPairVector::iterator, TKeyPayloadPairVector::iterator, TComparator);
  163. struct TCompareDescr {
  164. TCompareDescr(TComputationMutables& mutables, std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>&& keySchemeTypes,
  165. const TVector<NUdf::ICompare::TPtr>& comparators)
  166. : KeySchemeTypes(std::move(keySchemeTypes))
  167. , KeyTypes(PrepareKeyTypesByScheme(KeySchemeTypes))
  168. , Comparators(comparators)
  169. , Encoders(mutables)
  170. {}
  171. static TKeyPayloadPairVector::value_type::first_type& Set(TKeyPayloadPairVector::value_type& item) { return item.first; }
  172. static TUnboxedValueVector::value_type& Set(TUnboxedValueVector::value_type& item) { return item; }
  173. static const TKeyPayloadPairVector::value_type::first_type& Get(const TKeyPayloadPairVector::value_type& item) { return item.first; }
  174. static const TUnboxedValueVector::value_type& Get(const TUnboxedValueVector::value_type& item) { return item; }
  175. template<class Container>
  176. std::function<bool(const typename Container::value_type&, const typename Container::value_type&)>
  177. MakeComparator(const NUdf::TUnboxedValue& ascending) const {
  178. if (KeyTypes.size() > 1U) {
  179. // sort tuples
  180. if (!Comparators.empty()) {
  181. return [this, &ascending](const typename Container::value_type& x, const typename Container::value_type& y) {
  182. const auto& left = Get(x);
  183. const auto& right = Get(y);
  184. for (ui32 i = 0; i < KeyTypes.size(); ++i) {
  185. const auto& leftElem = left.GetElement(i);
  186. const auto& rightElem = right.GetElement(i);
  187. const bool asc = ascending.GetElement(i).Get<bool>();
  188. if (const auto cmp = Comparators[i]->Compare(leftElem, rightElem)) {
  189. return asc ? cmp < 0 : cmp > 0;
  190. }
  191. }
  192. return false;
  193. };
  194. }
  195. return [this, &ascending](const typename Container::value_type& x, const typename Container::value_type& y) {
  196. const auto& left = Get(x);
  197. const auto& right = Get(y);
  198. for (ui32 i = 0; i < KeyTypes.size(); ++i) {
  199. const auto& keyType = KeyTypes[i];
  200. const auto& leftElem = left.GetElement(i);
  201. const auto& rightElem = right.GetElement(i);
  202. const bool asc = ascending.GetElement(i).Get<bool>();
  203. if (const auto cmp = CompareValues(keyType, asc, std::get<1>(KeySchemeTypes[i]), leftElem, rightElem)) {
  204. return cmp < 0;
  205. }
  206. }
  207. return false;
  208. };
  209. } else {
  210. // sort one column
  211. const bool isOptional = std::get<1>(KeySchemeTypes.front());
  212. const bool asc = ascending.Get<bool>();
  213. if (!Comparators.empty()) {
  214. return [this, asc](const typename Container::value_type& x, const typename Container::value_type& y) {
  215. auto cmp = Comparators.front()->Compare(Get(x), Get(y));
  216. return asc ? cmp < 0 : cmp > 0;
  217. };
  218. }
  219. return [this, asc, isOptional](const typename Container::value_type& x, const typename Container::value_type& y) {
  220. return CompareValues(KeyTypes.front(), asc, isOptional, Get(x), Get(y)) < 0;
  221. };
  222. }
  223. }
  224. template<class Container>
  225. void Prepare(TComputationContext& ctx, Container& items) const {
  226. if (!KeyTypes.empty()) {
  227. auto& encoders = Encoders.RefMutableObject(ctx, KeySchemeTypes, Comparators.empty());
  228. for (auto& x : items) {
  229. PrepareImpl(ctx, x, encoders);
  230. }
  231. }
  232. }
  233. void PrepareValue(TComputationContext& ctx, NUdf::TUnboxedValue& item) const {
  234. if (!KeyTypes.empty()) {
  235. auto& encoders = Encoders.RefMutableObject(ctx, KeySchemeTypes, Comparators.empty());
  236. PrepareImpl(ctx, item, encoders);
  237. }
  238. }
  239. template <class T>
  240. void PrepareImpl(TComputationContext& ctx, T& item, TEncoders& encoders) const {
  241. if (KeyTypes.size() > 1U) {
  242. // sort tuples
  243. if (encoders.NeedEncode) {
  244. NUdf::TUnboxedValue* arrayItems = nullptr;
  245. NUdf::TUnboxedValue array = ctx.HolderFactory.CreateDirectArrayHolder(KeyTypes.size(), arrayItems);
  246. for (ui32 i = 0; i < KeyTypes.size(); ++i) {
  247. if (auto& e = encoders.Columns[i]) {
  248. arrayItems[i] = MakeString(e->Encode(Get(item).GetElement(i), false));
  249. } else {
  250. arrayItems[i] = Get(item).GetElement(i);
  251. }
  252. }
  253. Set(item) = std::move(array);
  254. }
  255. } else if (auto& encoder = encoders.Columns.front()) {
  256. Set(item) = MakeString(encoder->Encode(Get(item), false));
  257. }
  258. }
  259. const std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>> KeySchemeTypes;
  260. const std::vector<NUdf::EDataSlot> KeyTypes;
  261. const TVector<NUdf::ICompare::TPtr> Comparators;
  262. TMutableObjectOverBoxedValue<TEncoders> Encoders;
  263. };
  264. template<class TWrapperImpl, bool MaybeInplace>
  265. class TAlgoBaseWrapper : public TMutableComputationNode<TAlgoBaseWrapper<TWrapperImpl, MaybeInplace>> {
  266. using TBaseComputation = TMutableComputationNode<TAlgoBaseWrapper<TWrapperImpl, MaybeInplace>>;
  267. protected:
  268. TAlgoBaseWrapper(
  269. TComputationMutables& mutables,
  270. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>&& keySchemeTypes,
  271. const TVector<NUdf::ICompare::TPtr>& comparators,
  272. IComputationNode* list,
  273. IComputationExternalNode* item,
  274. IComputationNode* key,
  275. IComputationNode* ascending,
  276. bool stealed)
  277. : TBaseComputation(mutables)
  278. , Description(mutables, std::move(keySchemeTypes), comparators)
  279. , List(list)
  280. , Item(item)
  281. , Key(key)
  282. , Ascending(ascending)
  283. , Stealed(stealed)
  284. {}
  285. public:
  286. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  287. const auto& list = List->GetValue(ctx);
  288. auto ptr = list.GetElements();
  289. if (MaybeInplace && ptr) {
  290. TUnboxedValueVector keys;
  291. NUdf::TUnboxedValue *inplace = nullptr;
  292. NUdf::TUnboxedValue res;
  293. auto size = list.GetListLength();
  294. if (!size) {
  295. return ctx.HolderFactory.GetEmptyContainerLazy();
  296. }
  297. if (Stealed) {
  298. res = list;
  299. inplace = const_cast<NUdf::TUnboxedValue*>(ptr);
  300. } else {
  301. res = ctx.HolderFactory.CreateDirectArrayHolder(size, inplace);
  302. }
  303. keys.reserve(size);
  304. for (size_t i = 0; i < size; ++i) {
  305. if (!Stealed) {
  306. inplace[i] = ptr[i];
  307. }
  308. Item->SetValue(ctx, NUdf::TUnboxedValuePod(ptr[i]));
  309. keys.emplace_back(Key->GetValue(ctx));
  310. }
  311. Description.Prepare(ctx, keys);
  312. static_cast<const TWrapperImpl*>(this)->PerformInplace(ctx, size, keys.data(), inplace,
  313. Description.MakeComparator<TKeyPayloadPairVector>(Ascending->GetValue(ctx)));
  314. return res.Release();
  315. } else {
  316. TKeyPayloadPairVector items;
  317. if (ptr) {
  318. auto size = list.GetListLength();
  319. items.reserve(size);
  320. for (ui32 i = 0; i < size; ++i) {
  321. Item->SetValue(ctx, NUdf::TUnboxedValuePod(ptr[i]));
  322. items.emplace_back(Key->GetValue(ctx), Item->GetValue(ctx));
  323. }
  324. } else {
  325. const auto& iter = list.GetListIterator();
  326. if (list.HasFastListLength()) {
  327. items.reserve(list.GetListLength());
  328. }
  329. for (NUdf::TUnboxedValue item; iter.Next(item);) {
  330. Item->SetValue(ctx, std::move(item));
  331. items.emplace_back(Key->GetValue(ctx), Item->GetValue(ctx));
  332. }
  333. }
  334. if (items.empty()) {
  335. return ctx.HolderFactory.GetEmptyContainerLazy();
  336. }
  337. Description.Prepare(ctx, items);
  338. return static_cast<const TWrapperImpl*>(this)->Perform(ctx, items,
  339. Description.MakeComparator<TKeyPayloadPairVector>(Ascending->GetValue(ctx)));
  340. }
  341. }
  342. protected:
  343. void RegisterDependencies() const override {
  344. this->DependsOn(List);
  345. this->Own(Item);
  346. this->DependsOn(Key);
  347. this->DependsOn(Ascending);
  348. }
  349. private:
  350. TCompareDescr Description;
  351. IComputationNode* const List;
  352. IComputationExternalNode* const Item;
  353. IComputationNode* const Key;
  354. IComputationNode* const Ascending;
  355. const bool Stealed;
  356. };
  357. class TAlgoWrapper : public TAlgoBaseWrapper<TAlgoWrapper, true> {
  358. using TBaseComputation = TAlgoBaseWrapper<TAlgoWrapper, true>;
  359. public:
  360. TAlgoWrapper(
  361. TAlgorithm algorithm,
  362. TAlgorithmInplace algorithmInplace,
  363. TComputationMutables& mutables,
  364. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>&& keySchemeTypes,
  365. const TVector<NUdf::ICompare::TPtr>& comparators,
  366. IComputationNode* list,
  367. IComputationExternalNode* item,
  368. IComputationNode* key,
  369. IComputationNode* ascending,
  370. bool stealed)
  371. : TBaseComputation(mutables, std::move(keySchemeTypes), comparators, list, item, key, ascending, stealed)
  372. , Algorithm(algorithm)
  373. , AlgorithmInplace(algorithmInplace)
  374. {}
  375. NUdf::TUnboxedValuePod Perform(TComputationContext& ctx, TKeyPayloadPairVector& items, const TComparator& comparator) const {
  376. Algorithm(items.begin(), items.end(), comparator);
  377. NUdf::TUnboxedValue *inplace = nullptr;
  378. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(items.size(), inplace);
  379. for (auto& item : items) {
  380. *inplace++ = std::move(item.second);
  381. }
  382. return result;
  383. }
  384. void PerformInplace(TComputationContext&, ui32 size, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue* items, const TComparator& comparator) const {
  385. AlgorithmInplace(TGatherIterator(keys, items), TGatherIterator(keys, items) + size, comparator);
  386. }
  387. private:
  388. const TAlgorithm Algorithm;
  389. const TAlgorithmInplace AlgorithmInplace;
  390. };
  391. class TNthAlgoWrapper : public TAlgoBaseWrapper<TNthAlgoWrapper, false> {
  392. using TBaseComputation = TAlgoBaseWrapper<TNthAlgoWrapper, false>;
  393. public:
  394. TNthAlgoWrapper(
  395. TNthAlgorithm algorithm,
  396. TComputationMutables& mutables,
  397. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>&& keySchemeTypes,
  398. const TVector<NUdf::ICompare::TPtr>& comparators,
  399. IComputationNode* list,
  400. IComputationNode* nth,
  401. IComputationExternalNode* item,
  402. IComputationNode* key,
  403. IComputationNode* ascending)
  404. : TBaseComputation(mutables, std::move(keySchemeTypes), comparators, list, item, key, ascending, false)
  405. , Algorithm(algorithm), Nth(nth)
  406. {}
  407. NUdf::TUnboxedValuePod Perform(TComputationContext& ctx, TKeyPayloadPairVector& items, const TComparator& comparator) const {
  408. const auto n = std::min<ui64>(Nth->GetValue(ctx).Get<ui64>(), items.size());
  409. if (!n) {
  410. return ctx.HolderFactory.GetEmptyContainerLazy();
  411. }
  412. Algorithm(items.begin(), items.begin() + n, items.end(), comparator);
  413. items.resize(n);
  414. NUdf::TUnboxedValue *inplace = nullptr;
  415. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(n, inplace);
  416. for (auto& item : items) {
  417. *inplace++ = std::move(item.second);
  418. }
  419. return result;
  420. }
  421. void PerformInplace(TComputationContext& ctx, ui32 size, NUdf::TUnboxedValue* keys, NUdf::TUnboxedValue* items, const TComparator& comparator) const {
  422. Y_UNUSED(ctx);
  423. Y_UNUSED(size);
  424. Y_UNUSED(keys);
  425. Y_UNUSED(items);
  426. Y_UNUSED(comparator);
  427. Y_ABORT("Not supported");
  428. }
  429. private:
  430. void RegisterDependencies() const final {
  431. TBaseComputation::RegisterDependencies();
  432. this->DependsOn(Nth);
  433. }
  434. const TNthAlgorithm Algorithm;
  435. IComputationNode* const Nth;
  436. };
  437. class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
  438. using TBaseComputation = TMutableComputationNode<TKeepTopWrapper>;
  439. public:
  440. TKeepTopWrapper(
  441. TComputationMutables& mutables,
  442. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>>&& keySchemeTypes,
  443. const TVector<NUdf::ICompare::TPtr>& comparators,
  444. IComputationNode* count,
  445. IComputationNode* list,
  446. IComputationNode* item,
  447. IComputationExternalNode* arg,
  448. IComputationNode* key,
  449. IComputationNode* ascending,
  450. IComputationExternalNode* hotkey)
  451. : TBaseComputation(mutables)
  452. , Description(mutables, std::move(keySchemeTypes), comparators)
  453. , Count(count)
  454. , List(list)
  455. , Item(item)
  456. , Arg(arg)
  457. , Key(key)
  458. , Ascending(ascending)
  459. , HotKey(hotkey)
  460. {}
  461. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  462. const auto count = Count->GetValue(ctx).Get<ui64>();
  463. if (!count) {
  464. return ctx.HolderFactory.GetEmptyContainerLazy();
  465. }
  466. auto list = List->GetValue(ctx);
  467. auto item = Item->GetValue(ctx);
  468. const auto size = list.GetListLength();
  469. if (size < count) {
  470. return ctx.HolderFactory.Append(list.Release(), item.Release());
  471. }
  472. auto hotkey = HotKey->GetValue(ctx);
  473. auto hotkey_prepared = hotkey;
  474. if (!hotkey_prepared.IsInvalid()) {
  475. Description.PrepareValue(ctx, hotkey_prepared);
  476. }
  477. if (size == count) {
  478. if (hotkey.IsInvalid()) {
  479. TUnboxedValueVector keys;
  480. keys.reserve(size);
  481. const auto ptr = list.GetElements();
  482. std::transform(ptr, ptr + size, std::back_inserter(keys), [&](const NUdf::TUnboxedValuePod item) {
  483. Arg->SetValue(ctx, item);
  484. return Key->GetValue(ctx);
  485. });
  486. auto keys_copy = keys;
  487. Description.Prepare(ctx, keys);
  488. const auto& ascending = Ascending->GetValue(ctx);
  489. const auto max = std::max_element(keys.begin(), keys.end(), Description.MakeComparator<TUnboxedValueVector>(ascending));
  490. hotkey_prepared = *max;
  491. HotKey->SetValue(ctx, std::move(keys_copy[max - keys.begin()]));
  492. }
  493. }
  494. const auto copy = item;
  495. Arg->SetValue(ctx, item.Release());
  496. auto key_prepared = Key->GetValue(ctx);
  497. Description.PrepareValue(ctx, key_prepared);
  498. const auto& ascending = Ascending->GetValue(ctx);
  499. if (Description.MakeComparator<TUnboxedValueVector>(ascending)(key_prepared, hotkey_prepared)) {
  500. const auto reserve = std::max<ui64>(count << 1ULL, 1ULL << 8ULL);
  501. if (size < reserve) {
  502. return ctx.HolderFactory.Append(list.Release(), Arg->GetValue(ctx).Release());
  503. }
  504. TKeyPayloadPairVector items(1U, TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx)));
  505. items.reserve(items.size() + size);
  506. const auto ptr = list.GetElements();
  507. std::transform(ptr, ptr + size, std::back_inserter(items), [&](const NUdf::TUnboxedValuePod item) {
  508. Arg->SetValue(ctx, item);
  509. return TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx));
  510. });
  511. Description.Prepare(ctx, items);
  512. NYql::FastNthElement(items.begin(), items.begin() + count - 1U, items.end(), Description.MakeComparator<TKeyPayloadPairVector>(ascending));
  513. items.resize(count);
  514. NUdf::TUnboxedValue *inplace = nullptr;
  515. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(count, inplace); /// TODO: Use list holder.
  516. for (auto& item : items) {
  517. *inplace++ = std::move(item.second);
  518. }
  519. return result;
  520. }
  521. return list.Release();
  522. }
  523. private:
  524. void RegisterDependencies() const final {
  525. DependsOn(Count);
  526. DependsOn(List);
  527. DependsOn(Item);
  528. Own(Arg);
  529. DependsOn(Key);
  530. DependsOn(Ascending);
  531. Own(HotKey);
  532. }
  533. TCompareDescr Description;
  534. IComputationNode* const Count;
  535. IComputationNode* const List;
  536. IComputationNode* const Item;
  537. IComputationExternalNode* const Arg;
  538. IComputationNode* const Key;
  539. IComputationNode* const Ascending;
  540. IComputationExternalNode* const HotKey;
  541. };
  542. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>> GetKeySchemeTypes(TType* keyType, TType* ascType) {
  543. std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>> keySchemeTypes;
  544. if (ascType->IsTuple()) {
  545. MKQL_ENSURE(keyType->IsTuple(), "Key must be tuple");
  546. const auto keyDetailedType = static_cast<TTupleType*>(keyType);
  547. const auto keyElementsCount = keyDetailedType->GetElementsCount();
  548. keySchemeTypes.reserve(keyElementsCount);
  549. for (ui32 i = 0; i < keyElementsCount; ++i) {
  550. const auto elementType = keyDetailedType->GetElementType(i);
  551. bool isOptional;
  552. const auto unpacked = UnpackOptional(elementType, isOptional);
  553. if (!unpacked->IsData()) {
  554. keySchemeTypes.emplace_back(NUdf::EDataSlot::String, false, elementType);
  555. } else {
  556. keySchemeTypes.emplace_back(*static_cast<TDataType*>(unpacked)->GetDataSlot(), isOptional, nullptr);
  557. }
  558. }
  559. } else {
  560. keySchemeTypes.reserve(1);
  561. bool isOptional;
  562. const auto unpacked = UnpackOptional(keyType, isOptional);
  563. if (!unpacked->IsData()) {
  564. keySchemeTypes.emplace_back(NUdf::EDataSlot::String, false, keyType);
  565. } else {
  566. keySchemeTypes.emplace_back(*static_cast<TDataType*>(unpacked)->GetDataSlot(), isOptional, nullptr);
  567. }
  568. }
  569. return keySchemeTypes;
  570. }
  571. TVector<NUdf::ICompare::TPtr> MakeComparators(TType* keyType, bool isTuple) {
  572. if (keyType->IsPresortSupported()) {
  573. return {};
  574. }
  575. if (!isTuple) {
  576. return { MakeCompareImpl(keyType) };
  577. } else {
  578. MKQL_ENSURE(keyType->IsTuple(), "Key must be tuple");
  579. const auto keyDetailedType = static_cast<TTupleType*>(keyType);
  580. const auto keyElementsCount = keyDetailedType->GetElementsCount();
  581. TVector<NUdf::ICompare::TPtr> ret;
  582. for (ui32 i = 0; i < keyElementsCount; ++i) {
  583. ret.emplace_back(MakeCompareImpl(keyDetailedType->GetElementType(i)));
  584. }
  585. return ret;
  586. }
  587. }
  588. IComputationNode* WrapAlgo(TAlgorithm algorithm, TAlgorithmInplace algorithmInplace, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  589. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  590. const auto keyNode = callable.GetInput(2);
  591. const auto sortNode = callable.GetInput(3);
  592. const auto keyType = keyNode.GetStaticType();
  593. const auto ascType = sortNode.GetStaticType();
  594. auto listNode = callable.GetInput(0);
  595. IComputationNode* list = nullptr;
  596. bool stealed = false;
  597. if (listNode.GetNode()->GetType()->IsCallable()) {
  598. auto name = AS_TYPE(TCallableType, listNode.GetNode()->GetType())->GetName();
  599. if (name == "Steal") {
  600. list = LocateNode(ctx.NodeLocator, static_cast<TCallable&>(*listNode.GetNode()), 0);
  601. stealed = true;
  602. }
  603. }
  604. if (!list) {
  605. list = LocateNode(ctx.NodeLocator, callable, 0);
  606. }
  607. const auto key = LocateNode(ctx.NodeLocator, callable, 2);
  608. const auto ascending = LocateNode(ctx.NodeLocator, callable, 3);
  609. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  610. auto comparators = MakeComparators(keyType, ascType->IsTuple());
  611. return new TAlgoWrapper(algorithm, algorithmInplace, ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, list,
  612. itemArg, key, ascending, stealed);
  613. }
  614. IComputationNode* WrapNthAlgo(TNthAlgorithm algorithm, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  615. MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");
  616. const auto keyNode = callable.GetInput(3);
  617. const auto sortNode = callable.GetInput(4);
  618. const auto keyType = keyNode.GetStaticType();
  619. const auto ascType = sortNode.GetStaticType();
  620. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  621. const auto nth = LocateNode(ctx.NodeLocator, callable, 1);
  622. const auto key = LocateNode(ctx.NodeLocator, callable, 3);
  623. const auto ascending = LocateNode(ctx.NodeLocator, callable, 4);
  624. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 2);
  625. auto comparators = MakeComparators(keyType, ascType->IsTuple());
  626. return new TNthAlgoWrapper(algorithm, ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, list, nth, itemArg, key, ascending);
  627. }
  628. }
  629. IComputationNode* WrapUnstableSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  630. return WrapAlgo(&std::sort<TKeyPayloadPairVector::iterator, TComparator>,
  631. &std::sort<TGatherIterator, TComparator>, callable, ctx);
  632. }
  633. IComputationNode* WrapSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  634. return WrapAlgo(&std::stable_sort<TKeyPayloadPairVector::iterator, TComparator>,
  635. &std::stable_sort<TGatherIterator, TComparator>, callable, ctx);
  636. }
  637. IComputationNode* WrapTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  638. return WrapNthAlgo(&NYql::FastNthElement<TKeyPayloadPairVector::iterator, TComparator>, callable, ctx);
  639. }
  640. IComputationNode* WrapTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  641. return WrapNthAlgo(&NYql::FastPartialSort<TKeyPayloadPairVector::iterator, TComparator>, callable, ctx);
  642. }
  643. IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  644. MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args");
  645. const auto keyNode = callable.GetInput(4);
  646. const auto sortNode = callable.GetInput(5);
  647. const auto keyType = keyNode.GetStaticType();
  648. const auto ascType = sortNode.GetStaticType();
  649. const auto count = LocateNode(ctx.NodeLocator, callable, 0);
  650. const auto list = LocateNode(ctx.NodeLocator, callable, 1);
  651. const auto item = LocateNode(ctx.NodeLocator, callable, 2);
  652. const auto key = LocateNode(ctx.NodeLocator, callable, 4);
  653. const auto ascending = LocateNode(ctx.NodeLocator, callable, 5);
  654. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 3);
  655. const auto hotkey = LocateExternalNode(ctx.NodeLocator, callable, 6);
  656. auto comparators = MakeComparators(keyType, ascType->IsTuple());
  657. return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey);
  658. }
  659. }
  660. }