BlockExtractor.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass extracts the specified basic blocks from the module into their
  10. // own functions.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Transforms/IPO/BlockExtractor.h"
  14. #include "llvm/ADT/STLExtras.h"
  15. #include "llvm/ADT/Statistic.h"
  16. #include "llvm/IR/Instructions.h"
  17. #include "llvm/IR/Module.h"
  18. #include "llvm/IR/PassManager.h"
  19. #include "llvm/InitializePasses.h"
  20. #include "llvm/Pass.h"
  21. #include "llvm/Support/CommandLine.h"
  22. #include "llvm/Support/Debug.h"
  23. #include "llvm/Support/MemoryBuffer.h"
  24. #include "llvm/Transforms/IPO.h"
  25. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  26. #include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm;

// Tag used by the LLVM_DEBUG output emitted in this file.
#define DEBUG_TYPE "block-extractor"

// Incremented in BlockExtractor::runOnModule once for every basic block that
// is placed in an extraction group, before the extraction itself is attempted.
STATISTIC(NumExtracted, "Number of basic blocks extracted");

// Name of a file listing the blocks to extract. When non-empty,
// BlockExtractor::init() calls loadFile(), which expects lines of the form
// 'funcname bb1[;bb2..]'.
static cl::opt<std::string> BlockExtractorFile(
    "extract-blocks-file", cl::value_desc("filename"),
    cl::desc("A file containing list of basic blocks to extract"), cl::Hidden);

// When set, BlockExtractor::runOnModule deletes the bodies of the functions
// that were in the module before extraction, independently of the
// EraseFunctions value passed to the BlockExtractor constructor.
static cl::opt<bool>
    BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
                             cl::desc("Erase the existing functions"),
                             cl::Hidden);
  37. namespace {
  38. class BlockExtractor {
  39. public:
  40. BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {}
  41. bool runOnModule(Module &M);
  42. void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>>
  43. &GroupsOfBlocksToExtract) {
  44. for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks :
  45. GroupsOfBlocksToExtract) {
  46. SmallVector<BasicBlock *, 16> NewGroup;
  47. NewGroup.append(GroupOfBlocks.begin(), GroupOfBlocks.end());
  48. GroupsOfBlocks.emplace_back(NewGroup);
  49. }
  50. if (!BlockExtractorFile.empty())
  51. loadFile();
  52. }
  53. private:
  54. SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks;
  55. bool EraseFunctions;
  56. /// Map a function name to groups of blocks.
  57. SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4>
  58. BlocksByName;
  59. void loadFile();
  60. void splitLandingPadPreds(Function &F);
  61. };
  62. class BlockExtractorLegacyPass : public ModulePass {
  63. BlockExtractor BE;
  64. bool runOnModule(Module &M) override;
  65. public:
  66. static char ID;
  67. BlockExtractorLegacyPass(const SmallVectorImpl<BasicBlock *> &BlocksToExtract,
  68. bool EraseFunctions)
  69. : ModulePass(ID), BE(EraseFunctions) {
  70. // We want one group per element of the input list.
  71. SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks;
  72. for (BasicBlock *BB : BlocksToExtract) {
  73. SmallVector<BasicBlock *, 16> NewGroup;
  74. NewGroup.push_back(BB);
  75. MassagedGroupsOfBlocks.push_back(NewGroup);
  76. }
  77. BE.init(MassagedGroupsOfBlocks);
  78. }
  79. BlockExtractorLegacyPass(const SmallVectorImpl<SmallVector<BasicBlock *, 16>>
  80. &GroupsOfBlocksToExtract,
  81. bool EraseFunctions)
  82. : ModulePass(ID), BE(EraseFunctions) {
  83. BE.init(GroupsOfBlocksToExtract);
  84. }
  85. BlockExtractorLegacyPass()
  86. : BlockExtractorLegacyPass(SmallVector<BasicBlock *, 0>(), false) {}
  87. };
  88. } // end anonymous namespace
// Pass identifier; it is handed to the ModulePass base-class constructor by
// every BlockExtractorLegacyPass constructor.
char BlockExtractorLegacyPass::ID = 0;

// Registers the legacy pass under the command-line name "extract-blocks".
// NOTE(review): the two trailing 'false' arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS -- confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(BlockExtractorLegacyPass, "extract-blocks",
                "Extract basic blocks from module", false, false)
