EquivalenceClasses.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/ADT/EquivalenceClasses.h - Generic Equiv. Classes ---*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. ///
  14. /// \file
  15. /// Generic implementation of equivalence classes through the use of Tarjan's
  16. /// efficient union-find algorithm.
  17. ///
  18. //===----------------------------------------------------------------------===//
  19. #ifndef LLVM_ADT_EQUIVALENCECLASSES_H
  20. #define LLVM_ADT_EQUIVALENCECLASSES_H
  21. #include <cassert>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <iterator>
  25. #include <set>
  26. namespace llvm {
  /// EquivalenceClasses - A collection of disjoint equivalence classes that
  /// supports three operations: inserting an element into a class of its own,
  /// merging (unioning) two classes, and finding the class (leader) of a given
  /// element. In addition, it is possible to iterate over every stored element
  /// and over the members of a single class.
  ///
  /// Only one copy of each element is stored: the element lives inside the
  /// std::set entry that also carries the union-find links. Any element type
  /// can be used as long as it can be ordered by \p Compare (std::less by
  /// default).
  ///
  /// Here is a simple example using integers:
  ///
  /// \code
  ///   EquivalenceClasses<int> EC;
  ///   EC.unionSets(1, 2);               // insert 1, 2 into the same set
  ///   EC.insert(4); EC.insert(5);       // insert 4, 5 into own sets
  ///   EC.unionSets(5, 1);               // merge the set for 1 with 5's set.
  ///
  ///   for (EquivalenceClasses<int>::iterator I = EC.begin(), E = EC.end();
  ///        I != E; ++I) {               // Iterate over all of the elements.
  ///     if (!I->isLeader()) continue;   // Ignore non-leader elements.
  ///     for (EquivalenceClasses<int>::member_iterator MI = EC.member_begin(I);
  ///          MI != EC.member_end(); ++MI) // Loop over members in this set.
  ///       cerr << *MI << " ";           // Print member.
  ///     cerr << "\n";                   // Finish set.
  ///   }
  /// \endcode
  ///
  /// This example prints:
  ///   4
  ///   5 1 2
  ///
  template <class ElemTy, class Compare = std::less<ElemTy>>
  class EquivalenceClasses {
    /// ECValue - The EquivalenceClasses data structure is just a set of these.
    /// Each one stores:
    ///  * the value itself, which provides the ordering used by the set;
    ///  * a "next" pointer used to enumerate all members of the unioned class;
    ///  * a pointer that is the "end of list" pointer when the node is a leader
    ///    (it then points at the last member of the class) or the "leader"
    ///    pointer otherwise (it then points towards the leader of the class).
    /// Whether a node is a leader is recorded in the low bit of the next
    /// pointer.
    class ECValue {
      friend class EquivalenceClasses;

      // Both links are mutable because the nodes are stored as (const) keys of
      // a std::set, yet union and path compression must update them.
      mutable const ECValue *Leader;
      mutable const ECValue *Next;
      ElemTy Data;

      /// Build a singleton class: the end-of-list pointer refers to this node,
      /// the next pointer is null and the leader bit is set.
      ECValue(const ElemTy &Elt)
          : Leader(this),
            Next(reinterpret_cast<const ECValue *>(static_cast<intptr_t>(1))),
            Data(Elt) {}

      /// Return the leader of the class containing this node, re-pointing this
      /// node's leader link at it (path compression) when it was not already
      /// directly reachable.
      const ECValue *getLeader() const {
        if (isLeader())
          return this;
        if (Leader->isLeader())
          return Leader;
        // Path compression.
        return Leader = Leader->getLeader();
      }

      /// Return the last node of the member list. Only valid on a leader.
      const ECValue *getEndOfList() const {
        assert(isLeader() && "Cannot get the end of a list for a non-leader!");
        return Leader;
      }

      /// Link \p NewNext after this node, preserving this node's leader bit.
      /// The node must not already have a successor.
      void setNext(const ECValue *NewNext) const {
        assert(getNext() == nullptr && "Already has a next pointer!");
        Next = reinterpret_cast<const ECValue *>(
            reinterpret_cast<intptr_t>(NewNext) |
            static_cast<intptr_t>(isLeader()));
      }

    public:
      /// Copy a node. Only singleton nodes may be copied; the copy is a fresh
      /// singleton holding a copy of the data.
      ECValue(const ECValue &RHS)
          : Leader(this),
            Next(reinterpret_cast<const ECValue *>(static_cast<intptr_t>(1))),
            Data(RHS.Data) {
        // Only support copying of singleton nodes.
        assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
      }

      /// True if this node is the leader of its class (low bit of Next set).
      bool isLeader() const {
        return (reinterpret_cast<intptr_t>(Next) & 1) != 0;
      }

      /// The element stored in this node.
      const ElemTy &getData() const { return Data; }

      /// The next member of the class, or null at the end of the list. The
      /// leader bit is masked off before the pointer is returned.
      const ECValue *getNext() const {
        return reinterpret_cast<const ECValue *>(
            reinterpret_cast<intptr_t>(Next) & ~static_cast<intptr_t>(1));
      }
    };

    /// A wrapper of the comparator, to be passed to the set. It is transparent
    /// so that lookups can compare a bare value against stored nodes without
    /// first building an ECValue.
    struct ECValueComparator {
      using is_transparent = void;

      ECValueComparator() : compare(Compare()) {}

      bool operator()(const ECValue &lhs, const ECValue &rhs) const {
        return compare(lhs.getData(), rhs.getData());
      }

      template <typename T>
      bool operator()(const T &lhs, const ECValue &rhs) const {
        return compare(lhs, rhs.getData());
      }

      template <typename T>
      bool operator()(const ECValue &lhs, const T &rhs) const {
        return compare(lhs.getData(), rhs);
      }

      const Compare compare;
    };

    /// TheMapping - This implicitly provides a mapping from ElemTy values to
    /// the ECValues, it just keeps the key as part of the value.
    std::set<ECValue, ECValueComparator> TheMapping;

  public:
    /// iterator* - Provides a way to iterate over all values in the set.
    /// This must name the iterator of the container actually used for
    /// TheMapping (including its comparator), not of std::set<ECValue>.
    using iterator =
        typename std::set<ECValue, ECValueComparator>::const_iterator;

    /// member_* Iterate over the members of an equivalence class.
    class member_iterator;

    EquivalenceClasses() = default;

    /// Copy-construct by rebuilding every class of \p RHS in this object.
    EquivalenceClasses(const EquivalenceClasses &RHS) { operator=(RHS); }

    /// Replace the contents with a copy of \p RHS. Nodes cannot be copied
    /// verbatim (their links point into RHS), so each class is re-created by
    /// inserting its members and unioning them with the class leader.
    const EquivalenceClasses &operator=(const EquivalenceClasses &RHS) {
      // Self-assignment must be a no-op: clearing first would destroy the very
      // elements we are about to copy.
      if (this == &RHS)
        return *this;

      TheMapping.clear();
      for (iterator I = RHS.begin(), E = RHS.end(); I != E; ++I) {
        if (!I->isLeader())
          continue;
        member_iterator MI = RHS.member_begin(I);
        member_iterator LeaderIt = member_begin(insert(*MI));
        for (++MI; MI != member_end(); ++MI)
          unionSets(LeaderIt, member_begin(insert(*MI)));
      }
      return *this;
    }

    //===------------------------------------------------------------------===//
    // Inspection methods
    //

    iterator begin() const { return TheMapping.begin(); }
    iterator end() const { return TheMapping.end(); }

    bool empty() const { return TheMapping.empty(); }

    /// Return an iterator over the members of the class led by \p I. Only
    /// leaders provide anything to iterate over; for a non-leader the result
    /// equals member_end().
    member_iterator member_begin(iterator I) const {
      return member_iterator(I->isLeader() ? &*I : nullptr);
    }

    /// The past-the-end member iterator (shared by all classes).
    member_iterator member_end() const { return member_iterator(nullptr); }

    /// findValue - Return an iterator to the specified value. If it does not
    /// exist, end() is returned.
    iterator findValue(const ElemTy &V) const { return TheMapping.find(V); }

    /// getLeaderValue - Return the leader for the specified value that is in
    /// the set. It is an error to call this method for a value that is not yet
    /// in the set. For that, call getOrInsertLeaderValue(V).
    const ElemTy &getLeaderValue(const ElemTy &V) const {
      member_iterator MI = findLeader(V);
      assert(MI != member_end() && "Value is not in the set!");
      return *MI;
    }

    /// getOrInsertLeaderValue - Return the leader for the specified value that
    /// is in the set. If the member is not in the set, it is inserted, then
    /// returned.
    const ElemTy &getOrInsertLeaderValue(const ElemTy &V) {
      member_iterator MI = findLeader(insert(V));
      assert(MI != member_end() && "Value is not in the set!");
      return *MI;
    }

    /// getNumClasses - Return the number of equivalence classes in this set.
    /// Note that this is a linear time operation.
    unsigned getNumClasses() const {
      unsigned NC = 0;
      for (const ECValue &Entry : TheMapping)
        if (Entry.isLeader())
          ++NC;
      return NC;
    }

    //===------------------------------------------------------------------===//
    // Mutation methods

    /// insert - Insert a new value into the union/find set, ignoring the
    /// request if the value already exists. Returns an iterator to the (new or
    /// pre-existing) entry.
    iterator insert(const ElemTy &Data) {
      return TheMapping.insert(ECValue(Data)).first;
    }

    /// findLeader - Given a value in the set, return a member iterator for the
    /// equivalence class it is in. This does the path-compression part that
    /// makes union-find "union findy". This returns an end iterator if the
    /// value is not in the equivalence class.
    member_iterator findLeader(iterator I) const {
      if (I == TheMapping.end())
        return member_end();
      return member_iterator(I->getLeader());
    }
    member_iterator findLeader(const ElemTy &V) const {
      return findLeader(TheMapping.find(V));
    }

    /// union - Merge the two equivalence sets for the specified values,
    /// inserting them if they do not already exist in the equivalence set.
    /// Returns a member iterator positioned at the leader of the merged class.
    member_iterator unionSets(const ElemTy &V1, const ElemTy &V2) {
      iterator V1I = insert(V1);
      iterator V2I = insert(V2);
      return unionSets(findLeader(V1I), findLeader(V2I));
    }

    /// Merge the classes led by \p L1 and \p L2 (both must be leaders, not
    /// end iterators). L2's member list is appended to L1's and L1 stays the
    /// leader; \p L1 is returned.
    member_iterator unionSets(member_iterator L1, member_iterator L2) {
      assert(L1 != member_end() && L2 != member_end() && "Illegal inputs!");
      if (L1 == L2)
        return L1; // Unifying the same two sets, noop.

      // Otherwise, this is a real union operation. Set the end of the L1 list
      // to point to the L2 leader node.
      const ECValue &L1LV = *L1.Node;
      const ECValue &L2LV = *L2.Node;
      L1LV.getEndOfList()->setNext(&L2LV);

      // Update L1LV's end of list pointer. This must happen while L2 is still
      // flagged as a leader, since getEndOfList() asserts on that.
      L1LV.Leader = L2LV.getEndOfList();

      // Clear L2's leader flag:
      L2LV.Next = L2LV.getNext();

      // L2's leader is now L1.
      L2LV.Leader = &L1LV;
      return L1;
    }

    /// isEquivalent - Return true if V1 is equivalent to V2. This can happen
    /// if V1 is equal to V2 (compared with operator==, whether or not the
    /// value has been inserted) or if they belong to one equivalence class.
    bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const {
      // Fast path: any element is equivalent to itself.
      if (V1 == V2)
        return true;
      member_iterator It = findLeader(V1);
      return It != member_end() && It == findLeader(V2);
    }

    /// Forward iterator over the members of one equivalence class, starting
    /// at the leader and following the next links until null.
    class member_iterator {
      friend class EquivalenceClasses;

      // Null means "end"; a default-constructed iterator is therefore equal to
      // member_end() instead of holding an indeterminate pointer.
      const ECValue *Node = nullptr;

    public:
      using iterator_category = std::forward_iterator_tag;
      using value_type = const ElemTy;
      using size_type = std::size_t;
      using difference_type = std::ptrdiff_t;
      using pointer = value_type *;
      using reference = value_type &;

      explicit member_iterator() = default;
      explicit member_iterator(const ECValue *N) : Node(N) {}

      reference operator*() const {
        assert(Node != nullptr && "Dereferencing end()!");
        return Node->getData();
      }
      pointer operator->() const { return &operator*(); }

      member_iterator &operator++() {
        assert(Node != nullptr && "++'d off the end of the list!");
        Node = Node->getNext();
        return *this;
      }

      member_iterator operator++(int) { // postincrement operators.
        member_iterator tmp = *this;
        ++*this;
        return tmp;
      }

      bool operator==(const member_iterator &RHS) const {
        return Node == RHS.Node;
      }
      bool operator!=(const member_iterator &RHS) const {
        return Node != RHS.Node;
      }
    };
  };
  271. } // end namespace llvm
  272. #endif // LLVM_ADT_EQUIVALENCECLASSES_H
  273. #ifdef __GNUC__
  274. #pragma GCC diagnostic pop
  275. #endif