block_groupby.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. #include <util/datetime/cputimer.h>
  2. #include <yql/essentials/minikql/comp_nodes/mkql_rh_hash.h>
  3. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  4. #include <arrow/array/builder_primitive.h>
  5. #include <arrow/datum.h>
  6. #include <library/cpp/getopt/last_getopt.h>
  7. #include <util/digest/fnv.h>
  8. #include <util/digest/murmur.h>
  9. #include <util/digest/city.h>
  10. enum class EDistribution {
  11. Const,
  12. Linear,
  13. Random,
  14. Few,
  15. RandomFew
  16. };
  17. enum class EShape {
  18. Default,
  19. Sqrt,
  20. Log
  21. };
  22. arrow::Datum MakeIntColumn(ui32 len, EDistribution dist, EShape shape, ui32 buckets) {
  23. arrow::Int32Builder builder;
  24. ARROW_OK(builder.Reserve(len));
  25. for (ui32 i = 0; i < len; ++i) {
  26. ui32 val;
  27. switch (shape) {
  28. case EShape::Default:
  29. val = i;
  30. break;
  31. case EShape::Sqrt:
  32. val = (ui32)sqrt(i);
  33. break;
  34. case EShape::Log:
  35. val = (ui32)log(1 + i);
  36. break;
  37. }
  38. switch (dist) {
  39. case EDistribution::Const:
  40. builder.UnsafeAppend(0);
  41. break;
  42. case EDistribution::Few:
  43. builder.UnsafeAppend(val % buckets);
  44. break;
  45. case EDistribution::Linear:
  46. builder.UnsafeAppend(val);
  47. break;
  48. case EDistribution::Random:
  49. builder.UnsafeAppend(IntHash(val));
  50. break;
  51. case EDistribution::RandomFew:
  52. builder.UnsafeAppend(IntHash(val) % buckets);
  53. break;
  54. }
  55. }
  56. std::shared_ptr<arrow::ArrayData> result;
  57. ARROW_OK(builder.FinishInternal(&result));
  58. return arrow::Datum(result);
  59. }
  60. class IAggregator {
  61. public:
  62. virtual ~IAggregator() = default;
  63. virtual void Init(i64* state, i32 payload) = 0;
  64. virtual void Update(i64* state, i32 payload) = 0;
  65. };
  66. class TSumAggregator : public IAggregator {
  67. public:
  68. void Init(i64* state, i32 payload) final {
  69. *state = payload;
  70. }
  71. void Update(i64* state, i32 payload) final {
  72. *state += payload;
  73. }
  74. };
  75. template <typename T>
  76. struct TCityHasher {
  77. public:
  78. ui64 operator()(const T& x) const {
  79. return CityHash64(TStringBuf((char*)&x, sizeof(x)));
  80. }
  81. };
  82. // sum(payloads) group by keys
  83. template <bool CalculateHashStats, bool UseRH>
  84. class TAggregate {
  85. private:
  86. struct TOneCell {
  87. i32 Key = 0;
  88. bool IsEmpty = true;
  89. i64 State = 0;
  90. };
  91. struct TCell {
  92. i32 Key = 0;
  93. i32 PSL = -1;
  94. i64 State = 0;
  95. };
  96. public:
  97. TAggregate(const std::vector<IAggregator*>& aggs)
  98. : Aggs(aggs)
  99. , RH(sizeof(i64))
  100. {
  101. Cells.resize(1u << 8);
  102. }
  103. void AddBatch(arrow::Datum keys, arrow::Datum payloads) {
  104. auto arrKeys = keys.array();
  105. auto arrPayloads = payloads.array();
  106. auto len = arrKeys->length;
  107. const i32* ptrKeys = arrKeys->GetValues<i32>(1);
  108. const i32* ptrPayloads = arrPayloads->GetValues<i32>(1);
  109. for (int64_t i = 0; i < len; ++i) {
  110. auto key = ptrKeys[i];
  111. auto payload = ptrPayloads[i];
  112. if (!MoreThanOne) {
  113. if (One.IsEmpty) {
  114. One.IsEmpty = false;
  115. One.Key = key;
  116. for (const auto& a : Aggs) {
  117. a->Init(&One.State, payload);
  118. }
  119. Size = 1;
  120. continue;
  121. } else {
  122. if (key == One.Key) {
  123. for (const auto& a : Aggs) {
  124. a->Update(&One.State, payload);
  125. }
  126. continue;
  127. } else {
  128. MoreThanOne = true;
  129. if constexpr (UseRH) {
  130. bool isNew;
  131. auto iter = RH.Insert(One.Key, isNew);
  132. Y_ASSERT(isNew);
  133. *(i64*)RH.GetPayload(iter) = One.State;
  134. } else {
  135. bool isNew;
  136. ui64 bucket = AddBucketFromKeyImpl(One.Key, Cells, isNew);
  137. auto& c = Cells[bucket];
  138. c.PSL = 0;
  139. c.Key = One.Key;
  140. c.State = One.State;
  141. }
  142. }
  143. }
  144. }
  145. if constexpr (UseRH) {
  146. bool isNew = false;
  147. auto iter = RH.Insert(key, isNew);
  148. if (isNew) {
  149. for (const auto& a : Aggs) {
  150. a->Init((i64*)RH.GetPayload(iter), payload);
  151. }
  152. RH.CheckGrow();
  153. } else {
  154. for (const auto& a : Aggs) {
  155. a->Update((i64*)RH.GetPayload(iter), payload);
  156. }
  157. }
  158. } else {
  159. bool isNew = false;
  160. ui64 bucket = AddBucketFromKey(key, isNew);
  161. auto& c = Cells[bucket];
  162. if (isNew) {
  163. Size += 1;
  164. for (const auto& a : Aggs) {
  165. a->Init(&c.State, payload);
  166. }
  167. if (Size * 2 >= Cells.size()) {
  168. Grow();
  169. }
  170. } else {
  171. for (const auto& a : Aggs) {
  172. a->Update(&c.State, payload);
  173. }
  174. }
  175. }
  176. }
  177. }
  178. static ui64 MakeHash(i32 key) {
  179. //auto hash = FnvHash<ui64>(&key, sizeof(key));
  180. //auto hash = MurmurHash<ui64>(&key, sizeof(key));
  181. auto hash = CityHash64(TStringBuf((char*)&key, sizeof(key)));
  182. //auto hash = key;
  183. return hash;
  184. }
  185. Y_FORCE_INLINE ui64 AddBucketFromKey(i32 key, bool& isNew) {
  186. return AddBucketFromKeyImpl(key, Cells, isNew);
  187. }
  188. Y_FORCE_INLINE ui64 AddBucketFromKeyImpl(i32 key, std::vector<TCell>& cells, bool& isNew) {
  189. isNew = false;
  190. ui32 chainLen = 0;
  191. if constexpr (CalculateHashStats) {
  192. HashSearches++;
  193. }
  194. ui64 bucket = MakeHash(key) & (cells.size() - 1);
  195. i32 distance = 0;
  196. ui64 returnBucket;
  197. i64 oldState;
  198. for (;;) {
  199. if constexpr (CalculateHashStats) {
  200. HashProbes++;
  201. chainLen++;
  202. }
  203. if (cells[bucket].PSL < 0) {
  204. isNew = true;
  205. cells[bucket].Key = key;
  206. cells[bucket].PSL = distance;
  207. if constexpr (CalculateHashStats) {
  208. MaxHashChainLen = Max(MaxHashChainLen, chainLen);
  209. }
  210. return bucket;
  211. }
  212. if (cells[bucket].Key == key) {
  213. if constexpr (CalculateHashStats) {
  214. MaxHashChainLen = Max(MaxHashChainLen, chainLen);
  215. }
  216. return bucket;
  217. }
  218. if (distance > cells[bucket].PSL) {
  219. // swap keys & state
  220. returnBucket = bucket;
  221. oldState = cells[bucket].State;
  222. std::swap(key, cells[bucket].Key);
  223. std::swap(distance, cells[bucket].PSL);
  224. isNew = true;
  225. ++distance;
  226. bucket = (bucket + 1) & (cells.size() - 1);
  227. break;
  228. }
  229. ++distance;
  230. bucket = (bucket + 1) & (cells.size() - 1);
  231. }
  232. for (;;) {
  233. if constexpr (CalculateHashStats) {
  234. HashProbes++;
  235. chainLen++;
  236. }
  237. if (cells[bucket].PSL < 0) {
  238. if constexpr (CalculateHashStats) {
  239. MaxHashChainLen = Max(MaxHashChainLen, chainLen);
  240. }
  241. cells[bucket].Key = key;
  242. cells[bucket].State = oldState;
  243. cells[bucket].PSL = distance;
  244. return returnBucket; // for original key
  245. }
  246. Y_ENSURE(cells[bucket].Key != key);
  247. if (distance > cells[bucket].PSL) {
  248. // swap keys & state
  249. std::swap(key, cells[bucket].Key);
  250. std::swap(oldState, cells[bucket].State);
  251. std::swap(distance, cells[bucket].PSL);
  252. }
  253. ++distance;
  254. bucket = (bucket + 1) & (cells.size() - 1);
  255. }
  256. }
  257. void Grow() {
  258. std::vector<TCell> newCells;
  259. newCells.resize(Cells.size() * 2); // must be power of 2
  260. for (const auto& c : Cells) {
  261. if (c.PSL < 0) {
  262. continue;
  263. }
  264. bool isNew;
  265. auto newBucket = AddBucketFromKeyImpl(c.Key, newCells, isNew);
  266. auto& nc = newCells[newBucket];
  267. nc.State = c.State;
  268. }
  269. Cells.swap(newCells);
  270. }
  271. double GetAverageHashChainLen() {
  272. return 1.0*HashProbes/HashSearches;
  273. }
  274. ui32 GetMaxHashChainLen() {
  275. return MaxHashChainLen;
  276. }
  277. void GetResult(arrow::Datum& keys, arrow::Datum& sums) {
  278. arrow::Int32Builder keysBuilder;
  279. arrow::Int64Builder sumsBuilder;
  280. if (!MoreThanOne) {
  281. if (!One.IsEmpty) {
  282. ARROW_OK(keysBuilder.Reserve(1));
  283. ARROW_OK(sumsBuilder.Reserve(1));
  284. keysBuilder.UnsafeAppend(One.Key);
  285. sumsBuilder.UnsafeAppend(One.State);
  286. }
  287. } else {
  288. ui64 size;
  289. if constexpr (UseRH) {
  290. size = RH.GetSize();
  291. } else {
  292. size = Size;
  293. }
  294. ARROW_OK(keysBuilder.Reserve(size));
  295. ARROW_OK(sumsBuilder.Reserve(size));
  296. i32 maxPSL = 0;
  297. i64 sumPSL = 0;
  298. if constexpr (UseRH) {
  299. for (auto iter = RH.Begin(); iter != RH.End(); RH.Advance(iter)) {
  300. auto& psl = RH.GetPSL(iter);
  301. if (psl.Distance < 0) {
  302. continue;
  303. }
  304. keysBuilder.UnsafeAppend(RH.GetKey(iter));
  305. sumsBuilder.UnsafeAppend(*(i64*)RH.GetPayload(iter));
  306. maxPSL = Max(psl.Distance, maxPSL);
  307. sumPSL += psl.Distance;
  308. }
  309. } else {
  310. for (const auto& c : Cells) {
  311. if (c.PSL < 0) {
  312. continue;
  313. }
  314. keysBuilder.UnsafeAppend(c.Key);
  315. sumsBuilder.UnsafeAppend(c.State);
  316. maxPSL = Max(c.PSL, maxPSL);
  317. sumPSL += c.PSL;
  318. }
  319. }
  320. if constexpr (CalculateHashStats) {
  321. Cerr << "maxPSL = " << maxPSL << "\n";
  322. Cerr << "avgPSL = " << 1.0*sumPSL/size << "\n";
  323. }
  324. }
  325. std::shared_ptr<arrow::ArrayData> keysData;
  326. ARROW_OK(keysBuilder.FinishInternal(&keysData));
  327. keys = keysData;
  328. std::shared_ptr<arrow::ArrayData> sumsData;
  329. ARROW_OK(sumsBuilder.FinishInternal(&sumsData));
  330. sums = sumsData;
  331. }
  332. private:
  333. bool MoreThanOne = false;
  334. TOneCell One;
  335. std::vector<TCell> Cells;
  336. ui64 Size = 0;
  337. const std::vector<IAggregator*> Aggs;
  338. ui64 HashProbes = 0;
  339. ui64 HashSearches = 0;
  340. ui32 MaxHashChainLen = 0;
  341. NKikimr::NMiniKQL::TRobinHoodHashMap<i32> RH;
  342. NKikimr::NMiniKQL::TRobinHoodHashSet<i32> RHS;
  343. };
  344. int main(int argc, char** argv) {
  345. NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
  346. TString keysDistributionStr;
  347. TString shapeStr="default";
  348. ui32 nIters = 100;
  349. ui32 nRows = 1000000;
  350. ui32 nBuckets = 16;
  351. ui32 nRepeats = 10;
  352. opts.AddLongOption('k', "keys", "distribution of keys (const, linear, random, few, randomfew)").StoreResult(&keysDistributionStr).Required();
  353. opts.AddLongOption('s', "shape", "shape of counter (default, sqrt, log)").StoreResult(&shapeStr);
  354. opts.AddLongOption('i', "iter", "# of iterations").StoreResult(&nIters);
  355. opts.AddLongOption('r', "rows", "# of rows").StoreResult(&nRows);
  356. opts.AddLongOption('b', "buckets", "modulo for few/randomfew").StoreResult(&nBuckets);
  357. opts.AddLongOption('t', "repeats", "# of repeats").StoreResult(&nRepeats);
  358. opts.SetFreeArgsMax(0);
  359. NLastGetopt::TOptsParseResult res(&opts, argc, argv);
  360. EDistribution keysDist;
  361. EShape shape = EShape::Default;
  362. if (keysDistributionStr == "const") {
  363. keysDist = EDistribution::Const;
  364. } else if (keysDistributionStr == "linear") {
  365. keysDist = EDistribution::Linear;
  366. } else if (keysDistributionStr == "random") {
  367. keysDist = EDistribution::Random;
  368. } else if (keysDistributionStr == "few") {
  369. keysDist = EDistribution::Few;
  370. } else if (keysDistributionStr == "randomfew") {
  371. keysDist = EDistribution::RandomFew;
  372. } else {
  373. ythrow yexception() << "Unsupported distribution: " << keysDistributionStr;
  374. }
  375. if (shapeStr == "default") {
  376. shape = EShape::Default;
  377. } else if (shapeStr == "sqrt") {
  378. shape = EShape::Sqrt;
  379. } else if (shapeStr == "log") {
  380. shape = EShape::Log;
  381. } else {
  382. ythrow yexception() << "Unsupported shape: " << shapeStr;
  383. }
  384. auto col1 = MakeIntColumn(nRows, keysDist, shape, nBuckets);
  385. auto col2 = MakeIntColumn(nRows, EDistribution::Linear, EShape::Default, nBuckets);
  386. Cerr << "col1.length: " << col1.length() << "\n";
  387. Cerr << "col2.length: " << col2.length() << "\n";
  388. TSumAggregator sum;
  389. std::vector<IAggregator*> aggs;
  390. aggs.push_back(&sum);
  391. TAggregate<true, true> agg(aggs);
  392. agg.AddBatch(col1, col2);
  393. arrow::Datum keys, sums;
  394. agg.GetResult(keys, sums);
  395. ui64 total1 = 0;
  396. for (ui32 i = 0; i < col2.length(); ++i) {
  397. total1 += col2.array()->GetValues<i32>(1)[i];
  398. }
  399. Cerr << "total1: " << total1 << "\n";
  400. ui64 total2 = 0;
  401. Cerr << "keys.length: " << keys.length() << "\n";
  402. Cerr << "sums.length: " << sums.length() << "\n";
  403. for (ui32 i = 0; i < sums.length(); ++i) {
  404. total2 += sums.array()->GetValues<i64>(1)[i];
  405. }
  406. Cerr << "total2: " << total2 << "\n";
  407. Y_ENSURE(total1 == total2);
  408. Cerr << "AverageHashChainLen: " << agg.GetAverageHashChainLen() << "\n";
  409. Cerr << "MaxHashChainLen: " << agg.GetMaxHashChainLen() << "\n";
  410. std::vector<double> durations;
  411. for (ui32 j = 0; j < nRepeats; ++j) {
  412. TSimpleTimer timer;
  413. for (ui32 i = 0; i < nIters; ++i) {
  414. TAggregate<false, true> agg(aggs);
  415. agg.AddBatch(col1, col2);
  416. arrow::Datum keys, sums;
  417. agg.GetResult(keys, sums);
  418. }
  419. auto duration = timer.Get();
  420. durations.push_back(1e-6*duration.MicroSeconds());
  421. }
  422. double sumDurations = 0.0, sumDurationsQ = 0.0;
  423. for (auto d : durations) {
  424. sumDurations += d;
  425. sumDurationsQ += d * d;
  426. }
  427. double avgDuration = sumDurations / nRepeats;
  428. double dispDuration = sqrt(sumDurationsQ / nRepeats - avgDuration * avgDuration);
  429. Cerr << "Elapsed: " << avgDuration << ", noise: " << 100*dispDuration/avgDuration << "%\n";
  430. Cerr << "Speed: " << 1e-6 * (ui64(nIters) * nRows / avgDuration) << " M rows/sec\n";
  431. Cerr << "Speed: " << 1e-6 * (2 * sizeof(i32) * ui64(nIters) * nRows / avgDuration) << " M bytes/sec\n";
  432. return 0;
  433. }