topfreq_udf.h 15 KB


  1. #pragma once
#include <yql/essentials/public/udf/udf_helpers.h>
#include <yql/essentials/public/udf/udf_value_builder.h>
#include <yql/essentials/public/udf/udf_registrator.h>
#include <yql/essentials/public/udf/udf_type_builder.h>
#include <yql/essentials/public/udf/udf_type_inspection.h>
#include <yql/essentials/public/udf/udf_value.h>
#include <yql/essentials/public/udf/udf_types.h>
#include "topfreq.h"
#include <algorithm>
#include <array>
#include <limits>
#include <utility>
  12. using namespace NYql;
  13. using namespace NUdf;
  14. namespace {
  15. extern const char TopFreqResourceNameGeneric[] = "TopFreq.TopFreqResource.Generic";
  16. class TTopFreqResource:
  17. public TBoxedResource<TTopFreqGeneric, TopFreqResourceNameGeneric>
  18. {
  19. public:
  20. template <typename... Args>
  21. inline TTopFreqResource(Args&&... args)
  22. : TBoxedResource(std::forward<Args>(args)...)
  23. {}
  24. };
  25. template <EDataSlot Slot>
  26. class TTopFreqResourceData;
  27. template <EDataSlot Slot>
  28. TTopFreqResourceData<Slot>* GetTopFreqResourceData(const TUnboxedValuePod& arg) {
  29. TTopFreqResourceData<Slot>::Validate(arg);
  30. return static_cast<TTopFreqResourceData<Slot>*>(arg.AsBoxed().Get());
  31. }
  32. TTopFreqResource* GetTopFreqResource(const TUnboxedValuePod& arg) {
  33. TTopFreqResource::Validate(arg);
  34. return static_cast<TTopFreqResource*>(arg.AsBoxed().Get());
  35. }
  36. template <EDataSlot Slot>
  37. class TTopFreq_CreateData: public TBoxedValue {
  38. private:
  39. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  40. ui32 minSize = args[1].Get<ui32>();
  41. return TUnboxedValuePod(new TTopFreqResourceData<Slot>(args[0], minSize, minSize * 2));
  42. }
  43. };
  44. class TTopFreq_Create: public TBoxedValue {
  45. private:
  46. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  47. ui32 minSize = args[1].Get<ui32>();
  48. return TUnboxedValuePod(new TTopFreqResource(args[0], minSize, minSize * 2, Hash_, Equate_));
  49. }
  50. public:
  51. TTopFreq_Create(IHash::TPtr hash, IEquate::TPtr equate)
  52. : Hash_(hash)
  53. , Equate_(equate)
  54. {}
  55. private:
  56. IHash::TPtr Hash_;
  57. IEquate::TPtr Equate_;
  58. };
  59. template <EDataSlot Slot>
  60. class TTopFreq_AddValueData: public TBoxedValue {
  61. private:
  62. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  63. const auto topFreq = GetTopFreqResourceData<Slot>(args[0]);
  64. topFreq->Get()->AddValue(args[1]);
  65. return TUnboxedValuePod(topFreq);
  66. }
  67. };
  68. class TTopFreq_AddValue: public TBoxedValue {
  69. private:
  70. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  71. const auto topFreq = GetTopFreqResource(args[0]);
  72. topFreq->Get()->AddValue(args[1]);
  73. return TUnboxedValuePod(topFreq);
  74. }
  75. };
  76. template <EDataSlot Slot>
  77. class TTopFreq_SerializeData: public TBoxedValue {
  78. private:
  79. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
  80. return GetTopFreqResourceData<Slot>(args[0])->Get()->Serialize(valueBuilder);
  81. }
  82. };
  83. class TTopFreq_Serialize: public TBoxedValue {
  84. private:
  85. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
  86. return GetTopFreqResource(args[0])->Get()->Serialize(valueBuilder);
  87. }
  88. };
  89. template <EDataSlot Slot>
  90. class TTopFreq_DeserializeData: public TBoxedValue {
  91. private:
  92. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  93. return TUnboxedValuePod(new TTopFreqResourceData<Slot>(args[0]));
  94. }
  95. };
  96. class TTopFreq_Deserialize: public TBoxedValue {
  97. private:
  98. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  99. return TUnboxedValuePod(new TTopFreqResource(args[0], Hash_, Equate_));
  100. }
  101. public:
  102. TTopFreq_Deserialize(IHash::TPtr hash, IEquate::TPtr equate)
  103. : Hash_(hash)
  104. , Equate_(equate)
  105. {}
  106. private:
  107. IHash::TPtr Hash_;
  108. IEquate::TPtr Equate_;
  109. };
  110. template <EDataSlot Slot>
  111. class TTopFreq_MergeData: public TBoxedValue {
  112. private:
  113. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  114. const auto topFreq0 = GetTopFreqResourceData<Slot>(args[0]);
  115. const auto topFreq1 = GetTopFreqResourceData<Slot>(args[1]);
  116. return TUnboxedValuePod(new TTopFreqResourceData<Slot>(*topFreq0->Get(), *topFreq1->Get()));
  117. }
  118. };
  119. class TTopFreq_Merge: public TBoxedValue {
  120. private:
  121. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const {
  122. const auto topFreq0 = GetTopFreqResource(args[0]);
  123. const auto topFreq1 = GetTopFreqResource(args[1]);
  124. return TUnboxedValuePod(new TTopFreqResource(*topFreq0->Get(), *topFreq1->Get(), Hash_, Equate_));
  125. }
  126. public:
  127. TTopFreq_Merge(IHash::TPtr hash, IEquate::TPtr equate)
  128. : Hash_(hash)
  129. , Equate_(equate)
  130. {}
  131. private:
  132. IHash::TPtr Hash_;
  133. IEquate::TPtr Equate_;
  134. };
  135. template <EDataSlot Slot>
  136. class TTopFreq_GetData: public TBoxedValue {
  137. private:
  138. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
  139. return GetTopFreqResourceData<Slot>(args[0])->Get()->Get(valueBuilder, args[1].Get<ui32>());
  140. }
  141. };
  142. class TTopFreq_Get: public TBoxedValue {
  143. private:
  144. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
  145. return GetTopFreqResource(args[0])->Get()->Get(valueBuilder, args[1].Get<ui32>());
  146. }
  147. };
  148. #define MAKE_RESOURCE(slot, ...) \
  149. extern const char TopFreqResourceName##slot[] = "TopFreq.TopFreqResource."#slot; \
  150. template <> \
  151. class TTopFreqResourceData<EDataSlot::slot>: \
  152. public TBoxedResource<TTopFreqData<EDataSlot::slot>, TopFreqResourceName##slot> \
  153. { \
  154. public: \
  155. template <typename... Args> \
  156. inline TTopFreqResourceData(Args&&... args) \
  157. : TBoxedResource(std::forward<Args>(args)...) \
  158. {} \
  159. };
  160. UDF_TYPE_ID_MAP(MAKE_RESOURCE)
