// lfqueue.h — lock-free multi-producer / multi-consumer queue (util/generic).
#pragma once

#include "fwd.h"
#include "lfstack.h"

#include <util/generic/ptr.h>
#include <util/system/yassert.h>

#include <atomic>
#include <utility>
  7. struct TDefaultLFCounter {
  8. template <class T>
  9. void IncCount(const T& data) {
  10. (void)data;
  11. }
  12. template <class T>
  13. void DecCount(const T& data) {
  14. (void)data;
  15. }
  16. };
  17. // @brief lockfree queue
  18. // @tparam T - the queue element, should be movable
  19. // @tparam TCounter, a observer class to count number of items in queue
  20. // be careful, IncCount and DecCount can be called on a moved object and
  21. // it is TCounter class responsibility to check validity of passed object
  22. template <class T, class TCounter>
  23. class TLockFreeQueue: public TNonCopyable {
  24. struct TListNode {
  25. template <typename U>
  26. TListNode(U&& u, TListNode* next)
  27. : Next(next)
  28. , Data(std::forward<U>(u))
  29. {
  30. }
  31. template <typename U>
  32. explicit TListNode(U&& u)
  33. : Data(std::forward<U>(u))
  34. {
  35. }
  36. std::atomic<TListNode*> Next;
  37. T Data;
  38. };
  39. // using inheritance to be able to use 0 bytes for TCounter when we don't need one
  40. struct TRootNode: public TCounter {
  41. std::atomic<TListNode*> PushQueue = nullptr;
  42. std::atomic<TListNode*> PopQueue = nullptr;
  43. std::atomic<TListNode*> ToDelete = nullptr;
  44. std::atomic<TRootNode*> NextFree = nullptr;
  45. void CopyCounter(TRootNode* x) {
  46. *(TCounter*)this = *(TCounter*)x;
  47. }
  48. };
  49. static void EraseList(TListNode* n) {
  50. while (n) {
  51. TListNode* keepNext = n->Next.load(std::memory_order_acquire);
  52. delete n;
  53. n = keepNext;
  54. }
  55. }
  56. alignas(64) std::atomic<TRootNode*> JobQueue;
  57. alignas(64) std::atomic<size_t> FreememCounter;
  58. alignas(64) std::atomic<size_t> FreeingTaskCounter;
  59. alignas(64) std::atomic<TRootNode*> FreePtr;
  60. void TryToFreeAsyncMemory() {
  61. const auto keepCounter = FreeingTaskCounter.load();
  62. TRootNode* current = FreePtr.load(std::memory_order_acquire);
  63. if (current == nullptr) {
  64. return;
  65. }
  66. if (FreememCounter.load() == 1) {
  67. // we are the last thread, try to cleanup
  68. // check if another thread have cleaned up
  69. if (keepCounter != FreeingTaskCounter.load()) {
  70. return;
  71. }
  72. if (FreePtr.compare_exchange_strong(current, nullptr)) {
  73. // free list
  74. while (current) {
  75. TRootNode* p = current->NextFree.load(std::memory_order_acquire);
  76. EraseList(current->ToDelete.load(std::memory_order_acquire));
  77. delete current;
  78. current = p;
  79. }
  80. ++FreeingTaskCounter;
  81. }
  82. }
  83. }
  84. void AsyncRef() {
  85. ++FreememCounter;
  86. }
  87. void AsyncUnref() {
  88. TryToFreeAsyncMemory();
  89. --FreememCounter;
  90. }
  91. void AsyncDel(TRootNode* toDelete, TListNode* lst) {
  92. toDelete->ToDelete.store(lst, std::memory_order_release);
  93. for (auto freePtr = FreePtr.load();;) {
  94. toDelete->NextFree.store(freePtr, std::memory_order_release);
  95. if (FreePtr.compare_exchange_weak(freePtr, toDelete)) {
  96. break;
  97. }
  98. }
  99. }
  100. void AsyncUnref(TRootNode* toDelete, TListNode* lst) {
  101. TryToFreeAsyncMemory();
  102. if (--FreememCounter == 0) {
  103. // no other operations in progress, can safely reclaim memory
  104. EraseList(lst);
  105. delete toDelete;
  106. } else {
  107. // Dequeue()s in progress, put node to free list
  108. AsyncDel(toDelete, lst);
  109. }
  110. }
  111. struct TListInvertor {
  112. TListNode* Copy;
  113. TListNode* Tail;
  114. TListNode* PrevFirst;
  115. TListInvertor()
  116. : Copy(nullptr)
  117. , Tail(nullptr)
  118. , PrevFirst(nullptr)
  119. {
  120. }
  121. ~TListInvertor() {
  122. EraseList(Copy);
  123. }
  124. void CopyWasUsed() {
  125. Copy = nullptr;
  126. Tail = nullptr;
  127. PrevFirst = nullptr;
  128. }
  129. void DoCopy(TListNode* ptr) {
  130. TListNode* newFirst = ptr;
  131. TListNode* newCopy = nullptr;
  132. TListNode* newTail = nullptr;
  133. while (ptr) {
  134. if (ptr == PrevFirst) {
  135. // short cut, we have copied this part already
  136. Tail->Next.store(newCopy, std::memory_order_release);
  137. newCopy = Copy;
  138. Copy = nullptr; // do not destroy prev try
  139. if (!newTail) {
  140. newTail = Tail; // tried to invert same list
  141. }
  142. break;
  143. }
  144. TListNode* newElem = new TListNode(ptr->Data, newCopy);
  145. newCopy = newElem;
  146. ptr = ptr->Next.load(std::memory_order_acquire);
  147. if (!newTail) {
  148. newTail = newElem;
  149. }
  150. }
  151. EraseList(Copy); // copy was useless
  152. Copy = newCopy;
  153. PrevFirst = newFirst;
  154. Tail = newTail;
  155. }
  156. };
  157. void EnqueueImpl(TListNode* head, TListNode* tail) {
  158. TRootNode* newRoot = new TRootNode;
  159. AsyncRef();
  160. newRoot->PushQueue.store(head, std::memory_order_release);
  161. for (TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);;) {
  162. tail->Next.store(curRoot->PushQueue.load(std::memory_order_acquire), std::memory_order_release);
  163. newRoot->PopQueue.store(curRoot->PopQueue.load(std::memory_order_acquire), std::memory_order_release);
  164. newRoot->CopyCounter(curRoot);
  165. for (TListNode* node = head;; node = node->Next.load(std::memory_order_acquire)) {
  166. newRoot->IncCount(node->Data);
  167. if (node == tail) {
  168. break;
  169. }
  170. }
  171. if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
  172. AsyncUnref(curRoot, nullptr);
  173. break;
  174. }
  175. }
  176. }
  177. template <typename TCollection>
  178. static void FillCollection(TListNode* lst, TCollection* res) {
  179. while (lst) {
  180. res->emplace_back(std::move(lst->Data));
  181. lst = lst->Next.load(std::memory_order_acquire);
  182. }
  183. }
  184. /** Traverses a given list simultaneously creating its inversed version.
  185. * After that, fills a collection with a reversed version and returns the last visited lst's node.
  186. */
  187. template <typename TCollection>
  188. static TListNode* FillCollectionReverse(TListNode* lst, TCollection* res) {
  189. if (!lst) {
  190. return nullptr;
  191. }
  192. TListNode* newCopy = nullptr;
  193. do {
  194. TListNode* newElem = new TListNode(std::move(lst->Data), newCopy);
  195. newCopy = newElem;
  196. lst = lst->Next.load(std::memory_order_acquire);
  197. } while (lst);
  198. FillCollection(newCopy, res);
  199. EraseList(newCopy);
  200. return lst;
  201. }
  202. public:
  203. TLockFreeQueue()
  204. : JobQueue(new TRootNode)
  205. , FreememCounter(0)
  206. , FreeingTaskCounter(0)
  207. , FreePtr(nullptr)
  208. {
  209. }
  210. ~TLockFreeQueue() {
  211. AsyncRef();
  212. AsyncUnref(); // should free FreeList
  213. EraseList(JobQueue.load(std::memory_order_relaxed)->PushQueue.load(std::memory_order_relaxed));
  214. EraseList(JobQueue.load(std::memory_order_relaxed)->PopQueue.load(std::memory_order_relaxed));
  215. delete JobQueue;
  216. }
  217. template <typename U>
  218. void Enqueue(U&& data) {
  219. TListNode* newNode = new TListNode(std::forward<U>(data));
  220. EnqueueImpl(newNode, newNode);
  221. }
  222. void Enqueue(T&& data) {
  223. TListNode* newNode = new TListNode(std::move(data));
  224. EnqueueImpl(newNode, newNode);
  225. }
  226. void Enqueue(const T& data) {
  227. TListNode* newNode = new TListNode(data);
  228. EnqueueImpl(newNode, newNode);
  229. }
  230. template <typename TCollection>
  231. void EnqueueAll(const TCollection& data) {
  232. EnqueueAll(data.begin(), data.end());
  233. }
  234. template <typename TIter>
  235. void EnqueueAll(TIter dataBegin, TIter dataEnd) {
  236. if (dataBegin == dataEnd) {
  237. return;
  238. }
  239. TIter i = dataBegin;
  240. TListNode* node = new TListNode(*i);
  241. TListNode* tail = node;
  242. for (++i; i != dataEnd; ++i) {
  243. TListNode* nextNode = node;
  244. node = new TListNode(*i, nextNode);
  245. }
  246. EnqueueImpl(node, tail);
  247. }
  248. bool Dequeue(T* data) {
  249. TRootNode* newRoot = nullptr;
  250. TListInvertor listInvertor;
  251. AsyncRef();
  252. for (TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);;) {
  253. TListNode* tail = curRoot->PopQueue.load(std::memory_order_acquire);
  254. if (tail) {
  255. // has elems to pop
  256. if (!newRoot) {
  257. newRoot = new TRootNode;
  258. }
  259. newRoot->PushQueue.store(curRoot->PushQueue.load(std::memory_order_acquire), std::memory_order_release);
  260. newRoot->PopQueue.store(tail->Next.load(std::memory_order_acquire), std::memory_order_release);
  261. newRoot->CopyCounter(curRoot);
  262. newRoot->DecCount(tail->Data);
  263. Y_ASSERT(curRoot->PopQueue.load() == tail);
  264. if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
  265. *data = std::move(tail->Data);
  266. tail->Next.store(nullptr, std::memory_order_release);
  267. AsyncUnref(curRoot, tail);
  268. return true;
  269. }
  270. continue;
  271. }
  272. if (curRoot->PushQueue.load(std::memory_order_acquire) == nullptr) {
  273. delete newRoot;
  274. AsyncUnref();
  275. return false; // no elems to pop
  276. }
  277. if (!newRoot) {
  278. newRoot = new TRootNode;
  279. }
  280. newRoot->PushQueue.store(nullptr, std::memory_order_release);
  281. listInvertor.DoCopy(curRoot->PushQueue.load(std::memory_order_acquire));
  282. newRoot->PopQueue.store(listInvertor.Copy, std::memory_order_release);
  283. newRoot->CopyCounter(curRoot);
  284. Y_ASSERT(curRoot->PopQueue.load() == nullptr);
  285. if (JobQueue.compare_exchange_weak(curRoot, newRoot)) {
  286. AsyncDel(curRoot, curRoot->PushQueue.load(std::memory_order_acquire));
  287. curRoot = newRoot;
  288. newRoot = nullptr;
  289. listInvertor.CopyWasUsed();
  290. } else {
  291. newRoot->PopQueue.store(nullptr, std::memory_order_release);
  292. }
  293. }
  294. }
  295. template <typename TCollection>
  296. void DequeueAll(TCollection* res) {
  297. AsyncRef();
  298. TRootNode* newRoot = new TRootNode;
  299. TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
  300. do {
  301. } while (!JobQueue.compare_exchange_weak(curRoot, newRoot));
  302. FillCollection(curRoot->PopQueue, res);
  303. TListNode* toDeleteHead = curRoot->PushQueue;
  304. TListNode* toDeleteTail = FillCollectionReverse(curRoot->PushQueue, res);
  305. curRoot->PushQueue.store(nullptr, std::memory_order_release);
  306. if (toDeleteTail) {
  307. toDeleteTail->Next.store(curRoot->PopQueue.load());
  308. } else {
  309. toDeleteTail = curRoot->PopQueue;
  310. }
  311. curRoot->PopQueue.store(nullptr, std::memory_order_release);
  312. AsyncUnref(curRoot, toDeleteHead);
  313. }
  314. bool IsEmpty() {
  315. AsyncRef();
  316. TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
  317. bool res = curRoot->PushQueue.load(std::memory_order_acquire) == nullptr &&
  318. curRoot->PopQueue.load(std::memory_order_acquire) == nullptr;
  319. AsyncUnref();
  320. return res;
  321. }
  322. TCounter GetCounter() {
  323. AsyncRef();
  324. TRootNode* curRoot = JobQueue.load(std::memory_order_acquire);
  325. TCounter res = *(TCounter*)curRoot;
  326. AsyncUnref();
  327. return res;
  328. }
  329. };
  330. template <class T, class TCounter>
  331. class TAutoLockFreeQueue {
  332. public:
  333. using TRef = THolder<T>;
  334. inline ~TAutoLockFreeQueue() {
  335. TRef tmp;
  336. while (Dequeue(&tmp)) {
  337. }
  338. }
  339. inline bool Dequeue(TRef* t) {
  340. T* res = nullptr;
  341. if (Queue.Dequeue(&res)) {
  342. t->Reset(res);
  343. return true;
  344. }
  345. return false;
  346. }
  347. inline void Enqueue(TRef& t) {
  348. Queue.Enqueue(t.Get());
  349. Y_UNUSED(t.Release());
  350. }
  351. inline void Enqueue(TRef&& t) {
  352. Queue.Enqueue(t.Get());
  353. Y_UNUSED(t.Release());
  354. }
  355. inline bool IsEmpty() {
  356. return Queue.IsEmpty();
  357. }
  358. inline TCounter GetCounter() {
  359. return Queue.GetCounter();
  360. }
  361. private:
  362. TLockFreeQueue<T*, TCounter> Queue;
  363. };