LoopPredication.cpp 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324
  1. //===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // The LoopPredication pass tries to convert loop variant range checks to loop
  10. // invariant by widening checks across loop iterations. For example, it will
  11. // convert
  12. //
  13. // for (i = 0; i < n; i++) {
  14. // guard(i < len);
  15. // ...
  16. // }
  17. //
  18. // to
  19. //
  20. // for (i = 0; i < n; i++) {
  21. // guard(n - 1 < len);
  22. // ...
  23. // }
  24. //
  25. // After this transformation the condition of the guard is loop invariant, so
  26. // loop-unswitch can later unswitch the loop by this condition which basically
  27. // predicates the loop by the widened condition:
  28. //
  29. // if (n - 1 < len)
  30. // for (i = 0; i < n; i++) {
  31. // ...
  32. // }
  33. // else
  34. // deoptimize
  35. //
  36. // It's tempting to rely on SCEV here, but it has proven to be problematic.
  37. // Generally the facts SCEV provides about the increment step of add
  38. // recurrences are true if the backedge of the loop is taken, which implicitly
  39. // assumes that the guard doesn't fail. Using these facts to optimize the
  40. // guard results in a circular logic where the guard is optimized under the
  41. // assumption that it never fails.
  42. //
  43. // For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw, the guard predicate will be considered
  45. // monotonic. Given a monotonic condition it's tempting to replace the induction
  46. // variable in the condition with its value on the last iteration. But this
  47. // transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
  48. //
  49. // for (int i = b; i != e; i++)
  50. // guard(i u< len)
  51. //
  52. // One of the ways to reason about this problem is to use an inductive proof
  53. // approach. Given the loop:
  54. //
  55. // if (B(0)) {
  56. // do {
  57. // I = PHI(0, I.INC)
  58. // I.INC = I + Step
  59. // guard(G(I));
  60. // } while (B(I));
  61. // }
  62. //
  63. // where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same semantics
  65. // as the above:
  66. //
  67. // if (B(0)) {
  68. // do {
  69. // I = PHI(0, I.INC)
  70. // I.INC = I + Step
  71. // guard(G(0) && M);
  72. // } while (B(I));
  73. // }
  74. //
  75. // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
  76. //
  77. // Informal proof that the transformation above is correct:
  78. //
  79. // By the definition of guards we can rewrite the guard condition to:
  80. // G(I) && G(0) && M
  81. //
  82. // Let's prove that for each iteration of the loop:
  83. // G(0) && M => G(I)
  84. // And the condition above can be simplified to G(Start) && M.
  85. //
  86. // Induction base.
  87. // G(0) && M => G(0)
  88. //
  89. // Induction step. Assuming G(0) && M => G(I) on the subsequent
  90. // iteration:
  91. //
  92. // B(I) is true because it's the backedge condition.
  93. // G(I) is true because the backedge is guarded by this condition.
  94. //
  95. // So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
  96. //
  97. // Note that we can use anything stronger than M, i.e. any condition which
  98. // implies M.
  99. //
  100. // When S = 1 (i.e. forward iterating loop), the transformation is supported
  101. // when:
  102. // * The loop has a single latch with the condition of the form:
  103. // B(X) = latchStart + X <pred> latchLimit,
  104. // where <pred> is u<, u<=, s<, or s<=.
  105. // * The guard condition is of the form
  106. // G(X) = guardStart + X u< guardLimit
  107. //
  108. // For the ult latch comparison case M is:
  109. // forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit =>
  110. // guardStart + X + 1 u< guardLimit
  111. //
  112. // The only way the antecedent can be true and the consequent can be false is
  113. // if
  114. // X == guardLimit - 1 - guardStart
  115. // (and guardLimit is non-zero, but we won't use this latter fact).
  116. // If X == guardLimit - 1 - guardStart then the second half of the antecedent is
  117. // latchStart + guardLimit - 1 - guardStart u< latchLimit
  118. // and its negation is
  119. // latchStart + guardLimit - 1 - guardStart u>= latchLimit
  120. //
  121. // In other words, if
  122. // latchLimit u<= latchStart + guardLimit - 1 - guardStart
  123. // then:
  124. // (the ranges below are written in ConstantRange notation, where [A, B) is the
  125. // set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
  126. //
  127. // forall X . guardStart + X u< guardLimit &&
  128. // latchStart + X u< latchLimit =>
  129. // guardStart + X + 1 u< guardLimit
  130. // == forall X . guardStart + X u< guardLimit &&
  131. // latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
  132. // guardStart + X + 1 u< guardLimit
  133. // == forall X . (guardStart + X) in [0, guardLimit) &&
  134. // (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
  135. // (guardStart + X + 1) in [0, guardLimit)
  136. // == forall X . X in [-guardStart, guardLimit - guardStart) &&
  137. // X in [-latchStart, guardLimit - 1 - guardStart) =>
  138. // X in [-guardStart - 1, guardLimit - guardStart - 1)
  139. // == true
  140. //
  141. // So the widened condition is:
  142. // guardStart u< guardLimit &&
  143. // latchStart + guardLimit - 1 - guardStart u>= latchLimit
  144. // Similarly for ule condition the widened condition is:
  145. // guardStart u< guardLimit &&
  146. // latchStart + guardLimit - 1 - guardStart u> latchLimit
  147. // For slt condition the widened condition is:
  148. // guardStart u< guardLimit &&
  149. // latchStart + guardLimit - 1 - guardStart s>= latchLimit
  150. // For sle condition the widened condition is:
  151. // guardStart u< guardLimit &&
  152. // latchStart + guardLimit - 1 - guardStart s> latchLimit
  153. //
  154. // When S = -1 (i.e. reverse iterating loop), the transformation is supported
  155. // when:
  156. // * The loop has a single latch with the condition of the form:
  157. // B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
  158. // * The guard condition is of the form
  159. // G(X) = X - 1 u< guardLimit
  160. //
  161. // For the ugt latch comparison case M is:
  162. // forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
  163. //
  164. // The only way the antecedent can be true and the consequent can be false is if
  165. // X == 1.
  166. // If X == 1 then the second half of the antecedent is
  167. // 1 u> latchLimit, and its negation is latchLimit u>= 1.
  168. //
  169. // So the widened condition is:
  170. // guardStart u< guardLimit && latchLimit u>= 1.
  171. // Similarly for sgt condition the widened condition is:
  172. // guardStart u< guardLimit && latchLimit s>= 1.
  173. // For uge condition the widened condition is:
  174. // guardStart u< guardLimit && latchLimit u> 1.
  175. // For sge condition the widened condition is:
  176. // guardStart u< guardLimit && latchLimit s> 1.
  177. //===----------------------------------------------------------------------===//
  178. #include "llvm/Transforms/Scalar/LoopPredication.h"
  179. #include "llvm/ADT/Statistic.h"
  180. #include "llvm/Analysis/AliasAnalysis.h"
  181. #include "llvm/Analysis/BranchProbabilityInfo.h"
  182. #include "llvm/Analysis/GuardUtils.h"
  183. #include "llvm/Analysis/LoopInfo.h"
  184. #include "llvm/Analysis/LoopPass.h"
  185. #include "llvm/Analysis/MemorySSA.h"
  186. #include "llvm/Analysis/MemorySSAUpdater.h"
  187. #include "llvm/Analysis/ScalarEvolution.h"
  188. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  189. #include "llvm/IR/Function.h"
  190. #include "llvm/IR/IntrinsicInst.h"
  191. #include "llvm/IR/Module.h"
  192. #include "llvm/IR/PatternMatch.h"
  193. #include "llvm/IR/ProfDataUtils.h"
  194. #include "llvm/InitializePasses.h"
  195. #include "llvm/Pass.h"
  196. #include "llvm/Support/CommandLine.h"
  197. #include "llvm/Support/Debug.h"
  198. #include "llvm/Transforms/Scalar.h"
  199. #include "llvm/Transforms/Utils/GuardUtils.h"
  200. #include "llvm/Transforms/Utils/Local.h"
  201. #include "llvm/Transforms/Utils/LoopUtils.h"
  202. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  203. #include <optional>
