LoopDataPrefetch.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements a Loop Data Prefetching Pass.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
  13. #include "llvm/InitializePasses.h"
  14. #include "llvm/ADT/DepthFirstIterator.h"
  15. #include "llvm/ADT/Statistic.h"
  16. #include "llvm/Analysis/AssumptionCache.h"
  17. #include "llvm/Analysis/CodeMetrics.h"
  18. #include "llvm/Analysis/LoopInfo.h"
  19. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  20. #include "llvm/Analysis/ScalarEvolution.h"
  21. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  22. #include "llvm/Analysis/TargetTransformInfo.h"
  23. #include "llvm/IR/Dominators.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/IR/Module.h"
  26. #include "llvm/Support/CommandLine.h"
  27. #include "llvm/Support/Debug.h"
  28. #include "llvm/Transforms/Scalar.h"
  29. #include "llvm/Transforms/Utils.h"
  30. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  31. #define DEBUG_TYPE "loop-data-prefetch"
  32. using namespace llvm;
// Command-line knobs for tuning the prefetcher. Each option, when given
// explicitly on the command line, overrides the corresponding
// TargetTransformInfo hook (see the getter helpers in LoopDataPrefetch,
// which test getNumOccurrences() before falling back to TTI).
// NOTE(review): an older comment here mentioned "creating 16 PHIs"; that
// looks like a stale copy-paste from another pass — nothing below limits
// PHI creation.

// Also prefetch the addresses of stores, not only loads
// (overrides TTI->enableWritePrefetching()).
static cl::opt<bool>
PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
cl::desc("Prefetch write addresses"));

// Distance, in IR instructions, to prefetch ahead of an access
// (overrides TTI->getPrefetchDistance()).
static cl::opt<unsigned>
PrefetchDistance("prefetch-distance",
cl::desc("Number of instructions to prefetch ahead"),
cl::Hidden);

// Minimum constant stride, in bytes, required before a prefetch is emitted
// (overrides TTI->getMinPrefetchStride()).
static cl::opt<unsigned>
MinPrefetchStride("min-prefetch-stride",
cl::desc("Min stride to add prefetches"), cl::Hidden);

// Upper bound on how many loop iterations ahead we are willing to prefetch
// (overrides TTI->getMaxPrefetchIterationsAhead()).
static cl::opt<unsigned> MaxPrefetchIterationsAhead(
"max-prefetch-iters-ahead",
cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");
  49. namespace {
  50. /// Loop prefetch implementation class.
  51. class LoopDataPrefetch {
  52. public:
  53. LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
  54. ScalarEvolution *SE, const TargetTransformInfo *TTI,
  55. OptimizationRemarkEmitter *ORE)
  56. : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
  57. bool run();
  58. private:
  59. bool runOnLoop(Loop *L);
  60. /// Check if the stride of the accesses is large enough to
  61. /// warrant a prefetch.
  62. bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
  63. unsigned getMinPrefetchStride(unsigned NumMemAccesses,
  64. unsigned NumStridedMemAccesses,
  65. unsigned NumPrefetches,
  66. bool HasCall) {
  67. if (MinPrefetchStride.getNumOccurrences() > 0)
  68. return MinPrefetchStride;
  69. return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
  70. NumPrefetches, HasCall);
  71. }
  72. unsigned getPrefetchDistance() {
  73. if (PrefetchDistance.getNumOccurrences() > 0)
  74. return PrefetchDistance;
  75. return TTI->getPrefetchDistance();
  76. }
  77. unsigned getMaxPrefetchIterationsAhead() {
  78. if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
  79. return MaxPrefetchIterationsAhead;
  80. return TTI->getMaxPrefetchIterationsAhead();
  81. }
  82. bool doPrefetchWrites() {
  83. if (PrefetchWrites.getNumOccurrences() > 0)
  84. return PrefetchWrites;
  85. return TTI->enableWritePrefetching();
  86. }
  87. AssumptionCache *AC;
  88. DominatorTree *DT;
  89. LoopInfo *LI;
  90. ScalarEvolution *SE;
  91. const TargetTransformInfo *TTI;
  92. OptimizationRemarkEmitter *ORE;
  93. };
  94. /// Legacy class for inserting loop data prefetches.
  95. class LoopDataPrefetchLegacyPass : public FunctionPass {
  96. public:
  97. static char ID; // Pass ID, replacement for typeid
  98. LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
  99. initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  100. }
  101. void getAnalysisUsage(AnalysisUsage &AU) const override {
  102. AU.addRequired<AssumptionCacheTracker>();
  103. AU.addRequired<DominatorTreeWrapperPass>();
  104. AU.addPreserved<DominatorTreeWrapperPass>();
  105. AU.addRequired<LoopInfoWrapperPass>();
  106. AU.addPreserved<LoopInfoWrapperPass>();
  107. AU.addRequiredID(LoopSimplifyID);
  108. AU.addPreservedID(LoopSimplifyID);
  109. AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
  110. AU.addRequired<ScalarEvolutionWrapperPass>();
  111. AU.addPreserved<ScalarEvolutionWrapperPass>();
  112. AU.addRequired<TargetTransformInfoWrapperPass>();
  113. }
  114. bool runOnFunction(Function &F) override;
  115. };
  116. }
