LoopExtractor.cpp 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // A pass wrapper around CodeExtractor that extracts each top-level loop into
  10. // its own new function. If a function is only a minimal wrapper around its
  11. // single top-level loop, that loop is left in place and only its sub-loops are
  12. // considered for extraction. This pass is most useful for debugging via bugpoint.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Transforms/IPO/LoopExtractor.h"
  16. #include "llvm/ADT/Statistic.h"
  17. #include "llvm/Analysis/AssumptionCache.h"
  18. #include "llvm/Analysis/LoopInfo.h"
  19. #include "llvm/IR/Dominators.h"
  20. #include "llvm/IR/Instructions.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/IR/PassManager.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/CommandLine.h"
  26. #include "llvm/Transforms/IPO.h"
  27. #include "llvm/Transforms/Scalar.h"
  28. #include "llvm/Transforms/Utils.h"
  29. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  30. #include "llvm/Transforms/Utils/CodeExtractor.h"
  31. #include <fstream>
  32. #include <set>
  33. using namespace llvm;
  34. #define DEBUG_TYPE "loop-extract"
  35. STATISTIC(NumExtracted, "Number of loops extracted");
  36. namespace {
  37. struct LoopExtractorLegacyPass : public ModulePass {
  38. static char ID; // Pass identification, replacement for typeid
  39. unsigned NumLoops;
  40. explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0)
  41. : ModulePass(ID), NumLoops(NumLoops) {
  42. initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
  43. }
  44. bool runOnModule(Module &M) override;
  45. void getAnalysisUsage(AnalysisUsage &AU) const override {
  46. AU.addRequiredID(BreakCriticalEdgesID);
  47. AU.addRequired<DominatorTreeWrapperPass>();
  48. AU.addRequired<LoopInfoWrapperPass>();
  49. AU.addPreserved<LoopInfoWrapperPass>();
  50. AU.addRequiredID(LoopSimplifyID);
  51. AU.addUsedIfAvailable<AssumptionCacheTracker>();
  52. }
  53. };
  54. struct LoopExtractor {
  55. explicit LoopExtractor(
  56. unsigned NumLoops,
  57. function_ref<DominatorTree &(Function &)> LookupDomTree,
  58. function_ref<LoopInfo &(Function &)> LookupLoopInfo,
  59. function_ref<AssumptionCache *(Function &)> LookupAssumptionCache)
  60. : NumLoops(NumLoops), LookupDomTree(LookupDomTree),
  61. LookupLoopInfo(LookupLoopInfo),
  62. LookupAssumptionCache(LookupAssumptionCache) {}
  63. bool runOnModule(Module &M);
  64. private:
  65. // The number of natural loops to extract from the program into functions.
  66. unsigned NumLoops;
  67. function_ref<DominatorTree &(Function &)> LookupDomTree;
  68. function_ref<LoopInfo &(Function &)> LookupLoopInfo;
  69. function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
  70. bool runOnFunction(Function &F);
  71. bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI,
  72. DominatorTree &DT);
  73. bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT);
  74. };
  75. } // namespace
  76. char LoopExtractorLegacyPass::ID = 0;
  77. INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract",
  78. "Extract loops into new functions", false, false)
  79. INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
  80. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  81. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  82. INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  83. INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract",
  84. "Extract loops into new functions", false, false)
  85. namespace {
  86. /// SingleLoopExtractor - For bugpoint.
  87. struct SingleLoopExtractor : public LoopExtractorLegacyPass {
  88. static char ID; // Pass identification, replacement for typeid
  89. SingleLoopExtractor() : LoopExtractorLegacyPass(1) {}
  90. };
  91. } // End anonymous namespace
  92. char SingleLoopExtractor::ID = 0;
  93. INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
  94. "Extract at most one loop into a new function", false, false)
  95. // createLoopExtractorPass - This pass extracts all natural loops from the
  96. // program into a function if it can.
  97. //
  98. Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); }
  99. bool LoopExtractorLegacyPass::runOnModule(Module &M) {
  100. if (skipModule(M))
  101. return false;
  102. bool Changed = false;
  103. auto LookupDomTree = [this](Function &F) -> DominatorTree & {
  104. return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
  105. };
  106. auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & {
  107. return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo();
  108. };
  109. auto LookupACT = [this](Function &F) -> AssumptionCache * {
  110. if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>())
  111. return ACT->lookupAssumptionCache(F);
  112. return nullptr;
  113. };
  114. return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
  115. .runOnModule(M) ||
  116. Changed;
  117. }
  118. bool LoopExtractor::runOnModule(Module &M) {
  119. if (M.empty())
  120. return false;
  121. if (!NumLoops)
  122. return false;
  123. bool Changed = false;
  124. // The end of the function list may change (new functions will be added at the
  125. // end), so we run from the first to the current last.
  126. auto I = M.begin(), E = --M.end();
  127. while (true) {
  128. Function &F = *I;
  129. Changed |= runOnFunction(F);
  130. if (!NumLoops)
  131. break;
  132. // If this is the last function.
  133. if (I == E)
  134. break;
  135. ++I;
  136. }
  137. return Changed;
  138. }
  139. bool LoopExtractor::runOnFunction(Function &F) {
  140. // Do not modify `optnone` functions.
  141. if (F.hasOptNone())
  142. return false;
  143. if (F.empty())
  144. return false;
  145. bool Changed = false;
  146. LoopInfo &LI = LookupLoopInfo(F);
  147. // If there are no loops in the function.
  148. if (LI.empty())
  149. return Changed;
  150. DominatorTree &DT = LookupDomTree(F);
  151. // If there is more than one top-level loop in this function, extract all of
  152. // the loops.
  153. if (std::next(LI.begin()) != LI.end())
  154. return Changed | extractLoops(LI.begin(), LI.end(), LI, DT);
  155. // Otherwise there is exactly one top-level loop.
  156. Loop *TLL = *LI.begin();
  157. // If the loop is in LoopSimplify form, then extract it only if this function
  158. // is more than a minimal wrapper around the loop.
  159. if (TLL->isLoopSimplifyForm()) {
  160. bool ShouldExtractLoop = false;
  161. // Extract the loop if the entry block doesn't branch to the loop header.
  162. Instruction *EntryTI = F.getEntryBlock().getTerminator();
  163. if (!isa<BranchInst>(EntryTI) ||
  164. !cast<BranchInst>(EntryTI)->isUnconditional() ||
  165. EntryTI->getSuccessor(0) != TLL->getHeader()) {
  166. ShouldExtractLoop = true;
  167. } else {
  168. // Check to see if any exits from the loop are more than just return
  169. // blocks.
  170. SmallVector<BasicBlock *, 8> ExitBlocks;
  171. TLL->getExitBlocks(ExitBlocks);
  172. for (auto *ExitBlock : ExitBlocks)
  173. if (!isa<ReturnInst>(ExitBlock->getTerminator())) {
  174. ShouldExtractLoop = true;
  175. break;
  176. }
  177. }
  178. if (ShouldExtractLoop)
  179. return Changed | extractLoop(TLL, LI, DT);
  180. }
  181. // Okay, this function is a minimal container around the specified loop.
  182. // If we extract the loop, we will continue to just keep extracting it
  183. // infinitely... so don't extract it. However, if the loop contains any
  184. // sub-loops, extract them.
  185. return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT);
  186. }
  187. bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
  188. LoopInfo &LI, DominatorTree &DT) {
  189. bool Changed = false;
  190. SmallVector<Loop *, 8> Loops;
  191. // Save the list of loops, as it may change.
  192. Loops.assign(From, To);
  193. for (Loop *L : Loops) {
  194. // If LoopSimplify form is not available, stay out of trouble.
  195. if (!L->isLoopSimplifyForm())
  196. continue;
  197. Changed |= extractLoop(L, LI, DT);
  198. if (!NumLoops)
  199. break;
  200. }
  201. return Changed;
  202. }
  203. bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
  204. assert(NumLoops != 0);
  205. Function &Func = *L->getHeader()->getParent();
  206. AssumptionCache *AC = LookupAssumptionCache(Func);
  207. CodeExtractorAnalysisCache CEAC(Func);
  208. CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
  209. if (Extractor.extractCodeRegion(CEAC)) {
  210. LI.erase(L);
  211. --NumLoops;
  212. ++NumExtracted;
  213. return true;
  214. }
  215. return false;
  216. }
  217. // createSingleLoopExtractorPass - This pass extracts one natural loop from the
  218. // program into a function if it can. This is used by bugpoint.
  219. //
  220. Pass *llvm::createSingleLoopExtractorPass() {
  221. return new SingleLoopExtractor();
  222. }
  223. PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
  224. auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  225. auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
  226. return FAM.getResult<DominatorTreeAnalysis>(F);
  227. };
  228. auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & {
  229. return FAM.getResult<LoopAnalysis>(F);
  230. };
  231. auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
  232. return FAM.getCachedResult<AssumptionAnalysis>(F);
  233. };
  234. if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
  235. LookupAssumptionCache)
  236. .runOnModule(M))
  237. return PreservedAnalyses::all();
  238. PreservedAnalyses PA;
  239. PA.preserve<LoopAnalysis>();
  240. return PA;
  241. }
  242. void LoopExtractorPass::printPipeline(
  243. raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  244. static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
  245. OS, MapClassName2PassName);
  246. OS << "<";
  247. if (NumLoops == 1)
  248. OS << "single";
  249. OS << ">";
  250. }