LoopDataPrefetch.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements a Loop Data Prefetching Pass.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
  13. #include "llvm/InitializePasses.h"
  14. #include "llvm/ADT/DepthFirstIterator.h"
  15. #include "llvm/ADT/Statistic.h"
  16. #include "llvm/Analysis/AssumptionCache.h"
  17. #include "llvm/Analysis/CodeMetrics.h"
  18. #include "llvm/Analysis/LoopInfo.h"
  19. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  20. #include "llvm/Analysis/ScalarEvolution.h"
  21. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  22. #include "llvm/Analysis/TargetTransformInfo.h"
  23. #include "llvm/IR/CFG.h"
  24. #include "llvm/IR/Dominators.h"
  25. #include "llvm/IR/Function.h"
  26. #include "llvm/IR/Module.h"
  27. #include "llvm/Support/CommandLine.h"
  28. #include "llvm/Support/Debug.h"
  29. #include "llvm/Transforms/Scalar.h"
  30. #include "llvm/Transforms/Utils.h"
  31. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  32. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  33. #include "llvm/Transforms/Utils/ValueMapper.h"
  34. #define DEBUG_TYPE "loop-data-prefetch"
  35. using namespace llvm;
// Command-line overrides for the target's prefetching parameters. Each
// option, when explicitly given on the command line, takes precedence over
// the corresponding TargetTransformInfo hook (see the getters in the
// LoopDataPrefetch class below).

// Prefetch the addresses of stores as well as loads. Without this flag the
// target's TTI->enableWritePrefetching() decides.
static cl::opt<bool>
PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
               cl::desc("Prefetch write addresses"));

// How far ahead (in instructions) a prefetch should target; divided by the
// loop size to obtain the number of iterations to prefetch ahead.
static cl::opt<unsigned>
PrefetchDistance("prefetch-distance",
                 cl::desc("Number of instructions to prefetch ahead"),
                 cl::Hidden);

// Minimum constant stride an access must have to be worth prefetching.
static cl::opt<unsigned>
MinPrefetchStride("min-prefetch-stride",
                  cl::desc("Min stride to add prefetches"), cl::Hidden);

// Upper bound on the number of iterations to prefetch ahead; loops that
// would need more are skipped entirely.
static cl::opt<unsigned> MaxPrefetchIterationsAhead(
    "max-prefetch-iters-ahead",
    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");
  52. namespace {
  53. /// Loop prefetch implementation class.
  54. class LoopDataPrefetch {
  55. public:
  56. LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
  57. ScalarEvolution *SE, const TargetTransformInfo *TTI,
  58. OptimizationRemarkEmitter *ORE)
  59. : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
  60. bool run();
  61. private:
  62. bool runOnLoop(Loop *L);
  63. /// Check if the stride of the accesses is large enough to
  64. /// warrant a prefetch.
  65. bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
  66. unsigned getMinPrefetchStride(unsigned NumMemAccesses,
  67. unsigned NumStridedMemAccesses,
  68. unsigned NumPrefetches,
  69. bool HasCall) {
  70. if (MinPrefetchStride.getNumOccurrences() > 0)
  71. return MinPrefetchStride;
  72. return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
  73. NumPrefetches, HasCall);
  74. }
  75. unsigned getPrefetchDistance() {
  76. if (PrefetchDistance.getNumOccurrences() > 0)
  77. return PrefetchDistance;
  78. return TTI->getPrefetchDistance();
  79. }
  80. unsigned getMaxPrefetchIterationsAhead() {
  81. if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
  82. return MaxPrefetchIterationsAhead;
  83. return TTI->getMaxPrefetchIterationsAhead();
  84. }
  85. bool doPrefetchWrites() {
  86. if (PrefetchWrites.getNumOccurrences() > 0)
  87. return PrefetchWrites;
  88. return TTI->enableWritePrefetching();
  89. }
  90. AssumptionCache *AC;
  91. DominatorTree *DT;
  92. LoopInfo *LI;
  93. ScalarEvolution *SE;
  94. const TargetTransformInfo *TTI;
  95. OptimizationRemarkEmitter *ORE;
  96. };