char LoopDataPrefetchLegacyPass::ID = 0;

// Register the legacy pass under the command-line name "loop-data-prefetch",
// declaring the analyses it depends on so the legacy pass manager schedules
// them first.
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
"Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
"Loop Data Prefetch", false, false)

// Factory used by the legacy pass manager to instantiate this pass.
FunctionPass *llvm::createLoopDataPrefetchPass() {
return new LoopDataPrefetchLegacyPass();
}
  131. bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
  132. unsigned TargetMinStride) {
  133. // No need to check if any stride goes.
  134. if (TargetMinStride <= 1)
  135. return true;
  136. const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  137. // If MinStride is set, don't prefetch unless we can ensure that stride is
  138. // larger.
  139. if (!ConstStride)
  140. return false;
  141. unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  142. return TargetMinStride <= AbsStride;
  143. }
  144. PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
  145. FunctionAnalysisManager &AM) {
  146. DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  147. LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  148. ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  149. AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  150. OptimizationRemarkEmitter *ORE =
  151. &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  152. const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
  153. LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  154. bool Changed = LDP.run();
  155. if (Changed) {
  156. PreservedAnalyses PA;
  157. PA.preserve<DominatorTreeAnalysis>();
  158. PA.preserve<LoopAnalysis>();
  159. return PA;
  160. }
  161. return PreservedAnalyses::all();
  162. }
  163. bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  164. if (skipFunction(F))
  165. return false;
  166. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  167. LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  168. ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  169. AssumptionCache *AC =
  170. &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  171. OptimizationRemarkEmitter *ORE =
  172. &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  173. const TargetTransformInfo *TTI =
  174. &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  175. LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  176. return LDP.run();
  177. }
  178. bool LoopDataPrefetch::run() {
  179. // If PrefetchDistance is not set, don't run the pass. This gives an
  180. // opportunity for targets to run this pass for selected subtargets only
  181. // (whose TTI sets PrefetchDistance and CacheLineSize).
  182. if (getPrefetchDistance() == 0 || TTI->getCacheLineSize() == 0) {
  183. LLVM_DEBUG(dbgs() << "Please set both PrefetchDistance and CacheLineSize "
  184. "for loop data prefetch.\n");
  185. return false;
  186. }
  187. bool MadeChange = false;
  188. for (Loop *I : *LI)
  189. for (Loop *L : depth_first(I))
  190. MadeChange |= runOnLoop(L);
  191. return MadeChange;
  192. }
  193. /// A record for a potential prefetch made during the initial scan of the
  194. /// loop. This is used to let a single prefetch target multiple memory accesses.
  195. struct Prefetch {
  196. /// The address formula for this prefetch as returned by ScalarEvolution.
  197. const SCEVAddRecExpr *LSCEVAddRec;
  198. /// The point of insertion for the prefetch instruction.
  199. Instruction *InsertPt = nullptr;
  200. /// True if targeting a write memory access.
  201. bool Writes = false;
  202. /// The (first seen) prefetched instruction.
  203. Instruction *MemI = nullptr;
  204. /// Constructor to create a new Prefetch for \p I.
  205. Prefetch(const SCEVAddRecExpr *L, Instruction *I) : LSCEVAddRec(L) {
  206. addInstruction(I);
  207. };
  208. /// Add the instruction \param I to this prefetch. If it's not the first
  209. /// one, 'InsertPt' and 'Writes' will be updated as required.
  210. /// \param PtrDiff the known constant address difference to the first added
  211. /// instruction.
  212. void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
  213. int64_t PtrDiff = 0) {
  214. if (!InsertPt) {
  215. MemI = I;
  216. InsertPt = I;
  217. Writes = isa<StoreInst>(I);
  218. } else {
  219. BasicBlock *PrefBB = InsertPt->getParent();
  220. BasicBlock *InsBB = I->getParent();
  221. if (PrefBB != InsBB) {
  222. BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
  223. if (DomBB != PrefBB)
  224. InsertPt = DomBB->getTerminator();
  225. }
  226. if (isa<StoreInst>(I) && PtrDiff == 0)
  227. Writes = true;
  228. }
  229. }
  230. };
