TLSVariableHoist.cpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass identifies and eliminates redundant TLS loads when enabled by the tls-load-hoist option or function attribute.
  10. // For an example, please refer to the comment at the head of TLSVariableHoist.h.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/ADT/SmallVector.h"
  14. #include "llvm/IR/BasicBlock.h"
  15. #include "llvm/IR/Dominators.h"
  16. #include "llvm/IR/Function.h"
  17. #include "llvm/IR/InstrTypes.h"
  18. #include "llvm/IR/Instruction.h"
  19. #include "llvm/IR/Instructions.h"
  20. #include "llvm/IR/IntrinsicInst.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/IR/Value.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/Casting.h"
  26. #include "llvm/Support/Debug.h"
  27. #include "llvm/Support/raw_ostream.h"
  28. #include "llvm/Transforms/Scalar.h"
  29. #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
  30. #include <algorithm>
  31. #include <cassert>
  32. #include <cstdint>
  33. #include <iterator>
  34. #include <tuple>
  35. #include <utility>
  36. using namespace llvm;
  37. using namespace tlshoist;
  38. #define DEBUG_TYPE "tlshoist"
  39. static cl::opt<bool> TLSLoadHoist(
  40. "tls-load-hoist", cl::init(false), cl::Hidden,
  41. cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
  42. "TLS address calculation."));
  43. namespace {
  44. /// The TLS Variable hoist pass.
  45. class TLSVariableHoistLegacyPass : public FunctionPass {
  46. public:
  47. static char ID; // Pass identification, replacement for typeid
  48. TLSVariableHoistLegacyPass() : FunctionPass(ID) {
  49. initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
  50. }
  51. bool runOnFunction(Function &Fn) override;
  52. StringRef getPassName() const override { return "TLS Variable Hoist"; }
  53. void getAnalysisUsage(AnalysisUsage &AU) const override {
  54. AU.setPreservesCFG();
  55. AU.addRequired<DominatorTreeWrapperPass>();
  56. AU.addRequired<LoopInfoWrapperPass>();
  57. }
  58. private:
  59. TLSVariableHoistPass Impl;
  60. };
  61. } // end anonymous namespace
  62. char TLSVariableHoistLegacyPass::ID = 0;
  63. INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
  64. "TLS Variable Hoist", false, false)
  65. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  66. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  67. INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
  68. "TLS Variable Hoist", false, false)
  69. FunctionPass *llvm::createTLSVariableHoistPass() {
  70. return new TLSVariableHoistLegacyPass();
  71. }
  72. /// Perform the TLS Variable Hoist optimization for the given function.
  73. bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
  74. if (skipFunction(Fn))
  75. return false;
  76. LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
  77. LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
  78. bool MadeChange =
  79. Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
  80. getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
  81. if (MadeChange) {
  82. LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
  83. << Fn.getName() << '\n');
  84. LLVM_DEBUG(dbgs() << Fn);
  85. }
  86. LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
  87. return MadeChange;
  88. }
  89. void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
  90. // Skip all cast instructions. They are visited indirectly later on.
  91. if (Inst->isCast())
  92. return;
  93. // Scan all operands.
  94. for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
  95. auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
  96. if (!GV || !GV->isThreadLocal())
  97. continue;
  98. // Add Candidate to TLSCandMap (GV --> Candidate).
  99. TLSCandMap[GV].addUser(Inst, Idx);
  100. }
  101. }
  102. void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
  103. // First, quickly check if there is TLS Variable.
  104. Module *M = Fn.getParent();
  105. bool HasTLS = llvm::any_of(
  106. M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
  107. // If non, directly return.
  108. if (!HasTLS)
  109. return;
  110. TLSCandMap.clear();
  111. // Then, collect TLS Variable info.
  112. for (BasicBlock &BB : Fn) {
  113. // Ignore unreachable basic blocks.
  114. if (!DT->isReachableFromEntry(&BB))
  115. continue;
  116. for (Instruction &Inst : BB)
  117. collectTLSCandidate(&Inst);
  118. }
  119. }
  120. static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
  121. if (Cand.Users.size() != 1)
  122. return false;
  123. BasicBlock *BB = Cand.Users[0].Inst->getParent();
  124. if (LI->getLoopFor(BB))
  125. return false;
  126. return true;
  127. }
  128. Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
  129. Loop *L) {
  130. assert(L && "Unexcepted Loop status!");
  131. // Get the outermost loop.
  132. while (Loop *Parent = L->getParentLoop())
  133. L = Parent;
  134. BasicBlock *PreHeader = L->getLoopPreheader();
  135. // There is unique predecessor outside the loop.
  136. if (PreHeader)
  137. return PreHeader->getTerminator();
  138. BasicBlock *Header = L->getHeader();
  139. BasicBlock *Dom = Header;
  140. for (BasicBlock *PredBB : predecessors(Header))
  141. Dom = DT->findNearestCommonDominator(Dom, PredBB);
  142. assert(Dom && "Not find dominator BB!");
  143. Instruction *Term = Dom->getTerminator();
  144. return Term;
  145. }
  146. Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
  147. Instruction *I2) {
  148. if (!I1)
  149. return I2;
  150. return DT->findNearestCommonDominator(I1, I2);
  151. }
  152. BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
  153. GlobalVariable *GV,
  154. BasicBlock *&PosBB) {
  155. tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
  156. // We should hoist the TLS use out of loop, so choose its nearest instruction
  157. // which dominate the loop and the outside loops (if exist).
  158. Instruction *LastPos = nullptr;
  159. for (auto &User : Cand.Users) {
  160. BasicBlock *BB = User.Inst->getParent();
  161. Instruction *Pos = User.Inst;
  162. if (Loop *L = LI->getLoopFor(BB)) {
  163. Pos = getNearestLoopDomInst(BB, L);
  164. assert(Pos && "Not find insert position out of loop!");
  165. }
  166. Pos = getDomInst(LastPos, Pos);
  167. LastPos = Pos;
  168. }
  169. assert(LastPos && "Unexpected insert position!");
  170. BasicBlock *Parent = LastPos->getParent();
  171. PosBB = Parent;
  172. return LastPos->getIterator();
  173. }
  174. // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
  175. Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
  176. GlobalVariable *GV) {
  177. BasicBlock *PosBB = &Fn.getEntryBlock();
  178. BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
  179. Type *Ty = GV->getType();
  180. auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
  181. CastInst->insertInto(PosBB, Iter);
  182. return CastInst;
  183. }
  184. bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
  185. GlobalVariable *GV) {
  186. tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
  187. // If only used 1 time and not in loops, we no need to replace it.
  188. if (oneUseOutsideLoop(Cand, LI))
  189. return false;
  190. // Generate a bitcast (no type change)
  191. auto *CastInst = genBitCastInst(Fn, GV);
  192. // to replace the uses of TLS Candidate
  193. for (auto &User : Cand.Users)
  194. User.Inst->setOperand(User.OpndIdx, CastInst);
  195. return true;
  196. }
  197. bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
  198. if (TLSCandMap.empty())
  199. return false;
  200. bool Replaced = false;
  201. for (auto &GV2Cand : TLSCandMap) {
  202. GlobalVariable *GV = GV2Cand.first;
  203. Replaced |= tryReplaceTLSCandidate(Fn, GV);
  204. }
  205. return Replaced;
  206. }
  207. /// Optimize expensive TLS variables in the given function.
  208. bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
  209. LoopInfo &LI) {
  210. if (Fn.hasOptNone())
  211. return false;
  212. if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
  213. return false;
  214. this->LI = &LI;
  215. this->DT = &DT;
  216. assert(this->LI && this->DT && "Unexcepted requirement!");
  217. // Collect all TLS variable candidates.
  218. collectTLSCandidates(Fn);
  219. bool MadeChange = tryReplaceTLSCandidates(Fn);
  220. return MadeChange;
  221. }
  222. PreservedAnalyses TLSVariableHoistPass::run(Function &F,
  223. FunctionAnalysisManager &AM) {
  224. auto &LI = AM.getResult<LoopAnalysis>(F);
  225. auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  226. if (!runImpl(F, DT, LI))
  227. return PreservedAnalyses::all();
  228. PreservedAnalyses PA;
  229. PA.preserveSet<CFGAnalyses>();
  230. return PA;
  231. }