// concurrent_hash.h (4.8 KB)
  1. #pragma once
  2. #include <util/generic/hash.h>
  3. #include <util/system/spinlock.h>
  4. #include <array>
  5. namespace NPrivate {
  6. template <typename T, typename THash>
  7. concept CHashableBy = requires (const THash& hash, const T& t) {
  8. hash(t);
  9. };
  10. }
  11. template <typename K, typename V, size_t BucketCount = 64, typename L = TAdaptiveLock, class TLockOps = TCommonLockOps<L>>
  12. class TConcurrentHashMap {
  13. public:
  14. using TActualMap = THashMap<K, V>;
  15. using TLock = L;
  16. using TBucketGuard = TGuard<TLock, TLockOps>;
  17. struct TBucket {
  18. friend class TConcurrentHashMap;
  19. private:
  20. TActualMap Map;
  21. mutable TLock Mutex;
  22. public:
  23. TLock& GetMutex() const {
  24. return Mutex;
  25. }
  26. TActualMap& GetMap() {
  27. return Map;
  28. }
  29. const TActualMap& GetMap() const {
  30. return Map;
  31. }
  32. const V& GetUnsafe(const K& key) const {
  33. typename TActualMap::const_iterator it = Map.find(key);
  34. Y_ABORT_UNLESS(it != Map.end(), "not found by key");
  35. return it->second;
  36. }
  37. V& GetUnsafe(const K& key) {
  38. typename TActualMap::iterator it = Map.find(key);
  39. Y_ABORT_UNLESS(it != Map.end(), "not found by key");
  40. return it->second;
  41. }
  42. V RemoveUnsafe(const K& key) {
  43. typename TActualMap::iterator it = Map.find(key);
  44. Y_ABORT_UNLESS(it != Map.end(), "removing non-existent key");
  45. V r = std::move(it->second);
  46. Map.erase(it);
  47. return r;
  48. }
  49. bool HasUnsafe(const K& key) const {
  50. typename TActualMap::const_iterator it = Map.find(key);
  51. return (it != Map.end());
  52. }
  53. const V* TryGetUnsafe(const K& key) const {
  54. typename TActualMap::const_iterator it = Map.find(key);
  55. return it == Map.end() ? nullptr : &it->second;
  56. }
  57. V* TryGetUnsafe(const K& key) {
  58. typename TActualMap::iterator it = Map.find(key);
  59. return it == Map.end() ? nullptr : &it->second;
  60. }
  61. };
  62. std::array<TBucket, BucketCount> Buckets;
  63. public:
  64. template <NPrivate::CHashableBy<THash<K>> TKey>
  65. TBucket& GetBucketForKey(const TKey& key) {
  66. return Buckets[THash<K>()(key) % BucketCount];
  67. }
  68. template <NPrivate::CHashableBy<THash<K>> TKey>
  69. const TBucket& GetBucketForKey(const TKey& key) const {
  70. return Buckets[THash<K>()(key) % BucketCount];
  71. }
  72. void Insert(const K& key, const V& value) {
  73. TBucket& bucket = GetBucketForKey(key);
  74. TBucketGuard guard(bucket.Mutex);
  75. bucket.Map[key] = value;
  76. }
  77. void InsertUnique(const K& key, const V& value) {
  78. TBucket& bucket = GetBucketForKey(key);
  79. TBucketGuard guard(bucket.Mutex);
  80. if (!bucket.Map.insert(std::make_pair(key, value)).second) {
  81. Y_ABORT("non-unique key");
  82. }
  83. }
  84. V& InsertIfAbsent(const K& key, const V& value) {
  85. TBucket& bucket = GetBucketForKey(key);
  86. TBucketGuard guard(bucket.Mutex);
  87. return bucket.Map.insert(std::make_pair(key, value)).first->second;
  88. }
  89. template <typename TKey, typename... Args>
  90. V& EmplaceIfAbsent(TKey&& key, Args&&... args) {
  91. TBucket& bucket = GetBucketForKey(key);
  92. TBucketGuard guard(bucket.Mutex);
  93. if (V* value = bucket.TryGetUnsafe(key)) {
  94. return *value;
  95. }
  96. return bucket.Map.emplace(
  97. std::piecewise_construct,
  98. std::forward_as_tuple(std::forward<TKey>(key)),
  99. std::forward_as_tuple(std::forward<Args>(args)...)
  100. ).first->second;
  101. }
  102. template <typename Callable>
  103. V& InsertIfAbsentWithInit(const K& key, Callable initFunc) {
  104. TBucket& bucket = GetBucketForKey(key);
  105. TBucketGuard guard(bucket.Mutex);
  106. if (V* value = bucket.TryGetUnsafe(key)) {
  107. return *value;
  108. }
  109. return bucket.Map.insert(std::make_pair(key, initFunc())).first->second;
  110. }
  111. V Get(const K& key) const {
  112. const TBucket& bucket = GetBucketForKey(key);
  113. TBucketGuard guard(bucket.Mutex);
  114. return bucket.GetUnsafe(key);
  115. }
  116. bool Get(const K& key, V& result) const {
  117. const TBucket& bucket = GetBucketForKey(key);
  118. TBucketGuard guard(bucket.Mutex);
  119. if (const V* value = bucket.TryGetUnsafe(key)) {
  120. result = *value;
  121. return true;
  122. }
  123. return false;
  124. }
  125. V Remove(const K& key) {
  126. TBucket& bucket = GetBucketForKey(key);
  127. TBucketGuard guard(bucket.Mutex);
  128. return bucket.RemoveUnsafe(key);
  129. }
  130. bool Has(const K& key) const {
  131. const TBucket& bucket = GetBucketForKey(key);
  132. TBucketGuard guard(bucket.Mutex);
  133. return bucket.HasUnsafe(key);
  134. }
  135. };