// concurrent_hash.h
  1. #pragma once
  2. #include <util/generic/hash.h>
  3. #include <util/system/spinlock.h>
  4. #include <array>
  5. template <typename K, typename V, size_t BucketCount = 64, typename L = TAdaptiveLock>
  6. class TConcurrentHashMap {
  7. public:
  8. using TActualMap = THashMap<K, V>;
  9. using TLock = L;
  10. struct TBucket {
  11. friend class TConcurrentHashMap;
  12. private:
  13. TActualMap Map;
  14. mutable TLock Mutex;
  15. public:
  16. TLock& GetMutex() const {
  17. return Mutex;
  18. }
  19. TActualMap& GetMap() {
  20. return Map;
  21. }
  22. const TActualMap& GetMap() const {
  23. return Map;
  24. }
  25. const V& GetUnsafe(const K& key) const {
  26. typename TActualMap::const_iterator it = Map.find(key);
  27. Y_VERIFY(it != Map.end(), "not found by key");
  28. return it->second;
  29. }
  30. V& GetUnsafe(const K& key) {
  31. typename TActualMap::iterator it = Map.find(key);
  32. Y_VERIFY(it != Map.end(), "not found by key");
  33. return it->second;
  34. }
  35. V RemoveUnsafe(const K& key) {
  36. typename TActualMap::iterator it = Map.find(key);
  37. Y_VERIFY(it != Map.end(), "removing non-existent key");
  38. V r = std::move(it->second);
  39. Map.erase(it);
  40. return r;
  41. }
  42. bool HasUnsafe(const K& key) const {
  43. typename TActualMap::const_iterator it = Map.find(key);
  44. return (it != Map.end());
  45. }
  46. const V* TryGetUnsafe(const K& key) const {
  47. typename TActualMap::const_iterator it = Map.find(key);
  48. return it == Map.end() ? nullptr : &it->second;
  49. }
  50. V* TryGetUnsafe(const K& key) {
  51. typename TActualMap::iterator it = Map.find(key);
  52. return it == Map.end() ? nullptr : &it->second;
  53. }
  54. };
  55. std::array<TBucket, BucketCount> Buckets;
  56. public:
  57. TBucket& GetBucketForKey(const K& key) {
  58. return Buckets[THash<K>()(key) % BucketCount];
  59. }
  60. const TBucket& GetBucketForKey(const K& key) const {
  61. return Buckets[THash<K>()(key) % BucketCount];
  62. }
  63. void Insert(const K& key, const V& value) {
  64. TBucket& bucket = GetBucketForKey(key);
  65. TGuard<TLock> guard(bucket.Mutex);
  66. bucket.Map[key] = value;
  67. }
  68. void InsertUnique(const K& key, const V& value) {
  69. TBucket& bucket = GetBucketForKey(key);
  70. TGuard<TLock> guard(bucket.Mutex);
  71. if (!bucket.Map.insert(std::make_pair(key, value)).second) {
  72. Y_FAIL("non-unique key");
  73. }
  74. }
  75. V& InsertIfAbsent(const K& key, const V& value) {
  76. TBucket& bucket = GetBucketForKey(key);
  77. TGuard<TLock> guard(bucket.Mutex);
  78. return bucket.Map.insert(std::make_pair(key, value)).first->second;
  79. }
  80. template <typename TKey, typename... Args>
  81. V& EmplaceIfAbsent(TKey&& key, Args&&... args) {
  82. TBucket& bucket = GetBucketForKey(key);
  83. TGuard<TLock> guard(bucket.Mutex);
  84. if (V* value = bucket.TryGetUnsafe(key)) {
  85. return *value;
  86. }
  87. return bucket.Map.emplace(
  88. std::piecewise_construct,
  89. std::forward_as_tuple(std::forward<TKey>(key)),
  90. std::forward_as_tuple(std::forward<Args>(args)...)
  91. ).first->second;
  92. }
  93. template <typename Callable>
  94. V& InsertIfAbsentWithInit(const K& key, Callable initFunc) {
  95. TBucket& bucket = GetBucketForKey(key);
  96. TGuard<TLock> guard(bucket.Mutex);
  97. if (V* value = bucket.TryGetUnsafe(key)) {
  98. return *value;
  99. }
  100. return bucket.Map.insert(std::make_pair(key, initFunc())).first->second;
  101. }
  102. V Get(const K& key) const {
  103. const TBucket& bucket = GetBucketForKey(key);
  104. TGuard<TLock> guard(bucket.Mutex);
  105. return bucket.GetUnsafe(key);
  106. }
  107. bool Get(const K& key, V& result) const {
  108. const TBucket& bucket = GetBucketForKey(key);
  109. TGuard<TLock> guard(bucket.Mutex);
  110. if (const V* value = bucket.TryGetUnsafe(key)) {
  111. result = *value;
  112. return true;
  113. }
  114. return false;
  115. }
  116. V Remove(const K& key) {
  117. TBucket& bucket = GetBucketForKey(key);
  118. TGuard<TLock> guard(bucket.Mutex);
  119. return bucket.RemoveUnsafe(key);
  120. }
  121. bool Has(const K& key) const {
  122. const TBucket& bucket = GetBucketForKey(key);
  123. TGuard<TLock> guard(bucket.Mutex);
  124. return bucket.HasUnsafe(key);
  125. }
  126. };