LoopBoundSplit.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
  9. #include "llvm/ADT/Sequence.h"
  10. #include "llvm/Analysis/LoopAnalysisManager.h"
  11. #include "llvm/Analysis/LoopInfo.h"
  12. #include "llvm/Analysis/ScalarEvolution.h"
  13. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  14. #include "llvm/IR/PatternMatch.h"
  15. #include "llvm/Transforms/Scalar/LoopPassManager.h"
  16. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  17. #include "llvm/Transforms/Utils/Cloning.h"
  18. #include "llvm/Transforms/Utils/LoopSimplify.h"
  19. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  20. #define DEBUG_TYPE "loop-bound-split"
  21. namespace llvm {
  22. using namespace PatternMatch;
  23. namespace {
  24. struct ConditionInfo {
  25. /// Branch instruction with this condition
  26. BranchInst *BI = nullptr;
  27. /// ICmp instruction with this condition
  28. ICmpInst *ICmp = nullptr;
  29. /// Preciate info
  30. ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
  31. /// AddRec llvm value
  32. Value *AddRecValue = nullptr;
  33. /// Non PHI AddRec llvm value
  34. Value *NonPHIAddRecValue;
  35. /// Bound llvm value
  36. Value *BoundValue = nullptr;
  37. /// AddRec SCEV
  38. const SCEVAddRecExpr *AddRecSCEV = nullptr;
  39. /// Bound SCEV
  40. const SCEV *BoundSCEV = nullptr;
  41. ConditionInfo() = default;
  42. };
  43. } // namespace
  44. static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
  45. ConditionInfo &Cond, const Loop &L) {
  46. Cond.ICmp = ICmp;
  47. if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
  48. m_Value(Cond.BoundValue)))) {
  49. const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
  50. const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
  51. const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
  52. const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
  53. // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
  54. if (!LHSAddRecSCEV && RHSAddRecSCEV) {
  55. std::swap(Cond.AddRecValue, Cond.BoundValue);
  56. std::swap(AddRecSCEV, BoundSCEV);
  57. Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
  58. }
  59. Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
  60. Cond.BoundSCEV = BoundSCEV;
  61. Cond.NonPHIAddRecValue = Cond.AddRecValue;
  62. // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
  63. // value from backedge.
  64. if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
  65. PHINode *PN = cast<PHINode>(Cond.AddRecValue);
  66. Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
  67. }
  68. }
  69. }
  70. static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
  71. ConditionInfo &Cond, bool IsExitCond) {
  72. if (IsExitCond) {
  73. const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
  74. if (isa<SCEVCouldNotCompute>(ExitCount))
  75. return false;
  76. Cond.BoundSCEV = ExitCount;
  77. return true;
  78. }
  79. // For non-exit condtion, if pred is LT, keep existing bound.
  80. if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
  81. return true;
  82. // For non-exit condition, if pre is LE, try to convert it to LT.
  83. // Range Range
  84. // AddRec <= Bound --> AddRec < Bound + 1
  85. if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
  86. return false;
  87. if (IntegerType *BoundSCEVIntType =
  88. dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
  89. unsigned BitWidth = BoundSCEVIntType->getBitWidth();
  90. APInt Max = ICmpInst::isSigned(Cond.Pred)
  91. ? APInt::getSignedMaxValue(BitWidth)
  92. : APInt::getMaxValue(BitWidth);
  93. const SCEV *MaxSCEV = SE.getConstant(Max);
  94. // Check Bound < INT_MAX
  95. ICmpInst::Predicate Pred =
  96. ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
  97. if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
  98. const SCEV *BoundPlusOneSCEV =
  99. SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
  100. Cond.BoundSCEV = BoundPlusOneSCEV;
  101. Cond.Pred = Pred;
  102. return true;
  103. }
  104. }
  105. // ToDo: Support ICMP_NE/EQ.
  106. return false;
  107. }
  108. static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
  109. ICmpInst *ICmp, ConditionInfo &Cond,
  110. bool IsExitCond) {
  111. analyzeICmp(SE, ICmp, Cond, L);
  112. // The BoundSCEV should be evaluated at loop entry.
  113. if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
  114. return false;
  115. // Allowed AddRec as induction variable.
  116. if (!Cond.AddRecSCEV)
  117. return false;
  118. if (!Cond.AddRecSCEV->isAffine())
  119. return false;
  120. const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
  121. // Allowed constant step.
  122. if (!isa<SCEVConstant>(StepRecSCEV))
  123. return false;
  124. ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
  125. // Allowed positive step for now.
  126. // TODO: Support negative step.
  127. if (StepCI->isNegative() || StepCI->isZero())
  128. return false;
  129. // Calculate upper bound.
  130. if (!calculateUpperBound(L, SE, Cond, IsExitCond))
  131. return false;
  132. return true;
  133. }
  134. static bool isProcessableCondBI(const ScalarEvolution &SE,
  135. const BranchInst *BI) {
  136. BasicBlock *TrueSucc = nullptr;
  137. BasicBlock *FalseSucc = nullptr;
  138. ICmpInst::Predicate Pred;
  139. Value *LHS, *RHS;
  140. if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
  141. m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
  142. return false;
  143. if (!SE.isSCEVable(LHS->getType()))
  144. return false;
  145. assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
  146. if (TrueSucc == FalseSucc)
  147. return false;
  148. return true;
  149. }
  150. static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
  151. ScalarEvolution &SE, ConditionInfo &Cond) {
  152. // Skip function with optsize.
  153. if (L.getHeader()->getParent()->hasOptSize())
  154. return false;
  155. // Split only innermost loop.
  156. if (!L.isInnermost())
  157. return false;
  158. // Check loop is in simplified form.
  159. if (!L.isLoopSimplifyForm())
  160. return false;
  161. // Check loop is in LCSSA form.
  162. if (!L.isLCSSAForm(DT))
  163. return false;
  164. // Skip loop that cannot be cloned.
  165. if (!L.isSafeToClone())
  166. return false;
  167. BasicBlock *ExitingBB = L.getExitingBlock();
  168. // Assumed only one exiting block.
  169. if (!ExitingBB)
  170. return false;
  171. BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
  172. if (!ExitingBI)
  173. return false;
  174. // Allowed only conditional branch with ICmp.
  175. if (!isProcessableCondBI(SE, ExitingBI))
  176. return false;
  177. // Check the condition is processable.
  178. ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
  179. if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
  180. return false;
  181. Cond.BI = ExitingBI;
  182. return true;
  183. }
  184. static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
  185. // If the conditional branch splits a loop into two halves, we could
  186. // generally say it is profitable.
  187. //
  188. // ToDo: Add more profitable cases here.
  189. // Check this branch causes diamond CFG.
  190. BasicBlock *Succ0 = BI->getSuccessor(0);
  191. BasicBlock *Succ1 = BI->getSuccessor(1);
  192. BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
  193. BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
  194. if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
  195. return false;
  196. // ToDo: Calculate each successor's instruction cost.
  197. return true;
  198. }
  199. static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
  200. ConditionInfo &ExitingCond,
  201. ConditionInfo &SplitCandidateCond) {
  202. for (auto *BB : L.blocks()) {
  203. // Skip condition of backedge.
  204. if (L.getLoopLatch() == BB)
  205. continue;
  206. auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
  207. if (!BI)
  208. continue;
  209. // Check conditional branch with ICmp.
  210. if (!isProcessableCondBI(SE, BI))
  211. continue;
  212. // Skip loop invariant condition.
  213. if (L.isLoopInvariant(BI->getCondition()))
  214. continue;
  215. // Check the condition is processable.
  216. ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
  217. if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
  218. /*IsExitCond*/ false))
  219. continue;
  220. if (ExitingCond.BoundSCEV->getType() !=
  221. SplitCandidateCond.BoundSCEV->getType())
  222. continue;
  223. // After transformation, we assume the split condition of the pre-loop is
  224. // always true. In order to guarantee it, we need to check the start value
  225. // of the split cond AddRec satisfies the split condition.
  226. if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
  227. SplitCandidateCond.AddRecSCEV->getStart(),
  228. SplitCandidateCond.BoundSCEV))
  229. continue;
  230. SplitCandidateCond.BI = BI;
  231. return BI;
  232. }
  233. return nullptr;
  234. }
  235. static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
  236. ScalarEvolution &SE, LPMUpdater &U) {
  237. ConditionInfo SplitCandidateCond;
  238. ConditionInfo ExitingCond;
  239. // Check we can split this loop's bound.
  240. if (!canSplitLoopBound(L, DT, SE, ExitingCond))
  241. return false;
  242. if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
  243. return false;
  244. if (!isProfitableToTransform(L, SplitCandidateCond.BI))
  245. return false;
  246. // Now, we have a split candidate. Let's build a form as below.
  247. // +--------------------+
  248. // | preheader |
  249. // | set up newbound |
  250. // +--------------------+
  251. // | /----------------\
  252. // +--------v----v------+ |
  253. // | header |---\ |
  254. // | with true condition| | |
  255. // +--------------------+ | |
  256. // | | |
  257. // +--------v-----------+ | |
  258. // | if.then.BB | | |
  259. // +--------------------+ | |
  260. // | | |
  261. // +--------v-----------<---/ |
  262. // | latch >----------/
  263. // | with newbound |
  264. // +--------------------+
  265. // |
  266. // +--------v-----------+
  267. // | preheader2 |--------------\
  268. // | if (AddRec i != | |
  269. // | org bound) | |
  270. // +--------------------+ |
  271. // | /----------------\ |
  272. // +--------v----v------+ | |
  273. // | header2 |---\ | |
  274. // | conditional branch | | | |
  275. // |with false condition| | | |
  276. // +--------------------+ | | |
  277. // | | | |
  278. // +--------v-----------+ | | |
  279. // | if.then.BB2 | | | |
  280. // +--------------------+ | | |
  281. // | | | |
  282. // +--------v-----------<---/ | |
  283. // | latch2 >----------/ |
  284. // | with org bound | |
  285. // +--------v-----------+ |
  286. // | |
  287. // | +---------------+ |
  288. // +--> exit <-------/
  289. // +---------------+
  290. // Let's create post loop.
  291. SmallVector<BasicBlock *, 8> PostLoopBlocks;
  292. Loop *PostLoop;
  293. ValueToValueMapTy VMap;
  294. BasicBlock *PreHeader = L.getLoopPreheader();
  295. BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
  296. PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
  297. ".split", &LI, &DT, PostLoopBlocks);
  298. remapInstructionsInBlocks(PostLoopBlocks, VMap);
  299. BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
  300. IRBuilder<> Builder(&PostLoopPreHeader->front());
  301. // Update phi nodes in header of post-loop.
  302. bool isExitingLatch =
  303. (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
  304. Value *ExitingCondLCSSAPhi = nullptr;
  305. for (PHINode &PN : L.getHeader()->phis()) {
  306. // Create LCSSA phi node in preheader of post-loop.
  307. PHINode *LCSSAPhi =
  308. Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
  309. LCSSAPhi->setDebugLoc(PN.getDebugLoc());
  310. // If the exiting block is loop latch, the phi does not have the update at
  311. // last iteration. In this case, update lcssa phi with value from backedge.
  312. LCSSAPhi->addIncoming(
  313. isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
  314. L.getExitingBlock());
  315. // Update the start value of phi node in post-loop with the LCSSA phi node.
  316. PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
  317. PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
  318. // Find PHI with exiting condition from pre-loop. The PHI should be
  319. // SCEVAddRecExpr and have same incoming value from backedge with
  320. // ExitingCond.
  321. if (!SE.isSCEVable(PN.getType()))
  322. continue;
  323. const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
  324. if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
  325. PN.getIncomingValueForBlock(L.getLoopLatch()))
  326. ExitingCondLCSSAPhi = LCSSAPhi;
  327. }
  328. // Add conditional branch to check we can skip post-loop in its preheader.
  329. Instruction *OrigBI = PostLoopPreHeader->getTerminator();
  330. ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
  331. Value *Cond =
  332. Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
  333. Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
  334. OrigBI->eraseFromParent();
  335. // Create new loop bound and add it into preheader of pre-loop.
  336. const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
  337. const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
  338. NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
  339. ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
  340. : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
  341. SCEVExpander Expander(
  342. SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
  343. Instruction *InsertPt = SplitLoopPH->getTerminator();
  344. Value *NewBoundValue =
  345. Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
  346. NewBoundValue->setName("new.bound");
  347. // Replace exiting bound value of pre-loop NewBound.
  348. ExitingCond.ICmp->setOperand(1, NewBoundValue);
  349. // Replace SplitCandidateCond.BI's condition of pre-loop by True.
  350. LLVMContext &Context = PreHeader->getContext();
  351. SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
  352. // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
  353. BranchInst *ClonedSplitCandidateBI =
  354. cast<BranchInst>(VMap[SplitCandidateCond.BI]);
  355. ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
  356. // Replace exit branch target of pre-loop by post-loop's preheader.
  357. if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
  358. ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
  359. else
  360. ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
  361. // Update phi node in exit block of post-loop.
  362. Builder.SetInsertPoint(&PostLoopPreHeader->front());
  363. for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
  364. for (auto i : seq<int>(0, PN.getNumOperands())) {
  365. // Check incoming block is pre-loop's exiting block.
  366. if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
  367. Value *IncomingValue = PN.getIncomingValue(i);
  368. // Create LCSSA phi node for incoming value.
  369. PHINode *LCSSAPhi =
  370. Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
  371. LCSSAPhi->setDebugLoc(PN.getDebugLoc());
  372. LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
  373. // Replace pre-loop's exiting block by post-loop's preheader.
  374. PN.setIncomingBlock(i, PostLoopPreHeader);
  375. // Replace incoming value by LCSSAPhi.
  376. PN.setIncomingValue(i, LCSSAPhi);
  377. // Add a new incoming value with post-loop's exiting block.
  378. PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
  379. }
  380. }
  381. }
  382. // Update dominator tree.
  383. DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
  384. DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
  385. // Invalidate cached SE information.
  386. SE.forgetLoop(&L);
  387. // Canonicalize loops.
  388. simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
  389. simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
  390. // Add new post-loop to loop pass manager.
  391. U.addSiblingLoops(PostLoop);
  392. return true;
  393. }
  394. PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
  395. LoopStandardAnalysisResults &AR,
  396. LPMUpdater &U) {
  397. Function &F = *L.getHeader()->getParent();
  398. (void)F;
  399. LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
  400. << "\n");
  401. if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
  402. return PreservedAnalyses::all();
  403. assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
  404. AR.LI.verify(AR.DT);
  405. return getLoopPassPreservedAnalyses();
  406. }
  407. } // end namespace llvm