// mkql_heap_ut.cpp (13 KB)

  1. #include "mkql_computation_node_ut.h"
  2. #include <yql/essentials/minikql/mkql_node_cast.h>
  3. #include <yql/essentials/minikql/mkql_string_util.h>
  4. #include <yql/essentials/utils/sort.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. Y_UNIT_TEST_SUITE(TMiniKQLHeapTest) {
  8. Y_UNIT_TEST_LLVM(TestMakeHeap) {
  9. const std::array<float, 10U> xxx = {{0.f, 13.f, -3.14f, 1212.f, -7898.8f, 21E4f, HUGE_VALF, -HUGE_VALF, 3673.f, -32764.f}};
  10. TSetup<LLVM> setup;
  11. TProgramBuilder& pb = *setup.PgmBuilder;
  12. std::array<TRuntimeNode, 10U> data;
  13. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](float f) { return pb.NewDataLiteral(f); } );
  14. const auto type = pb.NewDataType(NUdf::TDataType<float>::Id);
  15. const auto list = pb.NewList(type, data);
  16. const auto pgmReturn = pb.MakeHeap(list,
  17. [&](TRuntimeNode l, TRuntimeNode r) {
  18. return pb.AggrLess(pb.Abs(l), pb.Abs(r));
  19. });
  20. const auto graph = setup.BuildGraph(pgmReturn);
  21. const auto& result = graph->GetValue();
  22. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  23. auto copy = xxx;
  24. std::make_heap(copy.begin(), copy.end(), [](float l, float r){ return std::abs(l) < std::abs(r); });
  25. for (auto i = 0U; i < copy.size(); ++i) {
  26. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<float>());
  27. }
  28. }
  29. Y_UNIT_TEST_LLVM(TestPopHeap) {
  30. const std::array<double, 10U> xxx = {{0.0, 13.0, -3.140, 1212.0, -7898.8, 210000.0, 17E13, -HUGE_VAL, 3673.0, -32764.0}};
  31. TSetup<LLVM> setup;
  32. TProgramBuilder& pb = *setup.PgmBuilder;
  33. std::array<TRuntimeNode, 10U> data;
  34. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](double f) { return pb.NewDataLiteral(f); } );
  35. const auto type = pb.NewDataType(NUdf::TDataType<double>::Id);
  36. const auto list = pb.NewList(type, data);
  37. const auto comparer = [&](TRuntimeNode l, TRuntimeNode r) {
  38. return pb.AggrGreater(pb.Abs(l), pb.Abs(r));
  39. };
  40. const auto pgmReturn = pb.PopHeap(pb.MakeHeap(list,comparer), comparer);
  41. const auto graph = setup.BuildGraph(pgmReturn);
  42. const auto& result = graph->GetValue();
  43. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  44. auto copy = xxx;
  45. const auto c = [](double l, double r){ return std::abs(l) > std::abs(r); };
  46. std::make_heap(copy.begin(), copy.end(), c);
  47. std::pop_heap(copy.begin(), copy.end(), c);
  48. for (auto i = 0U; i < copy.size(); ++i) {
  49. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<double>());
  50. }
  51. }
  52. Y_UNIT_TEST_LLVM(TestSortHeap) {
  53. const std::array<float, 10U> xxx = {{9E9f, -HUGE_VALF, 0.003f, 137.4f, -3.1415f, 1212.f, -7898.8f, 21E4f, 3673.f, -32764.f}};
  54. TSetup<LLVM> setup;
  55. TProgramBuilder& pb = *setup.PgmBuilder;
  56. std::array<TRuntimeNode, 10U> data;
  57. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](float f) { return pb.NewDataLiteral(f); } );
  58. const auto type = pb.NewDataType(NUdf::TDataType<float>::Id);
  59. const auto list = pb.NewList(type, data);
  60. const auto pgmReturn = pb.SortHeap(
  61. pb.MakeHeap(list,
  62. [&](TRuntimeNode l, TRuntimeNode r) {
  63. return pb.AggrGreater(l, r);
  64. }),
  65. [&](TRuntimeNode l, TRuntimeNode r) {
  66. return pb.AggrGreater(l, r);
  67. });
  68. const auto graph = setup.BuildGraph(pgmReturn);
  69. const auto& result = graph->GetValue();
  70. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  71. auto copy = xxx;
  72. std::make_heap(copy.begin(), copy.end(), std::greater<float>());
  73. std::sort_heap(copy.begin(), copy.end(), std::greater<float>());
  74. for (auto i = 0U; i < copy.size(); ++i) {
  75. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<float>());
  76. }
  77. }
  78. Y_UNIT_TEST_LLVM(TestStableSort) {
  79. const std::array<double, 10U> xxx = {{9E9f, -HUGE_VALF, 0.003f, HUGE_VALF, +3.1415f, -0.003f, -7898.8f, -3.1415f, 3673.f, 0.003f}};
  80. TSetup<LLVM> setup;
  81. TProgramBuilder& pb = *setup.PgmBuilder;
  82. std::array<TRuntimeNode, 10U> data;
  83. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](double f) { return pb.NewDataLiteral(f); } );
  84. const auto type = pb.NewDataType(NUdf::TDataType<double>::Id);
  85. const auto list = pb.NewList(type, data);
  86. const auto pgmReturn = pb.StableSort(list,
  87. [&](TRuntimeNode l, TRuntimeNode r) {
  88. return pb.AggrGreater(pb.Abs(l), pb.Abs(r));
  89. });
  90. const auto graph = setup.BuildGraph(pgmReturn);
  91. const auto& result = graph->GetValue();
  92. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  93. auto copy = xxx;
  94. std::stable_sort(copy.begin(), copy.end(), [](double l, double r){ return std::abs(l) > std::abs(r); });
  95. for (auto i = 0U; i < copy.size(); ++i) {
  96. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<double>());
  97. }
  98. }
  99. Y_UNIT_TEST_LLVM(TestNthElement) {
  100. const std::array<float, 10U> xxx = {{0.f, 13.f, -3.14f, 1212.f, -7898.8f, 21E4f, HUGE_VALF, -HUGE_VALF, 3673.f, -32764.f}};
  101. TSetup<LLVM> setup;
  102. TProgramBuilder& pb = *setup.PgmBuilder;
  103. std::array<TRuntimeNode, 10U> data;
  104. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](float f) { return pb.NewDataLiteral(f); } );
  105. const auto type = pb.NewDataType(NUdf::TDataType<float>::Id);
  106. const auto list = pb.NewList(type, data);
  107. const auto n = pb.NewDataLiteral<ui64>(4U);
  108. const auto pgmReturn = pb.NthElement(list, n,
  109. [&](TRuntimeNode l, TRuntimeNode r) {
  110. return pb.AggrGreater(pb.Abs(l), pb.Abs(r));
  111. });
  112. const auto graph = setup.BuildGraph(pgmReturn);
  113. const auto& result = graph->GetValue();
  114. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  115. auto copy = xxx;
  116. NYql::FastNthElement(copy.begin(), copy.begin() + 4U, copy.end(), [](float l, float r){ return std::abs(l) > std::abs(r); });
  117. for (auto i = 0U; i < copy.size(); ++i) {
  118. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<float>());
  119. }
  120. }
  121. Y_UNIT_TEST_LLVM(TestPartialSort) {
  122. const std::array<double, 10U> xxx = {{0.0, 13.0, -3.14, 1212.0, -7898.8, 21.0E4, HUGE_VAL, -HUGE_VAL, 3673.0, -32764.0}};
  123. TSetup<LLVM> setup;
  124. TProgramBuilder& pb = *setup.PgmBuilder;
  125. std::array<TRuntimeNode, 10U> data;
  126. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](double f) { return pb.NewDataLiteral(f); } );
  127. const auto type = pb.NewDataType(NUdf::TDataType<double>::Id);
  128. const auto list = pb.NewList(type, data);
  129. const auto n = pb.NewDataLiteral<ui64>(6U);
  130. const auto pgmReturn = pb.PartialSort(list, n,
  131. [&](TRuntimeNode l, TRuntimeNode r) {
  132. return pb.AggrLess(pb.Abs(l), pb.Abs(r));
  133. });
  134. const auto graph = setup.BuildGraph(pgmReturn);
  135. const auto& result = graph->GetValue();
  136. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), xxx.size());
  137. auto copy = xxx;
  138. NYql::FastPartialSort(copy.begin(), copy.begin() + 6U, copy.end(), [](double l, double r){ return std::abs(l) < std::abs(r); });
  139. for (auto i = 0U; i < copy.size(); ++i) {
  140. UNIT_ASSERT_VALUES_EQUAL(copy[i], result.GetElement(i).template Get<double>());
  141. }
  142. }
  143. Y_UNIT_TEST_LLVM(TestTopN) {
  144. const std::array<double, 10U> xxx = {{0.0, 13.0, -3.140, -7898.8, 210000.0, 17E13, 1212.0, -HUGE_VAL, 3673.0, -32764.0}};
  145. TSetup<LLVM> setup;
  146. TProgramBuilder& pb = *setup.PgmBuilder;
  147. std::array<TRuntimeNode, 10U> data;
  148. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](double f) { return pb.NewDataLiteral(f); } );
  149. const auto type = pb.NewDataType(NUdf::TDataType<double>::Id);
  150. const auto list = pb.NewList(type, data);
  151. const auto comparator = [&](TRuntimeNode l, TRuntimeNode r) { return pb.AggrGreater(pb.Abs(l), pb.Abs(r)); };
  152. const auto n = 5ULL;
  153. const auto limit = pb.NewDataLiteral<ui64>(n);
  154. const auto last = pb.Decrement(limit);
  155. const auto pgmReturn = pb.Take(pb.NthElement(pb.Fold(list, pb.NewEmptyList(type),
  156. [&](TRuntimeNode item, TRuntimeNode state) {
  157. const auto size = pb.Length(state);
  158. return pb.If(pb.AggrLess(size, limit),
  159. pb.If(pb.AggrLess(size, last),
  160. pb.Append(state, item), pb.MakeHeap(pb.Append(state, item), comparator)),
  161. pb.If(comparator(item, pb.Unwrap(pb.ToOptional(state), pb.NewDataLiteral<NUdf::EDataSlot::String>(""), "", 0, 0)),
  162. pb.PushHeap(pb.Append(pb.Take(pb.PopHeap(state, comparator), pb.Decrement(size)), item), comparator),
  163. state
  164. )
  165. );
  166. }
  167. ), last, comparator), limit);
  168. const auto graph = setup.BuildGraph(pgmReturn);
  169. const auto& result = graph->GetValue();
  170. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), n);
  171. auto copy = xxx;
  172. const auto comp = [](double l, double r){ return std::abs(l) > std::abs(r); };
  173. NYql::FastNthElement(copy.begin(), copy.begin() + n - 1U, copy.end(), comp);
  174. const auto mm = std::minmax_element(copy.begin(), copy.begin() + n, comp);
  175. double min = result.GetElement(0).template Get<double>(), max = min;
  176. for (auto i = 1U; i < n; ++i) {
  177. const auto v = result.GetElement(i).template Get<double>();
  178. min = std::min(min, v, comp);
  179. max = std::max(max, v, comp);
  180. }
  181. UNIT_ASSERT_VALUES_EQUAL(*mm.first, min);
  182. UNIT_ASSERT_VALUES_EQUAL(*mm.second, max);
  183. }
  184. Y_UNIT_TEST_LLVM(TestTopByNthElement) {
  185. const std::array<double, 10U> xxx = {{0.0, 13.0, -3.140, -7898.8, 210000.0, 17E13, 1212.0, -HUGE_VAL, 3673.0, -32764.0}};
  186. TSetup<LLVM> setup;
  187. TProgramBuilder& pb = *setup.PgmBuilder;
  188. std::array<TRuntimeNode, 10U> data;
  189. std::transform(xxx.cbegin(), xxx.cend(), data.begin(), [&pb](double f) { return pb.NewDataLiteral(f); } );
  190. const auto type = pb.NewDataType(NUdf::TDataType<double>::Id);
  191. const auto list = pb.NewList(type, data);
  192. const auto comparator = [&](TRuntimeNode l, TRuntimeNode r) { return pb.AggrLess(pb.Abs(l), pb.Abs(r)); };
  193. const auto n = 5ULL;
  194. const auto limit = pb.NewDataLiteral<ui64>(n);
  195. const auto reserve = pb.ShiftLeft(limit, pb.NewDataLiteral<ui8>(1U));
  196. const auto last = pb.Decrement(limit);
  197. const auto pgmReturn = pb.Take(pb.NthElement(pb.Fold(list, pb.NewEmptyList(type),
  198. [&](TRuntimeNode item, TRuntimeNode state) {
  199. const auto size = pb.Length(state);
  200. return pb.If(pb.AggrLess(size, limit),
  201. pb.If(pb.AggrLess(size, last),
  202. pb.Append(state, item), pb.MakeHeap(pb.Append(state, item), comparator)),
  203. pb.If(comparator(item, pb.Unwrap(pb.ToOptional(state), pb.NewDataLiteral<NUdf::EDataSlot::String>(""), "", 0, 0)),
  204. pb.If(pb.AggrLess(size, reserve),
  205. pb.Append(state, item),
  206. pb.Take(pb.NthElement(pb.Prepend(item, pb.Skip(state, pb.NewDataLiteral<ui64>(1U))), last, comparator), limit)
  207. ),
  208. state
  209. )
  210. );
  211. }
  212. ), last, comparator), limit);
  213. const auto graph = setup.BuildGraph(pgmReturn);
  214. const auto& result = graph->GetValue();
  215. UNIT_ASSERT_VALUES_EQUAL(result.GetListLength(), n);
  216. auto copy = xxx;
  217. const auto comp = [](double l, double r){ return std::abs(l) < std::abs(r); };
  218. NYql::FastNthElement(copy.begin(), copy.begin() + n - 1U, copy.end(), comp);
  219. const auto mm = std::minmax_element(copy.begin(), copy.begin() + n, comp);
  220. double min = result.GetElement(0).template Get<double>(), max = min;
  221. for (auto i = 1U; i < n; ++i) {
  222. const auto v = result.GetElement(i).template Get<double>();
  223. min = std::min(min, v, comp);
  224. max = std::max(max, v, comp);
  225. }
  226. UNIT_ASSERT_VALUES_EQUAL(*mm.first, min);
  227. UNIT_ASSERT_VALUES_EQUAL(*mm.second, max);
  228. }
  229. }
  230. }
  231. }