/// Legacy class for inserting loop data prefetches.
class LoopDataPrefetchLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
    initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  /// Declare the analyses this pass requires and the ones it preserves.
  /// Loops are required to be in simplified form (LoopSimplifyID).
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequiredID(LoopSimplifyID);
    AU.addPreservedID(LoopSimplifyID);
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  /// Entry point: gathers the analyses and runs the LoopDataPrefetch
  /// implementation on \p F.
  bool runOnFunction(Function &F) override;
};
  119. }
char LoopDataPrefetchLegacyPass::ID = 0;

// Register the legacy pass (command-line name "loop-data-prefetch") and its
// analysis dependencies with the global pass registry.
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                      "Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                    "Loop Data Prefetch", false, false)

/// Factory used by the legacy pass manager to create this pass.
FunctionPass *llvm::createLoopDataPrefetchPass() {
  return new LoopDataPrefetchLegacyPass();
}
  134. bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
  135. unsigned TargetMinStride) {
  136. // No need to check if any stride goes.
  137. if (TargetMinStride <= 1)
  138. return true;
  139. const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  140. // If MinStride is set, don't prefetch unless we can ensure that stride is
  141. // larger.
  142. if (!ConstStride)
  143. return false;
  144. unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  145. return TargetMinStride <= AbsStride;
  146. }
  147. PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
  148. FunctionAnalysisManager &AM) {
  149. DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  150. LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  151. ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  152. AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  153. OptimizationRemarkEmitter *ORE =
  154. &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  155. const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
  156. LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  157. bool Changed = LDP.run();
  158. if (Changed) {
  159. PreservedAnalyses PA;
  160. PA.preserve<DominatorTreeAnalysis>();
  161. PA.preserve<LoopAnalysis>();
  162. return PA;
  163. }
  164. return PreservedAnalyses::all();
  165. }
  166. bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  167. if (skipFunction(F))
  168. return false;
  169. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  170. LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  171. ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  172. AssumptionCache *AC =
  173. &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  174. OptimizationRemarkEmitter *ORE =
  175. &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  176. const TargetTransformInfo *TTI =
  177. &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  178. LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  179. return LDP.run();
  180. }
  181. bool LoopDataPrefetch::run() {
  182. // If PrefetchDistance is not set, don't run the pass. This gives an
  183. // opportunity for targets to run this pass for selected subtargets only
  184. // (whose TTI sets PrefetchDistance).
  185. if (getPrefetchDistance() == 0)
  186. return false;
  187. assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
  188. bool MadeChange = false;
  189. for (Loop *I : *LI)
  190. for (Loop *L : depth_first(I))
  191. MadeChange |= runOnLoop(L);
  192. return MadeChange;
  193. }
  194. /// A record for a potential prefetch made during the initial scan of the
  195. /// loop. This is used to let a single prefetch target multiple memory accesses.
  196. struct Prefetch {
  197. /// The address formula for this prefetch as returned by ScalarEvolution.
  198. const SCEVAddRecExpr *LSCEVAddRec;
  199. /// The point of insertion for the prefetch instruction.
  200. Instruction *InsertPt;
  201. /// True if targeting a write memory access.
  202. bool Writes;
  203. /// The (first seen) prefetched instruction.
  204. Instruction *MemI;
  205. /// Constructor to create a new Prefetch for \p I.
  206. Prefetch(const SCEVAddRecExpr *L, Instruction *I)
  207. : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
  208. addInstruction(I);
  209. };
  210. /// Add the instruction \param I to this prefetch. If it's not the first
  211. /// one, 'InsertPt' and 'Writes' will be updated as required.
  212. /// \param PtrDiff the known constant address difference to the first added
  213. /// instruction.
  214. void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
  215. int64_t PtrDiff = 0) {
  216. if (!InsertPt) {
  217. MemI = I;
  218. InsertPt = I;
  219. Writes = isa<StoreInst>(I);
  220. } else {
  221. BasicBlock *PrefBB = InsertPt->getParent();
  222. BasicBlock *InsBB = I->getParent();
  223. if (PrefBB != InsBB) {
  224. BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
  225. if (DomBB != PrefBB)
  226. InsertPt = DomBB->getTerminator();
  227. }
  228. if (isa<StoreInst>(I) && PtrDiff == 0)
  229. Writes = true;
  230. }
  231. }
  232. };
