mkql_computation_node_ut.h 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #pragma once
  2. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  3. #include <yql/essentials/minikql/mkql_program_builder.h>
  4. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
  5. #include <yql/essentials/minikql/mkql_function_registry.h>
  6. #include <yql/essentials/minikql/mkql_terminator.h>
  7. #include "../mkql_factories.h"
  8. #include <library/cpp/testing/unittest/registar.h>
  9. #define UNBOXED_VALUE_STR_EQUAL(unboxed, expected) \
  10. do { \
  11. const auto v = (unboxed); \
  12. if (!(v.AsStringRef() == (expected))) { \
  13. UNIT_FAIL_IMPL( \
  14. "equal assertion failed", \
  15. Sprintf("%s %s == %s", #unboxed, TString(v.AsStringRef()).c_str(), #expected)); \
  16. } \
  17. } while (0)
  18. #if defined(_msan_enabled_) || defined(_ubsan_enabled_) || defined(WITH_VALGRIND)
  19. #define Y_UNIT_TEST_TWIN_IMPL_REGISTER(N, OPT) \
  20. template<bool OPT> void N(NUnitTest::TTestContext&); \
  21. struct TTestRegistration##N { \
  22. TTestRegistration##N() { \
  23. TCurrentTest::AddTest(#N "-" #OPT, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<false>), false); \
  24. } \
  25. }; \
  26. static TTestRegistration##N testRegistration##N;
  27. #else
  28. #define Y_UNIT_TEST_TWIN_IMPL_REGISTER(N, OPT) \
  29. template<bool OPT> void N(NUnitTest::TTestContext&); \
  30. struct TTestRegistration##N { \
  31. TTestRegistration##N() { \
  32. TCurrentTest::AddTest(#N "-" #OPT, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<false>), false); \
  33. TCurrentTest::AddTest(#N "+" #OPT, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<true>), false); \
  34. } \
  35. }; \
  36. static TTestRegistration##N testRegistration##N;
  37. #endif
  38. #define Y_UNIT_TEST_TWIN(N, OPT) \
  39. Y_UNIT_TEST_TWIN_IMPL_REGISTER(N, OPT) \
  40. template<bool OPT> \
  41. void N(NUnitTest::TTestContext&)
  42. #define Y_UNIT_TEST_LLVM(N) Y_UNIT_TEST_TWIN(N, LLVM)
  43. #define Y_UNIT_TEST_LLVM_SPILLING(N) Y_UNIT_TEST_QUAD(N, LLVM, SPILLING)
  44. #define Y_UNIT_TEST_QUAD(N, OPT1, OPT2) \
  45. template<bool OPT1, bool OPT2> void N(NUnitTest::TTestContext&); \
  46. struct TTestRegistration##N { \
  47. TTestRegistration##N() { \
  48. TCurrentTest::AddTest(#N "-" #OPT1 "-" #OPT2, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<false, false>), false); \
  49. TCurrentTest::AddTest(#N "-" #OPT1 "+" #OPT2, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<false, true>), false); \
  50. TCurrentTest::AddTest(#N "+" #OPT1 "-" #OPT2, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<true, false>), false); \
  51. TCurrentTest::AddTest(#N "+" #OPT1 "+" #OPT2, static_cast<void (*)(NUnitTest::TTestContext&)>(&N<true, true>), false); \
  52. } \
  53. }; \
  54. static TTestRegistration##N testRegistration##N; \
  55. template<bool OPT1, bool OPT2> \
  56. void N(NUnitTest::TTestContext&)
  57. namespace NKikimr {
  58. namespace NMiniKQL {
  59. TComputationNodeFactory GetTestFactory(TComputationNodeFactory customFactory = {});
  60. template<typename T>
  61. NUdf::TUnboxedValuePod ToValue(T value) {
  62. return NUdf::TUnboxedValuePod(value);
  63. }
  64. struct TUdfModuleInfo {
  65. TString LibraryPath;
  66. TString ModuleName;
  67. NUdf::TUniquePtr<NUdf::IUdfModule> Module;
  68. };
  69. template<bool UseLLVM, bool EnableSpilling = false>
  70. struct TSetup {
  71. explicit TSetup(TComputationNodeFactory nodeFactory = GetTestFactory(), TVector<TUdfModuleInfo>&& modules = {})
  72. : Alloc(__LOCATION__)
  73. , StatsRegistry(CreateDefaultStatsRegistry())
  74. {
  75. NodeFactory = nodeFactory;
  76. FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
  77. if (!modules.empty()) {
  78. auto mutableRegistry = FunctionRegistry->Clone();
  79. for (auto& m : modules) {
  80. mutableRegistry->AddModule(m.LibraryPath, m.ModuleName, std::move(m.Module));
  81. }
  82. FunctionRegistry = mutableRegistry;
  83. }
  84. Alloc.Ref().ForcefullySetMemoryYellowZone(EnableSpilling);
  85. RandomProvider = CreateDeterministicRandomProvider(1);
  86. TimeProvider = CreateDeterministicTimeProvider(10000000);
  87. Env.Reset(new TTypeEnvironment(Alloc));
  88. PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
  89. }
  90. THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
  91. return BuildGraph(pgm, EGraphPerProcess::Multi, entryPoints);
  92. }
  93. THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, EGraphPerProcess graphPerProcess) {
  94. return BuildGraph(pgm, graphPerProcess, {});
  95. }
  96. TAutoPtr<IComputationGraph> BuildGraph(TRuntimeNode pgm, EGraphPerProcess graphPerProcess, const std::vector<TNode*>& entryPoints) {
  97. Reset();
  98. Explorer.Walk(pgm.GetNode(), *Env);
  99. TComputationPatternOpts opts(Alloc.Ref(), *Env, NodeFactory,
  100. FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception,
  101. UseLLVM ? "" : "OFF", graphPerProcess, StatsRegistry.Get(), nullptr, nullptr);
  102. Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
  103. auto graph = Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider));
  104. Terminator.Reset(new TBindTerminator(graph->GetTerminator()));
  105. return graph;
  106. }
  107. void RenameCallable(TRuntimeNode pgm, TString originalName, TString newName) {
  108. const auto renameProvider = [originalName = std::move(originalName), newName = std::move(newName)](TInternName name) -> TCallableVisitFunc {
  109. if (name == originalName) {
  110. return [name, newName = std::move(newName)](TCallable& callable, const TTypeEnvironment& env) {
  111. TCallableBuilder callableBuilder(env, newName,
  112. callable.GetType()->GetReturnType(), false);
  113. for (ui32 i = 0; i < callable.GetInputsCount(); ++i) {
  114. callableBuilder.Add(callable.GetInput(i));
  115. }
  116. return TRuntimeNode(callableBuilder.Build(), false);
  117. };
  118. } else {
  119. return TCallableVisitFunc();
  120. }
  121. };
  122. TExploringNodeVisitor explorer;
  123. explorer.Walk(pgm.GetNode(), *Env);
  124. bool wereChanges = false;
  125. SinglePassVisitCallables(pgm, explorer, renameProvider, *Env, true, wereChanges);
  126. }
  127. void Reset() {
  128. Terminator.Destroy();
  129. Pattern.Reset();
  130. }
  131. TScopedAlloc Alloc;
  132. TComputationNodeFactory NodeFactory;
  133. TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
  134. TIntrusivePtr<IRandomProvider> RandomProvider;
  135. TIntrusivePtr<ITimeProvider> TimeProvider;
  136. IStatsRegistryPtr StatsRegistry;
  137. THolder<TTypeEnvironment> Env;
  138. THolder<TProgramBuilder> PgmBuilder;
  139. TExploringNodeVisitor Explorer;
  140. IComputationPattern::TPtr Pattern;
  141. THolder<TBindTerminator> Terminator;
  142. };
  143. extern const std::vector<std::pair<i8, double>> I8Samples;
  144. extern const std::vector<std::pair<ui16, double>> Ui16Samples;
  145. extern const std::vector<std::tuple<ui64, std::string, std::string, double, double, double, double>> TpchSamples;
  146. }
  147. }