/// Creates a legacy block-extractor pass with no explicitly requested blocks.
/// The caller receives a newly allocated pass object.
ModulePass *llvm::createBlockExtractorPass() {
  return new BlockExtractorLegacyPass();
}
/// Creates a legacy block-extractor pass in which every block of
/// \p BlocksToExtract forms an extraction group of its own.
/// \p EraseFunctions is forwarded to the pass constructor.
ModulePass *llvm::createBlockExtractorPass(
    const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) {
  return new BlockExtractorLegacyPass(BlocksToExtract, EraseFunctions);
}
/// Creates a legacy block-extractor pass that extracts each group of
/// \p GroupsOfBlocksToExtract as a unit. \p EraseFunctions is forwarded to the
/// pass constructor.
ModulePass *llvm::createBlockExtractorPass(
    const SmallVectorImpl<SmallVector<BasicBlock *, 16>>
        &GroupsOfBlocksToExtract,
    bool EraseFunctions) {
  return new BlockExtractorLegacyPass(GroupsOfBlocksToExtract, EraseFunctions);
}
  105. /// Gets all of the blocks specified in the input file.
  106. void BlockExtractor::loadFile() {
  107. auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile);
  108. if (ErrOrBuf.getError())
  109. report_fatal_error("BlockExtractor couldn't load the file.");
  110. // Read the file.
  111. auto &Buf = *ErrOrBuf;
  112. SmallVector<StringRef, 16> Lines;
  113. Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1,
  114. /*KeepEmpty=*/false);
  115. for (const auto &Line : Lines) {
  116. SmallVector<StringRef, 4> LineSplit;
  117. Line.split(LineSplit, ' ', /*MaxSplit=*/-1,
  118. /*KeepEmpty=*/false);
  119. if (LineSplit.empty())
  120. continue;
  121. if (LineSplit.size()!=2)
  122. report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'");
  123. SmallVector<StringRef, 4> BBNames;
  124. LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1,
  125. /*KeepEmpty=*/false);
  126. if (BBNames.empty())
  127. report_fatal_error("Missing bbs name");
  128. BlocksByName.push_back(
  129. {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}});
  130. }
  131. }
  132. /// Extracts the landing pads to make sure all of them have only one
  133. /// predecessor.
  134. void BlockExtractor::splitLandingPadPreds(Function &F) {
  135. for (BasicBlock &BB : F) {
  136. for (Instruction &I : BB) {
  137. if (!isa<InvokeInst>(&I))
  138. continue;
  139. InvokeInst *II = cast<InvokeInst>(&I);
  140. BasicBlock *Parent = II->getParent();
  141. BasicBlock *LPad = II->getUnwindDest();
  142. // Look through the landing pad's predecessors. If one of them ends in an
  143. // 'invoke', then we want to split the landing pad.
  144. bool Split = false;
  145. for (auto PredBB : predecessors(LPad)) {
  146. if (PredBB->isLandingPad() && PredBB != Parent &&
  147. isa<InvokeInst>(Parent->getTerminator())) {
  148. Split = true;
  149. break;
  150. }
  151. }
  152. if (!Split)
  153. continue;
  154. SmallVector<BasicBlock *, 2> NewBBs;
  155. SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs);
  156. }
  157. }
  158. }
  159. bool BlockExtractor::runOnModule(Module &M) {
  160. bool Changed = false;
  161. // Get all the functions.
  162. SmallVector<Function *, 4> Functions;
  163. for (Function &F : M) {
  164. splitLandingPadPreds(F);
  165. Functions.push_back(&F);
  166. }
  167. // Get all the blocks specified in the input file.
  168. unsigned NextGroupIdx = GroupsOfBlocks.size();
  169. GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size());
  170. for (const auto &BInfo : BlocksByName) {
  171. Function *F = M.getFunction(BInfo.first);
  172. if (!F)
  173. report_fatal_error("Invalid function name specified in the input file");
  174. for (const auto &BBInfo : BInfo.second) {
  175. auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) {
  176. return BB.getName().equals(BBInfo);
  177. });
  178. if (Res == F->end())
  179. report_fatal_error("Invalid block name specified in the input file");
  180. GroupsOfBlocks[NextGroupIdx].push_back(&*Res);
  181. }
  182. ++NextGroupIdx;
  183. }
  184. // Extract each group of basic blocks.
  185. for (auto &BBs : GroupsOfBlocks) {
  186. SmallVector<BasicBlock *, 32> BlocksToExtractVec;
  187. for (BasicBlock *BB : BBs) {
  188. // Check if the module contains BB.
  189. if (BB->getParent()->getParent() != &M)
  190. report_fatal_error("Invalid basic block");
  191. LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
  192. << BB->getParent()->getName() << ":" << BB->getName()
  193. << "\n");
  194. BlocksToExtractVec.push_back(BB);
  195. if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()))
  196. BlocksToExtractVec.push_back(II->getUnwindDest());
  197. ++NumExtracted;
  198. Changed = true;
  199. }
  200. CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent());
  201. Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC);
  202. if (F)
  203. LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName()
  204. << "' in: " << F->getName() << '\n');
  205. else
  206. LLVM_DEBUG(dbgs() << "Failed to extract for group '"
  207. << (*BBs.begin())->getName() << "'\n");
  208. }
  209. // Erase the functions.
  210. if (EraseFunctions || BlockExtractorEraseFuncs) {
  211. for (Function *F : Functions) {
  212. LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName()
  213. << "\n");
  214. F->deleteBody();
  215. }
  216. // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
  217. for (Function &F : M)
  218. F.setLinkage(GlobalValue::ExternalLinkage);
  219. Changed = true;
  220. }
  221. return Changed;
  222. }
/// Legacy pass-manager entry point: forwards to the wrapped BlockExtractor and
/// returns its result.
bool BlockExtractorLegacyPass::runOnModule(Module &M) {
  return BE.runOnModule(M);
}
  226. PreservedAnalyses BlockExtractorPass::run(Module &M,
  227. ModuleAnalysisManager &AM) {
  228. BlockExtractor BE(false);
  229. BE.init(SmallVector<SmallVector<BasicBlock *, 16>, 0>());
  230. return BE.runOnModule(M) ? PreservedAnalyses::none()
  231. : PreservedAnalyses::all();
  232. }