LoopExtractor.cpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // A pass wrapper around the ExtractLoop() scalar transformation to extract each
  10. // top-level loop into its own new function. If the loop is the ONLY loop in a
  11. // given function, it is not touched. This is a pass most useful for debugging
  12. // via bugpoint.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Transforms/IPO/LoopExtractor.h"
  16. #include "llvm/ADT/Statistic.h"
  17. #include "llvm/Analysis/AssumptionCache.h"
  18. #include "llvm/Analysis/LoopInfo.h"
  19. #include "llvm/IR/Dominators.h"
  20. #include "llvm/IR/Instructions.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/IR/PassManager.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Transforms/IPO.h"
  26. #include "llvm/Transforms/Utils.h"
  27. #include "llvm/Transforms/Utils/CodeExtractor.h"
  28. using namespace llvm;
  29. #define DEBUG_TYPE "loop-extract"
  30. STATISTIC(NumExtracted, "Number of loops extracted");
  31. namespace {
  32. struct LoopExtractorLegacyPass : public ModulePass {
  33. static char ID; // Pass identification, replacement for typeid
  34. unsigned NumLoops;
  35. explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0)
  36. : ModulePass(ID), NumLoops(NumLoops) {
  37. initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
  38. }
  39. bool runOnModule(Module &M) override;
  40. void getAnalysisUsage(AnalysisUsage &AU) const override {
  41. AU.addRequiredID(BreakCriticalEdgesID);
  42. AU.addRequired<DominatorTreeWrapperPass>();
  43. AU.addRequired<LoopInfoWrapperPass>();
  44. AU.addPreserved<LoopInfoWrapperPass>();
  45. AU.addRequiredID(LoopSimplifyID);
  46. AU.addUsedIfAvailable<AssumptionCacheTracker>();
  47. }
  48. };
  49. struct LoopExtractor {
  50. explicit LoopExtractor(
  51. unsigned NumLoops,
  52. function_ref<DominatorTree &(Function &)> LookupDomTree,
  53. function_ref<LoopInfo &(Function &)> LookupLoopInfo,
  54. function_ref<AssumptionCache *(Function &)> LookupAssumptionCache)
  55. : NumLoops(NumLoops), LookupDomTree(LookupDomTree),
  56. LookupLoopInfo(LookupLoopInfo),
  57. LookupAssumptionCache(LookupAssumptionCache) {}
  58. bool runOnModule(Module &M);
  59. private:
  60. // The number of natural loops to extract from the program into functions.
  61. unsigned NumLoops;
  62. function_ref<DominatorTree &(Function &)> LookupDomTree;
  63. function_ref<LoopInfo &(Function &)> LookupLoopInfo;
  64. function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
  65. bool runOnFunction(Function &F);
  66. bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI,
  67. DominatorTree &DT);
  68. bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT);
  69. };
  70. } // namespace
  71. char LoopExtractorLegacyPass::ID = 0;
  72. INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract",
  73. "Extract loops into new functions", false, false)
  74. INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
  75. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  76. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  77. INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  78. INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract",
  79. "Extract loops into new functions", false, false)
  80. namespace {
  81. /// SingleLoopExtractor - For bugpoint.
  82. struct SingleLoopExtractor : public LoopExtractorLegacyPass {
  83. static char ID; // Pass identification, replacement for typeid
  84. SingleLoopExtractor() : LoopExtractorLegacyPass(1) {}
  85. };
  86. } // End anonymous namespace
  87. char SingleLoopExtractor::ID = 0;
  88. INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
  89. "Extract at most one loop into a new function", false, false)
  90. // createLoopExtractorPass - This pass extracts all natural loops from the
  91. // program into a function if it can.
  92. //
  93. Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); }
  94. bool LoopExtractorLegacyPass::runOnModule(Module &M) {
  95. if (skipModule(M))
  96. return false;
  97. bool Changed = false;
  98. auto LookupDomTree = [this](Function &F) -> DominatorTree & {
  99. return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
  100. };
  101. auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & {
  102. return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo();
  103. };
  104. auto LookupACT = [this](Function &F) -> AssumptionCache * {
  105. if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>())
  106. return ACT->lookupAssumptionCache(F);
  107. return nullptr;
  108. };
  109. return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
  110. .runOnModule(M) ||
  111. Changed;
  112. }
  113. bool LoopExtractor::runOnModule(Module &M) {
  114. if (M.empty())
  115. return false;
  116. if (!NumLoops)
  117. return false;
  118. bool Changed = false;
  119. // The end of the function list may change (new functions will be added at the
  120. // end), so we run from the first to the current last.
  121. auto I = M.begin(), E = --M.end();
  122. while (true) {
  123. Function &F = *I;
  124. Changed |= runOnFunction(F);
  125. if (!NumLoops)
  126. break;
  127. // If this is the last function.
  128. if (I == E)
  129. break;
  130. ++I;
  131. }
  132. return Changed;
  133. }
  134. bool LoopExtractor::runOnFunction(Function &F) {
  135. // Do not modify `optnone` functions.
  136. if (F.hasOptNone())
  137. return false;
  138. if (F.empty())
  139. return false;
  140. bool Changed = false;
  141. LoopInfo &LI = LookupLoopInfo(F);
  142. // If there are no loops in the function.
  143. if (LI.empty())
  144. return Changed;
  145. DominatorTree &DT = LookupDomTree(F);
  146. // If there is more than one top-level loop in this function, extract all of
  147. // the loops.
  148. if (std::next(LI.begin()) != LI.end())
  149. return Changed | extractLoops(LI.begin(), LI.end(), LI, DT);
  150. // Otherwise there is exactly one top-level loop.
  151. Loop *TLL = *LI.begin();
  152. // If the loop is in LoopSimplify form, then extract it only if this function
  153. // is more than a minimal wrapper around the loop.
  154. if (TLL->isLoopSimplifyForm()) {
  155. bool ShouldExtractLoop = false;
  156. // Extract the loop if the entry block doesn't branch to the loop header.
  157. Instruction *EntryTI = F.getEntryBlock().getTerminator();
  158. if (!isa<BranchInst>(EntryTI) ||
  159. !cast<BranchInst>(EntryTI)->isUnconditional() ||
  160. EntryTI->getSuccessor(0) != TLL->getHeader()) {
  161. ShouldExtractLoop = true;
  162. } else {
  163. // Check to see if any exits from the loop are more than just return
  164. // blocks.
  165. SmallVector<BasicBlock *, 8> ExitBlocks;
  166. TLL->getExitBlocks(ExitBlocks);
  167. for (auto *ExitBlock : ExitBlocks)
  168. if (!isa<ReturnInst>(ExitBlock->getTerminator())) {
  169. ShouldExtractLoop = true;
  170. break;
  171. }
  172. }
  173. if (ShouldExtractLoop)
  174. return Changed | extractLoop(TLL, LI, DT);
  175. }
  176. // Okay, this function is a minimal container around the specified loop.
  177. // If we extract the loop, we will continue to just keep extracting it
  178. // infinitely... so don't extract it. However, if the loop contains any
  179. // sub-loops, extract them.
  180. return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT);
  181. }
  182. bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
  183. LoopInfo &LI, DominatorTree &DT) {
  184. bool Changed = false;
  185. SmallVector<Loop *, 8> Loops;
  186. // Save the list of loops, as it may change.
  187. Loops.assign(From, To);
  188. for (Loop *L : Loops) {
  189. // If LoopSimplify form is not available, stay out of trouble.
  190. if (!L->isLoopSimplifyForm())
  191. continue;
  192. Changed |= extractLoop(L, LI, DT);
  193. if (!NumLoops)
  194. break;
  195. }
  196. return Changed;
  197. }
  198. bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
  199. assert(NumLoops != 0);
  200. Function &Func = *L->getHeader()->getParent();
  201. AssumptionCache *AC = LookupAssumptionCache(Func);
  202. CodeExtractorAnalysisCache CEAC(Func);
  203. CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
  204. if (Extractor.extractCodeRegion(CEAC)) {
  205. LI.erase(L);
  206. --NumLoops;
  207. ++NumExtracted;
  208. return true;
  209. }
  210. return false;
  211. }
  212. // createSingleLoopExtractorPass - This pass extracts one natural loop from the
  213. // program into a function if it can. This is used by bugpoint.
  214. //
  215. Pass *llvm::createSingleLoopExtractorPass() {
  216. return new SingleLoopExtractor();
  217. }
  218. PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
  219. auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  220. auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
  221. return FAM.getResult<DominatorTreeAnalysis>(F);
  222. };
  223. auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & {
  224. return FAM.getResult<LoopAnalysis>(F);
  225. };
  226. auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
  227. return FAM.getCachedResult<AssumptionAnalysis>(F);
  228. };
  229. if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
  230. LookupAssumptionCache)
  231. .runOnModule(M))
  232. return PreservedAnalyses::all();
  233. PreservedAnalyses PA;
  234. PA.preserve<LoopAnalysis>();
  235. return PA;
  236. }
  237. void LoopExtractorPass::printPipeline(
  238. raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  239. static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
  240. OS, MapClassName2PassName);
  241. OS << "<";
  242. if (NumLoops == 1)
  243. OS << "single";
  244. OS << ">";
  245. }