// mkql_rh_hash.h (18 KB)
  1. #pragma once
  2. #include <util/system/compiler.h>
  3. #include <util/system/types.h>
  4. #include <util/generic/bitops.h>
  5. #include <util/generic/yexception.h>
  6. #include <vector>
  7. #include <span>
  8. #include <yql/essentials/minikql/mkql_rh_hash_utils.h>
  9. #include <yql/essentials/utils/prefetch.h>
  10. #include <util/digest/city.h>
  11. #include <util/generic/scope.h>
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  // Default settings for the Robin Hood containers.
  // CacheHash is forwarded to TRobinHoodHashBase, where it selects the cell
  // header variant that also stores the 64-bit hash. By default the hash is
  // cached for every key type that is not an arithmetic type.
  template <class TKey>
  struct TRobinHoodDefaultSettings {
      static constexpr bool CacheHash = !std::is_arithmetic_v<TKey>;
  };
  18. template <typename TKey>
  19. struct TRobinHoodBatchRequestItem {
  20. // input
  21. alignas(TKey) char KeyStorage[sizeof(TKey)];
  22. const TKey& GetKey() const {
  23. return *reinterpret_cast<const TKey*>(KeyStorage);
  24. }
  25. void ConstructKey(const TKey& key) {
  26. new (KeyStorage) TKey(key);
  27. }
  28. // intermediate data
  29. ui64 Hash;
  30. char* InitialIterator;
  31. };
  32. constexpr ui32 PrefetchBatchSize = 64;
  // TODO: only POD keys and payloads are currently supported
  34. template <typename TKey, typename TEqual, typename THash, typename TAllocator, typename TDeriv, bool CacheHash>
  35. class TRobinHoodHashBase {
  36. public:
  37. using iterator = char*;
  38. using const_iterator = const char*;
  39. protected:
  40. THash HashLocal;
  41. TEqual EqualLocal;
  42. template <bool CacheHashForPSL>
  43. struct TPSLStorageImpl;
  44. template <>
  45. struct TPSLStorageImpl<true> {
  46. i32 Distance = -1;
  47. ui64 Hash = 0;
  48. TPSLStorageImpl() = default;
  49. TPSLStorageImpl(const ui64 hash)
  50. : Distance(0)
  51. , Hash(hash) {
  52. }
  53. };
  54. template <>
  55. struct TPSLStorageImpl<false> {
  56. i32 Distance = -1;
  57. TPSLStorageImpl() = default;
  58. TPSLStorageImpl(const ui64 /*hash*/)
  59. : Distance(0) {
  60. }
  61. };
  62. using TPSLStorage = TPSLStorageImpl<CacheHash>;
  63. explicit TRobinHoodHashBase(const ui64 initialCapacity, THash hash, TEqual equal)
  64. : HashLocal(std::move(hash))
  65. , EqualLocal(std::move(equal))
  66. , Capacity(initialCapacity)
  67. , CapacityShift(64 - MostSignificantBit(initialCapacity))
  68. , Allocator()
  69. , SelfHash(GetSelfHash(this))
  70. {
  71. Y_ENSURE((Capacity & (Capacity - 1)) == 0);
  72. }
  73. ~TRobinHoodHashBase() {
  74. if (Data) {
  75. Allocator.deallocate(Data, DataEnd - Data);
  76. }
  77. }
  78. TRobinHoodHashBase(const TRobinHoodHashBase&) = delete;
  79. TRobinHoodHashBase(TRobinHoodHashBase&&) = delete;
  80. void operator=(const TRobinHoodHashBase&) = delete;
  81. void operator=(TRobinHoodHashBase&&) = delete;
  82. public:
  83. // returns iterator
  84. Y_FORCE_INLINE char* Insert(TKey key, bool& isNew) {
  85. auto hash = HashLocal(key);
  86. auto ptr = MakeIterator(hash, Data, CapacityShift);
  87. auto ret = InsertImpl(key, hash, isNew, Data, DataEnd, ptr);
  88. Size += isNew ? 1 : 0;
  89. return ret;
  90. }
  91. // should be called after Insert if isNew is true
  92. Y_FORCE_INLINE void CheckGrow() {
  93. if (RHHashTableNeedsGrow(Size, Capacity)) {
  94. Grow();
  95. }
  96. }
  97. // returns iterator or nullptr if key is not present
  98. Y_FORCE_INLINE char* Lookup(TKey key) {
  99. auto hash = HashLocal(key);
  100. auto ptr = MakeIterator(hash, Data, CapacityShift);
  101. auto ret = LookupImpl(key, hash, Data, DataEnd, ptr);
  102. return ret;
  103. }
  104. template <typename TSink>
  105. Y_NO_INLINE void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
  106. while (RHHashTableNeedsGrow(Size + batchRequest.size(), Capacity)) {
  107. Grow();
  108. }
  109. for (size_t i = 0; i < batchRequest.size(); ++i) {
  110. auto& r = batchRequest[i];
  111. r.Hash = HashLocal(r.GetKey());
  112. r.InitialIterator = MakeIterator(r.Hash, Data, CapacityShift);
  113. NYql::PrefetchForWrite(r.InitialIterator);
  114. }
  115. for (size_t i = 0; i < batchRequest.size(); ++i) {
  116. auto& r = batchRequest[i];
  117. bool isNew;
  118. auto iter = InsertImpl(r.GetKey(), r.Hash, isNew, Data, DataEnd, r.InitialIterator);
  119. Size += isNew ? 1 : 0;
  120. sink(i, iter, isNew);
  121. }
  122. }
  123. template <typename TSink>
  124. Y_NO_INLINE void BatchLookup(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
  125. for (size_t i = 0; i < batchRequest.size(); ++i) {
  126. auto& r = batchRequest[i];
  127. r.Hash = HashLocal(r.GetKey());
  128. r.InitialIterator = MakeIterator(r.Hash, Data, CapacityShift);
  129. NYql::PrefetchForRead(r.InitialIterator);
  130. }
  131. for (size_t i = 0; i < batchRequest.size(); ++i) {
  132. auto& r = batchRequest[i];
  133. auto iter = LookupImpl(r.GetKey(), r.Hash, Data, DataEnd, r.InitialIterator);
  134. sink(i, iter);
  135. }
  136. }
  137. ui64 GetCapacity() const {
  138. return Capacity;
  139. }
  140. void Clear() {
  141. char* ptr = Data;
  142. for (ui64 i = 0; i < Capacity; ++i) {
  143. GetPSL(ptr).Distance = -1;
  144. ptr += AsDeriv().GetCellSize();
  145. }
  146. Size = 0;
  147. }
  148. bool Empty() const {
  149. return !Size;
  150. }
  151. ui64 GetSize() const {
  152. return Size;
  153. }
  154. const char* Begin() const {
  155. return Data;
  156. }
  157. const char* End() const {
  158. return DataEnd;
  159. }
  160. char* Begin() {
  161. return Data;
  162. }
  163. char* End() {
  164. return DataEnd;
  165. }
  166. void Advance(char*& ptr) const {
  167. ptr += AsDeriv().GetCellSize();
  168. }
  169. void Advance(const char*& ptr) const {
  170. ptr += AsDeriv().GetCellSize();
  171. }
  172. bool IsValid(const char* ptr) {
  173. return GetPSL(ptr).Distance >= 0;
  174. }
  175. static const TPSLStorage& GetPSL(const char* ptr) {
  176. return *(const TPSLStorage*)ptr;
  177. }
  178. static const TKey& GetKey(const char* ptr) {
  179. return *(const TKey*)(ptr + sizeof(TPSLStorage));
  180. }
  181. static TKey& GetKey(char* ptr) {
  182. return *(TKey*)(ptr + sizeof(TPSLStorage));
  183. }
  184. const void* GetPayload(const char* ptr) {
  185. return AsDeriv().GetPayloadImpl(ptr);
  186. }
  187. static TPSLStorage& GetPSL(char* ptr) {
  188. return *(TPSLStorage*)ptr;
  189. }
  190. void* GetMutablePayload(char* ptr) {
  191. return AsDeriv().GetPayloadImpl(ptr);
  192. }
  193. private:
  194. struct TInternalBatchRequestItem : TRobinHoodBatchRequestItem<TKey> {
  195. char* OriginalIterator;
  196. };
  197. Y_FORCE_INLINE char* MakeIterator(const ui64 hash, char* data, ui64 capacityShift) {
  198. // https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
  199. ui64 bucket = ((SelfHash ^ hash) * 11400714819323198485llu) >> capacityShift;
  200. char* ptr = data + AsDeriv().GetCellSize() * bucket;
  201. return ptr;
  202. }
  203. Y_FORCE_INLINE char* InsertImpl(TKey key, const ui64 hash, bool& isNew, char* data, char* dataEnd, char* ptr) {
  204. isNew = false;
  205. TPSLStorage psl(hash);
  206. char* returnPtr;
  207. typename TDeriv::TPayloadStore tmpPayload;
  208. for (;;) {
  209. auto& pslPtr = GetPSL(ptr);
  210. if (pslPtr.Distance < 0) {
  211. isNew = true;
  212. pslPtr = psl;
  213. GetKey(ptr) = key;
  214. return ptr;
  215. }
  216. if constexpr (CacheHash) {
  217. if (pslPtr.Hash == psl.Hash && EqualLocal(GetKey(ptr), key)) {
  218. return ptr;
  219. }
  220. } else {
  221. if (EqualLocal(GetKey(ptr), key)) {
  222. return ptr;
  223. }
  224. }
  225. if (psl.Distance > pslPtr.Distance) {
  226. // swap keys & state
  227. returnPtr = ptr;
  228. std::swap(psl, pslPtr);
  229. std::swap(key, GetKey(ptr));
  230. AsDeriv().SavePayload(GetPayload(ptr), tmpPayload);
  231. isNew = true;
  232. ++psl.Distance;
  233. AdvancePointer(ptr, data, dataEnd);
  234. break;
  235. }
  236. ++psl.Distance;
  237. AdvancePointer(ptr, data, dataEnd);
  238. }
  239. for (;;) {
  240. auto& pslPtr = GetPSL(ptr);
  241. if (pslPtr.Distance < 0) {
  242. pslPtr = psl;
  243. GetKey(ptr) = key;
  244. AsDeriv().RestorePayload(GetMutablePayload(ptr), tmpPayload);
  245. return returnPtr; // for original key
  246. }
  247. if (psl.Distance > pslPtr.Distance) {
  248. // swap keys & state
  249. std::swap(psl, pslPtr);
  250. std::swap(key, GetKey(ptr));
  251. AsDeriv().SwapPayload(GetMutablePayload(ptr), tmpPayload);
  252. }
  253. ++psl.Distance;
  254. AdvancePointer(ptr, data, dataEnd);
  255. }
  256. }
  257. Y_FORCE_INLINE char* LookupImpl(TKey key, const ui64 hash, char* data, char* dataEnd, char* ptr) {
  258. i32 currDistance = 0;
  259. for (;;) {
  260. auto& pslPtr = GetPSL(ptr);
  261. if (pslPtr.Distance < 0 || currDistance > pslPtr.Distance) {
  262. return nullptr;
  263. }
  264. if constexpr (CacheHash) {
  265. if (pslPtr.Hash == hash && EqualLocal(GetKey(ptr), key)) {
  266. return ptr;
  267. }
  268. } else {
  269. if (EqualLocal(GetKey(ptr), key)) {
  270. return ptr;
  271. }
  272. }
  273. ++currDistance;
  274. AdvancePointer(ptr, data, dataEnd);
  275. }
  276. }
  277. Y_NO_INLINE void Grow() {
  278. auto newCapacity = Capacity * CalculateRHHashTableGrowFactor(Capacity);
  279. auto newCapacityShift = 64 - MostSignificantBit(newCapacity);
  280. char *newData, *newDataEnd;
  281. Allocate(newCapacity, newData, newDataEnd);
  282. Y_DEFER {
  283. Allocator.deallocate(newData, newDataEnd - newData);
  284. };
  285. std::array<TInternalBatchRequestItem, PrefetchBatchSize> batch;
  286. ui32 batchLen = 0;
  287. for (auto iter = Begin(); iter != End(); Advance(iter)) {
  288. if (GetPSL(iter).Distance < 0) {
  289. continue;
  290. }
  291. if (batchLen == batch.size()) {
  292. CopyBatch({batch.data(), batchLen}, newData, newDataEnd);
  293. batchLen = 0;
  294. }
  295. auto& r = batch[batchLen++];
  296. r.ConstructKey(GetKey(iter));
  297. r.OriginalIterator = iter;
  298. if constexpr (CacheHash) {
  299. r.Hash = GetPSL(iter).Hash;
  300. } else {
  301. r.Hash = HashLocal(r.GetKey());
  302. }
  303. r.InitialIterator = MakeIterator(r.Hash, newData, newCapacityShift);
  304. NYql::PrefetchForWrite(r.InitialIterator);
  305. }
  306. CopyBatch({batch.data(), batchLen}, newData, newDataEnd);
  307. Capacity = newCapacity;
  308. CapacityShift = newCapacityShift;
  309. std::swap(Data, newData);
  310. std::swap(DataEnd, newDataEnd);
  311. }
  312. Y_NO_INLINE void CopyBatch(std::span<TInternalBatchRequestItem> batch, char* newData, char* newDataEnd) {
  313. for (auto& r : batch) {
  314. bool isNew;
  315. auto iter = InsertImpl(r.GetKey(), r.Hash, isNew, newData, newDataEnd, r.InitialIterator);
  316. Y_ASSERT(isNew);
  317. AsDeriv().CopyPayload(GetMutablePayload(iter), GetPayload(r.OriginalIterator));
  318. }
  319. }
  320. void AdvancePointer(char*& ptr, char* begin, char* end) const {
  321. ptr += AsDeriv().GetCellSize();
  322. ptr = (ptr == end) ? begin : ptr;
  323. }
  324. static ui64 GetSelfHash(void* self) {
  325. char buf[sizeof(void*)];
  326. *(void**)buf = self;
  327. return CityHash64(buf, sizeof(buf));
  328. }
  329. protected:
  330. void Init() {
  331. Allocate(Capacity, Data, DataEnd);
  332. }
  333. private:
  334. void Allocate(ui64 capacity, char*& data, char*& dataEnd) {
  335. ui64 bytes = capacity * AsDeriv().GetCellSize();
  336. data = Allocator.allocate(bytes);
  337. dataEnd = data + bytes;
  338. char* ptr = data;
  339. for (ui64 i = 0; i < capacity; ++i) {
  340. GetPSL(ptr).Distance = -1;
  341. ptr += AsDeriv().GetCellSize();
  342. }
  343. }
  344. const TDeriv& AsDeriv() const {
  345. return static_cast<const TDeriv&>(*this);
  346. }
  347. TDeriv& AsDeriv() {
  348. return static_cast<TDeriv&>(*this);
  349. }
  350. private:
  351. ui64 Size = 0;
  352. ui64 Capacity;
  353. ui64 CapacityShift;
  354. TAllocator Allocator;
  355. const ui64 SelfHash;
  356. char* Data = nullptr;
  357. char* DataEnd = nullptr;
  358. };
  359. template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = TRobinHoodDefaultSettings<TKey>>
  360. class TRobinHoodHashMap : public TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TRobinHoodHashMap<TKey, TEqual, THash, TAllocator, TSettings>, TSettings::CacheHash> {
  361. public:
  362. using TSelf = TRobinHoodHashMap<TKey, TEqual, THash, TAllocator, TSettings>;
  363. using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf, TSettings::CacheHash>;
  364. using TPayloadStore = int;
  365. explicit TRobinHoodHashMap(ui32 payloadSize, ui64 initialCapacity = 1u << 8)
  366. : TBase(initialCapacity, THash(), TEqual())
  367. , CellSize(sizeof(typename TBase::TPSLStorage) + sizeof(TKey) + payloadSize)
  368. , PayloadSize(payloadSize)
  369. {
  370. TmpPayload.resize(PayloadSize);
  371. TmpPayload2.resize(PayloadSize);
  372. TBase::Init();
  373. }
  374. explicit TRobinHoodHashMap(ui32 payloadSize, const THash& hash, const TEqual& equal, ui64 initialCapacity = 1u << 8)
  375. : TBase(initialCapacity, hash, equal)
  376. , CellSize(sizeof(typename TBase::TPSLStorage) + sizeof(TKey) + payloadSize)
  377. , PayloadSize(payloadSize)
  378. {
  379. TmpPayload.resize(PayloadSize);
  380. TmpPayload2.resize(PayloadSize);
  381. TBase::Init();
  382. }
  383. ui32 GetCellSize() const {
  384. return CellSize;
  385. }
  386. void* GetPayloadImpl(char* ptr) {
  387. return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
  388. }
  389. const void* GetPayloadImpl(const char* ptr) {
  390. return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
  391. }
  392. void CopyPayload(void* dst, const void* src) {
  393. memcpy(dst, src, PayloadSize);
  394. }
  395. void SavePayload(const void* p, int& store) {
  396. Y_UNUSED(store);
  397. memcpy(TmpPayload.data(), p, PayloadSize);
  398. }
  399. void RestorePayload(void* p, const int& store) {
  400. Y_UNUSED(store);
  401. memcpy(p, TmpPayload.data(), PayloadSize);
  402. }
  403. void SwapPayload(void* p, int& store) {
  404. Y_UNUSED(store);
  405. memcpy(TmpPayload2.data(), p, PayloadSize);
  406. memcpy(p, TmpPayload.data(), PayloadSize);
  407. TmpPayload2.swap(TmpPayload);
  408. }
  409. private:
  410. const ui32 CellSize;
  411. const ui32 PayloadSize;
  412. using TVec = std::vector<char, TAllocator>;
  413. TVec TmpPayload, TmpPayload2;
  414. };
  415. template <typename TKey, typename TPayload, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = TRobinHoodDefaultSettings<TKey>>
  416. class TRobinHoodHashFixedMap : public TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TRobinHoodHashFixedMap<TKey, TPayload, TEqual, THash, TAllocator, TSettings>, TSettings::CacheHash> {
  417. public:
  418. using TSelf = TRobinHoodHashFixedMap<TKey, TPayload, TEqual, THash, TAllocator, TSettings>;
  419. using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf, TSettings::CacheHash>;
  420. using TPayloadStore = TPayload;
  421. explicit TRobinHoodHashFixedMap(ui64 initialCapacity = 1u << 8)
  422. : TBase(initialCapacity, THash(), TEqual())
  423. {
  424. TBase::Init();
  425. }
  426. explicit TRobinHoodHashFixedMap(const THash& hash, const TEqual& equal, ui64 initialCapacity = 1u << 8)
  427. : TBase(initialCapacity, hash, equal)
  428. {
  429. TBase::Init();
  430. }
  431. static constexpr ui32 GetCellSize() {
  432. return sizeof(typename TBase::TPSLStorage) + sizeof(TKey) + sizeof(TPayload);
  433. }
  434. void* GetPayloadImpl(char* ptr) {
  435. return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
  436. }
  437. const void* GetPayloadImpl(const char* ptr) {
  438. return ptr + sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
  439. }
  440. void CopyPayload(void* dst, const void* src) {
  441. *(TPayload*)dst = *(const TPayload*)src;
  442. }
  443. void SavePayload(const void* p, TPayload& store) {
  444. store = *(const TPayload*)p;
  445. }
  446. void RestorePayload(void* p, const TPayload& store) {
  447. *(TPayload*)p = store;
  448. }
  449. void SwapPayload(void* p, TPayload& store) {
  450. std::swap(*(TPayload*)p, store);
  451. }
  452. };
  453. template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = TRobinHoodDefaultSettings<TKey>>
  454. class TRobinHoodHashSet : public TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TRobinHoodHashSet<TKey, TEqual, THash, TAllocator, TSettings>, TSettings::CacheHash> {
  455. public:
  456. using TSelf = TRobinHoodHashSet<TKey, TEqual, THash, TAllocator, TSettings>;
  457. using TBase = TRobinHoodHashBase<TKey, TEqual, THash, TAllocator, TSelf, TSettings::CacheHash>;
  458. using TPayloadStore = int;
  459. explicit TRobinHoodHashSet(THash hash, TEqual equal, ui64 initialCapacity = 1u << 8)
  460. : TBase(initialCapacity, hash, equal) {
  461. TBase::Init();
  462. }
  463. explicit TRobinHoodHashSet(ui64 initialCapacity = 1u << 8)
  464. : TBase(initialCapacity, THash(), TEqual()) {
  465. TBase::Init();
  466. }
  467. static constexpr ui32 GetCellSize() {
  468. return sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
  469. }
  470. void* GetPayloadImpl(char* ptr) {
  471. Y_UNUSED(ptr);
  472. return nullptr;
  473. }
  474. const void* GetPayloadImpl(const char* ptr) {
  475. Y_UNUSED(ptr);
  476. return nullptr;
  477. }
  478. void CopyPayload(void* dst, const void* src) {
  479. Y_UNUSED(dst);
  480. Y_UNUSED(src);
  481. }
  482. void SavePayload(const void* p, int& store) {
  483. Y_UNUSED(p);
  484. Y_UNUSED(store);
  485. }
  486. void RestorePayload(void* p, const int& store) {
  487. Y_UNUSED(p);
  488. Y_UNUSED(store);
  489. }
  490. void SwapPayload(void* p, int& store) {
  491. Y_UNUSED(p);
  492. Y_UNUSED(store);
  493. }
  494. };
  495. }
  496. }