vector_udf.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #include <yql/essentials/public/udf/udf_type_ops.h>
  2. #include <yql/essentials/public/udf/udf_helpers.h>
  3. #include <vector>
  4. using namespace NKikimr;
  5. using namespace NUdf;
  6. namespace {
  7. class TVector {
  8. private:
  9. std::vector<TUnboxedValue, TUnboxedValue::TAllocator> Vector;
  10. public:
  11. TVector()
  12. : Vector()
  13. {}
  14. TUnboxedValue GetResult(const IValueBuilder* builder) {
  15. TUnboxedValue* values = nullptr;
  16. auto list = builder->NewArray(Vector.size(), values);
  17. std::copy(Vector.begin(), Vector.end(), values);
  18. return list;
  19. }
  20. void Emplace(const ui64 index, const TUnboxedValuePod& value) {
  21. if (index < Vector.size()) {
  22. Vector[index] = value;
  23. } else {
  24. Vector.push_back(value);
  25. }
  26. }
  27. void Swap(const ui64 a, const ui64 b) {
  28. if (a < Vector.size() && b < Vector.size()) {
  29. std::swap(Vector[a], Vector[b]);
  30. }
  31. }
  32. void Reserve(ui64 expectedSize) {
  33. Vector.reserve(expectedSize);
  34. }
  35. };
  36. extern const char VectorResourceName[] = "Vector.VectorResource";
  37. class TVectorResource:
  38. public TBoxedResource<TVector, VectorResourceName>
  39. {
  40. public:
  41. template <typename... Args>
  42. inline TVectorResource(Args&&... args)
  43. : TBoxedResource(std::forward<Args>(args)...)
  44. {}
  45. };
  46. TVectorResource* GetVectorResource(const TUnboxedValuePod& arg) {
  47. TVectorResource::Validate(arg);
  48. return static_cast<TVectorResource*>(arg.AsBoxed().Get());
  49. }
  50. class TVectorCreate: public TBoxedValue {
  51. private:
  52. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  53. auto resource = new TVectorResource;
  54. resource->Get()->Reserve(args[0].Get<ui64>());
  55. return TUnboxedValuePod(resource);
  56. }
  57. };
  58. class TVectorEmplace: public TBoxedValue {
  59. private:
  60. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  61. auto resource = GetVectorResource(args[0]);
  62. resource->Get()->Emplace(args[1].Get<ui64>(), args[2]);
  63. return TUnboxedValuePod(resource);
  64. }
  65. };
  66. class TVectorSwap: public TBoxedValue {
  67. private:
  68. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  69. auto resource = GetVectorResource(args[0]);
  70. resource->Get()->Swap(args[1].Get<ui64>(), args[2].Get<ui64>());
  71. return TUnboxedValuePod(resource);
  72. }
  73. };
  74. class TVectorGetResult: public TBoxedValue {
  75. private:
  76. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
  77. return GetVectorResource(args[0])->Get()->GetResult(valueBuilder);
  78. }
  79. };
  80. static const auto CreateName = TStringRef::Of("Create");
  81. static const auto EmplaceName = TStringRef::Of("Emplace");
  82. static const auto SwapName = TStringRef::Of("Swap");
  83. static const auto GetResultName = TStringRef::Of("GetResult");
  84. class TVectorModule: public IUdfModule {
  85. public:
  86. TStringRef Name() const {
  87. return TStringRef::Of("Vector");
  88. }
  89. void CleanupOnTerminate() const final {
  90. }
  91. void GetAllFunctions(IFunctionsSink& sink) const final {
  92. sink.Add(CreateName)->SetTypeAwareness();
  93. sink.Add(EmplaceName)->SetTypeAwareness();
  94. sink.Add(SwapName)->SetTypeAwareness();
  95. sink.Add(GetResultName)->SetTypeAwareness();
  96. }
  97. void BuildFunctionTypeInfo(
  98. const TStringRef& name,
  99. TType* userType,
  100. const TStringRef& typeConfig,
  101. ui32 flags,
  102. IFunctionTypeInfoBuilder& builder) const final
  103. {
  104. Y_UNUSED(typeConfig);
  105. try {
  106. const bool typesOnly = (flags & TFlags::TypesOnly);
  107. builder.UserType(userType);
  108. auto typeHelper = builder.TypeInfoHelper();
  109. auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
  110. if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) {
  111. builder.SetError("User type is not a 3-tuple");
  112. return;
  113. }
  114. auto valueType = userTypeInspector.GetElementType(2);
  115. TType* vectorType = builder.Resource(VectorResourceName);
  116. if (name == CreateName) {
  117. builder.IsStrict();
  118. builder.Args()->Add<ui64>().Done().Returns(vectorType);
  119. if (!typesOnly) {
  120. builder.Implementation(new TVectorCreate);
  121. }
  122. }
  123. if (name == EmplaceName) {
  124. builder.IsStrict();
  125. builder.Args()->Add(vectorType).Add<ui64>().Add(valueType).Done().Returns(vectorType);
  126. if (!typesOnly) {
  127. builder.Implementation(new TVectorEmplace);
  128. }
  129. }
  130. if (name == SwapName) {
  131. builder.IsStrict();
  132. builder.Args()->Add(vectorType).Add<ui64>().Add<ui64>().Done().Returns(vectorType);
  133. if (!typesOnly) {
  134. builder.Implementation(new TVectorSwap);
  135. }
  136. }
  137. if (name == GetResultName) {
  138. auto resultType = builder.List()->Item(valueType).Build();
  139. builder.IsStrict();
  140. builder.Args()->Add(vectorType).Done().Returns(resultType);
  141. if (!typesOnly) {
  142. builder.Implementation(new TVectorGetResult);
  143. }
  144. }
  145. } catch (const std::exception& e) {
  146. builder.SetError(CurrentExceptionMessage());
  147. }
  148. }
  149. };
  150. } // namespace
  151. REGISTER_MODULES(TVectorModule)