CFGMST.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements a Union-find algorithm to compute Minimum Spanning Tree
  10. // for a given CFG.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
  14. #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
  15. #include "llvm/ADT/DenseMap.h"
  16. #include "llvm/ADT/STLExtras.h"
  17. #include "llvm/Analysis/BlockFrequencyInfo.h"
  18. #include "llvm/Analysis/BranchProbabilityInfo.h"
  19. #include "llvm/Analysis/CFG.h"
  20. #include "llvm/Support/BranchProbability.h"
  21. #include "llvm/Support/Debug.h"
  22. #include "llvm/Support/raw_ostream.h"
  23. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  24. #include <utility>
  25. #include <vector>
  26. #define DEBUG_TYPE "cfgmst"
  27. using namespace llvm;
  28. namespace llvm {
  29. /// An union-find based Minimum Spanning Tree for CFG
  30. ///
  31. /// Implements a Union-find algorithm to compute Minimum Spanning Tree
  32. /// for a given CFG.
  33. template <class Edge, class BBInfo> class CFGMST {
  34. public:
  35. Function &F;
  36. // Store all the edges in CFG. It may contain some stale edges
  37. // when Removed is set.
  38. std::vector<std::unique_ptr<Edge>> AllEdges;
  39. // This map records the auxiliary information for each BB.
  40. DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
  41. // Whehter the function has an exit block with no successors.
  42. // (For function with an infinite loop, this block may be absent)
  43. bool ExitBlockFound = false;
  44. // Find the root group of the G and compress the path from G to the root.
  45. BBInfo *findAndCompressGroup(BBInfo *G) {
  46. if (G->Group != G)
  47. G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
  48. return static_cast<BBInfo *>(G->Group);
  49. }
  50. // Union BB1 and BB2 into the same group and return true.
  51. // Returns false if BB1 and BB2 are already in the same group.
  52. bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
  53. BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
  54. BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
  55. if (BB1G == BB2G)
  56. return false;
  57. // Make the smaller rank tree a direct child or the root of high rank tree.
  58. if (BB1G->Rank < BB2G->Rank)
  59. BB1G->Group = BB2G;
  60. else {
  61. BB2G->Group = BB1G;
  62. // If the ranks are the same, increment root of one tree by one.
  63. if (BB1G->Rank == BB2G->Rank)
  64. BB1G->Rank++;
  65. }
  66. return true;
  67. }
  68. // Give BB, return the auxiliary information.
  69. BBInfo &getBBInfo(const BasicBlock *BB) const {
  70. auto It = BBInfos.find(BB);
  71. assert(It->second.get() != nullptr);
  72. return *It->second.get();
  73. }
  74. // Give BB, return the auxiliary information if it's available.
  75. BBInfo *findBBInfo(const BasicBlock *BB) const {
  76. auto It = BBInfos.find(BB);
  77. if (It == BBInfos.end())
  78. return nullptr;
  79. return It->second.get();
  80. }
  81. // Traverse the CFG using a stack. Find all the edges and assign the weight.
  82. // Edges with large weight will be put into MST first so they are less likely
  83. // to be instrumented.
  84. void buildEdges() {
  85. LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
  86. const BasicBlock *Entry = &(F.getEntryBlock());
  87. uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
  88. // If we want to instrument the entry count, lower the weight to 0.
  89. if (InstrumentFuncEntry)
  90. EntryWeight = 0;
  91. Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
  92. *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
  93. uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
  94. // Add a fake edge to the entry.
  95. EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
  96. LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName()
  97. << " w = " << EntryWeight << "\n");
  98. // Special handling for single BB functions.
  99. if (succ_empty(Entry)) {
  100. addEdge(Entry, nullptr, EntryWeight);
  101. return;
  102. }
  103. static const uint32_t CriticalEdgeMultiplier = 1000;
  104. for (BasicBlock &BB : F) {
  105. Instruction *TI = BB.getTerminator();
  106. uint64_t BBWeight =
  107. (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
  108. uint64_t Weight = 2;
  109. if (int successors = TI->getNumSuccessors()) {
  110. for (int i = 0; i != successors; ++i) {
  111. BasicBlock *TargetBB = TI->getSuccessor(i);
  112. bool Critical = isCriticalEdge(TI, i);
  113. uint64_t scaleFactor = BBWeight;
  114. if (Critical) {
  115. if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
  116. scaleFactor *= CriticalEdgeMultiplier;
  117. else
  118. scaleFactor = UINT64_MAX;
  119. }
  120. if (BPI != nullptr)
  121. Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
  122. if (Weight == 0)
  123. Weight++;
  124. auto *E = &addEdge(&BB, TargetBB, Weight);
  125. E->IsCritical = Critical;
  126. LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to "
  127. << TargetBB->getName() << " w=" << Weight << "\n");
  128. // Keep track of entry/exit edges:
  129. if (&BB == Entry) {
  130. if (Weight > MaxEntryOutWeight) {
  131. MaxEntryOutWeight = Weight;
  132. EntryOutgoing = E;
  133. }
  134. }
  135. auto *TargetTI = TargetBB->getTerminator();
  136. if (TargetTI && !TargetTI->getNumSuccessors()) {
  137. if (Weight > MaxExitInWeight) {
  138. MaxExitInWeight = Weight;
  139. ExitIncoming = E;
  140. }
  141. }
  142. }
  143. } else {
  144. ExitBlockFound = true;
  145. Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
  146. if (BBWeight > MaxExitOutWeight) {
  147. MaxExitOutWeight = BBWeight;
  148. ExitOutgoing = ExitO;
  149. }
  150. LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to fake exit"
  151. << " w = " << BBWeight << "\n");
  152. }
  153. }
  154. // Entry/exit edge adjustment heurisitic:
  155. // prefer instrumenting entry edge over exit edge
  156. // if possible. Those exit edges may never have a chance to be
  157. // executed (for instance the program is an event handling loop)
  158. // before the profile is asynchronously dumped.
  159. //
  160. // If EntryIncoming and ExitOutgoing has similar weight, make sure
  161. // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
  162. // and ExitIncoming has similar weight, make sure ExitIncoming becomes
  163. // the min-edge.
  164. uint64_t EntryInWeight = EntryWeight;
  165. if (EntryInWeight >= MaxExitOutWeight &&
  166. EntryInWeight * 2 < MaxExitOutWeight * 3) {
  167. EntryIncoming->Weight = MaxExitOutWeight;
  168. ExitOutgoing->Weight = EntryInWeight + 1;
  169. }
  170. if (MaxEntryOutWeight >= MaxExitInWeight &&
  171. MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
  172. EntryOutgoing->Weight = MaxExitInWeight;
  173. ExitIncoming->Weight = MaxEntryOutWeight + 1;
  174. }
  175. }
  176. // Sort CFG edges based on its weight.
  177. void sortEdgesByWeight() {
  178. llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
  179. const std::unique_ptr<Edge> &Edge2) {
  180. return Edge1->Weight > Edge2->Weight;
  181. });
  182. }
  183. // Traverse all the edges and compute the Minimum Weight Spanning Tree
  184. // using union-find algorithm.
  185. void computeMinimumSpanningTree() {
  186. // First, put all the critical edge with landing-pad as the Dest to MST.
  187. // This works around the insufficient support of critical edges split
  188. // when destination BB is a landing pad.
  189. for (auto &Ei : AllEdges) {
  190. if (Ei->Removed)
  191. continue;
  192. if (Ei->IsCritical) {
  193. if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
  194. if (unionGroups(Ei->SrcBB, Ei->DestBB))
  195. Ei->InMST = true;
  196. }
  197. }
  198. }
  199. for (auto &Ei : AllEdges) {
  200. if (Ei->Removed)
  201. continue;
  202. // If we detect infinite loops, force
  203. // instrumenting the entry edge:
  204. if (!ExitBlockFound && Ei->SrcBB == nullptr)
  205. continue;
  206. if (unionGroups(Ei->SrcBB, Ei->DestBB))
  207. Ei->InMST = true;
  208. }
  209. }
  210. // Dump the Debug information about the instrumentation.
  211. void dumpEdges(raw_ostream &OS, const Twine &Message) const {
  212. if (!Message.str().empty())
  213. OS << Message << "\n";
  214. OS << " Number of Basic Blocks: " << BBInfos.size() << "\n";
  215. for (auto &BI : BBInfos) {
  216. const BasicBlock *BB = BI.first;
  217. OS << " BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << " "
  218. << BI.second->infoString() << "\n";
  219. }
  220. OS << " Number of Edges: " << AllEdges.size()
  221. << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
  222. uint32_t Count = 0;
  223. for (auto &EI : AllEdges)
  224. OS << " Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
  225. << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
  226. }
  227. // Add an edge to AllEdges with weight W.
  228. Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
  229. uint32_t Index = BBInfos.size();
  230. auto Iter = BBInfos.end();
  231. bool Inserted;
  232. std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
  233. if (Inserted) {
  234. // Newly inserted, update the real info.
  235. Iter->second = std::move(std::make_unique<BBInfo>(Index));
  236. Index++;
  237. }
  238. std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
  239. if (Inserted)
  240. // Newly inserted, update the real info.
  241. Iter->second = std::move(std::make_unique<BBInfo>(Index));
  242. AllEdges.emplace_back(new Edge(Src, Dest, W));
  243. return *AllEdges.back();
  244. }
  245. BranchProbabilityInfo *BPI;
  246. BlockFrequencyInfo *BFI;
  247. // If function entry will be always instrumented.
  248. bool InstrumentFuncEntry;
  249. public:
  250. CFGMST(Function &Func, bool InstrumentFuncEntry_,
  251. BranchProbabilityInfo *BPI_ = nullptr,
  252. BlockFrequencyInfo *BFI_ = nullptr)
  253. : F(Func), BPI(BPI_), BFI(BFI_),
  254. InstrumentFuncEntry(InstrumentFuncEntry_) {
  255. buildEdges();
  256. sortEdgesByWeight();
  257. computeMinimumSpanningTree();
  258. if (AllEdges.size() > 1 && InstrumentFuncEntry)
  259. std::iter_swap(std::move(AllEdges.begin()),
  260. std::move(AllEdges.begin() + AllEdges.size() - 1));
  261. }
  262. };
  263. } // end namespace llvm
  264. #undef DEBUG_TYPE // "cfgmst"
  265. #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H