mkql_test_factory.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #include "mkql_computation_node_ut.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_string_util.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. ui64 g_Yield = std::numeric_limits<ui64>::max();
  10. ui64 g_TestStreamData[] = {0, 0, 1, 0, 0, 0, 1, 2, 3};
  11. ui64 g_TestYieldStreamData[] = {0, 1, 2, g_Yield, 0, g_Yield, 1, 2, 0, 1, 2, 0, g_Yield, 1, 2};
  12. class TTestStreamWrapper: public TMutableComputationNode<TTestStreamWrapper> {
  13. typedef TMutableComputationNode<TTestStreamWrapper> TBaseComputation;
  14. public:
  15. class TStreamValue : public TComputationValue<TStreamValue> {
  16. public:
  17. using TBase = TComputationValue<TStreamValue>;
  18. TStreamValue(TMemoryUsageInfo* memInfo, ui64 count)
  19. : TBase(memInfo)
  20. , Count(count)
  21. {
  22. }
  23. private:
  24. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  25. if (Index == Count) {
  26. return NUdf::EFetchStatus::Finish;
  27. }
  28. result = NUdf::TUnboxedValuePod(g_TestStreamData[Index++]);
  29. return NUdf::EFetchStatus::Ok;
  30. }
  31. private:
  32. ui64 Index = 0;
  33. const ui64 Count;
  34. };
  35. TTestStreamWrapper(TComputationMutables& mutables, ui64 count)
  36. : TBaseComputation(mutables)
  37. , Count(Min<ui64>(count, Y_ARRAY_SIZE(g_TestStreamData)))
  38. {
  39. }
  40. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  41. return ctx.HolderFactory.Create<TStreamValue>(Count);
  42. }
  43. private:
  44. void RegisterDependencies() const final {}
  45. private:
  46. const ui64 Count;
  47. };
  48. class TTestYieldStreamWrapper: public TMutableComputationNode<TTestYieldStreamWrapper> {
  49. typedef TMutableComputationNode<TTestYieldStreamWrapper> TBaseComputation;
  50. public:
  51. class TStreamValue : public TComputationValue<TStreamValue> {
  52. public:
  53. using TBase = TComputationValue<TStreamValue>;
  54. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx)
  55. : TBase(memInfo)
  56. , CompCtx(compCtx) {}
  57. private:
  58. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  59. if (Index == Y_ARRAY_SIZE(g_TestYieldStreamData)) {
  60. return NUdf::EFetchStatus::Finish;
  61. }
  62. const auto value = g_TestYieldStreamData[Index];
  63. if (value == g_Yield) {
  64. ++Index;
  65. return NUdf::EFetchStatus::Yield;
  66. }
  67. NUdf::TUnboxedValue* items = nullptr;
  68. result = CompCtx.HolderFactory.CreateDirectArrayHolder(2, items);
  69. items[0] = NUdf::TUnboxedValuePod(value);
  70. items[1] = MakeString(ToString(Index));
  71. ++Index;
  72. return NUdf::EFetchStatus::Ok;
  73. }
  74. private:
  75. TComputationContext& CompCtx;
  76. ui64 Index = 0;
  77. };
  78. TTestYieldStreamWrapper(TComputationMutables& mutables)
  79. : TBaseComputation(mutables) {}
  80. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  81. return ctx.HolderFactory.Create<TStreamValue>(ctx);
  82. }
  83. private:
  84. void RegisterDependencies() const final {}
  85. };
  86. IComputationNode* WrapTestStream(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  87. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args");
  88. const ui64 count = AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().Get<ui64>();
  89. return new TTestStreamWrapper(ctx.Mutables, count);
  90. }
  91. IComputationNode* WrapTestYieldStream(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  92. MKQL_ENSURE(!callable.GetInputsCount(), "Expected no args");
  93. return new TTestYieldStreamWrapper(ctx.Mutables);
  94. }
  95. }
  96. TComputationNodeFactory GetTestFactory(TComputationNodeFactory customFactory) {
  97. return [customFactory](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
  98. if (callable.GetType()->GetName() == "TestList") {
  99. return new TExternalComputationNode(ctx.Mutables);
  100. }
  101. if (callable.GetType()->GetName() == "TestStream") {
  102. return WrapTestStream(callable, ctx);
  103. }
  104. if (callable.GetType()->GetName() == "TestYieldStream") {
  105. return WrapTestYieldStream(callable, ctx);
  106. }
  107. if (customFactory) {
  108. auto ret = customFactory(callable, ctx);
  109. if (ret) {
  110. return ret;
  111. }
  112. }
  113. return GetBuiltinFactory()(callable, ctx);
  114. };
  115. }
  116. }
  117. }