#define DEBUG_TYPE "loop-predication"

STATISTIC(TotalConsidered, "Number of guards considered");
STATISTIC(TotalWidened, "Number of checks widened");

using namespace llvm;

// Permit truncating a wide latch IV to the range-check operand's type when
// isSafeToTruncateWideIVType proves the truncation is lossless.
static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
                                        cl::Hidden, cl::init(true));

// Permit widening checks in reverse (step == -1) iterating loops.
static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
                                         cl::Hidden, cl::init(true));

// When set, predicate even if isLoopProfitableToPredicate would say no.
static cl::opt<bool>
    SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
                            cl::Hidden, cl::init(false));

// This is the scale factor for the latch probability. We use this during
// profitability analysis to find other exiting blocks that have a much higher
// probability of exiting the loop instead of loop exiting via latch.
// This value should be greater than 1 for a sane profitability check.
static cl::opt<float> LatchExitProbabilityScale(
    "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
    cl::desc("scale factor for the latch probability. Value should be greater "
             "than 1. Lower values are ignored"));

static cl::opt<bool> PredicateWidenableBranchGuards(
    "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
    cl::desc("Whether or not we should predicate guards "
             "expressed as widenable branches to deoptimize blocks"),
    cl::init(true));

static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions(
    "loop-predication-insert-assumes-of-predicated-guards-conditions",
    cl::Hidden,
    cl::desc("Whether or not we should insert assumes of conditions of "
             "predicated guards"),
    cl::init(true));
  234. namespace {
/// Represents an induction variable check:
///   icmp Pred, <induction variable>, <loop invariant limit>
/// IV is the add-recurrence side of the compare; Limit is the (canonically
/// right-hand) loop-invariant bound.
struct LoopICmp {
  ICmpInst::Predicate Pred;
  const SCEVAddRecExpr *IV;
  const SCEV *Limit;
  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
           const SCEV *Limit)
      : Pred(Pred), IV(IV), Limit(Limit) {}
  // Default-constructed instances are filled in field-by-field by callers
  // (e.g. generateLoopLatchCheck).
  LoopICmp() = default;
  // Debug helper: print the three components to dbgs().
  void dump() {
    dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
           << ", Limit = " << *Limit << "\n";
  }
};
  250. class LoopPredication {
  251. AliasAnalysis *AA;
  252. DominatorTree *DT;
  253. ScalarEvolution *SE;
  254. LoopInfo *LI;
  255. MemorySSAUpdater *MSSAU;
  256. Loop *L;
  257. const DataLayout *DL;
  258. BasicBlock *Preheader;
  259. LoopICmp LatchCheck;
  260. bool isSupportedStep(const SCEV* Step);
  261. std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
  262. std::optional<LoopICmp> parseLoopLatchICmp();
  263. /// Return an insertion point suitable for inserting a safe to speculate
  264. /// instruction whose only user will be 'User' which has operands 'Ops'. A
  265. /// trivial result would be the at the User itself, but we try to return a
  266. /// loop invariant location if possible.
  267. Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
  268. /// Same as above, *except* that this uses the SCEV definition of invariant
  269. /// which is that an expression *can be made* invariant via SCEVExpander.
  270. /// Thus, this version is only suitable for finding an insert point to be be
  271. /// passed to SCEVExpander!
  272. Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
  273. ArrayRef<const SCEV *> Ops);
  274. /// Return true if the value is known to produce a single fixed value across
  275. /// all iterations on which it executes. Note that this does not imply
  276. /// speculation safety. That must be established separately.
  277. bool isLoopInvariantValue(const SCEV* S);
  278. Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
  279. ICmpInst::Predicate Pred, const SCEV *LHS,
  280. const SCEV *RHS);
  281. std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
  282. SCEVExpander &Expander,
  283. Instruction *Guard);
  284. std::optional<Value *>
  285. widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
  286. SCEVExpander &Expander,
  287. Instruction *Guard);
  288. std::optional<Value *>
  289. widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
  290. SCEVExpander &Expander,
  291. Instruction *Guard);
  292. unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
  293. SCEVExpander &Expander, Instruction *Guard);
  294. bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
  295. bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
  296. // If the loop always exits through another block in the loop, we should not
  297. // predicate based on the latch check. For example, the latch check can be a
  298. // very coarse grained check and there can be more fine grained exit checks
  299. // within the loop.
  300. bool isLoopProfitableToPredicate();
  301. bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
  302. public:
  303. LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
  304. LoopInfo *LI, MemorySSAUpdater *MSSAU)
  305. : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
  306. bool runOnLoop(Loop *L);
  307. };
