LoopBoundSplit.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
  9. #include "llvm/ADT/Sequence.h"
  10. #include "llvm/Analysis/LoopAccessAnalysis.h"
  11. #include "llvm/Analysis/LoopAnalysisManager.h"
  12. #include "llvm/Analysis/LoopInfo.h"
  13. #include "llvm/Analysis/LoopIterator.h"
  14. #include "llvm/Analysis/LoopPass.h"
  15. #include "llvm/Analysis/MemorySSA.h"
  16. #include "llvm/Analysis/MemorySSAUpdater.h"
  17. #include "llvm/Analysis/ScalarEvolution.h"
  18. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  19. #include "llvm/IR/PatternMatch.h"
  20. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  21. #include "llvm/Transforms/Utils/Cloning.h"
  22. #include "llvm/Transforms/Utils/LoopSimplify.h"
  23. #include "llvm/Transforms/Utils/LoopUtils.h"
  24. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  25. #define DEBUG_TYPE "loop-bound-split"
  26. namespace llvm {
  27. using namespace PatternMatch;
  28. namespace {
  29. struct ConditionInfo {
  30. /// Branch instruction with this condition
  31. BranchInst *BI;
  32. /// ICmp instruction with this condition
  33. ICmpInst *ICmp;
  34. /// Preciate info
  35. ICmpInst::Predicate Pred;
  36. /// AddRec llvm value
  37. Value *AddRecValue;
  38. /// Non PHI AddRec llvm value
  39. Value *NonPHIAddRecValue;
  40. /// Bound llvm value
  41. Value *BoundValue;
  42. /// AddRec SCEV
  43. const SCEVAddRecExpr *AddRecSCEV;
  44. /// Bound SCEV
  45. const SCEV *BoundSCEV;
  46. ConditionInfo()
  47. : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
  48. AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
  49. BoundSCEV(nullptr) {}
  50. };
  51. } // namespace
  52. static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
  53. ConditionInfo &Cond, const Loop &L) {
  54. Cond.ICmp = ICmp;
  55. if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
  56. m_Value(Cond.BoundValue)))) {
  57. const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
  58. const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
  59. const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
  60. const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
  61. // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
  62. if (!LHSAddRecSCEV && RHSAddRecSCEV) {
  63. std::swap(Cond.AddRecValue, Cond.BoundValue);
  64. std::swap(AddRecSCEV, BoundSCEV);
  65. Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
  66. }
  67. Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
  68. Cond.BoundSCEV = BoundSCEV;
  69. Cond.NonPHIAddRecValue = Cond.AddRecValue;
  70. // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
  71. // value from backedge.
  72. if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
  73. PHINode *PN = cast<PHINode>(Cond.AddRecValue);
  74. Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
  75. }
  76. }
  77. }
  78. static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
  79. ConditionInfo &Cond, bool IsExitCond) {
  80. if (IsExitCond) {
  81. const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
  82. if (isa<SCEVCouldNotCompute>(ExitCount))
  83. return false;
  84. Cond.BoundSCEV = ExitCount;
  85. return true;
  86. }
  87. // For non-exit condtion, if pred is LT, keep existing bound.
  88. if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
  89. return true;
  90. // For non-exit condition, if pre is LE, try to convert it to LT.
  91. // Range Range
  92. // AddRec <= Bound --> AddRec < Bound + 1
  93. if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
  94. return false;
  95. if (IntegerType *BoundSCEVIntType =
  96. dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
  97. unsigned BitWidth = BoundSCEVIntType->getBitWidth();
  98. APInt Max = ICmpInst::isSigned(Cond.Pred)
  99. ? APInt::getSignedMaxValue(BitWidth)
  100. : APInt::getMaxValue(BitWidth);
  101. const SCEV *MaxSCEV = SE.getConstant(Max);
  102. // Check Bound < INT_MAX
  103. ICmpInst::Predicate Pred =
  104. ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
  105. if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
  106. const SCEV *BoundPlusOneSCEV =
  107. SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
  108. Cond.BoundSCEV = BoundPlusOneSCEV;
  109. Cond.Pred = Pred;
  110. return true;
  111. }
  112. }
  113. // ToDo: Support ICMP_NE/EQ.
  114. return false;
  115. }
  116. static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
  117. ICmpInst *ICmp, ConditionInfo &Cond,
  118. bool IsExitCond) {
  119. analyzeICmp(SE, ICmp, Cond, L);
  120. // The BoundSCEV should be evaluated at loop entry.
  121. if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
  122. return false;
  123. // Allowed AddRec as induction variable.
  124. if (!Cond.AddRecSCEV)
  125. return false;
  126. if (!Cond.AddRecSCEV->isAffine())
  127. return false;
  128. const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
  129. // Allowed constant step.
  130. if (!isa<SCEVConstant>(StepRecSCEV))
  131. return false;
  132. ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
  133. // Allowed positive step for now.
  134. // TODO: Support negative step.
  135. if (StepCI->isNegative() || StepCI->isZero())
  136. return false;
  137. // Calculate upper bound.
  138. if (!calculateUpperBound(L, SE, Cond, IsExitCond))
  139. return false;
  140. return true;
  141. }
  142. static bool isProcessableCondBI(const ScalarEvolution &SE,
  143. const BranchInst *BI) {
  144. BasicBlock *TrueSucc = nullptr;
  145. BasicBlock *FalseSucc = nullptr;
  146. ICmpInst::Predicate Pred;
  147. Value *LHS, *RHS;
  148. if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
  149. m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
  150. return false;
  151. if (!SE.isSCEVable(LHS->getType()))
  152. return false;
  153. assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
  154. if (TrueSucc == FalseSucc)
  155. return false;
  156. return true;
  157. }
  158. static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
  159. ScalarEvolution &SE, ConditionInfo &Cond) {
  160. // Skip function with optsize.
  161. if (L.getHeader()->getParent()->hasOptSize())
  162. return false;
  163. // Split only innermost loop.
  164. if (!L.isInnermost())
  165. return false;
  166. // Check loop is in simplified form.
  167. if (!L.isLoopSimplifyForm())
  168. return false;
  169. // Check loop is in LCSSA form.
  170. if (!L.isLCSSAForm(DT))
  171. return false;
  172. // Skip loop that cannot be cloned.
  173. if (!L.isSafeToClone())
  174. return false;
  175. BasicBlock *ExitingBB = L.getExitingBlock();
  176. // Assumed only one exiting block.
  177. if (!ExitingBB)
  178. return false;
  179. BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
  180. if (!ExitingBI)
  181. return false;
  182. // Allowed only conditional branch with ICmp.
  183. if (!isProcessableCondBI(SE, ExitingBI))
  184. return false;
  185. // Check the condition is processable.
  186. ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
  187. if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
  188. return false;
  189. Cond.BI = ExitingBI;
  190. return true;
  191. }
  192. static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
  193. // If the conditional branch splits a loop into two halves, we could
  194. // generally say it is profitable.
  195. //
  196. // ToDo: Add more profitable cases here.
  197. // Check this branch causes diamond CFG.
  198. BasicBlock *Succ0 = BI->getSuccessor(0);
  199. BasicBlock *Succ1 = BI->getSuccessor(1);
  200. BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
  201. BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
  202. if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
  203. return false;
  204. // ToDo: Calculate each successor's instruction cost.
  205. return true;
  206. }
  207. static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
  208. ConditionInfo &ExitingCond,
  209. ConditionInfo &SplitCandidateCond) {
  210. for (auto *BB : L.blocks()) {
  211. // Skip condition of backedge.
  212. if (L.getLoopLatch() == BB)
  213. continue;
  214. auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
  215. if (!BI)
  216. continue;
  217. // Check conditional branch with ICmp.
  218. if (!isProcessableCondBI(SE, BI))
  219. continue;
  220. // Skip loop invariant condition.
  221. if (L.isLoopInvariant(BI->getCondition()))
  222. continue;
  223. // Check the condition is processable.
  224. ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
  225. if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
  226. /*IsExitCond*/ false))
  227. continue;
  228. if (ExitingCond.BoundSCEV->getType() !=
  229. SplitCandidateCond.BoundSCEV->getType())
  230. continue;
  231. // After transformation, we assume the split condition of the pre-loop is
  232. // always true. In order to guarantee it, we need to check the start value
  233. // of the split cond AddRec satisfies the split condition.
  234. if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
  235. SplitCandidateCond.AddRecSCEV->getStart(),
  236. SplitCandidateCond.BoundSCEV))
  237. continue;
  238. SplitCandidateCond.BI = BI;
  239. return BI;
  240. }
  241. return nullptr;
  242. }
  243. static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
  244. ScalarEvolution &SE, LPMUpdater &U) {
  245. ConditionInfo SplitCandidateCond;
  246. ConditionInfo ExitingCond;
  247. // Check we can split this loop's bound.
  248. if (!canSplitLoopBound(L, DT, SE, ExitingCond))
  249. return false;
  250. if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
  251. return false;
  252. if (!isProfitableToTransform(L, SplitCandidateCond.BI))
  253. return false;
  254. // Now, we have a split candidate. Let's build a form as below.
  255. // +--------------------+
  256. // | preheader |
  257. // | set up newbound |
  258. // +--------------------+
  259. // | /----------------\
  260. // +--------v----v------+ |
  261. // | header |---\ |
  262. // | with true condition| | |
  263. // +--------------------+ | |
  264. // | | |
  265. // +--------v-----------+ | |
  266. // | if.then.BB | | |
  267. // +--------------------+ | |
  268. // | | |
  269. // +--------v-----------<---/ |
  270. // | latch >----------/
  271. // | with newbound |
  272. // +--------------------+
  273. // |
  274. // +--------v-----------+
  275. // | preheader2 |--------------\
  276. // | if (AddRec i != | |
  277. // | org bound) | |
  278. // +--------------------+ |
  279. // | /----------------\ |
  280. // +--------v----v------+ | |
  281. // | header2 |---\ | |
  282. // | conditional branch | | | |
  283. // |with false condition| | | |
  284. // +--------------------+ | | |
  285. // | | | |
  286. // +--------v-----------+ | | |
  287. // | if.then.BB2 | | | |
  288. // +--------------------+ | | |
  289. // | | | |
  290. // +--------v-----------<---/ | |
  291. // | latch2 >----------/ |
  292. // | with org bound | |
  293. // +--------v-----------+ |
  294. // | |
  295. // | +---------------+ |
  296. // +--> exit <-------/
  297. // +---------------+
  298. // Let's create post loop.
  299. SmallVector<BasicBlock *, 8> PostLoopBlocks;
  300. Loop *PostLoop;
  301. ValueToValueMapTy VMap;
  302. BasicBlock *PreHeader = L.getLoopPreheader();
  303. BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
  304. PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
  305. ".split", &LI, &DT, PostLoopBlocks);
  306. remapInstructionsInBlocks(PostLoopBlocks, VMap);
  307. BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
  308. IRBuilder<> Builder(&PostLoopPreHeader->front());
  309. // Update phi nodes in header of post-loop.
  310. bool isExitingLatch =
  311. (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
  312. Value *ExitingCondLCSSAPhi = nullptr;
  313. for (PHINode &PN : L.getHeader()->phis()) {
  314. // Create LCSSA phi node in preheader of post-loop.
  315. PHINode *LCSSAPhi =
  316. Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
  317. LCSSAPhi->setDebugLoc(PN.getDebugLoc());
  318. // If the exiting block is loop latch, the phi does not have the update at
  319. // last iteration. In this case, update lcssa phi with value from backedge.
  320. LCSSAPhi->addIncoming(
  321. isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
  322. L.getExitingBlock());
  323. // Update the start value of phi node in post-loop with the LCSSA phi node.
  324. PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
  325. PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
  326. // Find PHI with exiting condition from pre-loop. The PHI should be
  327. // SCEVAddRecExpr and have same incoming value from backedge with
  328. // ExitingCond.
  329. if (!SE.isSCEVable(PN.getType()))
  330. continue;
  331. const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
  332. if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
  333. PN.getIncomingValueForBlock(L.getLoopLatch()))
  334. ExitingCondLCSSAPhi = LCSSAPhi;
  335. }
  336. // Add conditional branch to check we can skip post-loop in its preheader.
  337. Instruction *OrigBI = PostLoopPreHeader->getTerminator();
  338. ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
  339. Value *Cond =
  340. Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
  341. Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
  342. OrigBI->eraseFromParent();
  343. // Create new loop bound and add it into preheader of pre-loop.
  344. const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
  345. const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
  346. NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
  347. ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
  348. : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
  349. SCEVExpander Expander(
  350. SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
  351. Instruction *InsertPt = SplitLoopPH->getTerminator();
  352. Value *NewBoundValue =
  353. Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
  354. NewBoundValue->setName("new.bound");
  355. // Replace exiting bound value of pre-loop NewBound.
  356. ExitingCond.ICmp->setOperand(1, NewBoundValue);
  357. // Replace SplitCandidateCond.BI's condition of pre-loop by True.
  358. LLVMContext &Context = PreHeader->getContext();
  359. SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
  360. // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
  361. BranchInst *ClonedSplitCandidateBI =
  362. cast<BranchInst>(VMap[SplitCandidateCond.BI]);
  363. ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
  364. // Replace exit branch target of pre-loop by post-loop's preheader.
  365. if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
  366. ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
  367. else
  368. ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
  369. // Update phi node in exit block of post-loop.
  370. Builder.SetInsertPoint(&PostLoopPreHeader->front());
  371. for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
  372. for (auto i : seq<int>(0, PN.getNumOperands())) {
  373. // Check incoming block is pre-loop's exiting block.
  374. if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
  375. Value *IncomingValue = PN.getIncomingValue(i);
  376. // Create LCSSA phi node for incoming value.
  377. PHINode *LCSSAPhi =
  378. Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
  379. LCSSAPhi->setDebugLoc(PN.getDebugLoc());
  380. LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
  381. // Replace pre-loop's exiting block by post-loop's preheader.
  382. PN.setIncomingBlock(i, PostLoopPreHeader);
  383. // Replace incoming value by LCSSAPhi.
  384. PN.setIncomingValue(i, LCSSAPhi);
  385. // Add a new incoming value with post-loop's exiting block.
  386. PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
  387. }
  388. }
  389. }
  390. // Update dominator tree.
  391. DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
  392. DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
  393. // Invalidate cached SE information.
  394. SE.forgetLoop(&L);
  395. // Canonicalize loops.
  396. simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
  397. simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
  398. // Add new post-loop to loop pass manager.
  399. U.addSiblingLoops(PostLoop);
  400. return true;
  401. }
  402. PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
  403. LoopStandardAnalysisResults &AR,
  404. LPMUpdater &U) {
  405. Function &F = *L.getHeader()->getParent();
  406. (void)F;
  407. LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
  408. << "\n");
  409. if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
  410. return PreservedAnalyses::all();
  411. assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
  412. AR.LI.verify(AR.DT);
  413. return getLoopPassPreservedAnalyses();
  414. }
  415. } // end namespace llvm