/// Insert prefetches into a single (innermost) loop.
///
/// Strategy: estimate the loop body size, derive how many iterations ahead
/// to prefetch, collect every strided load/store (merging accesses that hit
/// the same cache line into one Prefetch record), then expand each record's
/// address ItersAhead iterations forward and emit an llvm.prefetch call.
/// Returns true if any prefetch was inserted.
bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in the inner-most loop
  if (!L->isInnermost())
    return MadeChange;

  // Ephemeral values (e.g. assume-related code) should not count towards the
  // loop's size estimate below.
  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);

  // Calculate the number of iterations ahead to prefetch
  CodeMetrics Metrics;
  bool HasCall = false;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB) {
      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;
          // Calls lowered to real calls may affect the profitable prefetch
          // stride (passed to getMinPrefetchStride below).
          if (TTI->isLoweredToCall(F))
            HasCall = true;
        } else { // indirect call.
          HasCall = true;
        }
      }
    }
    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }

  // Bail on loops whose size estimate is invalid (e.g. contains convergent
  // or otherwise uncountable instructions).
  if (!Metrics.NumInsts.isValid())
    return MadeChange;

  unsigned LoopSize = *Metrics.NumInsts.getValue();
  if (!LoopSize)
    LoopSize = 1;

  // Prefetch distance (in instructions) divided by the loop size gives the
  // number of iterations to look ahead; always look at least one ahead.
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

  // Don't prefetch past the end of a loop with a small known trip count.
  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
    return MadeChange;

  // Scan phase: collect candidate accesses into Prefetch records.
  unsigned NumMemAccesses = 0;
  unsigned NumStridedMemAccesses = 0;
  SmallVector<Prefetch, 16> Prefetches;
  for (const auto BB : L->blocks())
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      // Only loads, and (when enabled) stores, are prefetch candidates.
      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!doPrefetchWrites()) continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else continue;

      // Respect the target's address-space restrictions.
      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace))
        continue;
      NumMemAccesses++;
      // Loop-invariant addresses hit the same line every iteration; no
      // prefetch needed.
      if (L->isLoopInvariant(PtrValue))
        continue;

      // Only affine recurrences (base + iteration * step) can be advanced
      // by a known amount per iteration.
      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;
      NumStridedMemAccesses++;

      // We don't want to double prefetch individual cache lines. If this
      // access is known to be within one cache line of some other one that
      // has already been prefetched, then don't prefetch this one as well.
      bool DupPref = false;
      for (auto &Pref : Prefetches) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
        if (const SCEVConstant *ConstPtrDiff =
                dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t) TTI->getCacheLineSize()) {
            // Fold this access into the existing record instead.
            Pref.addInstruction(MemI, DT, PD);
            DupPref = true;
            break;
          }
        }
      }
      if (!DupPref)
        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
    }

  // The profitable minimum stride may depend on how many accesses and
  // prefetches the loop has, and whether it makes calls.
  unsigned TargetMinStride =
    getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                         Prefetches.size(), HasCall);

  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
             << " iterations ahead (loop size: " << LoopSize << ") in "
             << L->getHeader()->getParent()->getName() << ": " << *L);
  LLVM_DEBUG(dbgs() << "Loop has: "
             << NumMemAccesses << " memory accesses, "
             << NumStridedMemAccesses << " strided memory accesses, "
             << Prefetches.size() << " potential prefetch(es), "
             << "a minimum stride of " << TargetMinStride << ", "
             << (HasCall ? "calls" : "no calls") << ".\n");

  // Emission phase: materialize one llvm.prefetch per surviving record.
  for (auto &P : Prefetches) {
    // Check if the stride of the accesses is large enough to warrant a
    // prefetch.
    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
      continue;

    BasicBlock *BB = P.InsertPt->getParent();
    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
    // Address of the access ItersAhead iterations in the future:
    // AddRec + ItersAhead * Step.
    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
      P.LSCEVAddRec->getStepRecurrence(*SE)));
    if (!SCEVE.isSafeToExpand(NextLSCEV))
      continue;

    unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace();
    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);

    IRBuilder<> Builder(P.InsertPt);
    Module *M = BB->getParent()->getParent();
    Type *I32 = Type::getInt32Ty(BB->getContext());
    Function *PrefetchFunc = Intrinsic::getDeclaration(
        M, Intrinsic::prefetch, PrefPtrValue->getType());
    // Arguments: address, rw (0=read/1=write), locality (3 = keep in all
    // cache levels), cache type (1 = data cache).
    Builder.CreateCall(
        PrefetchFunc,
        {PrefPtrValue,
         ConstantInt::get(I32, P.Writes),
         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
    ++NumPrefetches;
    LLVM_DEBUG(dbgs() << " Access: "
               << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
               << ", SCEV: " << *P.LSCEVAddRec << "\n");
    ORE->emit([&]() {
      return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
             << "prefetched memory access";
      });

    MadeChange = true;
  }

  return MadeChange;
}