set_udf.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  #include <yql/essentials/public/udf/udf_type_ops.h>
  #include <yql/essentials/public/udf/udf_helpers.h>

  #include <algorithm>
  #include <limits>
  #include <unordered_set>

  using namespace NKikimr;
  using namespace NUdf;
  6. namespace {
  7. template <typename THash, typename TEquals>
  8. class TSetBase {
  9. private:
  10. std::unordered_set<TUnboxedValue, THash, TEquals, TUnboxedValue::TAllocator> Set;
  11. ui32 MaxSize = 0;
  12. bool WasChanged = false;
  13. protected:
  14. TSetBase(THash hash, TEquals equals)
  15. : Set(1, hash, equals)
  16. {}
  17. void Init(const TUnboxedValuePod& value, ui32 maxSize) {
  18. MaxSize = maxSize ? maxSize : std::numeric_limits<ui32>::max();
  19. AddValue(value);
  20. }
  21. void Merge(const TSetBase& left, const TSetBase& right) {
  22. MaxSize = std::max(left.MaxSize, right.MaxSize);
  23. for (const auto& item : left.Set) {
  24. AddValue(item);
  25. }
  26. for (const auto& item : right.Set) {
  27. AddValue(item);
  28. }
  29. }
  30. void Deserialize(const TUnboxedValuePod& serialized) {
  31. MaxSize = serialized.GetElement(0).Get<ui32>();
  32. auto list = serialized.GetElement(1);
  33. const auto listIter = list.GetListIterator();
  34. for (TUnboxedValue current; listIter.Next(current);) {
  35. AddValue(current);
  36. }
  37. }
  38. public:
  39. void ResetChanged() {
  40. WasChanged = false;
  41. }
  42. bool Changed() const {
  43. return WasChanged;
  44. }
  45. TUnboxedValue Serialize(const IValueBuilder* builder) {
  46. TUnboxedValue* values = nullptr;
  47. auto list = builder->NewArray(Set.size(), values);
  48. for (const auto& item : Set) {
  49. *values++ = item;
  50. }
  51. TUnboxedValue* items = nullptr;
  52. auto result = builder->NewArray(2U, items);
  53. items[0] = TUnboxedValuePod(MaxSize);
  54. items[1] = list;
  55. return result;
  56. }
  57. TUnboxedValue GetResult(const IValueBuilder* builder) {
  58. TUnboxedValue* values = nullptr;
  59. auto result = builder->NewArray(Set.size(), values);
  60. for (const auto& item : Set) {
  61. *values++ = item;
  62. }
  63. return result;
  64. }
  65. void AddValue(const TUnboxedValuePod& value) {
  66. if (Set.size() < MaxSize) {
  67. WasChanged = Set.insert(TUnboxedValuePod(value)).second;
  68. }
  69. }
  70. };
  71. template <EDataSlot Slot>
  72. class TSetData
  73. : public TSetBase<TUnboxedValueHash<Slot>, TUnboxedValueEquals<Slot>>
  74. {
  75. public:
  76. using TBase = TSetBase<TUnboxedValueHash<Slot>, TUnboxedValueEquals<Slot>>;
  77. TSetData(const TUnboxedValuePod& value, ui32 maxSize)
  78. : TBase(TUnboxedValueHash<Slot>(), TUnboxedValueEquals<Slot>())
  79. {
  80. TBase::Init(value, maxSize);
  81. }
  82. TSetData(const TSetData& left, const TSetData& right)
  83. : TBase(TUnboxedValueHash<Slot>(), TUnboxedValueEquals<Slot>())
  84. {
  85. TBase::Merge(left, right);
  86. }
  87. explicit TSetData(const TUnboxedValuePod& serialized)
  88. : TBase(TUnboxedValueHash<Slot>(), TUnboxedValueEquals<Slot>())
  89. {
  90. TBase::Deserialize(serialized);
  91. }
  92. };
  93. struct TGenericHash {
  94. IHash::TPtr Hash;
  95. std::size_t operator()(const TUnboxedValuePod& value) const {
  96. return Hash->Hash(value);
  97. }
  98. };
  99. struct TGenericEquals {
  100. IEquate::TPtr Equate;
  101. bool operator()(const TUnboxedValuePod& left, const TUnboxedValuePod& right) const {
  102. return Equate->Equals(left, right);
  103. }
  104. };
  105. class TSetGeneric
  106. : public TSetBase<TGenericHash, TGenericEquals>
  107. {
  108. public:
  109. using TBase = TSetBase<TGenericHash, TGenericEquals>;
  110. TSetGeneric(const TUnboxedValuePod& value, ui32 maxSize,
  111. IHash::TPtr hash, IEquate::TPtr equate)
  112. : TBase(TGenericHash{hash}, TGenericEquals{equate})
  113. {
  114. TBase::Init(value, maxSize);
  115. }
  116. TSetGeneric(const TSetGeneric& left, const TSetGeneric& right,
  117. IHash::TPtr hash, IEquate::TPtr equate)
  118. : TBase(TGenericHash{hash}, TGenericEquals{equate})
  119. {
  120. TBase::Merge(left, right);
  121. }
  122. TSetGeneric(const TUnboxedValuePod& serialized,
  123. IHash::TPtr hash, IEquate::TPtr equate)
  124. : TBase(TGenericHash{hash}, TGenericEquals{equate})
  125. {
  126. TBase::Deserialize(serialized);
  127. }
  128. };
  129. extern const char SetResourceNameGeneric[] = "Set.SetResource.Generic";
  130. class TSetResource:
  131. public TBoxedResource<TSetGeneric, SetResourceNameGeneric>
  132. {
  133. public:
  134. template <typename... Args>
  135. inline TSetResource(Args&&... args)
  136. : TBoxedResource(std::forward<Args>(args)...)
  137. {}
  138. };
  139. template <EDataSlot Slot>
  140. class TSetResourceData;
  141. template <EDataSlot Slot>
  142. TSetResourceData<Slot>* GetSetResourceData(const TUnboxedValuePod& arg) {
  143. TSetResourceData<Slot>::Validate(arg);
  144. return static_cast<TSetResourceData<Slot>*>(arg.AsBoxed().Get());
  145. }
  146. TSetResource* GetSetResource(const TUnboxedValuePod& arg) {
  147. TSetResource::Validate(arg);
  148. return static_cast<TSetResource*>(arg.AsBoxed().Get());
  149. }
  150. template <EDataSlot Slot>
  151. class TSetCreateData: public TBoxedValue {
  152. private:
  153. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  154. return TUnboxedValuePod(new TSetResourceData<Slot>(args[0], args[1].Get<ui32>()));
  155. }
  156. };
  157. class TSetCreate: public TBoxedValue {
  158. private:
  159. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  160. return TUnboxedValuePod(new TSetResource(args[0], args[1].Get<ui32>(), Hash_, Equate_));
  161. }
  162. public:
  163. TSetCreate(IHash::TPtr hash, IEquate::TPtr equate)
  164. : Hash_(hash)
  165. , Equate_(equate)
  166. {}
  167. private:
  168. IHash::TPtr Hash_;
  169. IEquate::TPtr Equate_;
  170. };
  171. template <EDataSlot Slot>
  172. class TSetAddValueData: public TBoxedValue {
  173. private:
  174. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  175. auto resource = GetSetResourceData<Slot>(args[0]);
  176. resource->Get()->ResetChanged();
  177. resource->Get()->AddValue(args[1]);
  178. return TUnboxedValuePod(resource);
  179. }
  180. };
  181. class TSetAddValue: public TBoxedValue {
  182. private:
  183. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  184. auto resource = GetSetResource(args[0]);
  185. resource->Get()->ResetChanged();
  186. resource->Get()->AddValue(args[1]);
  187. return TUnboxedValuePod(resource);
  188. }
  189. };
  190. template <EDataSlot Slot>
  191. class TSetWasChangedData: public TBoxedValue {
  192. private:
  193. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  194. auto resource = GetSetResourceData<Slot>(args[0]);
  195. return TUnboxedValuePod(resource->Get()->Changed());
  196. }
  197. };
  198. class TSetWasChanged: public TBoxedValue {
  199. private:
  200. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  201. auto resource = GetSetResource(args[0]);
  202. return TUnboxedValuePod(resource->Get()->Changed());
  203. }
  204. };
  205. template <EDataSlot Slot>
  206. class TSetSerializeData: public TBoxedValue {
  207. private:
  208. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
  209. return GetSetResourceData<Slot>(args[0])->Get()->Serialize(valueBuilder);
  210. }
  211. };
  212. class TSetSerialize: public TBoxedValue {
  213. private:
  214. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
  215. return GetSetResource(args[0])->Get()->Serialize(valueBuilder);
  216. }
  217. };
  218. template <EDataSlot Slot>
  219. class TSetDeserializeData: public TBoxedValue {
  220. private:
  221. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  222. return TUnboxedValuePod(new TSetResourceData<Slot>(args[0]));
  223. }
  224. };
  225. class TSetDeserialize: public TBoxedValue {
  226. private:
  227. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  228. return TUnboxedValuePod(new TSetResource(args[0], Hash_, Equate_));
  229. }
  230. public:
  231. TSetDeserialize(IHash::TPtr hash, IEquate::TPtr equate)
  232. : Hash_(hash)
  233. , Equate_(equate)
  234. {}
  235. private:
  236. IHash::TPtr Hash_;
  237. IEquate::TPtr Equate_;
  238. };
  239. template <EDataSlot Slot>
  240. class TSetMergeData: public TBoxedValue {
  241. private:
  242. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  243. auto left = GetSetResourceData<Slot>(args[0]);
  244. auto right = GetSetResourceData<Slot>(args[1]);
  245. return TUnboxedValuePod(new TSetResourceData<Slot>(*left->Get(), *right->Get()));
  246. }
  247. };
  248. class TSetMerge: public TBoxedValue {
  249. private:
  250. TUnboxedValue Run(const IValueBuilder*, const TUnboxedValuePod* args) const override {
  251. auto left = GetSetResource(args[0]);
  252. auto right = GetSetResource(args[1]);
  253. return TUnboxedValuePod(new TSetResource(*left->Get(), *right->Get(), Hash_, Equate_));
  254. }
  255. public:
  256. TSetMerge(IHash::TPtr hash, IEquate::TPtr equate)
  257. : Hash_(hash)
  258. , Equate_(equate)
  259. {}
  260. private:
  261. IHash::TPtr Hash_;
  262. IEquate::TPtr Equate_;
  263. };
  264. template <EDataSlot Slot>
  265. class TSetGetResultData: public TBoxedValue {
  266. private:
  267. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
  268. return GetSetResourceData<Slot>(args[0])->Get()->GetResult(valueBuilder);
  269. }
  270. };
  271. class TSetGetResult: public TBoxedValue {
  272. private:
  273. TUnboxedValue Run(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const override {
  274. return GetSetResource(args[0])->Get()->GetResult(valueBuilder);
  275. }
  276. };
  // For a data slot, defines the resource tag "Set.SetResource.<slot>" and the
  // TSetResourceData<slot> specialization that boxes a TSetData<slot> under that tag.
  // The wrapper's constructor forwards all arguments to the TBoxedResource base.
  // Applied below to every slot listed in UDF_TYPE_ID_MAP.
  #define MAKE_RESOURCE(slot, ...) \
  extern const char SetResourceName##slot[] = "Set.SetResource."#slot; \
  template <> \
  class TSetResourceData<EDataSlot::slot>: \
      public TBoxedResource<TSetData<EDataSlot::slot>, SetResourceName##slot> \
  { \
  public: \
      template <typename... Args> \
      inline TSetResourceData(Args&&... args) \
          : TBoxedResource(std::forward<Args>(args)...) \
      {} \
  };
  UDF_TYPE_ID_MAP(MAKE_RESOURCE)
  // Expands to one `case` of a switch over EDataSlot that installs a new
  // operation<EDataSlot::slot> as the function implementation. Expects a variable
  // `builder` in scope.
  #define MAKE_IMPL(operation, slot) \
  case EDataSlot::slot: \
      builder.Implementation(new operation<EDataSlot::slot>); \
      break;
  // Per-function adapters of MAKE_IMPL, shaped (slot, ...) so they can be passed to
  // UDF_TYPE_ID_MAP.
  #define MAKE_CREATE(slot, ...) MAKE_IMPL(TSetCreateData, slot)
  #define MAKE_ADD_VALUE(slot, ...) MAKE_IMPL(TSetAddValueData, slot)
  #define MAKE_WAS_CHANGED(slot, ...) MAKE_IMPL(TSetWasChangedData, slot)
  #define MAKE_SERIALIZE(slot, ...) MAKE_IMPL(TSetSerializeData, slot)
  #define MAKE_DESERIALIZE(slot, ...) MAKE_IMPL(TSetDeserializeData, slot)
  #define MAKE_MERGE(slot, ...) MAKE_IMPL(TSetMergeData, slot)
  #define MAKE_GET_RESULT(slot, ...) MAKE_IMPL(TSetGetResultData, slot)
  // Expands to one `case` that assigns to `setType` the resource type whose tag
  // (SetResourceName<slot>) was defined by MAKE_RESOURCE.
  #define MAKE_TYPE(slot, ...) \
  case EDataSlot::slot: \
      setType = builder.Resource(SetResourceName##slot); \
      break;
  305. static const auto CreateName = TStringRef::Of("Create");
  306. static const auto AddValueName = TStringRef::Of("AddValue");
  307. static const auto WasChangedName = TStringRef::Of("WasChanged"); // must be used right after AddValue
  308. static const auto SerializeName = TStringRef::Of("Serialize");
  309. static const auto DeserializeName = TStringRef::Of("Deserialize");
  310. static const auto MergeName = TStringRef::Of("Merge");
  311. static const auto GetResultName = TStringRef::Of("GetResult");
  312. class TSetModule: public IUdfModule {
  313. public:
  314. TStringRef Name() const {
  315. return TStringRef::Of("Set");
  316. }
  317. void CleanupOnTerminate() const final {
  318. }
  319. void GetAllFunctions(IFunctionsSink& sink) const final {
  320. sink.Add(CreateName)->SetTypeAwareness();
  321. sink.Add(AddValueName)->SetTypeAwareness();
  322. sink.Add(WasChangedName)->SetTypeAwareness();
  323. sink.Add(SerializeName)->SetTypeAwareness();
  324. sink.Add(DeserializeName)->SetTypeAwareness();
  325. sink.Add(MergeName)->SetTypeAwareness();
  326. sink.Add(GetResultName)->SetTypeAwareness();
  327. }
  328. void BuildFunctionTypeInfo(
  329. const TStringRef& name,
  330. TType* userType,
  331. const TStringRef& typeConfig,
  332. ui32 flags,
  333. IFunctionTypeInfoBuilder& builder) const final
  334. {
  335. Y_UNUSED(typeConfig);
  336. try {
  337. const bool typesOnly = (flags & TFlags::TypesOnly);
  338. builder.UserType(userType);
  339. auto typeHelper = builder.TypeInfoHelper();
  340. auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
  341. if (!userTypeInspector || userTypeInspector.GetElementsCount() != 3) {
  342. builder.SetError("User type is not a 3-tuple");
  343. return;
  344. }
  345. bool isGeneric = false;
  346. IHash::TPtr hash;
  347. IEquate::TPtr equate;
  348. TMaybe<EDataSlot> slot;
  349. auto valueType = userTypeInspector.GetElementType(2);
  350. auto valueTypeInspector = TDataTypeInspector(*typeHelper, valueType);
  351. if (!valueTypeInspector) {
  352. isGeneric = true;
  353. hash = builder.MakeHash(valueType);
  354. equate = builder.MakeEquate(valueType);
  355. if (!hash || !equate) {
  356. return;
  357. }
  358. } else {
  359. slot = FindDataSlot(valueTypeInspector.GetTypeId());
  360. if (!slot) {
  361. builder.SetError("Unknown data type");
  362. return;
  363. }
  364. const auto& info = NUdf::GetDataTypeInfo(*slot);
  365. const auto& features = info.Features;
  366. if (!(features & NUdf::CanHash) || !(features & NUdf::CanEquate)) {
  367. builder.SetError(TStringBuilder() << "Type " << info.Name << " is not hashable or equatable");
  368. return;
  369. }
  370. }
  371. auto serializedListType = builder.List()->Item(valueType).Build();
  372. auto serializedType = builder.Tuple()->Add<ui32>().Add(serializedListType).Build();
  373. TType* setType = nullptr;
  374. if (isGeneric) {
  375. setType = builder.Resource(SetResourceNameGeneric);
  376. } else {
  377. switch (*slot) {
  378. UDF_TYPE_ID_MAP(MAKE_TYPE)
  379. }
  380. }
  381. if (name == CreateName) {
  382. builder.IsStrict();
  383. builder.Args()->Add(valueType).Add<ui32>().Done().Returns(setType);
  384. if (!typesOnly) {
  385. if (isGeneric) {
  386. builder.Implementation(new TSetCreate(hash, equate));
  387. } else {
  388. switch (*slot) {
  389. UDF_TYPE_ID_MAP(MAKE_CREATE)
  390. }
  391. }
  392. }
  393. }
  394. if (name == AddValueName) {
  395. builder.IsStrict();
  396. builder.Args()->Add(setType).Add(valueType).Done().Returns(setType);
  397. if (!typesOnly) {
  398. if (isGeneric) {
  399. builder.Implementation(new TSetAddValue);
  400. } else {
  401. switch (*slot) {
  402. UDF_TYPE_ID_MAP(MAKE_ADD_VALUE)
  403. }
  404. }
  405. }
  406. }
  407. if (name == WasChangedName) {
  408. builder.IsStrict();
  409. builder.Args()->Add(setType).Done().Returns<bool>();
  410. if (!typesOnly) {
  411. if (isGeneric) {
  412. builder.Implementation(new TSetWasChanged);
  413. } else {
  414. switch (*slot) {
  415. UDF_TYPE_ID_MAP(MAKE_WAS_CHANGED)
  416. }
  417. }
  418. }
  419. }
  420. if (name == MergeName) {
  421. builder.IsStrict();
  422. builder.Args()->Add(setType).Add(setType).Done().Returns(setType);
  423. if (!typesOnly) {
  424. if (isGeneric) {
  425. builder.Implementation(new TSetMerge(hash, equate));
  426. } else {
  427. switch (*slot) {
  428. UDF_TYPE_ID_MAP(MAKE_MERGE)
  429. }
  430. }
  431. }
  432. }
  433. if (name == SerializeName) {
  434. builder.IsStrict();
  435. builder.Args()->Add(setType).Done().Returns(serializedType);
  436. if (!typesOnly) {
  437. if (isGeneric) {
  438. builder.Implementation(new TSetSerialize);
  439. } else {
  440. switch (*slot) {
  441. UDF_TYPE_ID_MAP(MAKE_SERIALIZE)
  442. }
  443. }
  444. }
  445. }
  446. if (name == DeserializeName) {
  447. builder.Args()->Add(serializedType).Done().Returns(setType);
  448. if (!typesOnly) {
  449. if (isGeneric) {
  450. builder.Implementation(new TSetDeserialize(hash, equate));
  451. } else {
  452. switch (*slot) {
  453. UDF_TYPE_ID_MAP(MAKE_DESERIALIZE)
  454. }
  455. }
  456. }
  457. }
  458. if (name == GetResultName) {
  459. auto resultType = builder.List()->Item(valueType).Build();
  460. builder.IsStrict();
  461. builder.Args()->Add(setType).Done().Returns(resultType);
  462. if (!typesOnly) {
  463. if (isGeneric) {
  464. builder.Implementation(new TSetGetResult);
  465. } else {
  466. switch (*slot) {
  467. UDF_TYPE_ID_MAP(MAKE_GET_RESULT)
  468. }
  469. }
  470. }
  471. }
  472. } catch (const std::exception& e) {
  473. builder.SetError(CurrentExceptionMessage());
  474. }
  475. }
  476. };
  477. } // namespace
  478. REGISTER_MODULES(TSetModule)