mkql_enumerate.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #include "mkql_enumerate.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/computation/mkql_custom_list.h>
  6. #include <yql/essentials/minikql/mkql_node_cast.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. class TEnumerateWrapper : public TMutableCodegeneratorNode<TEnumerateWrapper> {
  11. typedef TMutableCodegeneratorNode<TEnumerateWrapper> TBaseComputation;
  12. public:
  13. using TSelf = TEnumerateWrapper;
  14. class TValue : public TCustomListValue {
  15. public:
  16. class TIterator : public TComputationValue<TIterator> {
  17. public:
  18. TIterator(
  19. TMemoryUsageInfo* memInfo,
  20. NUdf::TUnboxedValue&& inner,
  21. ui64 start, ui64 step,
  22. TComputationContext& ctx, const TSelf* self)
  23. : TComputationValue(memInfo)
  24. , Inner(std::move(inner))
  25. , Step(step)
  26. , Counter(start - step)
  27. , Ctx(ctx)
  28. , Self(self)
  29. {
  30. }
  31. private:
  32. bool Next(NUdf::TUnboxedValue& value) override {
  33. NUdf::TUnboxedValue item;
  34. if (Inner.Next(item)) {
  35. Counter += Step;
  36. NUdf::TUnboxedValue* items = nullptr;
  37. value = Self->ResPair.NewArray(Ctx, 2, items);
  38. items[0] = NUdf::TUnboxedValuePod(Counter);
  39. items[1] = std::move(item);
  40. return true;
  41. }
  42. return false;
  43. }
  44. bool Skip() override {
  45. if (Inner.Skip()) {
  46. Counter += Step;
  47. return true;
  48. }
  49. return false;
  50. }
  51. const NUdf::TUnboxedValue Inner;
  52. const ui64 Step;
  53. ui64 Counter;
  54. TComputationContext& Ctx;
  55. const TSelf* const Self;
  56. };
  57. TValue(
  58. TMemoryUsageInfo* memInfo,
  59. const NUdf::TUnboxedValue& list,
  60. ui64 start, ui64 step,
  61. TComputationContext& ctx,
  62. const TSelf* self
  63. )
  64. : TCustomListValue(memInfo)
  65. , List(list)
  66. , Start(start)
  67. , Step(step)
  68. , Ctx(ctx)
  69. , Self(self)
  70. {}
  71. private:
  72. ui64 GetListLength() const override {
  73. if (!Length) {
  74. Length = List.GetListLength();
  75. }
  76. return *Length;
  77. }
  78. bool HasListItems() const override {
  79. if (!HasItems) {
  80. HasItems = List.HasListItems();
  81. }
  82. return *HasItems;
  83. }
  84. NUdf::TUnboxedValue GetListIterator() const override {
  85. return Ctx.HolderFactory.Create<TIterator>(List.GetListIterator(), Start, Step, Ctx, Self);
  86. }
  87. const NUdf::TUnboxedValue List;
  88. const ui64 Start;
  89. const ui64 Step;
  90. TComputationContext& Ctx;
  91. const TSelf* const Self;
  92. };
  93. TEnumerateWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* start, IComputationNode* step)
  94. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  95. , List(list)
  96. , Start(start)
  97. , Step(step)
  98. , ResPair(mutables)
  99. {}
  100. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  101. return WrapList(ctx, List->GetValue(ctx).Release(), Start->GetValue(ctx).Get<ui64>(), Step->GetValue(ctx).Get<ui64>());
  102. }
  103. #ifndef MKQL_DISABLE_CODEGEN
  104. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  105. auto& context = ctx.Codegen.GetContext();
  106. const auto list = GetNodeValue(List, ctx, block);
  107. const auto startv = GetNodeValue(Start, ctx, block);
  108. const auto stepv = GetNodeValue(Step, ctx, block);
  109. const auto start = GetterFor<ui64>(startv, context, block);
  110. const auto step = GetterFor<ui64>(stepv, context, block);
  111. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TEnumerateWrapper::WrapList));
  112. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  113. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  114. const auto signature = FunctionType::get(list->getType(), {self->getType(), ctx.Ctx->getType(), list->getType(), start->getType(), step->getType()}, false);
  115. const auto creator = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(signature), "creator", block);
  116. const auto output = CallInst::Create(signature, creator, {self, ctx.Ctx, list, start, step}, "output", block);
  117. return output;
  118. }
  119. #endif
  120. private:
  121. NUdf::TUnboxedValuePod WrapList(TComputationContext& ctx, NUdf::TUnboxedValuePod list, ui64 start, ui64 step) const {
  122. return ctx.HolderFactory.Create<TValue>(list, start, step, ctx, this);
  123. }
  124. void RegisterDependencies() const final {
  125. DependsOn(List);
  126. DependsOn(Start);
  127. DependsOn(Step);
  128. }
  129. IComputationNode* const List;
  130. IComputationNode* const Start;
  131. IComputationNode* const Step;
  132. const TContainerCacheOnContext ResPair;
  133. };
  134. }
  135. IComputationNode* WrapEnumerate(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  136. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  137. AS_TYPE(TListType, callable.GetInput(0));
  138. MKQL_ENSURE(AS_TYPE(TDataType, callable.GetInput(1))->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
  139. MKQL_ENSURE(AS_TYPE(TDataType, callable.GetInput(2))->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
  140. return new TEnumerateWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0),
  141. LocateNode(ctx.NodeLocator, callable, 1), LocateNode(ctx.NodeLocator, callable, 2));
  142. }
  143. }
  144. }