Browse Source

IGNIETFERRO-1105 TAtomic -> std::atomic in util/thread/lfqueue.h

ref:8cf44e7b3fecd13c3a0c699a8c1c7abe780eab0b
eeight 2 years ago
parent
commit
1af8bb8789
2 changed files with 71 additions and 80 deletions
  1. util/thread/lfqueue.h — 68 additions, 75 deletions
  2. util/thread/lfqueue_ut.cpp — 3 additions, 5 deletions

+ 68 - 75
util/thread/lfqueue.h

@@ -1,11 +1,12 @@
 #pragma once
 #pragma once
 
 
 #include "fwd.h"
 #include "fwd.h"
+#include "lfstack.h"
 
 
 #include <util/generic/ptr.h>
 #include <util/generic/ptr.h>
-#include <util/system/atomic.h>
 #include <util/system/yassert.h>
 #include <util/system/yassert.h>
-#include "lfstack.h"
+
+#include <atomic>
 
 
 struct TDefaultLFCounter {
 struct TDefaultLFCounter {
     template <class T>
     template <class T>
@@ -39,24 +40,17 @@ class TLockFreeQueue: public TNonCopyable {
         {
         {
         }
         }
 
 
-        TListNode* volatile Next;
+        std::atomic<TListNode*> Next;
         T Data;
         T Data;
     };
     };
 
 
     // using inheritance to be able to use 0 bytes for TCounter when we don't need one
     // using inheritance to be able to use 0 bytes for TCounter when we don't need one
     struct TRootNode: public TCounter {
     struct TRootNode: public TCounter {
-        TListNode* volatile PushQueue;
-        TListNode* volatile PopQueue;
-        TListNode* volatile ToDelete;
-        TRootNode* volatile NextFree;
-
-        TRootNode()
-            : PushQueue(nullptr)
-            , PopQueue(nullptr)
-            , ToDelete(nullptr)
-            , NextFree(nullptr)
-        {
-        }
+        std::atomic<TListNode*> PushQueue = nullptr;
+        std::atomic<TListNode*> PopQueue = nullptr;
+        std::atomic<TListNode*> ToDelete = nullptr;
+        std::atomic<TRootNode*> NextFree = nullptr;
+
         void CopyCounter(TRootNode* x) {
         void CopyCounter(TRootNode* x) {
             *(TCounter*)this = *(TCounter*)x;
             *(TCounter*)this = *(TCounter*)x;
         }
         }
@@ -64,58 +58,58 @@ class TLockFreeQueue: public TNonCopyable {
 
 
     static void EraseList(TListNode* n) {
     static void EraseList(TListNode* n) {
         while (n) {
         while (n) {
-            TListNode* keepNext = AtomicGet(n->Next);
+            TListNode* keepNext = n->Next.load(std::memory_order_acquire);
             delete n;
             delete n;
             n = keepNext;
             n = keepNext;
         }
         }
     }
     }
 
 
-    alignas(64) TRootNode* volatile JobQueue;
-    alignas(64) volatile TAtomic FreememCounter;
-    alignas(64) volatile TAtomic FreeingTaskCounter;
-    alignas(64) TRootNode* volatile FreePtr;
+    alignas(64) std::atomic<TRootNode*> JobQueue;
+    alignas(64) std::atomic<size_t> FreememCounter;
+    alignas(64) std::atomic<size_t> FreeingTaskCounter;
+    alignas(64) std::atomic<TRootNode*> FreePtr;
 
 
     void TryToFreeAsyncMemory() {
     void TryToFreeAsyncMemory() {
-        TAtomic keepCounter = AtomicAdd(FreeingTaskCounter, 0);
-        TRootNode* current = AtomicGet(FreePtr);
+        const auto keepCounter = FreeingTaskCounter.load();
+        TRootNode* current = FreePtr.load(std::memory_order_acquire);
         if (current == nullptr)
         if (current == nullptr)
             return;
             return;
-        if (AtomicAdd(FreememCounter, 0) == 1) {
+        if (FreememCounter.load() == 1) {
             // we are the last thread, try to cleanup
             // we are the last thread, try to cleanup
             // check if another thread have cleaned up
             // check if another thread have cleaned up
-            if (keepCounter != AtomicAdd(FreeingTaskCounter, 0)) {
+            if (keepCounter != FreeingTaskCounter.load()) {
                 return;
                 return;
             }
             }
-            if (AtomicCas(&FreePtr, (TRootNode*)nullptr, current)) {
+            if (FreePtr.compare_exchange_strong(current, nullptr)) {
                 // free list
                 // free list
                 while (current) {
                 while (current) {
-                    TRootNode* p = AtomicGet(current->NextFree);
-                    EraseList(AtomicGet(current->ToDelete));
+                    TRootNode* p = current->NextFree.load(std::memory_order_acquire);
+                    EraseList(current->ToDelete.load(std::memory_order_acquire));
                     delete current;
                     delete current;
                     current = p;
                     current = p;
                 }
                 }
-                AtomicAdd(FreeingTaskCounter, 1);
+                ++FreeingTaskCounter;
             }
             }
         }
         }
     }
     }
     void AsyncRef() {
     void AsyncRef() {
-        AtomicAdd(FreememCounter, 1);
+        ++FreememCounter;
     }
     }
     void AsyncUnref() {
     void AsyncUnref() {
         TryToFreeAsyncMemory();
         TryToFreeAsyncMemory();
-        AtomicAdd(FreememCounter, -1);
+        --FreememCounter;
     }
     }
     void AsyncDel(TRootNode* toDelete, TListNode* lst) {
     void AsyncDel(TRootNode* toDelete, TListNode* lst) {
-        AtomicSet(toDelete->ToDelete, lst);
-        for (;;) {
-            AtomicSet(toDelete->NextFree, AtomicGet(FreePtr));
-            if (AtomicCas(&FreePtr, toDelete, AtomicGet(toDelete->NextFree)))
+        toDelete->ToDelete.store(lst, std::memory_order_release);
+        for (auto freePtr = FreePtr.load();;) {
+            toDelete->NextFree.store(freePtr, std::memory_order_release);
+            if (FreePtr.compare_exchange_weak(freePtr, toDelete))
                 break;
                 break;
         }
         }
     }
     }
     void AsyncUnref(TRootNode* toDelete, TListNode* lst) {
     void AsyncUnref(TRootNode* toDelete, TListNode* lst) {
         TryToFreeAsyncMemory();
         TryToFreeAsyncMemory();
-        if (AtomicAdd(FreememCounter, -1) == 0) {
+        if (--FreememCounter == 0) {
             // no other operations in progress, can safely reclaim memory
             // no other operations in progress, can safely reclaim memory
             EraseList(lst);
             EraseList(lst);
             delete toDelete;
             delete toDelete;
@@ -151,7 +145,7 @@ class TLockFreeQueue: public TNonCopyable {
             while (ptr) {
             while (ptr) {
                 if (ptr == PrevFirst) {
                 if (ptr == PrevFirst) {
                     // short cut, we have copied this part already
                     // short cut, we have copied this part already
-                    AtomicSet(Tail->Next, newCopy);
+                    Tail->Next.store(newCopy, std::memory_order_release);
                     newCopy = Copy;
                     newCopy = Copy;
                     Copy = nullptr; // do not destroy prev try
                     Copy = nullptr; // do not destroy prev try
                     if (!newTail)
                     if (!newTail)
@@ -160,7 +154,7 @@ class TLockFreeQueue: public TNonCopyable {
                 }
                 }
                 TListNode* newElem = new TListNode(ptr->Data, newCopy);
                 TListNode* newElem = new TListNode(ptr->Data, newCopy);
                 newCopy = newElem;
                 newCopy = newElem;
-                ptr = AtomicGet(ptr->Next);
+                ptr = ptr->Next.load(std::memory_order_acquire);
                 if (!newTail)
                 if (!newTail)
                     newTail = newElem;
                     newTail = newElem;
             }
             }
@@ -174,20 +168,19 @@ class TLockFreeQueue: public TNonCopyable {
     void EnqueueImpl(TListNode* head, TListNode* tail) {
     void EnqueueImpl(TListNode* head, TListNode* tail) {
         TRootNode* newRoot = new TRootNode;
         TRootNode* newRoot = new TRootNode;
         AsyncRef();
         AsyncRef();
-        AtomicSet(newRoot->PushQueue, head);
-        for (;;) {
-            TRootNode* curRoot = AtomicGet(JobQueue);
-            AtomicSet(tail->Next, AtomicGet(curRoot->PushQueue));
-            AtomicSet(newRoot->PopQueue, AtomicGet(curRoot->PopQueue));
+        newRoot->PushQueue.store(head, std::memory_order_release);
+        for (TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);;) {
+            tail->Next.store(curRoot->PushQueue.load(std::memory_order_acquire), std::memory_order_release);
+            newRoot->PopQueue.store(curRoot->PopQueue.load(std::memory_order_acquire), std::memory_order_release);
             newRoot->CopyCounter(curRoot);
             newRoot->CopyCounter(curRoot);
 
 
-            for (TListNode* node = head;; node = AtomicGet(node->Next)) {
+            for (TListNode* node = head;; node = node->Next.load(std::memory_order_acquire)) {
                 newRoot->IncCount(node->Data);
                 newRoot->IncCount(node->Data);
                 if (node == tail)
                 if (node == tail)
                     break;
                     break;
             }
             }
 
 
-            if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+            if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
                 AsyncUnref(curRoot, nullptr);
                 AsyncUnref(curRoot, nullptr);
                 break;
                 break;
             }
             }
@@ -198,7 +191,7 @@ class TLockFreeQueue: public TNonCopyable {
     static void FillCollection(TListNode* lst, TCollection* res) {
     static void FillCollection(TListNode* lst, TCollection* res) {
         while (lst) {
         while (lst) {
             res->emplace_back(std::move(lst->Data));
             res->emplace_back(std::move(lst->Data));
-            lst = AtomicGet(lst->Next);
+            lst = lst->Next.load(std::memory_order_acquire);
         }
         }
     }
     }
 
 
@@ -215,7 +208,7 @@ class TLockFreeQueue: public TNonCopyable {
         do {
         do {
             TListNode* newElem = new TListNode(std::move(lst->Data), newCopy);
             TListNode* newElem = new TListNode(std::move(lst->Data), newCopy);
             newCopy = newElem;
             newCopy = newElem;
-            lst = AtomicGet(lst->Next);
+            lst = lst->Next.load(std::memory_order_acquire);
         } while (lst);
         } while (lst);
 
 
         FillCollection(newCopy, res);
         FillCollection(newCopy, res);
@@ -235,8 +228,8 @@ public:
     ~TLockFreeQueue() {
     ~TLockFreeQueue() {
         AsyncRef();
         AsyncRef();
         AsyncUnref(); // should free FreeList
         AsyncUnref(); // should free FreeList
-        EraseList(JobQueue->PushQueue);
-        EraseList(JobQueue->PopQueue);
+        EraseList(JobQueue.load(std::memory_order_relaxed)->PushQueue.load(std::memory_order_relaxed));
+        EraseList(JobQueue.load(std::memory_order_relaxed)->PopQueue.load(std::memory_order_relaxed));
         delete JobQueue;
         delete JobQueue;
     }
     }
     template <typename U>
     template <typename U>
@@ -262,8 +255,8 @@ public:
             return;
             return;
 
 
         TIter i = dataBegin;
         TIter i = dataBegin;
-        TListNode* volatile node = new TListNode(*i);
-        TListNode* volatile tail = node;
+        TListNode* node = new TListNode(*i);
+        TListNode* tail = node;
 
 
         for (++i; i != dataEnd; ++i) {
         for (++i; i != dataEnd; ++i) {
             TListNode* nextNode = node;
             TListNode* nextNode = node;
@@ -275,28 +268,27 @@ public:
         TRootNode* newRoot = nullptr;
         TRootNode* newRoot = nullptr;
         TListInvertor listInvertor;
         TListInvertor listInvertor;
         AsyncRef();
         AsyncRef();
-        for (;;) {
-            TRootNode* curRoot = AtomicGet(JobQueue);
-            TListNode* tail = AtomicGet(curRoot->PopQueue);
+        for (TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);;) {
+            TListNode* tail = curRoot->PopQueue.load(std::memory_order_acquire);
             if (tail) {
             if (tail) {
                 // has elems to pop
                 // has elems to pop
                 if (!newRoot)
                 if (!newRoot)
                     newRoot = new TRootNode;
                     newRoot = new TRootNode;
 
 
-                AtomicSet(newRoot->PushQueue, AtomicGet(curRoot->PushQueue));
-                AtomicSet(newRoot->PopQueue, AtomicGet(tail->Next));
+                newRoot->PushQueue.store(curRoot->PushQueue.load(std::memory_order_acquire), std::memory_order_release);
+                newRoot->PopQueue.store(tail->Next.load(std::memory_order_acquire), std::memory_order_release);
                 newRoot->CopyCounter(curRoot);
                 newRoot->CopyCounter(curRoot);
                 newRoot->DecCount(tail->Data);
                 newRoot->DecCount(tail->Data);
-                Y_ASSERT(AtomicGet(curRoot->PopQueue) == tail);
-                if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+                Y_ASSERT(curRoot->PopQueue.load() == tail);
+                if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
                     *data = std::move(tail->Data);
                     *data = std::move(tail->Data);
-                    AtomicSet(tail->Next, nullptr);
+                    tail->Next.store(nullptr, std::memory_order_release);
                     AsyncUnref(curRoot, tail);
                     AsyncUnref(curRoot, tail);
                     return true;
                     return true;
                 }
                 }
                 continue;
                 continue;
             }
             }
-            if (AtomicGet(curRoot->PushQueue) == nullptr) {
+            if (curRoot->PushQueue.load(std::memory_order_acquire) == nullptr) {
                 delete newRoot;
                 delete newRoot;
                 AsyncUnref();
                 AsyncUnref();
                 return false; // no elems to pop
                 return false; // no elems to pop
@@ -304,17 +296,18 @@ public:
 
 
             if (!newRoot)
             if (!newRoot)
                 newRoot = new TRootNode;
                 newRoot = new TRootNode;
-            AtomicSet(newRoot->PushQueue, nullptr);
-            listInvertor.DoCopy(AtomicGet(curRoot->PushQueue));
-            AtomicSet(newRoot->PopQueue, listInvertor.Copy);
+            newRoot->PushQueue.store(nullptr, std::memory_order_release);
+            listInvertor.DoCopy(curRoot->PushQueue.load(std::memory_order_acquire));
+            newRoot->PopQueue.store(listInvertor.Copy, std::memory_order_release);
             newRoot->CopyCounter(curRoot);
             newRoot->CopyCounter(curRoot);
-            Y_ASSERT(AtomicGet(curRoot->PopQueue) == nullptr);
-            if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+            Y_ASSERT(curRoot->PopQueue.load() == nullptr);
+            if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
+                AsyncDel(curRoot, curRoot->PushQueue.load(std::memory_order_acquire));
+                curRoot = newRoot;
                 newRoot = nullptr;
                 newRoot = nullptr;
                 listInvertor.CopyWasUsed();
                 listInvertor.CopyWasUsed();
-                AsyncDel(curRoot, AtomicGet(curRoot->PushQueue));
             } else {
             } else {
-                AtomicSet(newRoot->PopQueue, nullptr);
+                newRoot->PopQueue.store(nullptr, std::memory_order_release);
             }
             }
         }
         }
     }
     }
@@ -323,36 +316,36 @@ public:
         AsyncRef();
         AsyncRef();
 
 
         TRootNode* newRoot = new TRootNode;
         TRootNode* newRoot = new TRootNode;
-        TRootNode* curRoot;
+        TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
         do {
         do {
-            curRoot = AtomicGet(JobQueue);
-        } while (!AtomicCas(&JobQueue, newRoot, curRoot));
+        } while (!JobQueue.compare_exchange_weak(curRoot, newRoot));
 
 
         FillCollection(curRoot->PopQueue, res);
         FillCollection(curRoot->PopQueue, res);
 
 
         TListNode* toDeleteHead = curRoot->PushQueue;
         TListNode* toDeleteHead = curRoot->PushQueue;
         TListNode* toDeleteTail = FillCollectionReverse(curRoot->PushQueue, res);
         TListNode* toDeleteTail = FillCollectionReverse(curRoot->PushQueue, res);
-        AtomicSet(curRoot->PushQueue, nullptr);
+        curRoot->PushQueue.store(nullptr, std::memory_order_release);
 
 
         if (toDeleteTail) {
         if (toDeleteTail) {
-            toDeleteTail->Next = curRoot->PopQueue;
+            toDeleteTail->Next.store(curRoot->PopQueue.load());
         } else {
         } else {
             toDeleteTail = curRoot->PopQueue;
             toDeleteTail = curRoot->PopQueue;
         }
         }
-        AtomicSet(curRoot->PopQueue, nullptr);
+        curRoot->PopQueue.store(nullptr, std::memory_order_release);
 
 
         AsyncUnref(curRoot, toDeleteHead);
         AsyncUnref(curRoot, toDeleteHead);
     }
     }
     bool IsEmpty() {
     bool IsEmpty() {
         AsyncRef();
         AsyncRef();
-        TRootNode* curRoot = AtomicGet(JobQueue);
-        bool res = AtomicGet(curRoot->PushQueue) == nullptr && AtomicGet(curRoot->PopQueue) == nullptr;
+        TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
+        bool res = curRoot->PushQueue.load(std::memory_order_acquire) == nullptr &&
+                   curRoot->PopQueue.load(std::memory_order_acquire) == nullptr;
         AsyncUnref();
         AsyncUnref();
         return res;
         return res;
     }
     }
     TCounter GetCounter() {
     TCounter GetCounter() {
         AsyncRef();
         AsyncRef();
-        TRootNode* curRoot = AtomicGet(JobQueue);
+        TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
         TCounter res = *(TCounter*)curRoot;
         TCounter res = *(TCounter*)curRoot;
         AsyncUnref();
         AsyncUnref();
         return res;
         return res;

+ 3 - 5
util/thread/lfqueue_ut.cpp

@@ -4,7 +4,6 @@
 #include <util/generic/algorithm.h>
 #include <util/generic/algorithm.h>
 #include <util/generic/vector.h>
 #include <util/generic/vector.h>
 #include <util/generic/ptr.h>
 #include <util/generic/ptr.h>
-#include <util/system/atomic.h>
 #include <util/thread/pool.h>
 #include <util/thread/pool.h>
 
 
 #include "lfqueue.h"
 #include "lfqueue.h"
@@ -211,8 +210,7 @@ Y_UNIT_TEST_SUITE(TLockFreeQueueTests) {
             });
             });
         }
         }
 
 
-        TAtomic elementsLeft;
-        AtomicSet(elementsLeft, threadsNum * enqueuesPerThread);
+        std::atomic<size_t> elementsLeft = threadsNum * enqueuesPerThread;
 
 
         ui64 numOfConsumers = singleConsumer ? 1 : threadsNum;
         ui64 numOfConsumers = singleConsumer ? 1 : threadsNum;
 
 
@@ -224,12 +222,12 @@ Y_UNIT_TEST_SUITE(TLockFreeQueueTests) {
 
 
             p.SafeAddFunc([&queue, &elementsLeft, promise, consumerData{&dataBuckets[i]}]() mutable {
             p.SafeAddFunc([&queue, &elementsLeft, promise, consumerData{&dataBuckets[i]}]() mutable {
                 TVector<int> vec;
                 TVector<int> vec;
-                while (static_cast<i64>(AtomicGet(elementsLeft)) > 0) {
+                while (static_cast<i64>(elementsLeft.load()) > 0) {
                     for (size_t i = 0; i != 100; ++i) {
                     for (size_t i = 0; i != 100; ++i) {
                         vec.clear();
                         vec.clear();
                         queue.DequeueAll(&vec);
                         queue.DequeueAll(&vec);
 
 
-                        AtomicSub(elementsLeft, vec.size());
+                        elementsLeft -= vec.size();
                         consumerData->insert(consumerData->end(), vec.begin(), vec.end());
                         consumerData->insert(consumerData->end(), vec.begin(), vec.end());
                     }
                     }
                 }
                 }