LoopVersioning.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines a utility class to perform loop versioning. The versioned
  10. // loop speculates that otherwise may-aliasing memory accesses don't overlap and
  11. // emits checks to prove this.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Utils/LoopVersioning.h"
  15. #include "llvm/ADT/ArrayRef.h"
  16. #include "llvm/Analysis/AliasAnalysis.h"
  17. #include "llvm/Analysis/InstSimplifyFolder.h"
  18. #include "llvm/Analysis/LoopAccessAnalysis.h"
  19. #include "llvm/Analysis/LoopInfo.h"
  20. #include "llvm/Analysis/ScalarEvolution.h"
  21. #include "llvm/Analysis/TargetLibraryInfo.h"
  22. #include "llvm/IR/Dominators.h"
  23. #include "llvm/IR/MDBuilder.h"
  24. #include "llvm/IR/PassManager.h"
  25. #include "llvm/InitializePasses.h"
  26. #include "llvm/Support/CommandLine.h"
  27. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  28. #include "llvm/Transforms/Utils/Cloning.h"
  29. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  30. using namespace llvm;
  31. static cl::opt<bool>
  32. AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
  33. cl::Hidden,
  34. cl::desc("Add no-alias annotation for instructions that "
  35. "are disambiguated by memchecks"));
  36. LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
  37. ArrayRef<RuntimePointerCheck> Checks, Loop *L,
  38. LoopInfo *LI, DominatorTree *DT,
  39. ScalarEvolution *SE)
  40. : VersionedLoop(L), NonVersionedLoop(nullptr),
  41. AliasChecks(Checks.begin(), Checks.end()),
  42. Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT),
  43. SE(SE) {
  44. }
  45. void LoopVersioning::versionLoop(
  46. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  47. assert(VersionedLoop->getUniqueExitBlock() && "No single exit block");
  48. assert(VersionedLoop->isLoopSimplifyForm() &&
  49. "Loop is not in loop-simplify form");
  50. Value *MemRuntimeCheck;
  51. Value *SCEVRuntimeCheck;
  52. Value *RuntimeCheck = nullptr;
  53. // Add the memcheck in the original preheader (this is empty initially).
  54. BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
  55. const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
  56. SCEVExpander Exp2(*RtPtrChecking.getSE(),
  57. VersionedLoop->getHeader()->getModule()->getDataLayout(),
  58. "induction");
  59. MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
  60. VersionedLoop, AliasChecks, Exp2);
  61. SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
  62. "scev.check");
  63. SCEVRuntimeCheck =
  64. Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
  65. IRBuilder<InstSimplifyFolder> Builder(
  66. RuntimeCheckBB->getContext(),
  67. InstSimplifyFolder(RuntimeCheckBB->getModule()->getDataLayout()));
  68. if (MemRuntimeCheck && SCEVRuntimeCheck) {
  69. Builder.SetInsertPoint(RuntimeCheckBB->getTerminator());
  70. RuntimeCheck =
  71. Builder.CreateOr(MemRuntimeCheck, SCEVRuntimeCheck, "lver.safe");
  72. } else
  73. RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
  74. assert(RuntimeCheck && "called even though we don't need "
  75. "any runtime checks");
  76. // Rename the block to make the IR more readable.
  77. RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
  78. ".lver.check");
  79. // Create empty preheader for the loop (and after cloning for the
  80. // non-versioned loop).
  81. BasicBlock *PH =
  82. SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
  83. nullptr, VersionedLoop->getHeader()->getName() + ".ph");
  84. // Clone the loop including the preheader.
  85. //
  86. // FIXME: This does not currently preserve SimplifyLoop because the exit
  87. // block is a join between the two loops.
  88. SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
  89. NonVersionedLoop =
  90. cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
  91. ".lver.orig", LI, DT, NonVersionedLoopBlocks);
  92. remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
  93. // Insert the conditional branch based on the result of the memchecks.
  94. Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
  95. Builder.SetInsertPoint(OrigTerm);
  96. Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
  97. VersionedLoop->getLoopPreheader());
  98. OrigTerm->eraseFromParent();
  99. // The loops merge in the original exit block. This is now dominated by the
  100. // memchecking block.
  101. DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
  102. // Adds the necessary PHI nodes for the versioned loops based on the
  103. // loop-defined values used outside of the loop.
  104. addPHINodes(DefsUsedOutside);
  105. formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
  106. formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
  107. assert(NonVersionedLoop->isLoopSimplifyForm() &&
  108. VersionedLoop->isLoopSimplifyForm() &&
  109. "The versioned loops should be in simplify form.");
  110. }
  111. void LoopVersioning::addPHINodes(
  112. const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
  113. BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
  114. assert(PHIBlock && "No single successor to loop exit block");
  115. PHINode *PN;
  116. // First add a single-operand PHI for each DefsUsedOutside if one does not
  117. // exists yet.
  118. for (auto *Inst : DefsUsedOutside) {
  119. // See if we have a single-operand PHI with the value defined by the
  120. // original loop.
  121. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  122. if (PN->getIncomingValue(0) == Inst)
  123. break;
  124. }
  125. // If not create it.
  126. if (!PN) {
  127. PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
  128. &PHIBlock->front());
  129. SmallVector<User*, 8> UsersToUpdate;
  130. for (User *U : Inst->users())
  131. if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
  132. UsersToUpdate.push_back(U);
  133. for (User *U : UsersToUpdate)
  134. U->replaceUsesOfWith(Inst, PN);
  135. PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
  136. }
  137. }
  138. // Then for each PHI add the operand for the edge from the cloned loop.
  139. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
  140. assert(PN->getNumOperands() == 1 &&
  141. "Exit block should only have on predecessor");
  142. // If the definition was cloned used that otherwise use the same value.
  143. Value *ClonedValue = PN->getIncomingValue(0);
  144. auto Mapped = VMap.find(ClonedValue);
  145. if (Mapped != VMap.end())
  146. ClonedValue = Mapped->second;
  147. PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
  148. }
  149. }
  150. void LoopVersioning::prepareNoAliasMetadata() {
  151. // We need to turn the no-alias relation between pointer checking groups into
  152. // no-aliasing annotations between instructions.
  153. //
  154. // We accomplish this by mapping each pointer checking group (a set of
  155. // pointers memchecked together) to an alias scope and then also mapping each
  156. // group to the list of scopes it can't alias.
  157. const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
  158. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  159. // First allocate an aliasing scope for each pointer checking group.
  160. //
  161. // While traversing through the checking groups in the loop, also create a
  162. // reverse map from pointers to the pointer checking group they were assigned
  163. // to.
  164. MDBuilder MDB(Context);
  165. MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
  166. for (const auto &Group : RtPtrChecking->CheckingGroups) {
  167. GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
  168. for (unsigned PtrIdx : Group.Members)
  169. PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
  170. }
  171. // Go through the checks and for each pointer group, collect the scopes for
  172. // each non-aliasing pointer group.
  173. DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
  174. GroupToNonAliasingScopes;
  175. for (const auto &Check : AliasChecks)
  176. GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
  177. // Finally, transform the above to actually map to scope list which is what
  178. // the metadata uses.
  179. for (auto Pair : GroupToNonAliasingScopes)
  180. GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
  181. }
  182. void LoopVersioning::annotateLoopWithNoAlias() {
  183. if (!AnnotateNoAlias)
  184. return;
  185. // First prepare the maps.
  186. prepareNoAliasMetadata();
  187. // Add the scope and no-alias metadata to the instructions.
  188. for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
  189. annotateInstWithNoAlias(I);
  190. }
  191. }
  192. void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
  193. const Instruction *OrigInst) {
  194. if (!AnnotateNoAlias)
  195. return;
  196. LLVMContext &Context = VersionedLoop->getHeader()->getContext();
  197. const Value *Ptr = isa<LoadInst>(OrigInst)
  198. ? cast<LoadInst>(OrigInst)->getPointerOperand()
  199. : cast<StoreInst>(OrigInst)->getPointerOperand();
  200. // Find the group for the pointer and then add the scope metadata.
  201. auto Group = PtrToGroup.find(Ptr);
  202. if (Group != PtrToGroup.end()) {
  203. VersionedInst->setMetadata(
  204. LLVMContext::MD_alias_scope,
  205. MDNode::concatenate(
  206. VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
  207. MDNode::get(Context, GroupToScope[Group->second])));
  208. // Add the no-alias metadata.
  209. auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
  210. if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
  211. VersionedInst->setMetadata(
  212. LLVMContext::MD_noalias,
  213. MDNode::concatenate(
  214. VersionedInst->getMetadata(LLVMContext::MD_noalias),
  215. NonAliasingScopeList->second));
  216. }
  217. }
  218. namespace {
  219. bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
  220. DominatorTree *DT, ScalarEvolution *SE) {
  221. // Build up a worklist of inner-loops to version. This is necessary as the
  222. // act of versioning a loop creates new loops and can invalidate iterators
  223. // across the loops.
  224. SmallVector<Loop *, 8> Worklist;
  225. for (Loop *TopLevelLoop : *LI)
  226. for (Loop *L : depth_first(TopLevelLoop))
  227. // We only handle inner-most loops.
  228. if (L->isInnermost())
  229. Worklist.push_back(L);
  230. // Now walk the identified inner loops.
  231. bool Changed = false;
  232. for (Loop *L : Worklist) {
  233. if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
  234. !L->getExitingBlock())
  235. continue;
  236. const LoopAccessInfo &LAI = GetLAA(*L);
  237. if (!LAI.hasConvergentOp() &&
  238. (LAI.getNumRuntimePointerChecks() ||
  239. !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
  240. LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
  241. LI, DT, SE);
  242. LVer.versionLoop();
  243. LVer.annotateLoopWithNoAlias();
  244. Changed = true;
  245. }
  246. }
  247. return Changed;
  248. }
  249. /// Also expose this is a pass. Currently this is only used for
  250. /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
  251. /// array accesses from the loop.
  252. class LoopVersioningLegacyPass : public FunctionPass {
  253. public:
  254. LoopVersioningLegacyPass() : FunctionPass(ID) {
  255. initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
  256. }
  257. bool runOnFunction(Function &F) override {
  258. auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  259. auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
  260. return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
  261. };
  262. auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  263. auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  264. return runImpl(LI, GetLAA, DT, SE);
  265. }
  266. void getAnalysisUsage(AnalysisUsage &AU) const override {
  267. AU.addRequired<LoopInfoWrapperPass>();
  268. AU.addPreserved<LoopInfoWrapperPass>();
  269. AU.addRequired<LoopAccessLegacyAnalysis>();
  270. AU.addRequired<DominatorTreeWrapperPass>();
  271. AU.addPreserved<DominatorTreeWrapperPass>();
  272. AU.addRequired<ScalarEvolutionWrapperPass>();
  273. }
  274. static char ID;
  275. };
  276. }
  277. #define LVER_OPTION "loop-versioning"
  278. #define DEBUG_TYPE LVER_OPTION
  279. char LoopVersioningLegacyPass::ID;
  280. static const char LVer_name[] = "Loop Versioning";
  281. INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  282. false)
  283. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  284. INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
  285. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  286. INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  287. INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
  288. false)
  289. namespace llvm {
  290. FunctionPass *createLoopVersioningLegacyPass() {
  291. return new LoopVersioningLegacyPass();
  292. }
  293. PreservedAnalyses LoopVersioningPass::run(Function &F,
  294. FunctionAnalysisManager &AM) {
  295. auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  296. auto &LI = AM.getResult<LoopAnalysis>(F);
  297. auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  298. auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  299. auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  300. auto &AA = AM.getResult<AAManager>(F);
  301. auto &AC = AM.getResult<AssumptionAnalysis>(F);
  302. auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
  303. auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
  304. LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE,
  305. TLI, TTI, nullptr, nullptr, nullptr};
  306. return LAM.getResult<LoopAccessAnalysis>(L, AR);
  307. };
  308. if (runImpl(&LI, GetLAA, &DT, &SE))
  309. return PreservedAnalyses::none();
  310. return PreservedAnalyses::all();
  311. }
  312. } // namespace llvm