BlockExtractor.cpp 7.2 KB

  //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
  //
  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  // See https://llvm.org/LICENSE.txt for license information.
  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  //
  //===----------------------------------------------------------------------===//
  //
  // This pass extracts the specified basic blocks from the module into their
  // own functions.
  //
  //===----------------------------------------------------------------------===//
  #include "llvm/Transforms/IPO/BlockExtractor.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/Statistic.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/InitializePasses.h"
  #include "llvm/Pass.h"
  #include "llvm/Support/CommandLine.h"
  #include "llvm/Support/Debug.h"
  #include "llvm/Support/MemoryBuffer.h"
  #include "llvm/Transforms/IPO.h"
  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  #include "llvm/Transforms/Utils/CodeExtractor.h"
  #include <utility>
  27. using namespace llvm;
  28. #define DEBUG_TYPE "block-extractor"
  29. STATISTIC(NumExtracted, "Number of basic blocks extracted");
  30. static cl::opt<std::string> BlockExtractorFile(
  31. "extract-blocks-file", cl::value_desc("filename"),
  32. cl::desc("A file containing list of basic blocks to extract"), cl::Hidden);
  33. static cl::opt<bool>
  34. BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
  35. cl::desc("Erase the existing functions"),
  36. cl::Hidden);
  37. namespace {
  38. class BlockExtractor {
  39. public:
  40. BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {}
  41. bool runOnModule(Module &M);
  42. void
  43. init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) {
  44. GroupsOfBlocks = GroupsOfBlocksToExtract;
  45. if (!BlockExtractorFile.empty())
  46. loadFile();
  47. }
  48. private:
  49. std::vector<std::vector<BasicBlock *>> GroupsOfBlocks;
  50. bool EraseFunctions;
  51. /// Map a function name to groups of blocks.
  52. SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4>
  53. BlocksByName;
  54. void loadFile();
  55. void splitLandingPadPreds(Function &F);
  56. };
  57. } // end anonymous namespace
  58. /// Gets all of the blocks specified in the input file.
  59. void BlockExtractor::loadFile() {
  60. auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile);
  61. if (ErrOrBuf.getError())
  62. report_fatal_error("BlockExtractor couldn't load the file.");
  63. // Read the file.
  64. auto &Buf = *ErrOrBuf;
  65. SmallVector<StringRef, 16> Lines;
  66. Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1,
  67. /*KeepEmpty=*/false);
  68. for (const auto &Line : Lines) {
  69. SmallVector<StringRef, 4> LineSplit;
  70. Line.split(LineSplit, ' ', /*MaxSplit=*/-1,
  71. /*KeepEmpty=*/false);
  72. if (LineSplit.empty())
  73. continue;
  74. if (LineSplit.size()!=2)
  75. report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'",
  76. /*GenCrashDiag=*/false);
  77. SmallVector<StringRef, 4> BBNames;
  78. LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1,
  79. /*KeepEmpty=*/false);
  80. if (BBNames.empty())
  81. report_fatal_error("Missing bbs name");
  82. BlocksByName.push_back(
  83. {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}});
  84. }
  85. }
  86. /// Extracts the landing pads to make sure all of them have only one
  87. /// predecessor.
  88. void BlockExtractor::splitLandingPadPreds(Function &F) {
  89. for (BasicBlock &BB : F) {
  90. for (Instruction &I : BB) {
  91. if (!isa<InvokeInst>(&I))
  92. continue;
  93. InvokeInst *II = cast<InvokeInst>(&I);
  94. BasicBlock *Parent = II->getParent();
  95. BasicBlock *LPad = II->getUnwindDest();
  96. // Look through the landing pad's predecessors. If one of them ends in an
  97. // 'invoke', then we want to split the landing pad.
  98. bool Split = false;
  99. for (auto *PredBB : predecessors(LPad)) {
  100. if (PredBB->isLandingPad() && PredBB != Parent &&
  101. isa<InvokeInst>(Parent->getTerminator())) {
  102. Split = true;
  103. break;
  104. }
  105. }
  106. if (!Split)
  107. continue;
  108. SmallVector<BasicBlock *, 2> NewBBs;
  109. SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs);
  110. }
  111. }
  112. }
  113. bool BlockExtractor::runOnModule(Module &M) {
  114. bool Changed = false;
  115. // Get all the functions.
  116. SmallVector<Function *, 4> Functions;
  117. for (Function &F : M) {
  118. splitLandingPadPreds(F);
  119. Functions.push_back(&F);
  120. }
  121. // Get all the blocks specified in the input file.
  122. unsigned NextGroupIdx = GroupsOfBlocks.size();
  123. GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size());
  124. for (const auto &BInfo : BlocksByName) {
  125. Function *F = M.getFunction(BInfo.first);
  126. if (!F)
  127. report_fatal_error("Invalid function name specified in the input file",
  128. /*GenCrashDiag=*/false);
  129. for (const auto &BBInfo : BInfo.second) {
  130. auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) {
  131. return BB.getName().equals(BBInfo);
  132. });
  133. if (Res == F->end())
  134. report_fatal_error("Invalid block name specified in the input file",
  135. /*GenCrashDiag=*/false);
  136. GroupsOfBlocks[NextGroupIdx].push_back(&*Res);
  137. }
  138. ++NextGroupIdx;
  139. }
  140. // Extract each group of basic blocks.
  141. for (auto &BBs : GroupsOfBlocks) {
  142. SmallVector<BasicBlock *, 32> BlocksToExtractVec;
  143. for (BasicBlock *BB : BBs) {
  144. // Check if the module contains BB.
  145. if (BB->getParent()->getParent() != &M)
  146. report_fatal_error("Invalid basic block", /*GenCrashDiag=*/false);
  147. LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
  148. << BB->getParent()->getName() << ":" << BB->getName()
  149. << "\n");
  150. BlocksToExtractVec.push_back(BB);
  151. if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()))
  152. BlocksToExtractVec.push_back(II->getUnwindDest());
  153. ++NumExtracted;
  154. Changed = true;
  155. }
  156. CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent());
  157. Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC);
  158. if (F)
  159. LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName()
  160. << "' in: " << F->getName() << '\n');
  161. else
  162. LLVM_DEBUG(dbgs() << "Failed to extract for group '"
  163. << (*BBs.begin())->getName() << "'\n");
  164. }
  165. // Erase the functions.
  166. if (EraseFunctions || BlockExtractorEraseFuncs) {
  167. for (Function *F : Functions) {
  168. LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName()
  169. << "\n");
  170. F->deleteBody();
  171. }
  172. // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
  173. for (Function &F : M)
  174. F.setLinkage(GlobalValue::ExternalLinkage);
  175. Changed = true;
  176. }
  177. return Changed;
  178. }
  179. BlockExtractorPass::BlockExtractorPass(
  180. std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks,
  181. bool EraseFunctions)
  182. : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {}
  183. PreservedAnalyses BlockExtractorPass::run(Module &M,
  184. ModuleAnalysisManager &AM) {
  185. BlockExtractor BE(EraseFunctions);
  186. BE.init(GroupsOfBlocks);
  187. return BE.runOnModule(M) ? PreservedAnalyses::none()
  188. : PreservedAnalyses::all();
  189. }