LoopVersioning.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines a utility class to perform loop versioning. The versioned
  10. // loop speculates that otherwise may-aliasing memory accesses don't overlap and
  11. // emits checks to prove this.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Utils/LoopVersioning.h"
  15. #include "llvm/ADT/ArrayRef.h"
  16. #include "llvm/Analysis/LoopAccessAnalysis.h"
  17. #include "llvm/Analysis/LoopInfo.h"
  18. #include "llvm/Analysis/MemorySSA.h"
  19. #include "llvm/Analysis/ScalarEvolution.h"
  20. #include "llvm/Analysis/TargetLibraryInfo.h"
  21. #include "llvm/IR/Dominators.h"
  22. #include "llvm/IR/MDBuilder.h"
  23. #include "llvm/IR/PassManager.h"
  24. #include "llvm/InitializePasses.h"
  25. #include "llvm/Support/CommandLine.h"
  26. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  27. #include "llvm/Transforms/Utils/Cloning.h"
  28. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  29. using namespace llvm;
  30. static cl::opt<bool>
  31. AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
  32. cl::Hidden,
  33. cl::desc("Add no-alias annotation for instructions that "
  34. "are disambiguated by memchecks"));
  35. LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
  36. ArrayRef<RuntimePointerCheck> Checks, Loop *L,
  37. LoopInfo *LI, DominatorTree *DT,
  38. ScalarEvolution *SE)
  39. : VersionedLoop(L), NonVersionedLoop(nullptr),
  40. AliasChecks(Checks.begin(), Checks.end()),
  41. Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT),
  42. SE(SE) {
  43. assert(L->getUniqueExitBlock() && "No single exit block");
  44. }
  45. void LoopVersioning::versionLoop(
  46. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  47. assert(VersionedLoop->isLoopSimplifyForm() &&
  48. "Loop is not in loop-simplify form");
  49. Instruction *FirstCheckInst;
  50. Instruction *MemRuntimeCheck;
  51. Value *SCEVRuntimeCheck;
  52. Value *RuntimeCheck = nullptr;
  53. // Add the memcheck in the original preheader (this is empty initially).
  54. BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
  55. const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
  56. std::tie(FirstCheckInst, MemRuntimeCheck) =
  57. addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
  58. AliasChecks, RtPtrChecking.getSE());
  59. SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
  60. "scev.check");
  61. SCEVRuntimeCheck =
  62. Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
  63. auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
  64. // Discard the SCEV runtime check if it is always true.
  65. if (CI && CI->isZero())
  66. SCEVRuntimeCheck = nullptr;
  67. if (MemRuntimeCheck && SCEVRuntimeCheck) {
  68. RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
  69. SCEVRuntimeCheck, "lver.safe");
  70. if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
  71. I->insertBefore(RuntimeCheckBB->getTerminator());
  72. } else
  73. RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
  74. assert(RuntimeCheck && "called even though we don't need "
  75. "any runtime checks");
  76. // Rename the block to make the IR more readable.
  77. RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
  78. ".lver.check");
  79. // Create empty preheader for the loop (and after cloning for the
  80. // non-versioned loop).
  81. BasicBlock *PH =
  82. SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
  83. nullptr, VersionedLoop->getHeader()->getName() + ".ph");
  84. // Clone the loop including the preheader.
  85. //
  86. // FIXME: This does not currently preserve SimplifyLoop because the exit
  87. // block is a join between the two loops.
  88. SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
  89. NonVersionedLoop =
  90. cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
  91. ".lver.orig", LI, DT, NonVersionedLoopBlocks);
  92. remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
  93. // Insert the conditional branch based on the result of the memchecks.
  94. Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
  95. BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
  96. VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
  97. OrigTerm->eraseFromParent();
  98. // The loops merge in the original exit block. This is now dominated by the
  99. // memchecking block.
  100. DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
  101. // Adds the necessary PHI nodes for the versioned loops based on the
  102. // loop-defined values used outside of the loop.
  103. addPHINodes(DefsUsedOutside);
  104. formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
  105. formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
  106. assert(NonVersionedLoop->isLoopSimplifyForm() &&
  107. VersionedLoop->isLoopSimplifyForm() &&
  108. "The versioned loops should be in simplify form.");
  109. }
  110. void LoopVersioning::addPHINodes(
  111. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  112. BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
  113. assert(PHIBlock && "No single successor to loop exit block");
  114. PHINode *PN;
  115. // First add a single-operand PHI for each DefsUsedOutside if one does not
  116. // exists yet.
  117. for (auto *Inst : DefsUsedOutside) {
  118. // See if we have a single-operand PHI with the value defined by the
  119. // original loop.
  120. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  121. if (PN->getIncomingValue(0) == Inst)
  122. break;
  123. }
  124. // If not create it.
  125. if (!PN) {
  126. PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
  127. &PHIBlock->front());
  128. SmallVector<User*, 8> UsersToUpdate;
  129. for (User *U : Inst->users())
  130. if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
  131. UsersToUpdate.push_back(U);
  132. for (User *U : UsersToUpdate)
  133. U->replaceUsesOfWith(Inst, PN);
  134. PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
  135. }
  136. }
  137. // Then for each PHI add the operand for the edge from the cloned loop.
  138. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  139. assert(PN->getNumOperands() == 1 &&
  140. "Exit block should only have on predecessor");
  141. // If the definition was cloned used that otherwise use the same value.
  142. Value *ClonedValue = PN->getIncomingValue(0);
  143. auto Mapped = VMap.find(ClonedValue);
  144. if (Mapped != VMap.end())
  145. ClonedValue = Mapped->second;
  146. PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
  147. }
  148. }
  149. void LoopVersioning::prepareNoAliasMetadata() {
  150. // We need to turn the no-alias relation between pointer checking groups into
  151. // no-aliasing annotations between instructions.
  152. //
  153. // We accomplish this by mapping each pointer checking group (a set of
  154. // pointers memchecked together) to an alias scope and then also mapping each
  155. // group to the list of scopes it can't alias.
  156. const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
  157. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  158. // First allocate an aliasing scope for each pointer checking group.
  159. //
  160. // While traversing through the checking groups in the loop, also create a
  161. // reverse map from pointers to the pointer checking group they were assigned
  162. // to.
  163. MDBuilder MDB(Context);
  164. MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
  165. for (const auto &Group : RtPtrChecking->CheckingGroups) {
  166. GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
  167. for (unsigned PtrIdx : Group.Members)
  168. PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
  169. }
  170. // Go through the checks and for each pointer group, collect the scopes for
  171. // each non-aliasing pointer group.
  172. DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
  173. GroupToNonAliasingScopes;
  174. for (const auto &Check : AliasChecks)
  175. GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
  176. // Finally, transform the above to actually map to scope list which is what
  177. // the metadata uses.
  178. for (auto Pair : GroupToNonAliasingScopes)
  179. GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
  180. }
  181. void LoopVersioning::annotateLoopWithNoAlias() {
  182. if (!AnnotateNoAlias)
  183. return;
  184. // First prepare the maps.
  185. prepareNoAliasMetadata();
  186. // Add the scope and no-alias metadata to the instructions.
  187. for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
  188. annotateInstWithNoAlias(I);
  189. }
  190. }
  191. void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
  192. const Instruction *OrigInst) {
  193. if (!AnnotateNoAlias)
  194. return;
  195. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  196. const Value *Ptr = isa<LoadInst>(OrigInst)
  197. ? cast<LoadInst>(OrigInst)->getPointerOperand()
  198. : cast<StoreInst>(OrigInst)->getPointerOperand();
  199. // Find the group for the pointer and then add the scope metadata.
  200. auto Group = PtrToGroup.find(Ptr);
  201. if (Group != PtrToGroup.end()) {
  202. VersionedInst->setMetadata(
  203. LLVMContext::MD_alias_scope,
  204. MDNode::concatenate(
  205. VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
  206. MDNode::get(Context, GroupToScope[Group->second])));
  207. // Add the no-alias metadata.
  208. auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
  209. if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
  210. VersionedInst->setMetadata(
  211. LLVMContext::MD_noalias,
  212. MDNode::concatenate(
  213. VersionedInst->getMetadata(LLVMContext::MD_noalias),
  214. NonAliasingScopeList->second));
  215. }
  216. }
  217. namespace {
  218. bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
  219. DominatorTree *DT, ScalarEvolution *SE) {
  220. // Build up a worklist of inner-loops to version. This is necessary as the
  221. // act of versioning a loop creates new loops and can invalidate iterators
  222. // across the loops.
  223. SmallVector<Loop *, 8> Worklist;
  224. for (Loop *TopLevelLoop : *LI)
  225. for (Loop *L : depth_first(TopLevelLoop))
  226. // We only handle inner-most loops.
  227. if (L->isInnermost())
  228. Worklist.push_back(L);
  229. // Now walk the identified inner loops.
  230. bool Changed = false;
  231. for (Loop *L : Worklist) {
  232. if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
  233. !L->getExitingBlock())
  234. continue;
  235. const LoopAccessInfo &LAI = GetLAA(*L);
  236. if (!LAI.hasConvergentOp() &&
  237. (LAI.getNumRuntimePointerChecks() ||
  238. !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
  239. LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
  240. LI, DT, SE);
  241. LVer.versionLoop();
  242. LVer.annotateLoopWithNoAlias();
  243. Changed = true;
  244. }
  245. }
  246. return Changed;
  247. }
  248. /// Also expose this is a pass. Currently this is only used for
  249. /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
  250. /// array accesses from the loop.
  251. class LoopVersioningLegacyPass : public FunctionPass {
  252. public:
  253. LoopVersioningLegacyPass() : FunctionPass(ID) {
  254. initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
  255. }
  256. bool runOnFunction(Function &F) override {
  257. auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  258. auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
  259. return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
  260. };
  261. auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  262. auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  263. return runImpl(LI, GetLAA, DT, SE);
  264. }
  265. void getAnalysisUsage(AnalysisUsage &AU) const override {
  266. AU.addRequired<LoopInfoWrapperPass>();
  267. AU.addPreserved<LoopInfoWrapperPass>();
  268. AU.addRequired<LoopAccessLegacyAnalysis>();
  269. AU.addRequired<DominatorTreeWrapperPass>();
  270. AU.addPreserved<DominatorTreeWrapperPass>();
  271. AU.addRequired<ScalarEvolutionWrapperPass>();
  272. }
  273. static char ID;
  274. };
  275. }
  276. #define LVER_OPTION "loop-versioning"
  277. #define DEBUG_TYPE LVER_OPTION
  278. char LoopVersioningLegacyPass::ID;
  279. static const char LVer_name[] = "Loop Versioning";
  280. INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  281. false)
  282. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  283. INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
  284. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  285. INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  286. INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  287. false)
  288. namespace llvm {
  289. FunctionPass *createLoopVersioningLegacyPass() {
  290. return new LoopVersioningLegacyPass();
  291. }
  292. PreservedAnalyses LoopVersioningPass::run(Function &F,
  293. FunctionAnalysisManager &AM) {
  294. auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  295. auto &LI = AM.getResult<LoopAnalysis>(F);
  296. auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  297. auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  298. auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  299. auto &AA = AM.getResult<AAManager>(F);
  300. auto &AC = AM.getResult<AssumptionAnalysis>(F);
  301. MemorySSA *MSSA = EnableMSSALoopDependency
  302. ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA()
  303. : nullptr;
  304. auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
  305. auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
  306. LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE,
  307. TLI, TTI, nullptr, MSSA};
  308. return LAM.getResult<LoopAccessAnalysis>(L, AR);
  309. };
  310. if (runImpl(&LI, GetLAA, &DT, &SE))
  311. return PreservedAnalyses::none();
  312. return PreservedAnalyses::all();
  313. }
  314. } // namespace llvm