concurrent_hash.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #pragma once
  2. #include <util/generic/hash.h>
  3. #include <util/system/spinlock.h>
  4. #include <array>
  5. namespace NPrivate {
  6. template <typename T, typename THash>
  7. concept CHashableBy = requires (const THash& hash, const T& t) {
  8. hash(t);
  9. };
  10. }
  11. template <typename K, typename V, size_t BucketCount = 64, typename L = TAdaptiveLock>
  12. class TConcurrentHashMap {
  13. public:
  14. using TActualMap = THashMap<K, V>;
  15. using TLock = L;
  16. struct TBucket {
  17. friend class TConcurrentHashMap;
  18. private:
  19. TActualMap Map;
  20. mutable TLock Mutex;
  21. public:
  22. TLock& GetMutex() const {
  23. return Mutex;
  24. }
  25. TActualMap& GetMap() {
  26. return Map;
  27. }
  28. const TActualMap& GetMap() const {
  29. return Map;
  30. }
  31. const V& GetUnsafe(const K& key) const {
  32. typename TActualMap::const_iterator it = Map.find(key);
  33. Y_ABORT_UNLESS(it != Map.end(), "not found by key");
  34. return it->second;
  35. }
  36. V& GetUnsafe(const K& key) {
  37. typename TActualMap::iterator it = Map.find(key);
  38. Y_ABORT_UNLESS(it != Map.end(), "not found by key");
  39. return it->second;
  40. }
  41. V RemoveUnsafe(const K& key) {
  42. typename TActualMap::iterator it = Map.find(key);
  43. Y_ABORT_UNLESS(it != Map.end(), "removing non-existent key");
  44. V r = std::move(it->second);
  45. Map.erase(it);
  46. return r;
  47. }
  48. bool HasUnsafe(const K& key) const {
  49. typename TActualMap::const_iterator it = Map.find(key);
  50. return (it != Map.end());
  51. }
  52. const V* TryGetUnsafe(const K& key) const {
  53. typename TActualMap::const_iterator it = Map.find(key);
  54. return it == Map.end() ? nullptr : &it->second;
  55. }
  56. V* TryGetUnsafe(const K& key) {
  57. typename TActualMap::iterator it = Map.find(key);
  58. return it == Map.end() ? nullptr : &it->second;
  59. }
  60. };
  61. std::array<TBucket, BucketCount> Buckets;
  62. public:
  63. template <NPrivate::CHashableBy<THash<K>> TKey>
  64. TBucket& GetBucketForKey(const TKey& key) {
  65. return Buckets[THash<K>()(key) % BucketCount];
  66. }
  67. template <NPrivate::CHashableBy<THash<K>> TKey>
  68. const TBucket& GetBucketForKey(const TKey& key) const {
  69. return Buckets[THash<K>()(key) % BucketCount];
  70. }
  71. void Insert(const K& key, const V& value) {
  72. TBucket& bucket = GetBucketForKey(key);
  73. TGuard<TLock> guard(bucket.Mutex);
  74. bucket.Map[key] = value;
  75. }
  76. void InsertUnique(const K& key, const V& value) {
  77. TBucket& bucket = GetBucketForKey(key);
  78. TGuard<TLock> guard(bucket.Mutex);
  79. if (!bucket.Map.insert(std::make_pair(key, value)).second) {
  80. Y_ABORT("non-unique key");
  81. }
  82. }
  83. V& InsertIfAbsent(const K& key, const V& value) {
  84. TBucket& bucket = GetBucketForKey(key);
  85. TGuard<TLock> guard(bucket.Mutex);
  86. return bucket.Map.insert(std::make_pair(key, value)).first->second;
  87. }
  88. template <typename TKey, typename... Args>
  89. V& EmplaceIfAbsent(TKey&& key, Args&&... args) {
  90. TBucket& bucket = GetBucketForKey(key);
  91. TGuard<TLock> guard(bucket.Mutex);
  92. if (V* value = bucket.TryGetUnsafe(key)) {
  93. return *value;
  94. }
  95. return bucket.Map.emplace(
  96. std::piecewise_construct,
  97. std::forward_as_tuple(std::forward<TKey>(key)),
  98. std::forward_as_tuple(std::forward<Args>(args)...)
  99. ).first->second;
  100. }
  101. template <typename Callable>
  102. V& InsertIfAbsentWithInit(const K& key, Callable initFunc) {
  103. TBucket& bucket = GetBucketForKey(key);
  104. TGuard<TLock> guard(bucket.Mutex);
  105. if (V* value = bucket.TryGetUnsafe(key)) {
  106. return *value;
  107. }
  108. return bucket.Map.insert(std::make_pair(key, initFunc())).first->second;
  109. }
  110. V Get(const K& key) const {
  111. const TBucket& bucket = GetBucketForKey(key);
  112. TGuard<TLock> guard(bucket.Mutex);
  113. return bucket.GetUnsafe(key);
  114. }
  115. bool Get(const K& key, V& result) const {
  116. const TBucket& bucket = GetBucketForKey(key);
  117. TGuard<TLock> guard(bucket.Mutex);
  118. if (const V* value = bucket.TryGetUnsafe(key)) {
  119. result = *value;
  120. return true;
  121. }
  122. return false;
  123. }
  124. V Remove(const K& key) {
  125. TBucket& bucket = GetBucketForKey(key);
  126. TGuard<TLock> guard(bucket.Mutex);
  127. return bucket.RemoveUnsafe(key);
  128. }
  129. bool Has(const K& key) const {
  130. const TBucket& bucket = GetBucketForKey(key);
  131. TGuard<TLock> guard(bucket.Mutex);
  132. return bucket.HasUnsafe(key);
  133. }
  134. };