/// Legacy pass-manager wrapper: collects the analyses LoopPredication needs
/// and delegates the actual work to LoopPredication::runOnLoop.
class LoopPredicationLegacyPass : public LoopPass {
public:
  static char ID;

  LoopPredicationLegacyPass() : LoopPass(ID) {
    initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<BranchProbabilityInfoWrapperPass>();
    getLoopAnalysisUsage(AU);
    // MemorySSA is optional; when present it is kept up to date via MSSAU.
    AU.addPreserved<MemorySSAWrapperPass>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    if (skipLoop(L))
      return false;
    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    // MemorySSA may not have been scheduled; only build an updater if it was.
    auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
    std::unique_ptr<MemorySSAUpdater> MSSAU;
    if (MSSAWP)
      MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
    auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
    return LP.runOnLoop(L);
  }
};

char LoopPredicationLegacyPass::ID = 0;
  335. } // end namespace
// Register the legacy pass (and its analysis dependencies) with the pass
// registry under the "loop-predication" command-line name.
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
                      "Loop predication", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
                    "Loop predication", false, false)
/// Factory used by the legacy pass manager to instantiate this pass.
Pass *llvm::createLoopPredicationPass() {
  return new LoopPredicationLegacyPass();
}
  345. PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
  346. LoopStandardAnalysisResults &AR,
  347. LPMUpdater &U) {
  348. std::unique_ptr<MemorySSAUpdater> MSSAU;
  349. if (AR.MSSA)
  350. MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
  351. LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
  352. MSSAU ? MSSAU.get() : nullptr);
  353. if (!LP.runOnLoop(&L))
  354. return PreservedAnalyses::all();
  355. auto PA = getLoopPassPreservedAnalyses();
  356. if (AR.MSSA)
  357. PA.preserve<MemorySSAAnalysis>();
  358. return PA;
  359. }
  360. std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
  361. auto Pred = ICI->getPredicate();
  362. auto *LHS = ICI->getOperand(0);
  363. auto *RHS = ICI->getOperand(1);
  364. const SCEV *LHSS = SE->getSCEV(LHS);
  365. if (isa<SCEVCouldNotCompute>(LHSS))
  366. return std::nullopt;
  367. const SCEV *RHSS = SE->getSCEV(RHS);
  368. if (isa<SCEVCouldNotCompute>(RHSS))
  369. return std::nullopt;
  370. // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
  371. if (SE->isLoopInvariant(LHSS, L)) {
  372. std::swap(LHS, RHS);
  373. std::swap(LHSS, RHSS);
  374. Pred = ICmpInst::getSwappedPredicate(Pred);
  375. }
  376. const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
  377. if (!AR || AR->getLoop() != L)
  378. return std::nullopt;
  379. return LoopICmp(Pred, AR, RHSS);
  380. }
