
IGNIETFERRO-1105 TAtomic -> std::atomic in util/generic/* and threadpool

ref:39a714b781c60dca9e3b946d870971076e14ab7c
eeight 2 years ago
parent
commit 05a6fea781

+ 1 - 1
library/cpp/messagebus/use_count_checker.cpp

@@ -7,7 +7,7 @@ TUseCountChecker::TUseCountChecker() {
 }
 
 TUseCountChecker::~TUseCountChecker() {
-    TAtomicBase count = Counter.Val();
+    auto count = Counter.Val();
     Y_VERIFY(count == 0, "must not release when count is not zero: %ld", (long)count);
 }
 

+ 9 - 8
util/generic/object_counter.h

@@ -1,6 +1,7 @@
 #pragma once
 
-#include <util/system/atomic.h>
+#include <cstddef>
+#include <atomic>
 
 /**
  * Simple thread-safe per-class counter that can be used to make sure you don't
@@ -20,19 +21,19 @@ template <class T>
 class TObjectCounter {
 public:
     inline TObjectCounter() noexcept {
-        AtomicIncrement(Count_);
+        ++Count_;
     }
 
     inline TObjectCounter(const TObjectCounter& /*item*/) noexcept {
-        AtomicIncrement(Count_);
+        ++Count_;
     }
 
     inline ~TObjectCounter() {
-        AtomicDecrement(Count_);
+        --Count_;
     }
 
     static inline long ObjectCount() noexcept {
-        return AtomicGet(Count_);
+        return Count_.load();
     }
 
     /**
@@ -42,12 +43,12 @@ public:
      * \returns                         Current object count.
      */
     static inline long ResetObjectCount() noexcept {
-        return AtomicSwap(&Count_, 0);
+        return Count_.exchange(0);
     }
 
 private:
-    static TAtomic Count_;
+    static std::atomic<intptr_t> Count_;
 };
 
 template <class T>
-TAtomic TObjectCounter<T>::Count_ = 0;
+std::atomic<intptr_t> TObjectCounter<T>::Count_ = 0;
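
A usage sketch for the counter above. TLeakChecked is a hypothetical class, not part of util; the assertions only restate the semantics visible in the header (each constructor does ++Count_, each destructor does --Count_, ResetObjectCount is an exchange(0)):

    #include <util/generic/object_counter.h>
    #include <util/system/yassert.h>

    // Hypothetical class whose live instances we want to track.
    class TLeakChecked: public TObjectCounter<TLeakChecked> {
    };

    void ObjectCounterExample() {
        Y_ASSERT(TLeakChecked::ObjectCount() == 0);
        {
            TLeakChecked a, b;
            Y_ASSERT(TLeakChecked::ObjectCount() == 2); // two live instances
        }
        Y_ASSERT(TLeakChecked::ObjectCount() == 0);     // destructors decremented the counter
        TLeakChecked::ResetObjectCount();               // exchange(0), returns the previous value
    }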

+ 3 - 3
util/generic/ptr.h

@@ -365,7 +365,7 @@ public:
 
     inline ~TRefCounted() = default;
 
-    inline void Ref(TAtomicBase d) noexcept {
+    inline void Ref(intptr_t d) noexcept {
         auto resultCount = Counter_.Add(d);
         Y_ASSERT(resultCount >= d);
         (void)resultCount;
@@ -377,7 +377,7 @@ public:
         (void)resultCount;
     }
 
-    inline void UnRef(TAtomicBase d) noexcept {
+    inline void UnRef(intptr_t d) noexcept {
         auto resultCount = Counter_.Sub(d);
         Y_ASSERT(resultCount >= 0);
         if (resultCount == 0) {
@@ -389,7 +389,7 @@ public:
         UnRef(1);
     }
 
-    inline TAtomicBase RefCount() const noexcept {
+    inline intptr_t RefCount() const noexcept {
         return Counter_.Val();
     }
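
The Ref/UnRef pair above now traffics in intptr_t end to end. A standalone mirror of the same add/sub/destroy-on-zero pattern, under the assumption that a plain std::atomic stands in for TAtomicCounter; TMiniRefCounted is hypothetical and util's TRefCounted actually delegates deletion to a deleter policy rather than calling delete directly:

    #include <atomic>
    #include <cassert>
    #include <cstdint>

    class TMiniRefCounted {
    public:
        virtual ~TMiniRefCounted() = default;

        void Ref(intptr_t d = 1) noexcept {
            const intptr_t resultCount = Counter_.fetch_add(d) + d; // new value after the add
            assert(resultCount >= d);
            (void)resultCount;
        }

        void UnRef(intptr_t d = 1) noexcept {
            const intptr_t resultCount = Counter_.fetch_sub(d) - d; // new value after the sub
            assert(resultCount >= 0);
            if (resultCount == 0) {
                delete this; // last reference gone
            }
        }

        intptr_t RefCount() const noexcept {
            return Counter_.load();
        }

    private:
        std::atomic<intptr_t> Counter_{0};
    };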
 

+ 21 - 19
util/generic/ptr_ut.cpp

@@ -779,29 +779,31 @@ void TPointerTest::TestRefCountedPtrsInHashSet() {
 class TRefCountedWithStatistics: public TNonCopyable {
 public:
     struct TExternalCounter {
-        TAtomic Counter{0};
-        TAtomic Increments{0};
+        std::atomic<size_t> Counter{0};
+        std::atomic<size_t> Increments{0};
     };
 
     TRefCountedWithStatistics(TExternalCounter& cnt)
         : ExternalCounter_(cnt)
     {
-        ExternalCounter_ = {}; // reset counters
+        // Reset counters
+        ExternalCounter_.Counter.store(0);
+        ExternalCounter_.Increments.store(0);
     }
 
     void Ref() noexcept {
-        AtomicIncrement(ExternalCounter_.Counter);
-        AtomicIncrement(ExternalCounter_.Increments);
+        ++ExternalCounter_.Counter;
+        ++ExternalCounter_.Increments;
     }
 
     void UnRef() noexcept {
-        if (AtomicDecrement(ExternalCounter_.Counter) == 0) {
+        if (--ExternalCounter_.Counter == 0) {
             TDelete::Destroy(this);
         }
     }
 
     void DecRef() noexcept {
-        Y_VERIFY(AtomicDecrement(ExternalCounter_.Counter) != 0);
+        Y_VERIFY(--ExternalCounter_.Counter != 0);
     }
 
 private:
@@ -811,24 +813,24 @@ private:
 void TPointerTest::TestIntrusiveConstConstruction() {
     {
         TRefCountedWithStatistics::TExternalCounter cnt;
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 0);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 0);
         TIntrusivePtr<TRefCountedWithStatistics> i{MakeIntrusive<TRefCountedWithStatistics>(cnt)};
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 1);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1);
         i.Reset();
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 0);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1);
     }
     {
         TRefCountedWithStatistics::TExternalCounter cnt;
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 0);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 0);
         TIntrusiveConstPtr<TRefCountedWithStatistics> c{MakeIntrusive<TRefCountedWithStatistics>(cnt)};
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 1);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1);
         c.Reset();
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Counter), 0);
-        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(cnt.Increments), 1);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Counter.load(), 0);
+        UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1);
     }
 }
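
The switch from `ExternalCounter_ = {};` to two explicit store() calls is forced by the type change: std::atomic is neither copyable nor movable, so aggregate assignment of a struct holding atomics no longer compiles. A minimal illustration, with TCounters as a stand-in for TExternalCounter:

    #include <atomic>
    #include <cstddef>

    struct TCounters {
        std::atomic<size_t> Counter{0};
        std::atomic<size_t> Increments{0};
    };

    void ResetCounters(TCounters& c) {
        // c = {};               // would not compile: std::atomic's assignment from another atomic is deleted
        c.Counter.store(0);      // reset each member explicitly instead
        c.Increments.store(0);
    }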

+ 29 - 23
util/generic/refcount.h

@@ -1,10 +1,11 @@
 #pragma once
 
 #include <util/system/guard.h>
-#include <util/system/atomic.h>
 #include <util/system/defaults.h>
 #include <util/system/yassert.h>
 
+#include <atomic>
+
 template <class TCounterCheckPolicy>
 class TSimpleCounterTemplate: public TCounterCheckPolicy {
     using TCounterCheckPolicy::Check;
@@ -19,21 +20,21 @@ public:
         Check();
     }
 
-    inline TAtomicBase Add(TAtomicBase d) noexcept {
+    inline intptr_t Add(intptr_t d) noexcept {
         Check();
         return Counter_ += d;
     }
 
-    inline TAtomicBase Inc() noexcept {
+    inline intptr_t Inc() noexcept {
         return Add(1);
     }
 
-    inline TAtomicBase Sub(TAtomicBase d) noexcept {
+    inline intptr_t Sub(intptr_t d) noexcept {
         Check();
         return Counter_ -= d;
     }
 
-    inline TAtomicBase Dec() noexcept {
+    inline intptr_t Dec() noexcept {
         return Sub(1);
     }
 
@@ -48,12 +49,12 @@ public:
         return true;
     }
 
-    inline TAtomicBase Val() const noexcept {
+    inline intptr_t Val() const noexcept {
         return Counter_;
     }
 
 private:
-    TAtomicBase Counter_;
+    intptr_t Counter_;
 };
 
 class TNoCheckPolicy {
@@ -107,47 +108,52 @@ public:
     {
     }
 
+    TAtomicCounter(const TAtomicCounter& other)
+        : Counter_(other.Counter_.load())
+    {
+    }
+
+    TAtomicCounter& operator=(const TAtomicCounter& other) {
+        Counter_.store(other.Counter_.load());
+        return *this;
+    }
+
     inline ~TAtomicCounter() = default;
 
-    inline TAtomicBase Add(TAtomicBase d) noexcept {
-        return AtomicAdd(Counter_, d);
+    inline intptr_t Add(intptr_t d) noexcept {
+        return Counter_ += d;
     }
 
-    inline TAtomicBase Inc() noexcept {
+    inline intptr_t Inc() noexcept {
         return Add(1);
     }
 
-    inline TAtomicBase Sub(TAtomicBase d) noexcept {
-        return AtomicSub(Counter_, d);
+    inline intptr_t Sub(intptr_t d) noexcept {
+        return Counter_ -= d;
     }
 
-    inline TAtomicBase Dec() noexcept {
+    inline intptr_t Dec() noexcept {
         return Sub(1);
     }
 
-    inline TAtomicBase Val() const noexcept {
-        return AtomicGet(Counter_);
+    inline intptr_t Val() const noexcept {
+        return Counter_.load();
     }
 
     inline bool TryWeakInc() noexcept {
-        while (true) {
-            intptr_t curValue = Counter_;
-
+        for (auto curValue = Counter_.load(std::memory_order_acquire);;) {
             if (!curValue) {
                 return false;
             }
 
-            intptr_t newValue = curValue + 1;
-            Y_ASSERT(newValue != 0);
-
-            if (AtomicCas(&Counter_, newValue, curValue)) {
+            if (Counter_.compare_exchange_weak(curValue, curValue + 1)) {
                 return true;
             }
         }
     }
 
 private:
-    TAtomic Counter_;
+    std::atomic<intptr_t> Counter_;
 };
 
 template <>
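
Two details in this hunk are easy to miss: the hand-written copy constructor and assignment exist because std::atomic<intptr_t> is not copyable, so the previously implicit copies of TAtomicCounter have to be recreated via load()/store(); and TryWeakInc becomes a standard compare_exchange_weak loop, where a failed exchange refreshes curValue for the next iteration. A stripped-down sketch of that loop (TryIncIfNonZero is a hypothetical free function, not util API):

    #include <atomic>
    #include <cstdint>

    // Increment 'counter' only if it is currently non-zero; return whether we did.
    bool TryIncIfNonZero(std::atomic<intptr_t>& counter) {
        for (auto cur = counter.load(std::memory_order_acquire);;) {
            if (cur == 0) {
                return false; // already at zero; a weak increment must not revive it
            }
            // On failure compare_exchange_weak writes the fresh value into 'cur'
            // (and may also fail spuriously), so the loop simply retries.
            if (counter.compare_exchange_weak(cur, cur + 1)) {
                return true;
            }
        }
    }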

+ 13 - 12
util/generic/singleton.cpp

@@ -7,16 +7,17 @@
 #include <cstring>
 
 namespace {
-    static inline bool MyAtomicTryLock(TAtomic& a, TAtomicBase v) noexcept {
-        return AtomicCas(&a, v, 0);
+    static inline bool MyAtomicTryLock(std::atomic<size_t>& a, size_t v) noexcept {
+        size_t zero = 0;
+        return a.compare_exchange_strong(zero, v);
     }
 
-    static inline bool MyAtomicTryAndTryLock(TAtomic& a, TAtomicBase v) noexcept {
-        return (AtomicGet(a) == 0) && MyAtomicTryLock(a, v);
+    static inline bool MyAtomicTryAndTryLock(std::atomic<size_t>& a, size_t v) noexcept {
+        return a.load(std::memory_order_acquire) == 0 && MyAtomicTryLock(a, v);
     }
 
-    static inline TAtomicBase MyThreadId() noexcept {
-        const TAtomicBase ret = TThread::CurrentThreadId();
+    static inline size_t MyThreadId() noexcept {
+        const size_t ret = TThread::CurrentThreadId();
 
         if (ret) {
             return ret;
@@ -41,10 +42,10 @@ void NPrivate::FillWithTrash(void* ptr, size_t len) {
 #endif
 }
 
-void NPrivate::LockRecursive(TAtomic& lock) noexcept {
-    const TAtomicBase id = MyThreadId();
+void NPrivate::LockRecursive(std::atomic<size_t>& lock) noexcept {
+    const size_t id = MyThreadId();
 
-    Y_VERIFY(AtomicGet(lock) != id, "recursive singleton initialization");
+    Y_VERIFY(lock.load(std::memory_order_acquire) != id, "recursive singleton initialization");
 
     if (!MyAtomicTryLock(lock, id)) {
         TSpinWait sw;
@@ -55,7 +56,7 @@ void NPrivate::LockRecursive(TAtomic& lock) noexcept {
     }
 }
 
-void NPrivate::UnlockRecursive(TAtomic& lock) noexcept {
-    Y_VERIFY(AtomicGet(lock) == MyThreadId(), "unlock from another thread?!?!");
-    AtomicUnlock(&lock);
+void NPrivate::UnlockRecursive(std::atomic<size_t>& lock) noexcept {
+    Y_VERIFY(lock.load(std::memory_order_acquire) == MyThreadId(), "unlock from another thread?!?!");
+    lock.store(0);
 }
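
The lock here is a spin lock whose stored value is the id of the owning thread: LockRecursive first verifies that the current thread does not already hold it (which would mean recursive singleton initialization), then spins on a try-lock, and UnlockRecursive checks ownership before storing 0. A condensed sketch of that pattern using only standard types; the names are illustrative, std::hash of std::this_thread::get_id() stands in for TThread::CurrentThreadId(), and yield() replaces the TSpinWait backoff:

    #include <atomic>
    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <thread>

    // Illustrative stand-in for TThread::CurrentThreadId(); forced non-zero with | 1.
    static size_t CurrentThreadIdSketch() {
        return std::hash<std::thread::id>{}(std::this_thread::get_id()) | 1u;
    }

    static bool TryLockSketch(std::atomic<size_t>& lock, size_t id) noexcept {
        size_t zero = 0;
        return lock.compare_exchange_strong(zero, id); // 0 means "not owned"
    }

    void LockRecursiveSketch(std::atomic<size_t>& lock) {
        const size_t id = CurrentThreadIdSketch();
        // Re-entering from the owning thread would mean recursive initialization.
        assert(lock.load(std::memory_order_acquire) != id);
        while (!TryLockSketch(lock, id)) {
            std::this_thread::yield(); // the real code backs off with TSpinWait
        }
    }

    void UnlockRecursiveSketch(std::atomic<size_t>& lock) {
        assert(lock.load(std::memory_order_acquire) == CurrentThreadIdSketch());
        lock.store(0); // release ownership
    }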

+ 10 - 10
util/generic/singleton.h

@@ -1,8 +1,8 @@
 #pragma once
 
 #include <util/system/atexit.h>
-#include <util/system/atomic.h>
 
+#include <atomic>
 #include <new>
 #include <utility>
 
@@ -14,8 +14,8 @@ struct TSingletonTraits {
 namespace NPrivate {
     void FillWithTrash(void* ptr, size_t len);
 
-    void LockRecursive(TAtomic& lock) noexcept;
-    void UnlockRecursive(TAtomic& lock) noexcept;
+    void LockRecursive(std::atomic<size_t>& lock) noexcept;
+    void UnlockRecursive(std::atomic<size_t>& lock) noexcept;
 
     template <class T>
     void Destroyer(void* ptr) {
@@ -24,13 +24,13 @@ namespace NPrivate {
     }
 
     template <class T, size_t P, class... TArgs>
-    Y_NO_INLINE T* SingletonBase(T*& ptr, TArgs&&... args) {
+    Y_NO_INLINE T* SingletonBase(std::atomic<T*>& ptr, TArgs&&... args) {
         alignas(T) static char buf[sizeof(T)];
-        static TAtomic lock;
+        static std::atomic<size_t> lock;
 
         LockRecursive(lock);
 
-        auto ret = AtomicGet(ptr);
+        auto ret = ptr.load();
 
         try {
             if (!ret) {
@@ -44,7 +44,7 @@ namespace NPrivate {
                     throw;
                 }
 
-                AtomicSet(ptr, ret);
+                ptr.store(ret);
             }
         } catch (...) {
             UnlockRecursive(lock);
@@ -61,8 +61,8 @@ namespace NPrivate {
     T* SingletonInt(TArgs&&... args) {
         static_assert(sizeof(T) < 32000, "use HugeSingleton instead");
 
-        static T* ptr;
-        auto ret = AtomicGet(ptr);
+        static std::atomic<T*> ptr;
+        auto ret = ptr.load();
 
         if (Y_UNLIKELY(!ret)) {
             ret = SingletonBase<T, P>(ptr, std::forward<TArgs>(args)...);
@@ -108,7 +108,7 @@ namespace NPrivate {
     template <class T, size_t P, class... TArgs>    \
     friend T* ::NPrivate::SingletonInt(TArgs&&...); \
     template <class T, size_t P, class... TArgs>    \
-    friend T* ::NPrivate::SingletonBase(T*&, TArgs&&...);
+    friend T* ::NPrivate::SingletonBase(std::atomic<T*>&, TArgs&&...);
 
 template <class T, class... TArgs>
 T* Singleton(TArgs&&... args) {
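
SingletonInt now keeps the instance pointer in a std::atomic<T*>: the fast path is a plain load, and only a miss enters SingletonBase, which constructs the object in static storage under the recursive lock and publishes it with store(); both sides use the default sequentially consistent ordering. A compact sketch of the same double-checked shape, assuming a std::mutex in place of LockRecursive/UnlockRecursive and a hypothetical TConfig payload:

    #include <atomic>
    #include <mutex>
    #include <new>

    struct TConfig { // hypothetical singleton payload
        int Verbosity = 0;
    };

    TConfig* GetConfigInstance() {
        static std::atomic<TConfig*> ptr{nullptr};
        static std::mutex lock; // stands in for LockRecursive/UnlockRecursive

        auto ret = ptr.load();              // fast path: already constructed and published
        if (ret == nullptr) {
            std::lock_guard<std::mutex> guard(lock);
            ret = ptr.load();               // re-check under the lock
            if (ret == nullptr) {
                alignas(TConfig) static char buf[sizeof(TConfig)];
                ret = new (buf) TConfig();  // util additionally registers an atexit destroyer here
                ptr.store(ret);             // publish; default seq_cst, as in the diff
            }
        }
        return ret;
    }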

+ 9 - 10
util/thread/pool.cpp

@@ -18,7 +18,6 @@
 
 #include <util/system/event.h>
 #include <util/system/mutex.h>
-#include <util/system/atomic.h>
 #include <util/system/condvar.h>
 #include <util/system/thread.h>
 
@@ -76,7 +75,7 @@ public:
         , Blocking(params.Blocking_)
         , Catching(params.Catching_)
         , Namer(params)
-        , ShouldTerminate(1)
+        , ShouldTerminate(true)
         , MaxQueueSize(0)
         , ThreadCountExpected(0)
         , ThreadCountReal(0)
@@ -98,7 +97,7 @@ public:
     }
 
     inline bool Add(IObjectInQueue* obj) {
-        if (AtomicGet(ShouldTerminate)) {
+        if (ShouldTerminate.load()) {
             return false;
         }
 
@@ -110,14 +109,14 @@ public:
         }
 
         with_lock (QueueMutex) {
-            while (MaxQueueSize > 0 && Queue.Size() >= MaxQueueSize && !AtomicGet(ShouldTerminate)) {
+            while (MaxQueueSize > 0 && Queue.Size() >= MaxQueueSize && !ShouldTerminate.load()) {
                 if (!Blocking) {
                     return false;
                 }
                 QueuePopCond.Wait(QueueMutex);
             }
 
-            if (AtomicGet(ShouldTerminate)) {
+            if (ShouldTerminate.load()) {
                 return false;
             }
 
@@ -157,7 +156,7 @@ public:
 
 private:
     inline void Start(size_t num, size_t maxque) {
-        AtomicSet(ShouldTerminate, 0);
+        ShouldTerminate.store(false);
         MaxQueueSize = maxque;
         ThreadCountExpected = num;
 
@@ -174,7 +173,7 @@ private:
     }
 
     inline void Stop() {
-        AtomicSet(ShouldTerminate, 1);
+        ShouldTerminate.store(true);
 
         with_lock (QueueMutex) {
             QueuePopCond.BroadCast();
@@ -212,11 +211,11 @@ private:
             IObjectInQueue* job = nullptr;
 
             with_lock (QueueMutex) {
-                while (Queue.Empty() && !AtomicGet(ShouldTerminate)) {
+                while (Queue.Empty() && !ShouldTerminate.load()) {
                     QueuePushCond.Wait(QueueMutex);
                 }
 
-                if (AtomicGet(ShouldTerminate) && Queue.Empty()) {
+                if (ShouldTerminate.load() && Queue.Empty()) {
                     tsr.Destroy();
 
                     break;
@@ -264,7 +263,7 @@ private:
     TCondVar StopCond;
     TJobQueue Queue;
     TVector<TThreadRef> Tharr;
-    TAtomic ShouldTerminate;
+    std::atomic<bool> ShouldTerminate;
     size_t MaxQueueSize;
     size_t ThreadCountExpected;
     size_t ThreadCountReal;
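
ShouldTerminate is now a std::atomic<bool>, but the checks that matter for wakeups still run inside with_lock on QueueMutex, so Stop()'s store(true) plus the condvar broadcast cannot be missed by a waiter. A reduced sketch of that shutdown shape with standard primitives; the names are illustrative and the real pool uses TMutex/TCondVar and its own job queue:

    #include <atomic>
    #include <condition_variable>
    #include <deque>
    #include <mutex>

    struct TStopSketch {
        std::mutex QueueMutex;
        std::condition_variable QueuePushCond;
        std::deque<int> Queue;
        std::atomic<bool> ShouldTerminate{true}; // starts "stopped", like the pool

        bool Add(int job) {
            if (ShouldTerminate.load()) {
                return false;                    // fast reject without taking the lock
            }
            std::lock_guard<std::mutex> g(QueueMutex);
            if (ShouldTerminate.load()) {        // re-check under the lock
                return false;
            }
            Queue.push_back(job);
            QueuePushCond.notify_one();
            return true;
        }

        void Stop() {
            ShouldTerminate.store(true);
            std::lock_guard<std::mutex> g(QueueMutex);
            QueuePushCond.notify_all();          // wake workers blocked on an empty queue
        }

        bool PopOrQuit(int& job) {
            std::unique_lock<std::mutex> g(QueueMutex);
            QueuePushCond.wait(g, [this] { return !Queue.empty() || ShouldTerminate.load(); });
            if (Queue.empty()) {
                return false;                    // terminating and drained
            }
            job = Queue.front();
            Queue.pop_front();
            return true;
        }
    };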

+ 1 - 0
ydb/library/yql/utils/rand_guid.h

@@ -2,6 +2,7 @@
 
 #include <util/random/mersenne.h>
 #include <util/generic/ptr.h>
+#include <util/system/atomic.h>
 
 namespace NYql {
 class TRandGuid {