CoalescingBitVector.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. ///
  14. /// \file
  15. /// A bitvector that uses an IntervalMap to coalesce adjacent elements
  16. /// into intervals.
  17. ///
  18. //===----------------------------------------------------------------------===//
  19. #ifndef LLVM_ADT_COALESCINGBITVECTOR_H
  20. #define LLVM_ADT_COALESCINGBITVECTOR_H
  21. #include "llvm/ADT/IntervalMap.h"
  22. #include "llvm/ADT/STLExtras.h"
  23. #include "llvm/ADT/SmallVector.h"
  24. #include "llvm/ADT/iterator_range.h"
  25. #include "llvm/Support/Debug.h"
  26. #include "llvm/Support/raw_ostream.h"
  27. #include <initializer_list>
  28. namespace llvm {
  29. /// A bitvector that, under the hood, relies on an IntervalMap to coalesce
  30. /// elements into intervals. Good for representing sets which predominantly
  31. /// contain contiguous ranges. Bad for representing sets with lots of gaps
  32. /// between elements.
  33. ///
  34. /// Compared to SparseBitVector, CoalescingBitVector offers more predictable
  35. /// performance for non-sequential find() operations.
  36. ///
  37. /// \tparam IndexT - The type of the index into the bitvector.
  38. template <typename IndexT> class CoalescingBitVector {
  39. static_assert(std::is_unsigned<IndexT>::value,
  40. "Index must be an unsigned integer.");
  41. using ThisT = CoalescingBitVector<IndexT>;
  42. /// An interval map for closed integer ranges. The mapped values are unused.
  43. using MapT = IntervalMap<IndexT, char>;
  44. using UnderlyingIterator = typename MapT::const_iterator;
  45. using IntervalT = std::pair<IndexT, IndexT>;
  46. public:
  47. using Allocator = typename MapT::Allocator;
  48. /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
  49. /// reference.
  50. CoalescingBitVector(Allocator &Alloc)
  51. : Alloc(&Alloc), Intervals(Alloc) {}
  52. /// \name Copy/move constructors and assignment operators.
  53. /// @{
  54. CoalescingBitVector(const ThisT &Other)
  55. : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
  56. set(Other);
  57. }
  58. ThisT &operator=(const ThisT &Other) {
  59. clear();
  60. set(Other);
  61. return *this;
  62. }
  63. CoalescingBitVector(ThisT &&Other) = delete;
  64. ThisT &operator=(ThisT &&Other) = delete;
  65. /// @}
  66. /// Clear all the bits.
  67. void clear() { Intervals.clear(); }
  68. /// Check whether no bits are set.
  69. bool empty() const { return Intervals.empty(); }
  70. /// Count the number of set bits.
  71. unsigned count() const {
  72. unsigned Bits = 0;
  73. for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
  74. Bits += 1 + It.stop() - It.start();
  75. return Bits;
  76. }
  77. /// Set the bit at \p Index.
  78. ///
  79. /// This method does /not/ support setting a bit that has already been set,
  80. /// for efficiency reasons. If possible, restructure your code to not set the
  81. /// same bit multiple times, or use \ref test_and_set.
  82. void set(IndexT Index) {
  83. assert(!test(Index) && "Setting already-set bits not supported/efficient, "
  84. "IntervalMap will assert");
  85. insert(Index, Index);
  86. }
  87. /// Set the bits set in \p Other.
  88. ///
  89. /// This method does /not/ support setting already-set bits, see \ref set
  90. /// for the rationale. For a safe set union operation, use \ref operator|=.
  91. void set(const ThisT &Other) {
  92. for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
  93. It != End; ++It)
  94. insert(It.start(), It.stop());
  95. }
  96. /// Set the bits at \p Indices. Used for testing, primarily.
  97. void set(std::initializer_list<IndexT> Indices) {
  98. for (IndexT Index : Indices)
  99. set(Index);
  100. }
  101. /// Check whether the bit at \p Index is set.
  102. bool test(IndexT Index) const {
  103. const auto It = Intervals.find(Index);
  104. if (It == Intervals.end())
  105. return false;
  106. assert(It.stop() >= Index && "Interval must end after Index");
  107. return It.start() <= Index;
  108. }
  109. /// Set the bit at \p Index. Supports setting an already-set bit.
  110. void test_and_set(IndexT Index) {
  111. if (!test(Index))
  112. set(Index);
  113. }
  114. /// Reset the bit at \p Index. Supports resetting an already-unset bit.
  115. void reset(IndexT Index) {
  116. auto It = Intervals.find(Index);
  117. if (It == Intervals.end())
  118. return;
  119. // Split the interval containing Index into up to two parts: one from
  120. // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
  121. // either Start or Stop, we create one new interval. If Index is equal to
  122. // both Start and Stop, we simply erase the existing interval.
  123. IndexT Start = It.start();
  124. if (Index < Start)
  125. // The index was not set.
  126. return;
  127. IndexT Stop = It.stop();
  128. assert(Index <= Stop && "Wrong interval for index");
  129. It.erase();
  130. if (Start < Index)
  131. insert(Start, Index - 1);
  132. if (Index < Stop)
  133. insert(Index + 1, Stop);
  134. }
  135. /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
  136. /// be a faster alternative.
  137. void operator|=(const ThisT &RHS) {
  138. // Get the overlaps between the two interval maps.
  139. SmallVector<IntervalT, 8> Overlaps;
  140. getOverlaps(RHS, Overlaps);
  141. // Insert the non-overlapping parts of all the intervals from RHS.
  142. for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
  143. It != End; ++It) {
  144. IndexT Start = It.start();
  145. IndexT Stop = It.stop();
  146. SmallVector<IntervalT, 8> NonOverlappingParts;
  147. getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
  148. for (IntervalT AdditivePortion : NonOverlappingParts)
  149. insert(AdditivePortion.first, AdditivePortion.second);
  150. }
  151. }
  152. /// Set intersection.
  153. void operator&=(const ThisT &RHS) {
  154. // Get the overlaps between the two interval maps (i.e. the intersection).
  155. SmallVector<IntervalT, 8> Overlaps;
  156. getOverlaps(RHS, Overlaps);
  157. // Rebuild the interval map, including only the overlaps.
  158. clear();
  159. for (IntervalT Overlap : Overlaps)
  160. insert(Overlap.first, Overlap.second);
  161. }
  162. /// Reset all bits present in \p Other.
  163. void intersectWithComplement(const ThisT &Other) {
  164. SmallVector<IntervalT, 8> Overlaps;
  165. if (!getOverlaps(Other, Overlaps)) {
  166. // If there is no overlap with Other, the intersection is empty.
  167. return;
  168. }
  169. // Delete the overlapping intervals. Split up intervals that only partially
  170. // intersect an overlap.
  171. for (IntervalT Overlap : Overlaps) {
  172. IndexT OlapStart, OlapStop;
  173. std::tie(OlapStart, OlapStop) = Overlap;
  174. auto It = Intervals.find(OlapStart);
  175. IndexT CurrStart = It.start();
  176. IndexT CurrStop = It.stop();
  177. assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
  178. "Expected some intersection!");
  179. // Split the overlap interval into up to two parts: one from [CurrStart,
  180. // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
  181. // equal to CurrStart, the first split interval is unnecessary. Ditto for
  182. // when OlapStop is equal to CurrStop, we omit the second split interval.
  183. It.erase();
  184. if (CurrStart < OlapStart)
  185. insert(CurrStart, OlapStart - 1);
  186. if (OlapStop < CurrStop)
  187. insert(OlapStop + 1, CurrStop);
  188. }
  189. }
  190. bool operator==(const ThisT &RHS) const {
  191. // We cannot just use std::equal because it checks the dereferenced values
  192. // of an iterator pair for equality, not the iterators themselves. In our
  193. // case that results in comparison of the (unused) IntervalMap values.
  194. auto ItL = Intervals.begin();
  195. auto ItR = RHS.Intervals.begin();
  196. while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
  197. ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
  198. ++ItL;
  199. ++ItR;
  200. }
  201. return ItL == Intervals.end() && ItR == RHS.Intervals.end();
  202. }
  203. bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
  204. class const_iterator {
  205. friend class CoalescingBitVector;
  206. public:
  207. using iterator_category = std::forward_iterator_tag;
  208. using value_type = IndexT;
  209. using difference_type = std::ptrdiff_t;
  210. using pointer = value_type *;
  211. using reference = value_type &;
  212. private:
  213. // For performance reasons, make the offset at the end different than the
  214. // one used in \ref begin, to optimize the common `It == end()` pattern.
  215. static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
  216. UnderlyingIterator MapIterator;
  217. unsigned OffsetIntoMapIterator = 0;
  218. // Querying the start/stop of an IntervalMap iterator can be very expensive.
  219. // Cache these values for performance reasons.
  220. IndexT CachedStart = IndexT();
  221. IndexT CachedStop = IndexT();
  222. void setToEnd() {
  223. OffsetIntoMapIterator = kIteratorAtTheEndOffset;
  224. CachedStart = IndexT();
  225. CachedStop = IndexT();
  226. }
  227. /// MapIterator has just changed, reset the cached state to point to the
  228. /// start of the new underlying iterator.
  229. void resetCache() {
  230. if (MapIterator.valid()) {
  231. OffsetIntoMapIterator = 0;
  232. CachedStart = MapIterator.start();
  233. CachedStop = MapIterator.stop();
  234. } else {
  235. setToEnd();
  236. }
  237. }
  238. /// Advance the iterator to \p Index, if it is contained within the current
  239. /// interval. The public-facing method which supports advancing past the
  240. /// current interval is \ref advanceToLowerBound.
  241. void advanceTo(IndexT Index) {
  242. assert(Index <= CachedStop && "Cannot advance to OOB index");
  243. if (Index < CachedStart)
  244. // We're already past this index.
  245. return;
  246. OffsetIntoMapIterator = Index - CachedStart;
  247. }
  248. const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
  249. resetCache();
  250. }
  251. public:
  252. const_iterator() { setToEnd(); }
  253. bool operator==(const const_iterator &RHS) const {
  254. // Do /not/ compare MapIterator for equality, as this is very expensive.
  255. // The cached start/stop values make that check unnecessary.
  256. return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
  257. std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
  258. RHS.CachedStop);
  259. }
  260. bool operator!=(const const_iterator &RHS) const {
  261. return !operator==(RHS);
  262. }
  263. IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
  264. const_iterator &operator++() { // Pre-increment (++It).
  265. if (CachedStart + OffsetIntoMapIterator < CachedStop) {
  266. // Keep going within the current interval.
  267. ++OffsetIntoMapIterator;
  268. } else {
  269. // We reached the end of the current interval: advance.
  270. ++MapIterator;
  271. resetCache();
  272. }
  273. return *this;
  274. }
  275. const_iterator operator++(int) { // Post-increment (It++).
  276. const_iterator tmp = *this;
  277. operator++();
  278. return tmp;
  279. }
  280. /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
  281. /// no such set bit exists, advance to end(). This is like std::lower_bound.
  282. /// This is useful if \p Index is close to the current iterator position.
  283. /// However, unlike \ref find(), this has worst-case O(n) performance.
  284. void advanceToLowerBound(IndexT Index) {
  285. if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
  286. return;
  287. // Advance to the first interval containing (or past) Index, or to end().
  288. while (Index > CachedStop) {
  289. ++MapIterator;
  290. resetCache();
  291. if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
  292. return;
  293. }
  294. advanceTo(Index);
  295. }
  296. };
  297. const_iterator begin() const { return const_iterator(Intervals.begin()); }
  298. const_iterator end() const { return const_iterator(); }
  299. /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
  300. /// If no such set bit exists, return end(). This is like std::lower_bound.
  301. /// This has worst-case logarithmic performance (roughly O(log(gaps between
  302. /// contiguous ranges))).
  303. const_iterator find(IndexT Index) const {
  304. auto UnderlyingIt = Intervals.find(Index);
  305. if (UnderlyingIt == Intervals.end())
  306. return end();
  307. auto It = const_iterator(UnderlyingIt);
  308. It.advanceTo(Index);
  309. return It;
  310. }
  311. /// Return a range iterator which iterates over all of the set bits in the
  312. /// half-open range [Start, End).
  313. iterator_range<const_iterator> half_open_range(IndexT Start,
  314. IndexT End) const {
  315. assert(Start < End && "Not a valid range");
  316. auto StartIt = find(Start);
  317. if (StartIt == end() || *StartIt >= End)
  318. return {end(), end()};
  319. auto EndIt = StartIt;
  320. EndIt.advanceToLowerBound(End);
  321. return {StartIt, EndIt};
  322. }
  323. void print(raw_ostream &OS) const {
  324. OS << "{";
  325. for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
  326. ++It) {
  327. OS << "[" << It.start();
  328. if (It.start() != It.stop())
  329. OS << ", " << It.stop();
  330. OS << "]";
  331. }
  332. OS << "}";
  333. }
  334. #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  335. LLVM_DUMP_METHOD void dump() const {
  336. // LLDB swallows the first line of output after callling dump(). Add
  337. // newlines before/after the braces to work around this.
  338. dbgs() << "\n";
  339. print(dbgs());
  340. dbgs() << "\n";
  341. }
  342. #endif
  343. private:
  344. void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
  345. /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
  346. /// true if there is any overlap.
  347. bool getOverlaps(const ThisT &Other,
  348. SmallVectorImpl<IntervalT> &Overlaps) const {
  349. for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
  350. I.valid(); ++I)
  351. Overlaps.emplace_back(I.start(), I.stop());
  352. assert(llvm::is_sorted(Overlaps,
  353. [](IntervalT LHS, IntervalT RHS) {
  354. return LHS.second < RHS.first;
  355. }) &&
  356. "Overlaps must be sorted");
  357. return !Overlaps.empty();
  358. }
  359. /// Given the set of overlaps between this and some other bitvector, and an
  360. /// interval [Start, Stop] from that bitvector, determine the portions of the
  361. /// interval which do not overlap with this.
  362. void getNonOverlappingParts(IndexT Start, IndexT Stop,
  363. const SmallVectorImpl<IntervalT> &Overlaps,
  364. SmallVectorImpl<IntervalT> &NonOverlappingParts) {
  365. IndexT NextUncoveredBit = Start;
  366. for (IntervalT Overlap : Overlaps) {
  367. IndexT OlapStart, OlapStop;
  368. std::tie(OlapStart, OlapStop) = Overlap;
  369. // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
  370. // and Start <= OlapStop.
  371. bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
  372. if (!DoesOverlap)
  373. continue;
  374. // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
  375. // the next uncovered range at OlapStop+1.
  376. if (NextUncoveredBit < OlapStart)
  377. NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
  378. NextUncoveredBit = OlapStop + 1;
  379. if (NextUncoveredBit > Stop)
  380. break;
  381. }
  382. if (NextUncoveredBit <= Stop)
  383. NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
  384. }
  385. Allocator *Alloc;
  386. MapT Intervals;
  387. };
  388. } // namespace llvm
  389. #endif // LLVM_ADT_COALESCINGBITVECTOR_H
  390. #ifdef __GNUC__
  391. #pragma GCC diagnostic pop
  392. #endif