/// Materialize the comparison (LHS <Pred> RHS) for use at \p Guard.
/// Folds to a constant when SCEV can decide the compare from loop-entry
/// facts; otherwise expands both operands (hoisted out of the loop when
/// safe) and emits an icmp at the outermost point where both are available.
Value *LoopPredication::expandCheck(SCEVExpander &Expander,
                                    Instruction *Guard,
                                    ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS) {
  Type *Ty = LHS->getType();
  assert(Ty == RHS->getType() && "expandCheck operands have different types?");

  // If loop entry already guarantees the compare (or its inverse), avoid
  // materializing any code at all.
  if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
    IRBuilder<> Builder(Guard);
    if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
      return Builder.getTrue();
    if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
                                     LHS, RHS))
      return Builder.getFalse();
  }

  // Each operand expands at its own best (possibly preheader) position...
  Value *LHSV =
      Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
  Value *RHSV =
      Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
  // ...and the compare itself goes wherever both expanded values dominate.
  IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
  return Builder.CreateICmp(Pred, LHSV, RHSV);
}
  402. // Returns true if its safe to truncate the IV to RangeCheckType.
  403. // When the IV type is wider than the range operand type, we can still do loop
  404. // predication, by generating SCEVs for the range and latch that are of the
  405. // same type. We achieve this by generating a SCEV truncate expression for the
  406. // latch IV. This is done iff truncation of the IV is a safe operation,
  407. // without loss of information.
  408. // Another way to achieve this is by generating a wider type SCEV for the
  409. // range check operand, however, this needs a more involved check that
  410. // operands do not overflow. This can lead to loss of information when the
  411. // range operand is of the form: add i32 %offset, %iv. We need to prove that
  412. // sext(x + y) is same as sext(x) + sext(y).
  413. // This function returns true if we can safely represent the IV type in
  414. // the RangeCheckType without loss of information.
  415. static bool isSafeToTruncateWideIVType(const DataLayout &DL,
  416. ScalarEvolution &SE,
  417. const LoopICmp LatchCheck,
  418. Type *RangeCheckType) {
  419. if (!EnableIVTruncation)
  420. return false;
  421. assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
  422. DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
  423. "Expected latch check IV type to be larger than range check operand "
  424. "type!");
  425. // The start and end values of the IV should be known. This is to guarantee
  426. // that truncating the wide type will not lose information.
  427. auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
  428. auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
  429. if (!Limit || !Start)
  430. return false;
  431. // This check makes sure that the IV does not change sign during loop
  432. // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
  433. // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
  434. // IV wraps around, and the truncation of the IV would lose the range of
  435. // iterations between 2^32 and 2^64.
  436. if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
  437. return false;
  438. // The active bits should be less than the bits in the RangeCheckType. This
  439. // guarantees that truncating the latch check to RangeCheckType is a safe
  440. // operation.
  441. auto RangeCheckTypeBitSize =
  442. DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
  443. return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
  444. Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
  445. }
  446. // Return an LoopICmp describing a latch check equivlent to LatchCheck but with
  447. // the requested type if safe to do so. May involve the use of a new IV.
  448. static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
  449. ScalarEvolution &SE,
  450. const LoopICmp LatchCheck,
  451. Type *RangeCheckType) {
  452. auto *LatchType = LatchCheck.IV->getType();
  453. if (RangeCheckType == LatchType)
  454. return LatchCheck;
  455. // For now, bail out if latch type is narrower than range type.
  456. if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
  457. DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
  458. return std::nullopt;
  459. if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
  460. return std::nullopt;
  461. // We can now safely identify the truncated version of the IV and limit for
  462. // RangeCheckType.
  463. LoopICmp NewLatchCheck;
  464. NewLatchCheck.Pred = LatchCheck.Pred;
  465. NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
  466. SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
  467. if (!NewLatchCheck.IV)
  468. return std::nullopt;
  469. NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
  470. LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
  471. << "can be represented as range check type:"
  472. << *RangeCheckType << "\n");
  473. LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
  474. LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
  475. return NewLatchCheck;
  476. }
  477. bool LoopPredication::isSupportedStep(const SCEV* Step) {
  478. return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
  479. }
  480. Instruction *LoopPredication::findInsertPt(Instruction *Use,
  481. ArrayRef<Value*> Ops) {
  482. for (Value *Op : Ops)
  483. if (!L->isLoopInvariant(Op))
  484. return Use;
  485. return Preheader->getTerminator();
  486. }
  487. Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
  488. Instruction *Use,
  489. ArrayRef<const SCEV *> Ops) {
  490. // Subtlety: SCEV considers things to be invariant if the value produced is
  491. // the same across iterations. This is not the same as being able to
  492. // evaluate outside the loop, which is what we actually need here.
  493. for (const SCEV *Op : Ops)
  494. if (!SE->isLoopInvariant(Op, L) ||
  495. !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
  496. return Use;
  497. return Preheader->getTerminator();
  498. }
  499. bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
  500. // Handling expressions which produce invariant results, but *haven't* yet
  501. // been removed from the loop serves two important purposes.
  502. // 1) Most importantly, it resolves a pass ordering cycle which would
  503. // otherwise need us to iteration licm, loop-predication, and either
  504. // loop-unswitch or loop-peeling to make progress on examples with lots of
  505. // predicable range checks in a row. (Since, in the general case, we can't
  506. // hoist the length checks until the dominating checks have been discharged
  507. // as we can't prove doing so is safe.)
  508. // 2) As a nice side effect, this exposes the value of peeling or unswitching
  509. // much more obviously in the IR. Otherwise, the cost modeling for other
  510. // transforms would end up needing to duplicate all of this logic to model a
  511. // check which becomes predictable based on a modeled peel or unswitch.
  512. //
  513. // The cost of doing so in the worst case is an extra fill from the stack in
  514. // the loop to materialize the loop invariant test value instead of checking
  515. // against the original IV which is presumable in a register inside the loop.
  516. // Such cases are presumably rare, and hint at missing oppurtunities for
  517. // other passes.
  518. if (SE->isLoopInvariant(S, L))
  519. // Note: This the SCEV variant, so the original Value* may be within the
  520. // loop even though SCEV has proven it is loop invariant.
  521. return true;
  522. // Handle a particular important case which SCEV doesn't yet know about which
  523. // shows up in range checks on arrays with immutable lengths.
  524. // TODO: This should be sunk inside SCEV.
  525. if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
  526. if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
  527. if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
  528. if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
  529. LI->hasMetadata(LLVMContext::MD_invariant_load))
  530. return true;
  531. return false;
  532. }
