hyperloglog_udf.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. #include <yql/essentials/public/udf/udf_helpers.h>
  2. #include <library/cpp/hyperloglog/hyperloglog.h>
  3. #include <util/generic/hash_set.h>
  4. #include <variant>
  5. using namespace NKikimr;
  6. using namespace NUdf;
  7. namespace {
  8. class THybridHyperLogLog {
  9. private:
  10. using THybridSet = THashSet<ui64, std::hash<ui64>, std::equal_to<ui64>, TStdAllocatorForUdf<ui64>>;
  11. using THybridHll = THyperLogLogWithAlloc<TStdAllocatorForUdf<ui8>>;
  12. explicit THybridHyperLogLog(unsigned precision)
  13. : Var(THybridSet()), SizeLimit((1u << precision) / 8), Precision(precision)
  14. { }
  15. THybridHll ConvertToHyperLogLog() const {
  16. auto res = THybridHll::Create(Precision);
  17. for (auto& el : GetSetRef()) {
  18. res.Update(el);
  19. }
  20. return res;
  21. }
  22. bool IsSet() const {
  23. return Var.index() == 1;
  24. }
  25. const THybridSet& GetSetRef() const {
  26. return std::get<1>(Var);
  27. }
  28. THybridSet& GetMutableSetRef() {
  29. return std::get<1>(Var);
  30. }
  31. const THybridHll& GetHllRef() const {
  32. return std::get<0>(Var);
  33. }
  34. THybridHll& GetMutableHllRef() {
  35. return std::get<0>(Var);
  36. }
  37. public:
  38. THybridHyperLogLog (THybridHyperLogLog&&) = default;
  39. THybridHyperLogLog& operator=(THybridHyperLogLog&&) = default;
  40. void Update(ui64 hash) {
  41. if (IsSet()) {
  42. GetMutableSetRef().insert(hash);
  43. if (GetSetRef().size() >= SizeLimit) {
  44. Var = ConvertToHyperLogLog();
  45. }
  46. } else {
  47. GetMutableHllRef().Update(hash);
  48. }
  49. }
  50. void Merge(const THybridHyperLogLog& rh) {
  51. if (IsSet() && rh.IsSet()) {
  52. GetMutableSetRef().insert(rh.GetSetRef().begin(), rh.GetSetRef().end());
  53. if (GetSetRef().size() >= SizeLimit) {
  54. Var = ConvertToHyperLogLog();
  55. }
  56. } else {
  57. if (IsSet()) {
  58. Var = ConvertToHyperLogLog();
  59. }
  60. if (rh.IsSet()) {
  61. GetMutableHllRef().Merge(rh.ConvertToHyperLogLog());
  62. } else {
  63. GetMutableHllRef().Merge(rh.GetHllRef());
  64. }
  65. }
  66. }
  67. void Save(IOutputStream& out) const {
  68. out.Write(static_cast<char>(Var.index()));
  69. out.Write(static_cast<char>(Precision));
  70. if (IsSet()) {
  71. ::Save(&out, GetSetRef());
  72. } else {
  73. GetHllRef().Save(out);
  74. }
  75. }
  76. ui64 Estimate() const {
  77. if (IsSet()) {
  78. return GetSetRef().size();
  79. }
  80. return GetHllRef().Estimate();
  81. }
  82. static THybridHyperLogLog Create(unsigned precision) {
  83. Y_ENSURE(precision >= THyperLogLog::PRECISION_MIN && precision <= THyperLogLog::PRECISION_MAX);
  84. return THybridHyperLogLog(precision);
  85. }
  86. static THybridHyperLogLog Load(IInputStream& in) {
  87. char type;
  88. Y_ENSURE(in.ReadChar(type));
  89. char precision;
  90. Y_ENSURE(in.ReadChar(precision));
  91. auto res = Create(precision);
  92. if (type) {
  93. ::Load(&in, res.GetMutableSetRef());
  94. } else {
  95. res.Var = THybridHll::Load(in);
  96. }
  97. return res;
  98. }
  99. private:
  100. std::variant<THybridHll, THybridSet> Var;
  101. size_t SizeLimit;
  102. unsigned Precision;
  103. };
  // Resource tag string. It is used as the template argument of TResource<> in the function
  // signatures of this file and of TBoxedResource<> in the alias below.
  // NOTE(review): presumably declared `extern` so the array has linkage suitable for a
  // non-type template argument even inside the anonymous namespace — confirm.
  104. extern const char HyperLogLogResourceName[] = "HyperLogLog.State";
  // Boxed value type that carries a THybridHyperLogLog between the UDF functions of this file.
  105. using THyperLogLogResource = TBoxedResource<THybridHyperLogLog, HyperLogLogResourceName>;
  106. class THyperLogLog_Create: public TBoxedValue {
  107. public:
  108. THyperLogLog_Create(TSourcePosition pos)
  109. : Pos_(pos)
  110. {}
  111. static const TStringRef& Name() {
  112. static auto nameRef = TStringRef::Of("Create");
  113. return nameRef;
  114. }
  115. private:
  116. TUnboxedValue Run(
  117. const IValueBuilder*,
  118. const TUnboxedValuePod* args) const override {
  119. try {
  120. THolder<THyperLogLogResource> hll(new THyperLogLogResource(THybridHyperLogLog::Create(args[1].Get<ui32>())));
  121. hll->Get()->Update(args[0].Get<ui64>());
  122. return TUnboxedValuePod(hll.Release());
  123. } catch (const std::exception& e) {
  124. UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
  125. }
  126. }
  127. public:
  128. static bool DeclareSignature(
  129. const TStringRef& name,
  130. TType* userType,
  131. IFunctionTypeInfoBuilder& builder,
  132. bool typesOnly) {
  133. Y_UNUSED(userType);
  134. if (Name() == name) {
  135. builder.SimpleSignature<TResource<HyperLogLogResourceName>(ui64, ui32)>();
  136. if (!typesOnly) {
  137. builder.Implementation(new THyperLogLog_Create(builder.GetSourcePosition()));
  138. }
  139. return true;
  140. } else {
  141. return false;
  142. }
  143. }
  144. private:
  145. TSourcePosition Pos_;
  146. };
  147. class THyperLogLog_AddValue: public TBoxedValue {
  148. public:
  149. THyperLogLog_AddValue(TSourcePosition pos)
  150. : Pos_(pos)
  151. {}
  152. static const TStringRef& Name() {
  153. static auto nameRef = TStringRef::Of("AddValue");
  154. return nameRef;
  155. }
  156. private:
  157. TUnboxedValue Run(
  158. const IValueBuilder* valueBuilder,
  159. const TUnboxedValuePod* args) const override {
  160. try {
  161. Y_UNUSED(valueBuilder);
  162. THyperLogLogResource* resource = static_cast<THyperLogLogResource*>(args[0].AsBoxed().Get());
  163. resource->Get()->Update(args[1].Get<ui64>());
  164. return TUnboxedValuePod(args[0]);
  165. } catch (const std::exception& e) {
  166. UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
  167. }
  168. }
  169. public:
  170. static bool DeclareSignature(
  171. const TStringRef& name,
  172. TType* userType,
  173. IFunctionTypeInfoBuilder& builder,
  174. bool typesOnly) {
  175. Y_UNUSED(userType);
  176. if (Name() == name) {
  177. builder.SimpleSignature<TResource<HyperLogLogResourceName>(TResource<HyperLogLogResourceName>, ui64)>();
  178. if (!typesOnly) {
  179. builder.Implementation(new THyperLogLog_AddValue(builder.GetSourcePosition()));
  180. }
  181. builder.IsStrict();
  182. return true;
  183. } else {
  184. return false;
  185. }
  186. }
  187. private:
  188. TSourcePosition Pos_;
  189. };
  190. class THyperLogLog_Serialize: public TBoxedValue {
  191. public:
  192. THyperLogLog_Serialize(TSourcePosition pos)
  193. : Pos_(pos)
  194. {}
  195. public:
  196. static const TStringRef& Name() {
  197. static auto nameRef = TStringRef::Of("Serialize");
  198. return nameRef;
  199. }
  200. private:
  201. TUnboxedValue Run(
  202. const IValueBuilder* valueBuilder,
  203. const TUnboxedValuePod* args) const override {
  204. try {
  205. TStringStream result;
  206. static_cast<THyperLogLogResource*>(args[0].AsBoxed().Get())->Get()->Save(result);
  207. return valueBuilder->NewString(result.Str());
  208. } catch (const std::exception& e) {
  209. UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
  210. }
  211. }
  212. public:
  213. static bool DeclareSignature(
  214. const TStringRef& name,
  215. TType* userType,
  216. IFunctionTypeInfoBuilder& builder,
  217. bool typesOnly) {
  218. Y_UNUSED(userType);
  219. if (Name() == name) {
  220. builder.SimpleSignature<char*(TResource<HyperLogLogResourceName>)>();
  221. if (!typesOnly) {
  222. builder.Implementation(new THyperLogLog_Serialize(builder.GetSourcePosition()));
  223. }
  224. return true;
  225. } else {
  226. return false;
  227. }
  228. }
  229. private:
  230. TSourcePosition Pos_;
  231. };
  232. class THyperLogLog_Deserialize: public TBoxedValue {
  233. public:
  234. THyperLogLog_Deserialize(TSourcePosition pos)
  235. : Pos_(pos)
  236. {}
  237. static const TStringRef& Name() {
  238. static auto nameRef = TStringRef::Of("Deserialize");
  239. return nameRef;
  240. }
  241. private:
  242. TUnboxedValue Run(
  243. const IValueBuilder* valueBuilder,
  244. const TUnboxedValuePod* args) const override {
  245. try {
  246. Y_UNUSED(valueBuilder);
  247. const TString arg(args[0].AsStringRef());
  248. TStringInput input(arg);
  249. THolder<THyperLogLogResource> hll(new THyperLogLogResource(THybridHyperLogLog::Load(input)));
  250. return TUnboxedValuePod(hll.Release());
  251. } catch (const std::exception& e) {
  252. UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
  253. }
  254. }
  255. public:
  256. static bool DeclareSignature(
  257. const TStringRef& name,
  258. TType* userType,
  259. IFunctionTypeInfoBuilder& builder,
  260. bool typesOnly) {
  261. Y_UNUSED(userType);
  262. if (Name() == name) {
  263. builder.SimpleSignature<TResource<HyperLogLogResourceName>(char*)>();
  264. if (!typesOnly) {
  265. builder.Implementation(new THyperLogLog_Deserialize(builder.GetSourcePosition()));
  266. }
  267. return true;
  268. } else {
  269. return false;
  270. }
  271. }
  272. private:
  273. TSourcePosition Pos_;
  274. };
  275. class THyperLogLog_Merge: public TBoxedValue {
  276. public:
  277. THyperLogLog_Merge(TSourcePosition pos)
  278. : Pos_(pos)
  279. {}
  280. static const TStringRef& Name() {
  281. static auto nameRef = TStringRef::Of("Merge");
  282. return nameRef;
  283. }
  284. private:
  285. TUnboxedValue Run(
  286. const IValueBuilder* valueBuilder,
  287. const TUnboxedValuePod* args) const override {
  288. try {
  289. Y_UNUSED(valueBuilder);
  290. auto left = static_cast<THyperLogLogResource*>(args[0].AsBoxed().Get())->Get();
  291. static_cast<THyperLogLogResource*>(args[1].AsBoxed().Get())->Get()->Merge(*left);
  292. return TUnboxedValuePod(args[1]);
  293. } catch (const std::exception& e) {
  294. UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
  295. }
  296. }
  297. public:
  298. static bool DeclareSignature(
  299. const TStringRef& name,
  300. TType* userType,
  301. IFunctionTypeInfoBuilder& builder,
  302. bool typesOnly) {
  303. Y_UNUSED(userType);
  304. if (Name() == name) {
  305. builder.SimpleSignature<TResource<HyperLogLogResourceName>(TResource<HyperLogLogResourceName>, TResource<HyperLogLogResourceName>)>();
  306. if (!typesOnly) {
  307. builder.Implementation(new THyperLogLog_Merge(builder.GetSourcePosition()));
  308. }
  309. builder.IsStrict();
  310. return true;
  311. } else {
  312. return false;
  313. }
  314. }
  315. private:
  316. TSourcePosition Pos_;
  317. };
  318. class THyperLogLog_GetResult: public TBoxedValue {
  319. public:
  320. THyperLogLog_GetResult(TSourcePosition pos)
  321. : Pos_(pos)
  322. {}
  323. static const TStringRef& Name() {
  324. static auto nameRef = TStringRef::Of("GetResult");
  325. return nameRef;
  326. }
  327. private:
  328. TUnboxedValue Run(
  329. const IValueBuilder* valueBuilder,
  330. const TUnboxedValuePod* args) const override {
  331. Y_UNUSED(valueBuilder);
  332. auto hll = static_cast<THyperLogLogResource*>(args[0].AsBoxed().Get())->Get();
  333. return TUnboxedValuePod(hll->Estimate());
  334. }
  335. public:
  336. static bool DeclareSignature(
  337. const TStringRef& name,
  338. TType* userType,
  339. IFunctionTypeInfoBuilder& builder,
  340. bool typesOnly) {
  341. Y_UNUSED(userType);
  342. if (Name() == name) {
  343. auto resource = builder.Resource(HyperLogLogResourceName);
  344. builder.Args()->Add(resource).Done().Returns<ui64>();
  345. if (!typesOnly) {
  346. builder.Implementation(new THyperLogLog_GetResult(builder.GetSourcePosition()));
  347. }
  348. builder.IsStrict();
  349. return true;
  350. } else {
  351. return false;
  352. }
  353. }
  354. private:
  355. TSourcePosition Pos_;
  356. };
  // Lists the six function classes defined above under the module name THyperLogLogModule.
  // Each listed class provides the static Name() and DeclareSignature() members.
  // NOTE(review): SIMPLE_MODULE is defined outside this file (presumably in udf_helpers.h);
  // it presumably generates the module class that dispatches to DeclareSignature() — confirm.
  357. SIMPLE_MODULE(THyperLogLogModule,
  358. THyperLogLog_Create,
  359. THyperLogLog_AddValue,
  360. THyperLogLog_Serialize,
  361. THyperLogLog_Deserialize,
  362. THyperLogLog_Merge,
  363. THyperLogLog_GetResult)
  364. }
  // Invoked after the anonymous namespace has been closed, i.e. at file scope.
  // NOTE(review): REGISTER_MODULES is defined outside this file; presumably it emits the
  // entry points through which the UDF host discovers THyperLogLogModule — confirm.
  365. REGISTER_MODULES(THyperLogLogModule)