/// Scan the innermost loop \p L for strided memory accesses and insert
/// llvm.prefetch calls for them. Returns true if any prefetch was inserted.
bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in the inner-most loop
  if (!L->isInnermost())
    return MadeChange;

  // Ephemeral values (only feeding assumes) must not count toward loop size.
  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);

  // Calculate the number of iterations ahead to prefetch
  CodeMetrics Metrics;
  bool HasCall = false;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB) {
      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;
          if (TTI->isLoweredToCall(F))
            HasCall = true;
        } else { // indirect call.
          HasCall = true;
        }
      }
    }
    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }

  // Guard against division by zero below for empty-looking loops.
  unsigned LoopSize = Metrics.NumInsts;
  if (!LoopSize)
    LoopSize = 1;

  // Cover the configured prefetch distance (in instructions), but always
  // look at least one iteration ahead.
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

  // Skip loops whose known maximum trip count is too small to ever reach the
  // prefetched iteration.
  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
    return MadeChange;

  unsigned NumMemAccesses = 0;
  unsigned NumStridedMemAccesses = 0;
  SmallVector<Prefetch, 16> Prefetches;
  for (const auto BB : L->blocks())
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      // Candidates are loads and (optionally) stores; anything else is
      // skipped.
      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!doPrefetchWrites()) continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else continue;

      // Only consider pointers in the default address space (0).
      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (PtrAddrSpace)
        continue;
      NumMemAccesses++;

      // Loop-invariant addresses need no prefetch inside the loop.
      if (L->isLoopInvariant(PtrValue))
        continue;

      // Only addresses that are add-recurrences (i.e. strided in this loop)
      // can be prefetched ahead.
      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;
      NumStridedMemAccesses++;

      // We don't want to double prefetch individual cache lines. If this
      // access is known to be within one cache line of some other one that
      // has already been prefetched, then don't prefetch this one as well.
      bool DupPref = false;
      for (auto &Pref : Prefetches) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
        if (const SCEVConstant *ConstPtrDiff =
                dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t) TTI->getCacheLineSize()) {
            // Fold this access into the existing prefetch record.
            Pref.addInstruction(MemI, DT, PD);
            DupPref = true;
            break;
          }
        }
      }
      if (!DupPref)
        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
    }

  unsigned TargetMinStride =
      getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                           Prefetches.size(), HasCall);

  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
                    << " iterations ahead (loop size: " << LoopSize << ") in "
                    << L->getHeader()->getParent()->getName() << ": " << *L);
  LLVM_DEBUG(dbgs() << "Loop has: "
                    << NumMemAccesses << " memory accesses, "
                    << NumStridedMemAccesses << " strided memory accesses, "
                    << Prefetches.size() << " potential prefetch(es), "
                    << "a minimum stride of " << TargetMinStride << ", "
                    << (HasCall ? "calls" : "no calls") << ".\n");

  for (auto &P : Prefetches) {
    // Check if the stride of the accesses is large enough to warrant a
    // prefetch.
    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
      continue;

    // Address to prefetch: the access's add-rec advanced by ItersAhead
    // iterations (base + ItersAhead * step).
    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
      P.LSCEVAddRec->getStepRecurrence(*SE)));
    if (!isSafeToExpand(NextLSCEV, *SE))
      continue;

    // Materialize the prefetch address as an i8* at the insertion point.
    BasicBlock *BB = P.InsertPt->getParent();
    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);

    IRBuilder<> Builder(P.InsertPt);
    Module *M = BB->getParent()->getParent();
    Type *I32 = Type::getInt32Ty(BB->getContext());
    Function *PrefetchFunc = Intrinsic::getDeclaration(
        M, Intrinsic::prefetch, PrefPtrValue->getType());
    // llvm.prefetch operands: address, rw (P.Writes), locality (3 = keep in
    // cache), cache type (1 = data) — see the LangRef for the intrinsic.
    Builder.CreateCall(
        PrefetchFunc,
        {PrefPtrValue,
         ConstantInt::get(I32, P.Writes),
         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
    ++NumPrefetches;
    LLVM_DEBUG(dbgs() << " Access: "
                      << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
                      << ", SCEV: " << *P.LSCEVAddRec << "\n");
    ORE->emit([&]() {
      return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
             << "prefetched memory access";
    });

    MadeChange = true;
  }

  return MadeChange;
}