/// Widen a range check in a forward (step == +1) loop.  Returns the widened,
/// loop-invariant condition, or std::nullopt if the needed values can't be
/// proven invariant / expanded at the guard.
std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  // Generate the widened condition for the forward loop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
  // where <pred> depends on the latch condition predicate. See the file
  // header comment for the reasoning.
  // guardLimit - guardStart + latchStart - 1
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  // guardLimit - guardStart + latchStart - 1
  const SCEV *RHS =
      SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
                     SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
  // Flip strictness (e.g. u< becomes u<=): a u< latch yields the widened
  // check "latchLimit u<= latchStart + guardLimit - 1 - guardStart" per the
  // file header derivation.
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);

  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");

  auto *LimitCheck =
      expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
  // The widened condition also requires the range check to hold on the first
  // iteration: guardStart <pred> guardLimit.
  auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
                                          GuardStart, GuardLimit);
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
/// Emit a loop-invariant replacement for a range check inside a down-counting
/// loop (the latch IV steps by -1).  Returns the widened condition, or
/// std::nullopt if the check cannot be proven expandable/equivalent.
std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  // The decrement of the latch check IV should be the same as the
  // rangeCheckIV.  If not, we can't reason about the relationship between the
  // two checks and must bail out.
  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
  if (RangeCheck.IV != PostDecLatchCheckIV) {
    LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
                      << *PostDecLatchCheckIV
                      << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
    return std::nullopt;
  }
  // Generate the widened condition for CountDownLoop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> 1.
  // See the header comment for reasoning of the checks.
  // getFlippedStrictnessPredicate swaps strict/non-strict (e.g. u> <-> u>=)
  // while keeping direction and signedness.
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
  auto *FirstIterationCheck = expandCheck(Expander, Guard,
                                          ICmpInst::ICMP_ULT,
                                          GuardStart, GuardLimit);
  auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
                                 SE->getOne(Ty));
  // Insert the conjunction at a point dominated by both expanded checks.
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
  624. static void normalizePredicate(ScalarEvolution *SE, Loop *L,
  625. LoopICmp& RC) {
  626. // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
  627. // ULT/UGE form for ease of handling by our caller.
  628. if (ICmpInst::isEquality(RC.Pred) &&
  629. RC.IV->getStepRecurrence(*SE)->isOne() &&
  630. SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
  631. RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
  632. ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
  633. }
  634. /// If ICI can be widened to a loop invariant condition emits the loop
  635. /// invariant condition in the loop preheader and return it, otherwise
  636. /// returns std::nullopt.
  637. std::optional<Value *>
  638. LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
  639. Instruction *Guard) {
  640. LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
  641. LLVM_DEBUG(ICI->dump());
  642. // parseLoopStructure guarantees that the latch condition is:
  643. // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
  644. // We are looking for the range checks of the form:
  645. // i u< guardLimit
  646. auto RangeCheck = parseLoopICmp(ICI);
  647. if (!RangeCheck) {
  648. LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
  649. return std::nullopt;
  650. }
  651. LLVM_DEBUG(dbgs() << "Guard check:\n");
  652. LLVM_DEBUG(RangeCheck->dump());
  653. if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
  654. LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
  655. << RangeCheck->Pred << ")!\n");
  656. return std::nullopt;
  657. }
  658. auto *RangeCheckIV = RangeCheck->IV;
  659. if (!RangeCheckIV->isAffine()) {
  660. LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
  661. return std::nullopt;
  662. }
  663. auto *Step = RangeCheckIV->getStepRecurrence(*SE);
  664. // We cannot just compare with latch IV step because the latch and range IVs
  665. // may have different types.
  666. if (!isSupportedStep(Step)) {
  667. LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
  668. return std::nullopt;
  669. }
  670. auto *Ty = RangeCheckIV->getType();
  671. auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
  672. if (!CurrLatchCheckOpt) {
  673. LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
  674. "corresponding to range type: "
  675. << *Ty << "\n");
  676. return std::nullopt;
  677. }
  678. LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
  679. // At this point, the range and latch step should have the same type, but need
  680. // not have the same value (we support both 1 and -1 steps).
  681. assert(Step->getType() ==
  682. CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
  683. "Range and latch steps should be of same type!");
  684. if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
  685. LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
  686. return std::nullopt;
  687. }
  688. if (Step->isOne())
  689. return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
  690. Expander, Guard);
  691. else {
  692. assert(Step->isAllOnesValue() && "Step should be -1!");
  693. return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
  694. Expander, Guard);
  695. }
  696. }
  697. unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
  698. Value *Condition,
  699. SCEVExpander &Expander,
  700. Instruction *Guard) {
  701. unsigned NumWidened = 0;
  702. // The guard condition is expected to be in form of:
  703. // cond1 && cond2 && cond3 ...
  704. // Iterate over subconditions looking for icmp conditions which can be
  705. // widened across loop iterations. Widening these conditions remember the
  706. // resulting list of subconditions in Checks vector.
  707. SmallVector<Value *, 4> Worklist(1, Condition);
  708. SmallPtrSet<Value *, 4> Visited;
  709. Visited.insert(Condition);
  710. Value *WideableCond = nullptr;
  711. do {
  712. Value *Condition = Worklist.pop_back_val();
  713. Value *LHS, *RHS;
  714. using namespace llvm::PatternMatch;
  715. if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
  716. if (Visited.insert(LHS).second)
  717. Worklist.push_back(LHS);
  718. if (Visited.insert(RHS).second)
  719. Worklist.push_back(RHS);
  720. continue;
  721. }
  722. if (match(Condition,
  723. m_Intrinsic<Intrinsic::experimental_widenable_condition>())) {
  724. // Pick any, we don't care which
  725. WideableCond = Condition;
  726. continue;
  727. }
  728. if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
  729. if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
  730. Guard)) {
  731. Checks.push_back(*NewRangeCheck);
  732. NumWidened++;
  733. continue;
  734. }
  735. }
  736. // Save the condition as is if we can't widen it
  737. Checks.push_back(Condition);
  738. } while (!Worklist.empty());
  739. // At the moment, our matching logic for wideable conditions implicitly
  740. // assumes we preserve the form: (br (and Cond, WC())). FIXME
  741. // Note that if there were multiple calls to wideable condition in the
  742. // traversal, we only need to keep one, and which one is arbitrary.
  743. if (WideableCond)
  744. Checks.push_back(WideableCond);
  745. return NumWidened;
  746. }
