mkql_length.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #include "mkql_length.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. template <bool IsDict, bool IsOptional>
  10. class TLengthWrapper : public TMutableCodegeneratorNode<TLengthWrapper<IsDict, IsOptional>> {
  11. typedef TMutableCodegeneratorNode<TLengthWrapper<IsDict, IsOptional>> TBaseComputation;
  12. public:
  13. TLengthWrapper(TComputationMutables& mutables, IComputationNode* collection)
  14. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  15. , Collection(collection)
  16. {}
  17. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  18. const auto& collection = Collection->GetValue(compCtx);
  19. if (IsOptional && !collection) {
  20. return NUdf::TUnboxedValuePod();
  21. }
  22. const auto length = IsDict ? collection.GetDictLength() : collection.GetListLength();
  23. return NUdf::TUnboxedValuePod(length);
  24. }
  25. #ifndef MKQL_DISABLE_CODEGEN
  26. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  27. auto& context = ctx.Codegen.GetContext();
  28. const auto collection = GetNodeValue(Collection, ctx, block);
  29. if constexpr (IsOptional) {
  30. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  31. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  32. const auto result = PHINode::Create(collection->getType(), 2U, "result", done);
  33. result->addIncoming(collection, block);
  34. BranchInst::Create(done, good, IsEmpty(collection, block, context), block);
  35. block = good;
  36. const auto length = CallBoxedValueVirtualMethod<IsDict ? NUdf::TBoxedValueAccessor::EMethod::GetDictLength : NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), collection, ctx.Codegen, block);
  37. if (Collection->IsTemporaryValue())
  38. CleanupBoxed(collection, ctx, block);
  39. result->addIncoming(SetterFor<ui64>(length, context, block), block);
  40. BranchInst::Create(done, block);
  41. block = done;
  42. return result;
  43. } else {
  44. const auto length = CallBoxedValueVirtualMethod<IsDict ? NUdf::TBoxedValueAccessor::EMethod::GetDictLength : NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), collection, ctx.Codegen, block);
  45. if (Collection->IsTemporaryValue())
  46. CleanupBoxed(collection, ctx, block);
  47. return SetterFor<ui64>(length, context, block);
  48. }
  49. }
  50. #endif
  51. private:
  52. void RegisterDependencies() const final {
  53. this->DependsOn(Collection);
  54. }
  55. IComputationNode* const Collection;
  56. };
  57. }
  58. IComputationNode* WrapLength(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  59. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  60. bool isOptional;
  61. const auto type = UnpackOptional(callable.GetInput(0).GetStaticType(), isOptional);
  62. if (type->IsDict() || type->IsEmptyDict()) {
  63. if (isOptional)
  64. return new TLengthWrapper<true, true>(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  65. else
  66. return new TLengthWrapper<true, false>(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  67. } else if (type->IsList() || type->IsEmptyList()) {
  68. if (isOptional)
  69. return new TLengthWrapper<false, true>(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  70. else
  71. return new TLengthWrapper<false, false>(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  72. }
  73. THROW yexception() << "Expected list or dict.";
  74. }
  75. }
  76. }