mkql_aggrcount.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #include "mkql_aggrcount.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TAggrCountInitWrapper : public TDecoratorCodegeneratorNode<TAggrCountInitWrapper> {
  10. typedef TDecoratorCodegeneratorNode<TAggrCountInitWrapper> TBaseComputation;
  11. public:
  12. TAggrCountInitWrapper(IComputationNode* value)
  13. : TBaseComputation(value)
  14. {}
  15. NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& value) const {
  16. return NUdf::TUnboxedValuePod(ui64(value ? 1ULL : 0ULL));
  17. }
  18. #ifndef MKQL_DISABLE_CODEGEN
  19. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* value, BasicBlock*& block) const {
  20. auto& context = ctx.Codegen.GetContext();
  21. const auto check = IsExists(value, block, context);
  22. if (Node->IsTemporaryValue())
  23. ValueCleanup(Node->GetRepresentation(), value, ctx, block);
  24. return MakeBoolean(check, context, block);
  25. }
  26. #endif
  27. };
  28. class TAggrCountUpdateWrapper : public TDecoratorCodegeneratorNode<TAggrCountUpdateWrapper> {
  29. typedef TDecoratorCodegeneratorNode<TAggrCountUpdateWrapper> TBaseComputation;
  30. public:
  31. TAggrCountUpdateWrapper(IComputationNode* state)
  32. : TBaseComputation(state)
  33. {}
  34. NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& value) const {
  35. return NUdf::TUnboxedValuePod(value.Get<ui64>() + 1U);
  36. }
  37. #ifndef MKQL_DISABLE_CODEGEN
  38. Value* DoGenerateGetValue(const TCodegenContext&, Value* value, BasicBlock*& block) const {
  39. return BinaryOperator::CreateAdd(value, ConstantInt::get(value->getType(), 1), "incr", block);
  40. }
  41. #endif
  42. };
  43. class TAggrCountIfUpdateWrapper : public TMutableCodegeneratorNode<TAggrCountIfUpdateWrapper> {
  44. typedef TMutableCodegeneratorNode<TAggrCountIfUpdateWrapper> TBaseComputation;
  45. public:
  46. TAggrCountIfUpdateWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* state)
  47. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  48. , Arg(value)
  49. , State(state)
  50. {
  51. }
  52. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  53. auto state = State->GetValue(compCtx);
  54. return Arg->GetValue(compCtx) ? NUdf::TUnboxedValuePod(state.Get<ui64>() + 1U) : state.Release();
  55. }
  56. #ifndef MKQL_DISABLE_CODEGEN
  57. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  58. auto& context = ctx.Codegen.GetContext();
  59. const auto state = GetNodeValue(State, ctx, block);
  60. const auto value = GetNodeValue(Arg, ctx, block);
  61. const auto check = IsExists(value, block, context);
  62. if (Arg->IsTemporaryValue())
  63. ValueCleanup(Arg->GetRepresentation(), value, ctx, block);
  64. const auto zext = new ZExtInst(check, state->getType(), "zext", block);
  65. const auto incr = BinaryOperator::CreateAdd(state, zext, "incr", block);
  66. return incr;
  67. }
  68. #endif
  69. private:
  70. void RegisterDependencies() const final {
  71. DependsOn(Arg);
  72. DependsOn(State);
  73. }
  74. IComputationNode* const Arg;
  75. IComputationNode* const State;
  76. };
  77. }
  78. IComputationNode* WrapAggrCountInit(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  79. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  80. if (callable.GetInput(0).GetStaticType()->IsOptional()) {
  81. return new TAggrCountInitWrapper(LocateNode(ctx.NodeLocator, callable, 0));
  82. } else {
  83. return ctx.NodeFactory.CreateImmutableNode(NUdf::TUnboxedValuePod(ui64(1ULL)));
  84. }
  85. }
  86. IComputationNode* WrapAggrCountUpdate(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  87. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  88. MKQL_ENSURE(AS_TYPE(TDataType, callable.GetInput(1))->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64 type");
  89. if (callable.GetInput(0).GetStaticType()->IsOptional()) {
  90. return new TAggrCountIfUpdateWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), LocateNode(ctx.NodeLocator, callable, 1));
  91. } else {
  92. return new TAggrCountUpdateWrapper(LocateNode(ctx.NodeLocator, callable, 1));
  93. }
  94. }
  95. }
  96. }