/// Attempt to widen the condition of a single llvm.experimental.guard call in
/// place.  Returns true iff at least one subcondition was widened (in which
/// case the guard's condition operand has been rewritten).
bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
                                           SCEVExpander &Expander) {
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(Guard->dump());
  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
                                      Guard);
  if (NumWidened == 0)
    return false;
  TotalWidened += NumWidened;
  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(Guard, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = Guard->getOperand(0);
  Guard->setOperand(0, AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    // Re-assert the original (narrower) condition via an assume placed
    // immediately after the guard, so later passes keep the facts it implied.
    Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
    Builder.CreateAssumption(OldCond);
  }
  // The old condition tree may now be dead; clean it up (MemorySSA-aware).
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
  LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
  return true;
}
/// Analogue of widenGuardConditions for guards expressed as widenable
/// branches, i.e. `br (and %cond, widenable_condition())`.  Returns true iff
/// the branch condition was rewritten.
bool LoopPredication::widenWidenableBranchGuardConditions(
    BranchInst *BI, SCEVExpander &Expander) {
  assert(isGuardAsWidenableBranch(BI) && "Must be!");
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(BI->dump());
  Value *Cond, *WC;
  BasicBlock *IfTrueBB, *IfFalseBB;
  // Parsing must succeed: isGuardAsWidenableBranch was asserted above.
  bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB);
  assert(Parsed && "Must be able to parse widenable branch");
  (void)Parsed;
  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
                                      Expander, BI);
  if (NumWidened == 0)
    return false;
  TotalWidened += NumWidened;
  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(BI, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = BI->getCondition();
  BI->setCondition(AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    // Re-assert the original non-widenable condition on the guarded (taken)
    // path so later passes keep the facts it implied.
    Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
    Builder.CreateAssumption(Cond);
  }
  // The old condition tree may now be dead; clean it up (MemorySSA-aware).
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
  assert(isGuardAsWidenableBranch(BI) &&
         "Stopped being a guard after transform?");
  LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
  return true;
}
/// Parse and validate the loop latch's exit comparison.  Returns the parsed
/// LoopICmp (predicate normalized via normalizePredicate) on success, or
/// std::nullopt if the latch shape is not one this pass supports.
std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
  using namespace PatternMatch;
  BasicBlock *LoopLatch = L->getLoopLatch();
  if (!LoopLatch) {
    LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
    return std::nullopt;
  }
  auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || !BI->isConditional()) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
    return std::nullopt;
  }
  BasicBlock *TrueDest = BI->getSuccessor(0);
  assert(
      (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
      "One of the latch's destinations must be the header");
  auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
  if (!ICI) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
    return std::nullopt;
  }
  auto Result = parseLoopICmp(ICI);
  if (!Result) {
    LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
    return std::nullopt;
  }
  // If the false edge is the back edge, the parsed predicate describes the
  // exit condition; invert it so it describes "continue looping" instead.
  if (TrueDest != L->getHeader())
    Result->Pred = ICmpInst::getInversePredicate(Result->Pred);
  // Check affine first, so if it's not we don't try to compute the step
  // recurrence.
  if (!Result->IV->isAffine()) {
    LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
    return std::nullopt;
  }
  auto *Step = Result->IV->getStepRecurrence(*SE);
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
    return std::nullopt;
  }
  // Counting-up loops (step 1) must use a less-than style predicate;
  // counting-down loops (step -1) a greater-than style one.
  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
    if (Step->isOne()) {
      return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
             Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
    } else {
      assert(Step->isAllOnesValue() && "Step should be -1!");
      return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
             Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
    }
  };
  normalizePredicate(SE, L, *Result);
  if (IsUnsupportedPredicate(Step, Result->Pred)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
                      << ")!\n");
    return std::nullopt;
  }
  return Result;
}
/// Profitability heuristic: predication pays off only when the latch is the
/// dominant way out of the loop.  Returns true when predication should
/// proceed.
bool LoopPredication::isLoopProfitableToPredicate() {
  if (SkipProfitabilityChecks)
    return true;
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
  L->getExitEdges(ExitEdges);
  // If there is only one exiting edge in the loop, it is always profitable to
  // predicate the loop.
  if (ExitEdges.size() == 1)
    return true;
  // Calculate the exiting probabilities of all exiting edges from the loop,
  // starting with the LatchExitProbability.
  // Heuristic for profitability: If any of the exiting blocks' probability of
  // exiting the loop is larger than exiting through the latch block, it's not
  // profitable to predicate the loop.
  auto *LatchBlock = L->getLoopLatch();
  assert(LatchBlock && "Should have a single latch at this point!");
  auto *LatchTerm = LatchBlock->getTerminator();
  assert(LatchTerm->getNumSuccessors() == 2 &&
         "expected to be an exiting block with 2 succs!");
  unsigned LatchBrExitIdx =
      LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
  // We compute branch probabilities without BPI. We do not rely on BPI since
  // Loop predication is usually run in an LPM and BPI is only preserved
  // lossily within loop pass managers, while BPI has an inherent notion of
  // being complete for an entire function.
  // If the latch exits into a deoptimize or an unreachable block, do not
  // predicate on that latch check.
  // NOTE(review): given the assert above that LatchTerm has exactly two
  // successors, `isa<UnreachableInst>(LatchTerm)` appears unreachable here;
  // possibly the exit block's terminator was intended — confirm upstream.
  auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
  if (isa<UnreachableInst>(LatchTerm) ||
      LatchExitBlock->getTerminatingDeoptimizeCall())
    return false;
  // Latch terminator has no valid profile data, so nothing to check
  // profitability on.
  if (!hasValidBranchWeightMD(*LatchTerm))
    return true;
  // Probability of leaving via ExitBlock, computed directly from branch
  // weight metadata (summing weights of all edges into ExitBlock).
  auto ComputeBranchProbability =
      [&](const BasicBlock *ExitingBlock,
          const BasicBlock *ExitBlock) -> BranchProbability {
    auto *Term = ExitingBlock->getTerminator();
    unsigned NumSucc = Term->getNumSuccessors();
    if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
      SmallVector<uint32_t> Weights;
      extractBranchWeights(ProfileData, Weights);
      uint64_t Numerator = 0, Denominator = 0;
      for (auto [i, Weight] : llvm::enumerate(Weights)) {
        if (Term->getSuccessor(i) == ExitBlock)
          Numerator += Weight;
        Denominator += Weight;
      }
      return BranchProbability::getBranchProbability(Numerator, Denominator);
    } else {
      assert(LatchBlock != ExitingBlock &&
             "Latch term should always have profile data!");
      // No profile data, so we choose the weight as 1/num_of_succ(Src)
      return BranchProbability::getBranchProbability(1, NumSucc);
    }
  };
  BranchProbability LatchExitProbability =
      ComputeBranchProbability(LatchBlock, LatchExitBlock);
  // Protect against degenerate inputs provided by the user. Providing a value
  // less than one, can invert the definition of profitable loop predication.
  float ScaleFactor = LatchExitProbabilityScale;
  if (ScaleFactor < 1) {
    LLVM_DEBUG(
        dbgs()
        << "Ignored user setting for loop-predication-latch-probability-scale: "
        << LatchExitProbabilityScale << "\n");
    LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
    ScaleFactor = 1.0;
  }
  const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
  for (const auto &ExitEdge : ExitEdges) {
    BranchProbability ExitingBlockProbability =
        ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
    // Some exiting edge has higher probability than the latch exiting edge.
    // No longer profitable to predicate.
    if (ExitingBlockProbability > LatchProbabilityThreshold)
      return false;
  }
  // We have concluded that the most probable way to exit from the
  // loop is through the latch (or there's no profile information and all
  // exits are equally likely).
  return true;
}
  944. /// If we can (cheaply) find a widenable branch which controls entry into the
  945. /// loop, return it.
  946. static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
  947. // Walk back through any unconditional executed blocks and see if we can find
  948. // a widenable condition which seems to control execution of this loop. Note
  949. // that we predict that maythrow calls are likely untaken and thus that it's
  950. // profitable to widen a branch before a maythrow call with a condition
  951. // afterwards even though that may cause the slow path to run in a case where
  952. // it wouldn't have otherwise.
  953. BasicBlock *BB = L->getLoopPreheader();
  954. if (!BB)
  955. return nullptr;
  956. do {
  957. if (BasicBlock *Pred = BB->getSinglePredecessor())
  958. if (BB == Pred->getSingleSuccessor()) {
  959. BB = Pred;
  960. continue;
  961. }
  962. break;
  963. } while (true);
  964. if (BasicBlock *Pred = BB->getSinglePredecessor()) {
  965. auto *Term = Pred->getTerminator();
  966. Value *Cond, *WC;
  967. BasicBlock *IfTrueBB, *IfFalseBB;
  968. if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) &&
  969. IfTrueBB == BB)
  970. return cast<BranchInst>(Term);
  971. }
  972. return nullptr;
  973. }
  974. /// Return the minimum of all analyzeable exit counts. This is an upper bound
  975. /// on the actual exit count. If there are not at least two analyzeable exits,
  976. /// returns SCEVCouldNotCompute.
  977. static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
  978. DominatorTree &DT,
  979. Loop *L) {
  980. SmallVector<BasicBlock *, 16> ExitingBlocks;
  981. L->getExitingBlocks(ExitingBlocks);
  982. SmallVector<const SCEV *, 4> ExitCounts;
  983. for (BasicBlock *ExitingBB : ExitingBlocks) {
  984. const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
  985. if (isa<SCEVCouldNotCompute>(ExitCount))
  986. continue;
  987. assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
  988. "We should only have known counts for exiting blocks that "
  989. "dominate latch!");
  990. ExitCounts.push_back(ExitCount);
  991. }
  992. if (ExitCounts.size() < 2)
  993. return SE.getCouldNotCompute();
  994. return SE.getUMinFromMismatchedTypes(ExitCounts);
  995. }
