skiplist.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
#pragma once

#include "compare.h"

#include <util/generic/algorithm.h>
#include <util/generic/noncopyable.h>
#include <util/generic/typetraits.h>
#include <util/memory/pool.h>
#include <util/random/random.h>
#include <library/cpp/deprecated/atomic/atomic.h>

#include <type_traits>
  9. namespace NThreading {
  10. ////////////////////////////////////////////////////////////////////////////////
  // Counter policy that records nothing: all three hooks are empty.
  // Its hook names (OnInsert/OnUpdate/Reset) match what TSkipList calls on its
  // TCounter base, so it can be plugged in when no element count is wanted.
  class TNopCounter {
  protected:
    // Invoked after a new value has been stored; intentionally ignored.
    template <typename T>
    void OnInsert(const T& /*inserted*/) {}

    // Invoked after an existing value was updated in place; ignored as well.
    template <typename T>
    void OnUpdate(const T& /*updated*/) {}

    // Invoked when the owner (re)initializes itself; there is no state to clear.
    void Reset() {}
  };
  22. ////////////////////////////////////////////////////////////////////////////////
  // Counter policy that keeps the number of inserted values.
  // OnUpdate() leaves the count untouched; Reset() brings it back to zero.
  class TSizeCounter {
  private:
    size_t Size{0}; // values inserted since construction or the last Reset()

  public:
    TSizeCounter() = default;

    size_t GetSize() const {
      return Size;
    }

  protected:
    // One more value has been stored.
    template <typename T>
    void OnInsert(const T& /*inserted*/) {
      Size += 1;
    }

    // An existing value changed in place; the element count is unaffected.
    template <typename T>
    void OnUpdate(const T& /*updated*/) {
    }

    void Reset() {
      Size = 0;
    }
  };
  46. ////////////////////////////////////////////////////////////////////////////////
  47. // Append-only concurrent skip-list
  48. //
  49. // Readers do not require any synchronization.
  50. // Writers should be externally synchronized.
  51. // Nodes will be allocated using TMemoryPool instance.
  52. template <
  53. typename T,
  54. typename TComparer = TCompare<T>,
  55. typename TAllocator = TMemoryPool,
  56. typename TCounter = TSizeCounter,
  57. int MaxHeight = 12,
  58. int Branching = 4>
  59. class TSkipList: public TCounter, private TNonCopyable {
  60. class TNode {
  61. private:
  62. T Value; // should be immutable after insert
  63. TNode* Next[]; // variable-size array maximum of MaxHeight values
  64. public:
  65. TNode(T&& value)
  66. : Value(std::move(value))
  67. {
  68. Y_UNUSED(Next);
  69. }
  70. const T& GetValue() const {
  71. return Value;
  72. }
  73. T& GetValue() {
  74. return Value;
  75. }
  76. TNode* GetNext(int height) const {
  77. return AtomicGet(Next[height]);
  78. }
  79. void Link(int height, TNode** prev) {
  80. for (int i = 0; i < height; ++i) {
  81. Next[i] = prev[i]->Next[i];
  82. AtomicSet(prev[i]->Next[i], this);
  83. }
  84. }
  85. };
  86. public:
  87. class TIterator {
  88. private:
  89. const TSkipList* List;
  90. const TNode* Node;
  91. public:
  92. TIterator()
  93. : List(nullptr)
  94. , Node(nullptr)
  95. {
  96. }
  97. TIterator(const TSkipList* list, const TNode* node)
  98. : List(list)
  99. , Node(node)
  100. {
  101. }
  102. TIterator(const TIterator& other)
  103. : List(other.List)
  104. , Node(other.Node)
  105. {
  106. }
  107. TIterator& operator=(const TIterator& other) {
  108. List = other.List;
  109. Node = other.Node;
  110. return *this;
  111. }
  112. void Next() {
  113. Node = Node ? Node->GetNext(0) : nullptr;
  114. }
  115. // much less efficient than Next as our list is single-linked
  116. void Prev() {
  117. if (Node) {
  118. TNode* node = List->FindLessThan(Node->GetValue(), nullptr);
  119. Node = (node != List->Head ? node : nullptr);
  120. }
  121. }
  122. void Reset() {
  123. Node = nullptr;
  124. }
  125. bool IsValid() const {
  126. return Node != nullptr;
  127. }
  128. const T& GetValue() const {
  129. Y_ASSERT(IsValid());
  130. return Node->GetValue();
  131. }
  132. };
  133. private:
  134. TAllocator& Allocator;
  135. TComparer Comparer;
  136. TNode* Head;
  137. TAtomic Height;
  138. TCounter Counter;
  139. TNode* Prev[MaxHeight];
  140. template <typename TValue>
  141. using TComparerReturnType = std::invoke_result_t<TComparer, const T&, const TValue&>;
  142. public:
  143. TSkipList(TAllocator& allocator, const TComparer& comparer = TComparer())
  144. : Allocator(allocator)
  145. , Comparer(comparer)
  146. {
  147. Init();
  148. }
  149. ~TSkipList() {
  150. CallDtors();
  151. }
  152. void Clear() {
  153. CallDtors();
  154. Allocator.ClearKeepFirstChunk();
  155. Init();
  156. }
  157. bool Insert(T value) {
  158. TNode* node = PrepareInsert(value);
  159. if (Y_UNLIKELY(node && Compare(node, value) == 0)) {
  160. // we do not allow duplicates
  161. return false;
  162. }
  163. node = DoInsert(std::move(value));
  164. TCounter::OnInsert(node->GetValue());
  165. return true;
  166. }
  167. template <typename TInsertAction, typename TUpdateAction>
  168. bool Insert(const T& value, TInsertAction insert, TUpdateAction update) {
  169. TNode* node = PrepareInsert(value);
  170. if (Y_UNLIKELY(node && Compare(node, value) == 0)) {
  171. if (update(node->GetValue())) {
  172. TCounter::OnUpdate(node->GetValue());
  173. return true;
  174. }
  175. // we do not allow duplicates
  176. return false;
  177. }
  178. node = DoInsert(insert(value));
  179. TCounter::OnInsert(node->GetValue());
  180. return true;
  181. }
  182. template <typename TValue>
  183. bool Contains(const TValue& value) const {
  184. TNode* node = FindGreaterThanOrEqual(value);
  185. return node && Compare(node, value) == 0;
  186. }
  187. TIterator SeekToFirst() const {
  188. return TIterator(this, FindFirst());
  189. }
  190. TIterator SeekToLast() const {
  191. TNode* last = FindLast();
  192. return TIterator(this, last != Head ? last : nullptr);
  193. }
  194. template <typename TValue>
  195. TIterator SeekTo(const TValue& value) const {
  196. return TIterator(this, FindGreaterThanOrEqual(value));
  197. }
  198. private:
  199. static int RandomHeight() {
  200. int height = 1;
  201. while (height < MaxHeight && (RandomNumber<unsigned int>() % Branching) == 0) {
  202. ++height;
  203. }
  204. return height;
  205. }
  206. void Init() {
  207. Head = AllocateRootNode();
  208. Height = 1;
  209. TCounter::Reset();
  210. for (int i = 0; i < MaxHeight; ++i) {
  211. Prev[i] = Head;
  212. }
  213. }
  214. void CallDtors() {
  215. if (!TTypeTraits<T>::IsPod) {
  216. // we should explicitly call destructors for our nodes
  217. TNode* node = Head->GetNext(0);
  218. while (node) {
  219. TNode* next = node->GetNext(0);
  220. node->~TNode();
  221. node = next;
  222. }
  223. }
  224. }
  225. TNode* AllocateRootNode() {
  226. size_t size = sizeof(TNode) + sizeof(TNode*) * MaxHeight;
  227. void* buffer = Allocator.Allocate(size);
  228. memset(buffer, 0, size);
  229. return static_cast<TNode*>(buffer);
  230. }
  231. TNode* AllocateNode(T&& value, int height) {
  232. size_t size = sizeof(TNode) + sizeof(TNode*) * height;
  233. void* buffer = Allocator.Allocate(size);
  234. memset(buffer, 0, size);
  235. return new (buffer) TNode(std::move(value));
  236. }
  237. TNode* FindFirst() const {
  238. return Head->GetNext(0);
  239. }
  240. TNode* FindLast() const {
  241. TNode* node = Head;
  242. int height = AtomicGet(Height) - 1;
  243. while (true) {
  244. TNode* next = node->GetNext(height);
  245. if (next) {
  246. node = next;
  247. continue;
  248. }
  249. if (height) {
  250. --height;
  251. } else {
  252. return node;
  253. }
  254. }
  255. }
  256. template <typename TValue>
  257. TComparerReturnType<TValue> Compare(const TNode* node, const TValue& value) const {
  258. return Comparer(node->GetValue(), value);
  259. }
  260. template <typename TValue>
  261. TNode* FindLessThan(const TValue& value, TNode** links) const {
  262. TNode* node = Head;
  263. int height = AtomicGet(Height) - 1;
  264. TNode* prev = nullptr;
  265. while (true) {
  266. TNode* next = node->GetNext(height);
  267. if (next && next != prev) {
  268. TComparerReturnType<TValue> cmp = Compare(next, value);
  269. if (cmp < 0) {
  270. node = next;
  271. continue;
  272. }
  273. }
  274. if (links) {
  275. // collect links from upper levels
  276. links[height] = node;
  277. }
  278. if (height) {
  279. prev = next;
  280. --height;
  281. } else {
  282. return node;
  283. }
  284. }
  285. }
  286. template <typename TValue>
  287. TNode* FindGreaterThanOrEqual(const TValue& value) const {
  288. TNode* node = Head;
  289. int height = AtomicGet(Height) - 1;
  290. TNode* prev = nullptr;
  291. while (true) {
  292. TNode* next = node->GetNext(height);
  293. if (next && next != prev) {
  294. TComparerReturnType<TValue> cmp = Compare(next, value);
  295. if (cmp < 0) {
  296. node = next;
  297. continue;
  298. }
  299. if (cmp == 0) {
  300. return next;
  301. }
  302. }
  303. if (height) {
  304. prev = next;
  305. --height;
  306. } else {
  307. return next;
  308. }
  309. }
  310. }
  311. TNode* PrepareInsert(const T& value) {
  312. TNode* prev = Prev[0];
  313. TNode* next = prev->GetNext(0);
  314. if ((prev == Head || Compare(prev, value) < 0) && (next == nullptr || Compare(next, value) >= 0)) {
  315. // avoid seek in case of sequential insert
  316. } else {
  317. prev = FindLessThan(value, Prev);
  318. next = prev->GetNext(0);
  319. }
  320. return next;
  321. }
  322. TNode* DoInsert(T&& value) {
  323. // choose level to place new node
  324. int currentHeight = AtomicGet(Height);
  325. int height = RandomHeight();
  326. if (height > currentHeight) {
  327. for (int i = currentHeight; i < height; ++i) {
  328. // head should link to all levels
  329. Prev[i] = Head;
  330. }
  331. AtomicSet(Height, height);
  332. }
  333. TNode* node = AllocateNode(std::move(value), height);
  334. node->Link(height, Prev);
  335. // keep last inserted node to optimize sequential inserts
  336. for (int i = 0; i < height; i++) {
  337. Prev[i] = node;
  338. }
  339. return node;
  340. }
  341. };
  342. }