LoopVersioning.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines a utility class to perform loop versioning. The versioned
  10. // loop speculates that otherwise may-aliasing memory accesses don't overlap and
  11. // emits checks to prove this.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Utils/LoopVersioning.h"
  15. #include "llvm/ADT/ArrayRef.h"
  16. #include "llvm/Analysis/AliasAnalysis.h"
  17. #include "llvm/Analysis/InstSimplifyFolder.h"
  18. #include "llvm/Analysis/LoopAccessAnalysis.h"
  19. #include "llvm/Analysis/LoopInfo.h"
  20. #include "llvm/Analysis/ScalarEvolution.h"
  21. #include "llvm/Analysis/TargetLibraryInfo.h"
  22. #include "llvm/IR/Dominators.h"
  23. #include "llvm/IR/MDBuilder.h"
  24. #include "llvm/IR/PassManager.h"
  25. #include "llvm/InitializePasses.h"
  26. #include "llvm/Support/CommandLine.h"
  27. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  28. #include "llvm/Transforms/Utils/Cloning.h"
  29. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  30. using namespace llvm;
  31. static cl::opt<bool>
  32. AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
  33. cl::Hidden,
  34. cl::desc("Add no-alias annotation for instructions that "
  35. "are disambiguated by memchecks"));
  36. LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
  37. ArrayRef<RuntimePointerCheck> Checks, Loop *L,
  38. LoopInfo *LI, DominatorTree *DT,
  39. ScalarEvolution *SE)
  40. : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()),
  41. Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT),
  42. SE(SE) {
  43. }
  44. void LoopVersioning::versionLoop(
  45. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  46. assert(VersionedLoop->getUniqueExitBlock() && "No single exit block");
  47. assert(VersionedLoop->isLoopSimplifyForm() &&
  48. "Loop is not in loop-simplify form");
  49. Value *MemRuntimeCheck;
  50. Value *SCEVRuntimeCheck;
  51. Value *RuntimeCheck = nullptr;
  52. // Add the memcheck in the original preheader (this is empty initially).
  53. BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
  54. const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
  55. SCEVExpander Exp2(*RtPtrChecking.getSE(),
  56. VersionedLoop->getHeader()->getModule()->getDataLayout(),
  57. "induction");
  58. MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
  59. VersionedLoop, AliasChecks, Exp2);
  60. SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
  61. "scev.check");
  62. SCEVRuntimeCheck =
  63. Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
  64. IRBuilder<InstSimplifyFolder> Builder(
  65. RuntimeCheckBB->getContext(),
  66. InstSimplifyFolder(RuntimeCheckBB->getModule()->getDataLayout()));
  67. if (MemRuntimeCheck && SCEVRuntimeCheck) {
  68. Builder.SetInsertPoint(RuntimeCheckBB->getTerminator());
  69. RuntimeCheck =
  70. Builder.CreateOr(MemRuntimeCheck, SCEVRuntimeCheck, "lver.safe");
  71. } else
  72. RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
  73. assert(RuntimeCheck && "called even though we don't need "
  74. "any runtime checks");
  75. // Rename the block to make the IR more readable.
  76. RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
  77. ".lver.check");
  78. // Create empty preheader for the loop (and after cloning for the
  79. // non-versioned loop).
  80. BasicBlock *PH =
  81. SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
  82. nullptr, VersionedLoop->getHeader()->getName() + ".ph");
  83. // Clone the loop including the preheader.
  84. //
  85. // FIXME: This does not currently preserve SimplifyLoop because the exit
  86. // block is a join between the two loops.
  87. SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
  88. NonVersionedLoop =
  89. cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
  90. ".lver.orig", LI, DT, NonVersionedLoopBlocks);
  91. remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
  92. // Insert the conditional branch based on the result of the memchecks.
  93. Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
  94. Builder.SetInsertPoint(OrigTerm);
  95. Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
  96. VersionedLoop->getLoopPreheader());
  97. OrigTerm->eraseFromParent();
  98. // The loops merge in the original exit block. This is now dominated by the
  99. // memchecking block.
  100. DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
  101. // Adds the necessary PHI nodes for the versioned loops based on the
  102. // loop-defined values used outside of the loop.
  103. addPHINodes(DefsUsedOutside);
  104. formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
  105. formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
  106. assert(NonVersionedLoop->isLoopSimplifyForm() &&
  107. VersionedLoop->isLoopSimplifyForm() &&
  108. "The versioned loops should be in simplify form.");
  109. }
  110. void LoopVersioning::addPHINodes(
  111. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  112. BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
  113. assert(PHIBlock && "No single successor to loop exit block");
  114. PHINode *PN;
  115. // First add a single-operand PHI for each DefsUsedOutside if one does not
  116. // exists yet.
  117. for (auto *Inst : DefsUsedOutside) {
  118. // See if we have a single-operand PHI with the value defined by the
  119. // original loop.
  120. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  121. if (PN->getIncomingValue(0) == Inst) {
  122. SE->forgetValue(PN);
  123. break;
  124. }
  125. }
  126. // If not create it.
  127. if (!PN) {
  128. PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
  129. &PHIBlock->front());
  130. SmallVector<User*, 8> UsersToUpdate;
  131. for (User *U : Inst->users())
  132. if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
  133. UsersToUpdate.push_back(U);
  134. for (User *U : UsersToUpdate)
  135. U->replaceUsesOfWith(Inst, PN);
  136. PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
  137. }
  138. }
  139. // Then for each PHI add the operand for the edge from the cloned loop.
  140. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  141. assert(PN->getNumOperands() == 1 &&
  142. "Exit block should only have on predecessor");
  143. // If the definition was cloned used that otherwise use the same value.
  144. Value *ClonedValue = PN->getIncomingValue(0);
  145. auto Mapped = VMap.find(ClonedValue);
  146. if (Mapped != VMap.end())
  147. ClonedValue = Mapped->second;
  148. PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
  149. }
  150. }
  151. void LoopVersioning::prepareNoAliasMetadata() {
  152. // We need to turn the no-alias relation between pointer checking groups into
  153. // no-aliasing annotations between instructions.
  154. //
  155. // We accomplish this by mapping each pointer checking group (a set of
  156. // pointers memchecked together) to an alias scope and then also mapping each
  157. // group to the list of scopes it can't alias.
  158. const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
  159. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  160. // First allocate an aliasing scope for each pointer checking group.
  161. //
  162. // While traversing through the checking groups in the loop, also create a
  163. // reverse map from pointers to the pointer checking group they were assigned
  164. // to.
  165. MDBuilder MDB(Context);
  166. MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
  167. for (const auto &Group : RtPtrChecking->CheckingGroups) {
  168. GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
  169. for (unsigned PtrIdx : Group.Members)
  170. PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
  171. }
  172. // Go through the checks and for each pointer group, collect the scopes for
  173. // each non-aliasing pointer group.
  174. DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
  175. GroupToNonAliasingScopes;
  176. for (const auto &Check : AliasChecks)
  177. GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
  178. // Finally, transform the above to actually map to scope list which is what
  179. // the metadata uses.
  180. for (auto Pair : GroupToNonAliasingScopes)
  181. GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
  182. }
  183. void LoopVersioning::annotateLoopWithNoAlias() {
  184. if (!AnnotateNoAlias)
  185. return;
  186. // First prepare the maps.
  187. prepareNoAliasMetadata();
  188. // Add the scope and no-alias metadata to the instructions.
  189. for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
  190. annotateInstWithNoAlias(I);
  191. }
  192. }
  193. void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
  194. const Instruction *OrigInst) {
  195. if (!AnnotateNoAlias)
  196. return;
  197. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  198. const Value *Ptr = isa<LoadInst>(OrigInst)
  199. ? cast<LoadInst>(OrigInst)->getPointerOperand()
  200. : cast<StoreInst>(OrigInst)->getPointerOperand();
  201. // Find the group for the pointer and then add the scope metadata.
  202. auto Group = PtrToGroup.find(Ptr);
  203. if (Group != PtrToGroup.end()) {
  204. VersionedInst->setMetadata(
  205. LLVMContext::MD_alias_scope,
  206. MDNode::concatenate(
  207. VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
  208. MDNode::get(Context, GroupToScope[Group->second])));
  209. // Add the no-alias metadata.
  210. auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
  211. if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
  212. VersionedInst->setMetadata(
  213. LLVMContext::MD_noalias,
  214. MDNode::concatenate(
  215. VersionedInst->getMetadata(LLVMContext::MD_noalias),
  216. NonAliasingScopeList->second));
  217. }
  218. }
  219. namespace {
  220. bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT,
  221. ScalarEvolution *SE) {
  222. // Build up a worklist of inner-loops to version. This is necessary as the
  223. // act of versioning a loop creates new loops and can invalidate iterators
  224. // across the loops.
  225. SmallVector<Loop *, 8> Worklist;
  226. for (Loop *TopLevelLoop : *LI)
  227. for (Loop *L : depth_first(TopLevelLoop))
  228. // We only handle inner-most loops.
  229. if (L->isInnermost())
  230. Worklist.push_back(L);
  231. // Now walk the identified inner loops.
  232. bool Changed = false;
  233. for (Loop *L : Worklist) {
  234. if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
  235. !L->getExitingBlock())
  236. continue;
  237. const LoopAccessInfo &LAI = LAIs.getInfo(*L);
  238. if (!LAI.hasConvergentOp() &&
  239. (LAI.getNumRuntimePointerChecks() ||
  240. !LAI.getPSE().getPredicate().isAlwaysTrue())) {
  241. LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
  242. LI, DT, SE);
  243. LVer.versionLoop();
  244. LVer.annotateLoopWithNoAlias();
  245. Changed = true;
  246. LAIs.clear();
  247. }
  248. }
  249. return Changed;
  250. }
  251. /// Also expose this is a pass. Currently this is only used for
  252. /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
  253. /// array accesses from the loop.
  254. class LoopVersioningLegacyPass : public FunctionPass {
  255. public:
  256. LoopVersioningLegacyPass() : FunctionPass(ID) {
  257. initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
  258. }
  259. bool runOnFunction(Function &F) override {
  260. auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  261. auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
  262. auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  263. auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  264. return runImpl(LI, LAIs, DT, SE);
  265. }
  266. void getAnalysisUsage(AnalysisUsage &AU) const override {
  267. AU.addRequired<LoopInfoWrapperPass>();
  268. AU.addPreserved<LoopInfoWrapperPass>();
  269. AU.addRequired<LoopAccessLegacyAnalysis>();
  270. AU.addRequired<DominatorTreeWrapperPass>();
  271. AU.addPreserved<DominatorTreeWrapperPass>();
  272. AU.addRequired<ScalarEvolutionWrapperPass>();
  273. }
  274. static char ID;
  275. };
  276. }
  277. #define LVER_OPTION "loop-versioning"
  278. #define DEBUG_TYPE LVER_OPTION
  279. char LoopVersioningLegacyPass::ID;
  280. static const char LVer_name[] = "Loop Versioning";
  281. INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  282. false)
  283. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  284. INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
  285. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  286. INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  287. INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  288. false)
  289. namespace llvm {
  290. FunctionPass *createLoopVersioningLegacyPass() {
  291. return new LoopVersioningLegacyPass();
  292. }
  293. PreservedAnalyses LoopVersioningPass::run(Function &F,
  294. FunctionAnalysisManager &AM) {
  295. auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  296. auto &LI = AM.getResult<LoopAnalysis>(F);
  297. LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F);
  298. auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  299. if (runImpl(&LI, LAIs, &DT, &SE))
  300. return PreservedAnalyses::none();
  301. return PreservedAnalyses::all();
  302. }
  303. } // namespace llvm