// Helpers for the per-slot `switch` statements in
// TTopFreqModule::BuildFunctionTypeInfo. They are expanded through
// UDF_TYPE_ID_MAP and refer to the local `builder` (and, for MAKE_TYPE,
// `topFreqType`) that must be in scope at the expansion site.

// One `case` that installs operation<EDataSlot::slot> as the implementation.
#define MAKE_IMPL(operation, slot) \
case EDataSlot::slot: \
builder.Implementation(new operation<EDataSlot::slot>); \
break;
// Adapters with the (slot, ...) shape used with UDF_TYPE_ID_MAP, one per
// exported function, each selecting the matching per-slot implementation.
#define MAKE_CREATE(slot, ...) MAKE_IMPL(TTopFreq_CreateData, slot)
#define MAKE_ADD_VALUE(slot, ...) MAKE_IMPL(TTopFreq_AddValueData, slot)
#define MAKE_SERIALIZE(slot, ...) MAKE_IMPL(TTopFreq_SerializeData, slot)
#define MAKE_DESERIALIZE(slot, ...) MAKE_IMPL(TTopFreq_DeserializeData, slot)
#define MAKE_MERGE(slot, ...) MAKE_IMPL(TTopFreq_MergeData, slot)
#define MAKE_GET(slot, ...) MAKE_IMPL(TTopFreq_GetData, slot)
// One `case` that picks the per-slot resource type (tag generated by
// MAKE_RESOURCE) for use in the function signatures.
#define MAKE_TYPE(slot, ...) \
case EDataSlot::slot: \
topFreqType = builder.Resource(TopFreqResourceName##slot); \
break;
  175. static const auto CreateName = TStringRef::Of("TopFreq_Create");
  176. static const auto AddValueName = TStringRef::Of("TopFreq_AddValue");
  177. static const auto SerializeName = TStringRef::Of("TopFreq_Serialize");
  178. static const auto DeserializeName = TStringRef::Of("TopFreq_Deserialize");
  179. static const auto MergeName = TStringRef::Of("TopFreq_Merge");
  180. static const auto GetName = TStringRef::Of("TopFreq_Get");
  181. class TTopFreqModule: public IUdfModule {
  182. public:
  183. TStringRef Name() const {
  184. return TStringRef::Of("TopFreq");
  185. }
  186. void CleanupOnTerminate() const final {
  187. }
  188. void GetAllFunctions(IFunctionsSink& sink) const final {
  189. sink.Add(CreateName)->SetTypeAwareness();
  190. sink.Add(AddValueName)->SetTypeAwareness();
  191. sink.Add(SerializeName)->SetTypeAwareness();
  192. sink.Add(DeserializeName)->SetTypeAwareness();
  193. sink.Add(MergeName)->SetTypeAwareness();
  194. sink.Add(GetName)->SetTypeAwareness();
  195. }
  196. void BuildFunctionTypeInfo(
  197. const TStringRef& name,
  198. TType* userType,
  199. const TStringRef& typeConfig,
  200. ui32 flags,
  201. IFunctionTypeInfoBuilder& builder) const final
  202. {
  203. Y_UNUSED(typeConfig);
  204. try {
  205. const bool typesOnly = (flags & TFlags::TypesOnly);
  206. builder.UserType(userType);
  207. auto typeHelper = builder.TypeInfoHelper();
  208. auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
  209. if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) {
  210. builder.SetError("User type is not a 3-tuple");
  211. return;
  212. }
  213. bool isGeneric = false;
  214. IHash::TPtr hash;
  215. IEquate::TPtr equate;
  216. TMaybe<EDataSlot> slot;
  217. auto valueType = userTypeInspector.GetElementType(2);
  218. auto valueTypeInspector = TDataTypeInspector(*typeHelper, valueType);
  219. if (!valueTypeInspector) {
  220. isGeneric = true;
  221. hash = builder.MakeHash(valueType);
  222. equate = builder.MakeEquate(valueType);
  223. if (!hash || !equate) {
  224. return;
  225. }
  226. } else {
  227. slot = FindDataSlot(valueTypeInspector.GetTypeId());
  228. if (!slot) {
  229. builder.SetError("Unknown data type");
  230. return;
  231. }
  232. const auto& features = NUdf::GetDataTypeInfo(*slot).Features;
  233. if (!(features & NUdf::CanHash) || !(features & NUdf::CanEquate)) {
  234. builder.SetError("Data type is not hashable or equatable");
  235. return;
  236. }
  237. }
  238. auto serializedItemType = builder.Tuple()->Add<ui64>().Add(valueType).Build();
  239. auto serializedListType = builder.List()->Item(serializedItemType).Build();
  240. auto serializedType = builder.Tuple()->Add<ui32>().Add<ui32>().Add(serializedListType).Build();
  241. TType* topFreqType = nullptr;
  242. if (isGeneric) {
  243. topFreqType = builder.Resource(TopFreqResourceNameGeneric);
  244. } else {
  245. switch (*slot) {
  246. UDF_TYPE_ID_MAP(MAKE_TYPE)
  247. }
  248. }
  249. if (name == CreateName) {
  250. builder.Args()->Add(valueType).Add<ui32>().Done().Returns(topFreqType);
  251. if (!typesOnly) {
  252. if (isGeneric) {
  253. builder.Implementation(new TTopFreq_Create(hash, equate));
  254. } else {
  255. switch (*slot) {
  256. UDF_TYPE_ID_MAP(MAKE_CREATE)
  257. }
  258. }
  259. }
  260. builder.IsStrict();
  261. }
  262. if (name == AddValueName) {
  263. builder.Args()->Add(topFreqType).Add(valueType).Done().Returns(topFreqType);
  264. if (!typesOnly) {
  265. if (isGeneric) {
  266. builder.Implementation(new TTopFreq_AddValue);
  267. } else {
  268. switch (*slot) {
  269. UDF_TYPE_ID_MAP(MAKE_ADD_VALUE)
  270. }
  271. }
  272. }
  273. builder.IsStrict();
  274. }
  275. if (name == MergeName) {
  276. builder.Args()->Add(topFreqType).Add(topFreqType).Done().Returns(topFreqType);
  277. if (!typesOnly) {
  278. if (isGeneric) {
  279. builder.Implementation(new TTopFreq_Merge(hash, equate));
  280. } else {
  281. switch (*slot) {
  282. UDF_TYPE_ID_MAP(MAKE_MERGE)
  283. }
  284. }
  285. }
  286. builder.IsStrict();
  287. }
  288. if (name == SerializeName) {
  289. builder.Args()->Add(topFreqType).Done().Returns(serializedType);
  290. if (!typesOnly) {
  291. if (isGeneric) {
  292. builder.Implementation(new TTopFreq_Serialize);
  293. } else {
  294. switch (*slot) {
  295. UDF_TYPE_ID_MAP(MAKE_SERIALIZE)
  296. }
  297. }
  298. }
  299. builder.IsStrict();
  300. }
  301. if (name == DeserializeName) {
  302. builder.Args()->Add(serializedType).Done().Returns(topFreqType);
  303. if (!typesOnly) {
  304. if (isGeneric) {
  305. builder.Implementation(new TTopFreq_Deserialize(hash, equate));
  306. } else {
  307. switch (*slot) {
  308. UDF_TYPE_ID_MAP(MAKE_DESERIALIZE)
  309. }
  310. }
  311. }
  312. }
  313. if (name == GetName) {
  314. ui32 indexF, indexV;
  315. auto itemType = builder.Struct()->AddField<ui64>("Frequency", &indexF).AddField("Value", valueType, &indexV).Build();
  316. auto resultType = builder.List()->Item(itemType).Build();
  317. builder.Args()->Add(topFreqType).Add<ui32>().Done().Returns(resultType);
  318. if (!typesOnly) {
  319. if (isGeneric) {
  320. builder.Implementation(new TTopFreq_Get);
  321. } else {
  322. switch (*slot) {
  323. UDF_TYPE_ID_MAP(MAKE_GET)
  324. }
  325. }
  326. }
  327. builder.IsStrict();
  328. }
  329. } catch (const std::exception& e) {
  330. builder.SetError(CurrentExceptionMessage());
  331. }
  332. }
  333. };
  334. } // namespace