/// This implements an analogous, but entirely distinct transform from the main
/// loop predication transform. This one is phrased in terms of using a
/// widenable branch *outside* the loop to allow us to simplify loop exits in a
/// following loop. This is close in spirit to the IndVarSimplify transform
/// of the same name, but is materially different in that widening loosens
/// legality sharply.
bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
  // The transformation performed here aims to widen a widenable condition
  // above the loop such that all analyzeable exits leading to deopt are dead.
  // It assumes that the latch is the dominant exit for profitability and that
  // exits branching to deoptimizing blocks are rarely taken. It relies on the
  // semantics of widenable expressions for legality. (i.e. being able to fall
  // down the widenable path spuriously allows us to ignore exit order,
  // unanalyzeable exits, side effects, exceptional exits, and other challenges
  // which restrict the applicability of the non-WC based version of this
  // transform in IndVarSimplify.)
  //
  // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
  // imply flags on the expression being hoisted and inserting new uses (flags
  // are only correct for current uses). The result is that we may be
  // inserting a branch on the value which can be either poison or undef. In
  // this case, the branch can legally go either way; we just need to avoid
  // introducing UB. This is achieved through the use of the freeze
  // instruction.
  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);
  if (ExitingBlocks.empty())
    return false; // Nothing to do.
  auto *Latch = L->getLoopLatch();
  if (!Latch)
    return false;
  auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
  if (!WidenableBR)
    return false;
  const SCEV *LatchEC = SE->getExitCount(L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchEC))
    return false; // profitability - want hot exit in analyzeable set
  // At this point, we have found an analyzeable latch, and a widenable
  // condition above the loop. If we have a widenable exit within the loop
  // (for which we can't compute exit counts), drop the ability to further
  // widen so that we gain ability to analyze its exit count and perform this
  // transform. TODO: It'd be nice to know for sure the exit became
  // analyzeable after dropping widenability.
  bool ChangedLoop = false;
  for (auto *ExitingBB : ExitingBlocks) {
    if (LI->getLoopFor(ExitingBB) != L)
      continue;
    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;
    Use *Cond, *WC;
    BasicBlock *IfTrueBB, *IfFalseBB;
    if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
        L->contains(IfTrueBB)) {
      // Replace the widenable condition use with `true`, freezing this
      // in-loop widenable branch into an ordinary analyzable branch.
      WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
      ChangedLoop = true;
    }
  }
  if (ChangedLoop)
    SE->forgetLoop(L);
  // The use of umin(all analyzeable exits) instead of latch is subtle, but
  // important for profitability. We may have a loop which hasn't been fully
  // canonicalized just yet. If the exit we chose to widen is provably never
  // taken, we want the widened form to *also* be provably never taken. We
  // can't guarantee this as a current unanalyzeable exit may later become
  // analyzeable, but we can at least avoid the obvious cases.
  const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
  if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
      !SE->isLoopInvariant(MinEC, L) ||
      !Rewriter.isSafeToExpandAt(MinEC, WidenableBR))
    return ChangedLoop;
  // Subtlety: We need to avoid inserting additional uses of the WC. We know
  // that it can only have one transitive use at the moment, and thus moving
  // that use to just before the branch and inserting code before it and then
  // modifying the operand is legal.
  auto *IP = cast<Instruction>(WidenableBR->getCondition());
  // Here we unconditionally modify the IR, so after this point we should return
  // only `true`!
  IP->moveBefore(WidenableBR);
  if (MSSAU)
    if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP))
      MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
                         MemorySSA::BeforeTerminator);
  Rewriter.setInsertPoint(IP);
  IRBuilder<> B(IP);
  bool InvalidateLoop = false;
  Value *MinECV = nullptr; // lazily generated if needed
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    // If our exiting block exits multiple loops, we can only rewrite the
    // innermost one. Otherwise, we're changing how many times the innermost
    // loop runs before it exits.
    if (LI->getLoopFor(ExitingBB) != L)
      continue;
    // Can't rewrite non-branch yet.
    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;
    // If already constant, nothing to do.
    if (isa<Constant>(BI->getCondition()))
      continue;
    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount) ||
        ExitCount->getType()->isPointerTy() ||
        !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
      continue;
    const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
    BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
    if (!ExitBB->getPostdominatingDeoptimizeCall())
      continue;
    /// Here we can be fairly sure that executing this exit will most likely
    /// lead to executing llvm.experimental.deoptimize.
    /// This is a profitability heuristic, not a legality constraint.
    // If we found a widenable exit condition, do two things:
    // 1) fold the widened exit test into the widenable condition
    // 2) fold the branch to untaken - avoids infinite looping
    Value *ECV = Rewriter.expandCodeFor(ExitCount);
    if (!MinECV)
      MinECV = Rewriter.expandCodeFor(MinEC);
    Value *RHS = MinECV;
    if (ECV->getType() != RHS->getType()) {
      // Exit counts may have been computed at different bit widths; compare
      // in the common wider type (both are unsigned counts, so zext is safe).
      Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
      ECV = B.CreateZExt(ECV, WiderTy);
      RHS = B.CreateZExt(RHS, WiderTy);
    }
    assert(!Latch || DT->dominates(ExitingBB, Latch));
    Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
    // Freeze poison or undef to an arbitrary bit pattern to ensure we can
    // branch without introducing UB. See NOTE ON POISON/UNDEF above for
    // context.
    NewCond = B.CreateFreeze(NewCond);
    widenWidenableBranch(WidenableBR, NewCond);
    Value *OldCond = BI->getCondition();
    BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
    InvalidateLoop = true;
  }
  if (InvalidateLoop)
    // We just mutated a bunch of loop exits changing their exit counts
    // widely. We need to force recomputation of the exit counts given these
    // changes. Note that all of the inserted exits are never taken, and
    // should be removed next time the CFG is modified.
    SE->forgetLoop(L);
  // Always return `true` since we have moved the WidenableBR's condition.
  return true;
}
/// Pass entry point for a single loop: parse the latch, check profitability,
/// then widen every guard / widenable-branch guard in the loop and finally
/// try the exit-predication transform.  Returns true iff the IR changed.
bool LoopPredication::runOnLoop(Loop *Loop) {
  L = Loop;
  LLVM_DEBUG(dbgs() << "Analyzing ");
  LLVM_DEBUG(L->dump());
  Module *M = L->getHeader()->getModule();
  // There is nothing to do if the module doesn't use guards
  auto *GuardDecl =
      M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
  bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
  auto *WCDecl = M->getFunction(
      Intrinsic::getName(Intrinsic::experimental_widenable_condition));
  bool HasWidenableConditions =
      PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
  if (!HasIntrinsicGuards && !HasWidenableConditions)
    return false;
  DL = &M->getDataLayout();
  // A preheader is required as the insertion point for widened conditions.
  Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;
  auto LatchCheckOpt = parseLoopLatchICmp();
  if (!LatchCheckOpt)
    return false;
  LatchCheck = *LatchCheckOpt;
  LLVM_DEBUG(dbgs() << "Latch check:\n");
  LLVM_DEBUG(LatchCheck.dump());
  if (!isLoopProfitableToPredicate()) {
    LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
    return false;
  }
  // Collect all the guards into a vector and process later, so as not
  // to invalidate the instruction iterator.
  SmallVector<IntrinsicInst *, 4> Guards;
  SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
  for (const auto BB : L->blocks()) {
    for (auto &I : *BB)
      if (isGuard(&I))
        Guards.push_back(cast<IntrinsicInst>(&I));
    if (PredicateWidenableBranchGuards &&
        isGuardAsWidenableBranch(BB->getTerminator()))
      GuardsAsWidenableBranches.push_back(
          cast<BranchInst>(BB->getTerminator()));
  }
  SCEVExpander Expander(*SE, *DL, "loop-predication");
  bool Changed = false;
  for (auto *Guard : Guards)
    Changed |= widenGuardConditions(Guard, Expander);
  for (auto *Guard : GuardsAsWidenableBranches)
    Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
  Changed |= predicateLoopExits(L, Expander);
  if (MSSAU && VerifyMemorySSA)
    MSSAU->getMemorySSA()->verifyMemorySSA();
  return Changed;
}