#pragma once #include #include #include #include #include #include #include #include #include #include #include #include namespace NThreading { // TThreadLocalValue // // Safe RAII-friendly thread local storage without dirty hacks from util/system/tls // // Example 1: // // THolder pool = CreateThreadPool(threads); // TThreadLocalValue tls; // for (ui32 i : xrange(threads)) { // pool->SafeAddFunc([&]) { // *tls->Get() = 1337; // } // } // // Example 2: // // class TNoisy { // public: // TNoisy(const char* name = "TNoisy") // : Name_{name} { // printf("%s::%s\n", Name_, Name_); // } // // ~TNoisy() { // printf("%s::~%s\n", Name_, Name_); // } // private: // const char* Name_; // }; // // class TWrapper { // public: // TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // ~TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // void DoWork() { // ThreadLocal_->Get(); // } // // private: // TNoisy Noisy_{"TWrapper"}; // TThreadLocalValue ThreadLocal_; // }; // // THolder pool = CreateThreadPool(3); // { // TWrapper wrapper; // for (ui32 i : xrange(3)) { // pool->SafeAddFunc([&] { // wrapper.DoWork(); // }); // } // } // // Will always print: // TWrapper::TWrapper() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TWrapper::~TWrapper() // enum class EThreadLocalImpl { HotSwap, SkipList, ForwardList, }; namespace NDetail { template class TThreadLocalValueImpl; } // namespace NDetail inline constexpr size_t DefaultNumShards = 3; template class TThreadLocalValue : private TNonCopyable { public: template T& GetRef(ConstructArgs&& ...args) const { return *Get(std::forward(args)...); } template T* Get(ConstructArgs&& ...args) const { TThread::TId tid = TThread::CurrentThreadId(); return Shards_[tid % NumShards].Get(tid, std::forward(args)...); } private: using TStorage = NDetail::TThreadLocalValueImpl; mutable std::array Shards_; }; namespace NDetail { template class TThreadLocalValueImpl { private: class TStorage: public THashMap>, public TAtomicRefCount { }; public: TThreadLocalValueImpl() { Registered_.AtomicStore(new TStorage()); } template T* Get(TThread::TId tid, ConstructArgs&& ...args) { if (TIntrusivePtr state = Registered_.AtomicLoad(); TAtomicSharedPtr* result = state->FindPtr(tid)) { return result->Get(); } else { TAtomicSharedPtr value = MakeAtomicShared(std::forward(args)...); with_lock(RegisterLock_) { TIntrusivePtr oldState = Registered_.AtomicLoad(); THolder newState = MakeHolder(*oldState); (*newState)[tid] = value; Registered_.AtomicStore(newState.Release()); } return value.Get(); } } private: THotSwap Registered_; TMutex RegisterLock_; }; template class TThreadLocalValueImpl { private: struct TNode { TThread::TId Key; THolder Value; }; struct TCompare { int operator()(const TNode& lhs, const TNode& rhs) const { return ::NThreading::TCompare{}(lhs.Key, rhs.Key); } }; public: TThreadLocalValueImpl() : ListPool_{InitialPoolSize()} , SkipList_{ListPool_} {} template T* Get(TThread::TId tid, ConstructArgs&& ...args) { TNode key{tid, {}}; auto iterator = SkipList_.SeekTo(key); if (iterator.IsValid() && iterator.GetValue().Key == key.Key) { return iterator.GetValue().Value.Get(); } with_lock (RegisterLock_) { SkipList_.Insert({tid, MakeHolder(std::forward(args)...)}); } iterator = SkipList_.SeekTo(key); return iterator.GetValue().Value.Get(); } private: static size_t InitialPoolSize() { return std::thread::hardware_concurrency() * (sizeof(T) + sizeof(TThread::TId) + sizeof(void*)) / NumShards; } private: static inline constexpr size_t MaxHeight = 6; using TCustomSkipList = TSkipList; TMemoryPool ListPool_; TCustomSkipList SkipList_; TAdaptiveLock RegisterLock_; }; template class TThreadLocalValueImpl { private: struct TNode { TThread::TId Key = 0; T Value; TNode* Next = nullptr; }; public: TThreadLocalValueImpl() : Head_{nullptr} , Pool_{0} {} template T* Get(TThread::TId tid, ConsturctArgs&& ...args) { TNode* node = Head_.load(std::memory_order_relaxed); for (; node; node = node->Next) { if (node->Key == tid) { return &node->Value; } } TNode* newNode = AllocateNode(tid, node, std::forward(args)...); while (!Head_.compare_exchange_weak(node, newNode, std::memory_order_release, std::memory_order_relaxed)) { newNode->Next = node; } return &newNode->Value; } template TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) { TNode* storage = nullptr; with_lock(PoolMutex_) { storage = Pool_.Allocate(); } new (storage) TNode{tid, T{std::forward(args)...}, next}; return storage; } ~TThreadLocalValueImpl() { if constexpr (!std::is_trivially_destructible_v) { TNode* next = nullptr; for (TNode* node = Head_.load(); node; node = next) { next = node->Next; node->~TNode(); } } } private: std::atomic Head_; TMemoryPool Pool_; TMutex PoolMutex_; }; } // namespace NDetail } // namespace NThreading