thread_local.h 6.9 KB


  1. #pragma once
  2. #include <util/generic/hash.h>
  3. #include <util/generic/maybe.h>
  4. #include <util/generic/ptr.h>
  5. #include <util/generic/vector.h>
  6. #include <util/memory/pool.h>
  7. #include <util/system/mutex.h>
  8. #include <util/system/thread.h>
  9. #include <library/cpp/threading/hot_swap/hot_swap.h>
  10. #include <library/cpp/threading/skip_list/skiplist.h>
  11. #include <array>
  12. #include <atomic>
  13. #include <thread>
  14. namespace NThreading {
  15. // TThreadLocalValue
  16. //
  17. // Safe RAII-friendly thread local storage without dirty hacks from util/system/tls
  18. //
  19. // Example 1:
  20. //
  21. // THolder<IThreadPool> pool = CreateThreadPool(threads);
  22. // TThreadLocalValue<ui32> tls;
  // for (ui32 i : xrange(threads)) {
  //     pool->SafeAddFunc([&] {
  //         *tls.Get() = 1337;
  //     });
  // }
  28. //
  29. // Example 2:
  30. //
  31. // class TNoisy {
  32. // public:
  33. // TNoisy(const char* name = "TNoisy")
  34. // : Name_{name} {
  35. // printf("%s::%s\n", Name_, Name_);
  36. // }
  37. //
  38. // ~TNoisy() {
  39. // printf("%s::~%s\n", Name_, Name_);
  40. // }
  41. // private:
  42. // const char* Name_;
  43. // };
  44. //
  45. // class TWrapper {
  46. // public:
  47. // TWrapper() {
  48. // Println(__PRETTY_FUNCTION__);
  49. // }
  50. //
  51. // ~TWrapper() {
  52. // Println(__PRETTY_FUNCTION__);
  53. // }
  54. //
  // void DoWork() {
  //     ThreadLocal_.Get();
  // }
  58. //
  59. // private:
  60. // TNoisy Noisy_{"TWrapper"};
  61. // TThreadLocalValue<TNoisy> ThreadLocal_;
  62. // };
  63. //
  64. // THolder<IThreadPool> pool = CreateThreadPool(3);
  65. // {
  66. // TWrapper wrapper;
  67. // for (ui32 i : xrange(3)) {
  68. // pool->SafeAddFunc([&] {
  69. // wrapper.DoWork();
  70. // });
  71. // }
  72. // }
  73. //
  74. // Will always print:
  75. // TWrapper::TWrapper()
  76. // TNoisy::TNoisy()
  77. // TNoisy::TNoisy()
  78. // TNoisy::TNoisy()
  79. // TNoisy::~TNoisy()
  80. // TNoisy::~TNoisy()
  81. // TNoisy::~TNoisy()
  82. // TWrapper::~TWrapper()
  83. //
  // Selects the per-shard storage backend used by TThreadLocalValue.
  enum class EThreadLocalImpl {
      // Hash map snapshot held in a THotSwap; each registration publishes a
      // modified copy of the whole map.
      HotSwap,
      // Skip list keyed by thread id; insertions are serialized by a lock.
      SkipList,
      // Singly linked list; new nodes are pushed at the head with a CAS loop.
      ForwardList,
  };
  89. namespace NDetail {
  90. template <typename T, EThreadLocalImpl Impl, size_t NumShards>
  91. class TThreadLocalValueImpl;
  92. } // namespace NDetail
  // Number of shards TThreadLocalValue uses when the caller does not pick one.
  inline constexpr size_t DefaultNumShards = 3;
  94. template <typename T, EThreadLocalImpl Impl = EThreadLocalImpl::SkipList, size_t NumShards = DefaultNumShards>
  95. class TThreadLocalValue : private TNonCopyable {
  96. public:
  97. template <typename ...ConstructArgs>
  98. T& GetRef(ConstructArgs&& ...args) const {
  99. return *Get(std::forward<ConstructArgs>(args)...);
  100. }
  101. template <typename ...ConstructArgs>
  102. T* Get(ConstructArgs&& ...args) const {
  103. TThread::TId tid = TThread::CurrentThreadId();
  104. return Shards_[tid % NumShards].Get(tid, std::forward<ConstructArgs>(args)...);
  105. }
  106. private:
  107. using TStorage = NDetail::TThreadLocalValueImpl<T, Impl, NumShards>;
  108. mutable std::array<TStorage, NumShards> Shards_;
  109. };
  110. namespace NDetail {
  111. template <typename T, size_t NumShards>
  112. class TThreadLocalValueImpl<T, EThreadLocalImpl::HotSwap, NumShards> {
  113. private:
  114. class TStorage: public THashMap<TThread::TId, TAtomicSharedPtr<T>>, public TAtomicRefCount<TStorage> {
  115. };
  116. public:
  117. TThreadLocalValueImpl() {
  118. Registered_.AtomicStore(new TStorage());
  119. }
  120. template <typename ...ConstructArgs>
  121. T* Get(TThread::TId tid, ConstructArgs&& ...args) {
  122. if (TIntrusivePtr<TStorage> state = Registered_.AtomicLoad(); TAtomicSharedPtr<T>* result = state->FindPtr(tid)) {
  123. return result->Get();
  124. } else {
  125. TAtomicSharedPtr<T> value = MakeAtomicShared<T>(std::forward<ConstructArgs>(args)...);
  126. with_lock(RegisterLock_) {
  127. TIntrusivePtr<TStorage> oldState = Registered_.AtomicLoad();
  128. THolder<TStorage> newState = MakeHolder<TStorage>(*oldState);
  129. (*newState)[tid] = value;
  130. Registered_.AtomicStore(newState.Release());
  131. }
  132. return value.Get();
  133. }
  134. }
  135. private:
  136. THotSwap<TStorage> Registered_;
  137. TMutex RegisterLock_;
  138. };
  139. template <typename T, size_t NumShards>
  140. class TThreadLocalValueImpl<T, EThreadLocalImpl::SkipList, NumShards> {
  141. private:
  142. struct TNode {
  143. TThread::TId Key;
  144. THolder<T> Value;
  145. };
  146. struct TCompare {
  147. int operator()(const TNode& lhs, const TNode& rhs) const {
  148. return ::NThreading::TCompare<TThread::TId>{}(lhs.Key, rhs.Key);
  149. }
  150. };
  151. public:
  152. TThreadLocalValueImpl()
  153. : ListPool_{InitialPoolSize()}
  154. , SkipList_{ListPool_}
  155. {}
  156. template <typename ...ConstructArgs>
  157. T* Get(TThread::TId tid, ConstructArgs&& ...args) {
  158. TNode key{tid, {}};
  159. auto iterator = SkipList_.SeekTo(key);
  160. if (iterator.IsValid() && iterator.GetValue().Key == key.Key) {
  161. return iterator.GetValue().Value.Get();
  162. }
  163. with_lock (RegisterLock_) {
  164. SkipList_.Insert({tid, MakeHolder<T>(std::forward<ConstructArgs>(args)...)});
  165. }
  166. iterator = SkipList_.SeekTo(key);
  167. return iterator.GetValue().Value.Get();
  168. }
  169. private:
  170. static size_t InitialPoolSize() {
  171. return std::thread::hardware_concurrency() * (sizeof(T) + sizeof(TThread::TId) + sizeof(void*)) / NumShards;
  172. }
  173. private:
  174. static inline constexpr size_t MaxHeight = 6;
  175. using TCustomSkipList = TSkipList<TNode, TCompare, TMemoryPool, TSizeCounter, MaxHeight>;
  176. TMemoryPool ListPool_;
  177. TCustomSkipList SkipList_;
  178. TAdaptiveLock RegisterLock_;
  179. };
  180. template <typename T, size_t NumShards>
  181. class TThreadLocalValueImpl<T, EThreadLocalImpl::ForwardList, NumShards> {
  182. private:
  183. struct TNode {
  184. TThread::TId Key = 0;
  185. T Value;
  186. TNode* Next = nullptr;
  187. };
  188. public:
  189. TThreadLocalValueImpl()
  190. : Head_{nullptr}
  191. , Pool_{0}
  192. {}
  193. template <typename ...ConstructArgs>
  194. T* Get(TThread::TId tid, ConstructArgs&& ...args) {
  195. TNode* head = Head_.load(std::memory_order_acquire);
  196. for (TNode* node = head; node; node = node->Next) {
  197. if (node->Key == tid) {
  198. return &node->Value;
  199. }
  200. }
  201. TNode* newNode = AllocateNode(tid, head, std::forward<ConstructArgs>(args)...);
  202. while (!Head_.compare_exchange_weak(head, newNode, std::memory_order_release, std::memory_order_relaxed)) {
  203. newNode->Next = head;
  204. }
  205. return &newNode->Value;
  206. }
  207. template <typename ...ConstructArgs>
  208. TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) {
  209. TNode* storage = nullptr;
  210. with_lock(PoolMutex_) {
  211. storage = Pool_.Allocate<TNode>();
  212. }
  213. new (storage) TNode{tid, T{std::forward<ConstructArgs>(args)...}, next};
  214. return storage;
  215. }
  216. ~TThreadLocalValueImpl() {
  217. if constexpr (!std::is_trivially_destructible_v<T>) {
  218. TNode* next = nullptr;
  219. for (TNode* node = Head_.load(); node; node = next) {
  220. next = node->Next;
  221. node->~TNode();
  222. }
  223. }
  224. }
  225. private:
  226. std::atomic<TNode*> Head_;
  227. TMemoryPool Pool_;
  228. TMutex PoolMutex_;
  229. };
  230. } // namespace NDetail
  231. } // namespace NThreading