// mkql_computation_node_dict_ut.cpp
  1. #include <yql/essentials/minikql/mkql_node.h>
  2. #include <yql/essentials/minikql/mkql_node_cast.h>
  3. #include <yql/essentials/minikql/mkql_program_builder.h>
  4. #include <yql/essentials/minikql/mkql_function_registry.h>
  5. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  7. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
  8. #include <yql/essentials/minikql/comp_nodes/mkql_factories.h>
  9. #include <library/cpp/testing/unittest/registar.h>
  10. #include <vector>
  11. #include <utility>
  12. #include <algorithm>
  13. namespace NKikimr {
  14. namespace NMiniKQL {
  15. namespace {
  16. struct TSetup {
  17. TSetup(TScopedAlloc& alloc)
  18. : Alloc(alloc)
  19. {
  20. FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
  21. RandomProvider = CreateDeterministicRandomProvider(1);
  22. TimeProvider = CreateDeterministicTimeProvider(10000000);
  23. Env.Reset(new TTypeEnvironment(Alloc));
  24. PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
  25. }
  26. THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
  27. Explorer.Walk(pgm.GetNode(), *Env);
  28. TComputationPatternOpts opts(Alloc.Ref(), *Env, GetBuiltinFactory(),
  29. FunctionRegistry.Get(),
  30. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, "OFF", EGraphPerProcess::Multi);
  31. Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
  32. TComputationOptsFull compOpts = opts.ToComputationOptions(*RandomProvider, *TimeProvider);
  33. return Pattern->Clone(compOpts);
  34. }
  35. TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
  36. TIntrusivePtr<IRandomProvider> RandomProvider;
  37. TIntrusivePtr<ITimeProvider> TimeProvider;
  38. TScopedAlloc& Alloc;
  39. THolder<TTypeEnvironment> Env;
  40. THolder<TProgramBuilder> PgmBuilder;
  41. TExploringNodeVisitor Explorer;
  42. IComputationPattern::TPtr Pattern;
  43. };
  44. }
  45. Y_UNIT_TEST_SUITE(TestCompactMultiDict) {
  46. Y_UNIT_TEST(TestIterate) {
  47. TScopedAlloc alloc(__LOCATION__);
  48. TSetup setup(alloc);
  49. const std::vector<std::pair<ui32, std::vector<ui32>>> items = {{1, {1, 2}}, {2, {1}}, {3, {0}}, {6, {1, 7}}};
  50. TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
  51. TVector<TRuntimeNode> rItems;
  52. for (auto& [k, vv]: items) {
  53. for (auto& v: vv) {
  54. rItems.push_back(pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui32>(k), pgmBuilder.NewDataLiteral<ui32>(v)}));
  55. }
  56. }
  57. auto ui32Type = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
  58. auto list = pgmBuilder.NewList(pgmBuilder.NewTupleType({ui32Type, ui32Type}), rItems);
  59. auto dict = pgmBuilder.ToHashedDict(list, /*all*/true,
  60. [&pgmBuilder](TRuntimeNode item) { return pgmBuilder.Nth(item, 0); },
  61. [&pgmBuilder](TRuntimeNode item) { return pgmBuilder.Nth(item, 1); },
  62. /*isCompact*/true,
  63. items.size());
  64. auto graph = setup.BuildGraph(dict, {});
  65. NUdf::TUnboxedValue res = graph->GetValue();
  66. std::vector<ui32> keyVals;
  67. for (NUdf::TUnboxedValue keys = res.GetKeysIterator(), v; keys.Next(v);) {
  68. keyVals.push_back(v.Get<ui32>());
  69. }
  70. UNIT_ASSERT_VALUES_EQUAL(keyVals.size(), items.size());
  71. std::sort(keyVals.begin(), keyVals.end());
  72. UNIT_ASSERT(
  73. std::equal(keyVals.begin(), keyVals.end(), items.begin(),
  74. [](ui32 l, const std::pair<ui32, std::vector<ui32>>& r) { return l == r.first; }
  75. )
  76. );
  77. std::vector<std::vector<ui32>> origPayloads;
  78. for (auto& [k, vv]: items) {
  79. origPayloads.push_back(vv);
  80. std::sort(origPayloads.back().begin(), origPayloads.back().end());
  81. }
  82. std::sort(origPayloads.begin(), origPayloads.end());
  83. std::vector<std::vector<ui32>> payloadVals;
  84. for (NUdf::TUnboxedValue payloads = res.GetPayloadsIterator(), v; payloads.Next(v);) {
  85. payloadVals.emplace_back();
  86. for (NUdf::TUnboxedValue i = v.GetListIterator(), p; i.Next(p);) {
  87. payloadVals.back().push_back(p.Get<ui32>());
  88. }
  89. std::sort(payloadVals.back().begin(), payloadVals.back().end());
  90. }
  91. std::sort(payloadVals.begin(), payloadVals.end());
  92. UNIT_ASSERT_VALUES_EQUAL(origPayloads, payloadVals);
  93. std::vector<std::pair<ui32, std::vector<ui32>>> vals;
  94. for (NUdf::TUnboxedValue values = res.GetDictIterator(), k, payloads; values.NextPair(k, payloads);) {
  95. vals.emplace_back(k.Get<ui32>(), std::vector<ui32>{});
  96. for (NUdf::TUnboxedValue i = payloads.GetListIterator(), p; i.Next(p);) {
  97. vals.back().second.push_back(p.Get<ui32>());
  98. }
  99. std::sort(vals.back().second.begin(), vals.back().second.end());
  100. }
  101. UNIT_ASSERT_VALUES_EQUAL(items, vals);
  102. }
  103. }
  104. }
  105. }