LoopFlatten.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022
  1. //===- LoopFlatten.cpp - Loop flattening pass------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass flattens pairs of nested loops into a single loop.
  10. //
  11. // The intention is to optimise loop nests like this, which together access an
  12. // array linearly:
  13. //
  14. // for (int i = 0; i < N; ++i)
  15. // for (int j = 0; j < M; ++j)
  16. // f(A[i*M+j]);
  17. //
  18. // into one loop:
  19. //
  20. // for (int i = 0; i < (N*M); ++i)
  21. // f(A[i]);
  22. //
  23. // It can also flatten loops where the induction variables are not used in the
  24. // loop. This is only worth doing if the induction variables are only used in an
  25. // expression like i*M+j. If they had any other uses, we would have to insert a
  26. // div/mod to reconstruct the original values, so this wouldn't be profitable.
  27. //
  28. // We also need to prove that N*M will not overflow. The preferred solution is
  29. // to widen the IV, which avoids overflow checks, so that is tried first. If
  30. // the IV cannot be widened, then we try to determine that this new tripcount
  31. // expression won't overflow.
  32. //
  33. // Q: Does LoopFlatten use SCEV?
  34. // Short answer: Yes and no.
  35. //
  36. // Long answer:
  37. // For this transformation to be valid, we require all uses of the induction
  38. // variables to be linear expressions of the form i*M+j. The different Loop
  39. // APIs are used to get some loop components like the induction variable,
  40. // compare statement, etc. In addition, we do some pattern matching to find the
  41. // linear expressions and other loop components like the loop increment. The
  42. // latter are examples of expressions that do use the induction variable, but
  43. // are safe to ignore when we check all uses to be of the form i*M+j. We keep
  44. // track of all of this in bookkeeping struct FlattenInfo.
  45. // We assume the loops to be canonical, i.e. starting at 0 and incrementing
  46. // by 1. This makes the RHS of the compare the loop tripcount (with the right
  47. // predicate). We use SCEV to then sanity check that this tripcount matches
  48. // with the tripcount as computed by SCEV.
  49. //
  50. //===----------------------------------------------------------------------===//
  51. #include "llvm/Transforms/Scalar/LoopFlatten.h"
  52. #include "llvm/ADT/Statistic.h"
  53. #include "llvm/Analysis/AssumptionCache.h"
  54. #include "llvm/Analysis/LoopInfo.h"
  55. #include "llvm/Analysis/LoopNestAnalysis.h"
  56. #include "llvm/Analysis/MemorySSAUpdater.h"
  57. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  58. #include "llvm/Analysis/ScalarEvolution.h"
  59. #include "llvm/Analysis/TargetTransformInfo.h"
  60. #include "llvm/Analysis/ValueTracking.h"
  61. #include "llvm/IR/Dominators.h"
  62. #include "llvm/IR/Function.h"
  63. #include "llvm/IR/IRBuilder.h"
  64. #include "llvm/IR/Module.h"
  65. #include "llvm/IR/PatternMatch.h"
  66. #include "llvm/InitializePasses.h"
  67. #include "llvm/Pass.h"
  68. #include "llvm/Support/Debug.h"
  69. #include "llvm/Support/raw_ostream.h"
  70. #include "llvm/Transforms/Scalar.h"
  71. #include "llvm/Transforms/Scalar/LoopPassManager.h"
  72. #include "llvm/Transforms/Utils/Local.h"
  73. #include "llvm/Transforms/Utils/LoopUtils.h"
  74. #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
  75. #include "llvm/Transforms/Utils/SimplifyIndVar.h"
  76. #include <optional>
  77. using namespace llvm;
  78. using namespace llvm::PatternMatch;
  79. #define DEBUG_TYPE "loop-flatten"
STATISTIC(NumFlattened, "Number of loops flattened");

// Maximum total cost (TCK_SizeAndLatency units) of outer-loop-only
// instructions whose execution count increases after flattening; above this
// the transformation is considered unprofitable.
static cl::opt<unsigned> RepeatedInstructionThreshold(
    "loop-flatten-cost-threshold", cl::Hidden, cl::init(2),
    cl::desc("Limit on the cost of instructions that can be repeated due to "
             "loop flattening"));

// Debugging/testing aid: skips the overflow analysis entirely (see
// checkOverflow).
static cl::opt<bool>
    AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
                     cl::init(false),
                     cl::desc("Assume that the product of the two iteration "
                              "trip counts will never overflow"));

// Preferred strategy: widening the IVs makes the i*M+j product provably
// non-overflowing, so explicit overflow checks can be avoided.
static cl::opt<bool>
    WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true),
            cl::desc("Widen the loop induction variables, if possible, so "
                     "overflow checks won't reject flattening"));
  94. namespace {
  95. // We require all uses of both induction variables to match this pattern:
  96. //
  97. // (OuterPHI * InnerTripCount) + InnerPHI
  98. //
  99. // I.e., it needs to be a linear expression of the induction variables and the
  100. // inner loop trip count. We keep track of all different expressions on which
  101. // checks will be performed in this bookkeeping struct.
  102. //
  103. struct FlattenInfo {
  104. Loop *OuterLoop = nullptr; // The loop pair to be flattened.
  105. Loop *InnerLoop = nullptr;
  106. PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop
  107. PHINode *OuterInductionPHI = nullptr; // induction variables, which are
  108. // expected to start at zero and
  109. // increment by one on each loop.
  110. Value *InnerTripCount = nullptr; // The product of these two tripcounts
  111. Value *OuterTripCount = nullptr; // will be the new flattened loop
  112. // tripcount. Also used to recognise a
  113. // linear expression that will be replaced.
  114. SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions
  115. // of the form i*M+j that will be
  116. // replaced.
  117. BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in
  118. BinaryOperator *OuterIncrement = nullptr; // loop control statements that
  119. BranchInst *InnerBranch = nullptr; // are safe to ignore.
  120. BranchInst *OuterBranch = nullptr; // The instruction that needs to be
  121. // updated with new tripcount.
  122. SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
  123. bool Widened = false; // Whether this holds the flatten info before or after
  124. // widening.
  125. PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction
  126. PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV
  127. // has been applied. Used to skip
  128. // checks on phi nodes.
  129. FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
  130. bool isNarrowInductionPhi(PHINode *Phi) {
  131. // This can't be the narrow phi if we haven't widened the IV first.
  132. if (!Widened)
  133. return false;
  134. return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
  135. }
  136. bool isInnerLoopIncrement(User *U) {
  137. return InnerIncrement == U;
  138. }
  139. bool isOuterLoopIncrement(User *U) {
  140. return OuterIncrement == U;
  141. }
  142. bool isInnerLoopTest(User *U) {
  143. return InnerBranch->getCondition() == U;
  144. }
  145. bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  146. for (User *U : OuterInductionPHI->users()) {
  147. if (isOuterLoopIncrement(U))
  148. continue;
  149. auto IsValidOuterPHIUses = [&] (User *U) -> bool {
  150. LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
  151. if (!ValidOuterPHIUses.count(U)) {
  152. LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
  153. return false;
  154. }
  155. LLVM_DEBUG(dbgs() << "Use is optimisable\n");
  156. return true;
  157. };
  158. if (auto *V = dyn_cast<TruncInst>(U)) {
  159. for (auto *K : V->users()) {
  160. if (!IsValidOuterPHIUses(K))
  161. return false;
  162. }
  163. continue;
  164. }
  165. if (!IsValidOuterPHIUses(U))
  166. return false;
  167. }
  168. return true;
  169. }
  170. bool matchLinearIVUser(User *U, Value *InnerTripCount,
  171. SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  172. LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump());
  173. Value *MatchedMul = nullptr;
  174. Value *MatchedItCount = nullptr;
  175. bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
  176. m_Value(MatchedMul))) &&
  177. match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
  178. m_Value(MatchedItCount)));
  179. // Matches the same pattern as above, except it also looks for truncs
  180. // on the phi, which can be the result of widening the induction variables.
  181. bool IsAddTrunc =
  182. match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
  183. m_Value(MatchedMul))) &&
  184. match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
  185. m_Value(MatchedItCount)));
  186. if (!MatchedItCount)
  187. return false;
  188. LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump());
  189. LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump());
  190. // The mul should not have any other uses. Widening may leave trivially dead
  191. // uses, which can be ignored.
  192. if (count_if(MatchedMul->users(), [](User *U) {
  193. return !isInstructionTriviallyDead(cast<Instruction>(U));
  194. }) > 1) {
  195. LLVM_DEBUG(dbgs() << "Multiply has more than one use\n");
  196. return false;
  197. }
  198. // Look through extends if the IV has been widened. Don't look through
  199. // extends if we already looked through a trunc.
  200. if (Widened && IsAdd &&
  201. (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
  202. assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
  203. "Unexpected type mismatch in types after widening");
  204. MatchedItCount = isa<SExtInst>(MatchedItCount)
  205. ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
  206. : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
  207. }
  208. LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
  209. InnerTripCount->dump());
  210. if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
  211. LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
  212. ValidOuterPHIUses.insert(MatchedMul);
  213. LinearIVUses.insert(U);
  214. return true;
  215. }
  216. LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
  217. return false;
  218. }
  219. bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
  220. Value *SExtInnerTripCount = InnerTripCount;
  221. if (Widened &&
  222. (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
  223. SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
  224. for (User *U : InnerInductionPHI->users()) {
  225. LLVM_DEBUG(dbgs() << "Checking User: "; U->dump());
  226. if (isInnerLoopIncrement(U)) {
  227. LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n");
  228. continue;
  229. }
  230. // After widening the IVs, a trunc instruction might have been introduced,
  231. // so look through truncs.
  232. if (isa<TruncInst>(U)) {
  233. if (!U->hasOneUse())
  234. return false;
  235. U = *U->user_begin();
  236. }
  237. // If the use is in the compare (which is also the condition of the inner
  238. // branch) then the compare has been altered by another transformation e.g
  239. // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
  240. // a constant. Ignore this use as the compare gets removed later anyway.
  241. if (isInnerLoopTest(U)) {
  242. LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n");
  243. continue;
  244. }
  245. if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) {
  246. LLVM_DEBUG(dbgs() << "Not a linear IV user\n");
  247. return false;
  248. }
  249. LLVM_DEBUG(dbgs() << "Linear IV users found!\n");
  250. }
  251. return true;
  252. }
  253. };
  254. } // namespace
  255. static bool
  256. setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
  257. SmallPtrSetImpl<Instruction *> &IterationInstructions) {
  258. TripCount = TC;
  259. IterationInstructions.insert(Increment);
  260. LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());
  261. LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
  262. LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
  263. return true;
  264. }
// Given the RHS of the loop latch compare instruction, verify with SCEV
// that this is indeed the loop tripcount.
// TODO: This used to be a straightforward check but has grown to be quite
// complicated now. It is therefore worth revisiting what the additional
// benefits are of this (compared to relying on canonical loops and pattern
// matching).
static bool verifyTripCount(Value *RHS, Loop *L,
     SmallPtrSetImpl<Instruction *> &IterationInstructions,
    PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
  // SCEV must be able to compute an exact backedge-taken count, otherwise we
  // have nothing trustworthy to compare the pattern-matched RHS against.
  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
    return false;
  }

  // The Extend=false flag is used for getTripCountFromExitCount as we want
  // to verify and match it with the pattern matched tripcount. Please note
  // that overflow checks are performed in checkOverflow, but are first tried
  // to avoid by widening the IV.
  const SCEV *SCEVTripCount =
      SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);

  const SCEV *SCEVRHS = SE->getSCEV(RHS);
  // Easy case: the compare RHS agrees exactly with the SCEV trip count.
  if (SCEVRHS == SCEVTripCount)
    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
  ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
  if (ConstantRHS) {
    const SCEV *BackedgeTCExt = nullptr;
    if (IsWidened) {
      const SCEV *SCEVTripCountExt;
      // Find the extended backedge taken count and extended trip count using
      // SCEV. One of these should now match the RHS of the compare.
      BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
        return false;
      }
    }
    // If the RHS of the compare is equal to the backedge taken count we need
    // to add one to get the trip count.
    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
      Value *NewRHS = ConstantInt::get(
          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
      return setLoopComponents(NewRHS, TripCount, Increment,
                               IterationInstructions);
    }
    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
  }
  // If the RHS isn't a constant then check that the reason it doesn't match
  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
  // (and take the trip count to be the RHS).
  if (!IsWidened) {
    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
    return false;
  }
  auto *TripCountInst = dyn_cast<Instruction>(RHS);
  if (!TripCountInst) {
    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
    return false;
  }
  // The RHS must be a zext/sext whose (narrow) operand is the SCEV trip
  // count; anything else is not a trip count we recognise.
  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
    return false;
  }
  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}
// Finds the induction variable, increment and trip count for a simple loop that
// we can flatten.
static bool findLoopComponents(
    Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
    PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
  LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");

  if (!L->isLoopSimplifyForm()) {
    LLVM_DEBUG(dbgs() << "Loop is not in normal form\n");
    return false;
  }

  // Currently, to simplify the implementation, the Loop induction variable must
  // start at zero and increment with a step size of one.
  if (!L->isCanonical(*SE)) {
    LLVM_DEBUG(dbgs() << "Loop is not canonical\n");
    return false;
  }

  // There must be exactly one exiting block, and it must be the same at the
  // latch.
  BasicBlock *Latch = L->getLoopLatch();
  if (L->getExitingBlock() != Latch) {
    LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");
    return false;
  }

  // Find the induction PHI. If there is no induction PHI, we can't do the
  // transformation. TODO: could other variables trigger this? Do we have to
  // search for the best one?
  InductionPHI = L->getInductionVariable(*SE);
  if (!InductionPHI) {
    LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");
    return false;
  }
  LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());

  // Whether taking the branch condition 'true' continues the loop.
  bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));
  auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
    if (ContinueOnTrue)
      return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;
    else
      return Pred == CmpInst::ICMP_EQ;
  };

  // Find Compare and make sure it is valid. getLatchCmpInst checks that the
  // back branch of the latch is conditional.
  ICmpInst *Compare = L->getLatchCmpInst();
  if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
      Compare->hasNUsesOrMore(2)) {
    LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");
    return false;
  }
  BackBranch = cast<BranchInst>(Latch->getTerminator());
  IterationInstructions.insert(BackBranch);
  LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
  IterationInstructions.insert(Compare);
  LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());

  // Find increment and trip count.
  // There are exactly 2 incoming values to the induction phi; one from the
  // pre-header and one from the latch. The incoming latch value is the
  // increment variable.
  Increment =
      cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
  // The increment may be used by the compare and the phi (2 uses), or by the
  // phi alone (1 use); any other use count means unanalysed extra uses.
  if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) &&
      !Increment->hasNUses(1)) {
    LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
    return false;
  }

  // The trip count is the RHS of the compare. If this doesn't match the trip
  // count computed by SCEV then this is because the trip count variable
  // has been widened so the types don't match, or because it is a constant and
  // another transformation has changed the compare (e.g. icmp ult %inc,
  // tripcount -> icmp ult %j, tripcount-1), or both.
  Value *RHS = Compare->getOperand(1);

  return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
                         Increment, BackBranch, SE, IsWidened);
}
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
  // All PHIs in the inner and outer headers must either be:
  // - The induction PHI, which we are going to rewrite as one induction in
  //   the new loop. This is already checked by findLoopComponents.
  // - An outer header PHI with all incoming values from outside the loop.
  //   LoopSimplify guarantees we have a pre-header, so we don't need to
  //   worry about that here.
  // - Pairs of PHIs in the inner and outer headers, which implement a
  //   loop-carried dependency that will still be valid in the new loop. To
  //   be valid, this variable must be modified only in the inner loop.

  // The set of PHI nodes in the outer loop header that we know will still be
  // valid after the transformation. These will not need to be modified (with
  // the exception of the induction variable), but we do need to check that
  // there are no unsafe PHI nodes.
  SmallPtrSet<PHINode *, 4> SafeOuterPHIs;
  SafeOuterPHIs.insert(FI.OuterInductionPHI);

  // Check that all PHI nodes in the inner loop header match one of the valid
  // patterns.
  for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {
    // The induction PHIs break these rules, and that's OK because we treat
    // them specially when doing the transformation.
    if (&InnerPHI == FI.InnerInductionPHI)
      continue;
    if (FI.isNarrowInductionPhi(&InnerPHI))
      continue;

    // Each inner loop PHI node must have two incoming values/blocks - one
    // from the pre-header, and one from the latch.
    assert(InnerPHI.getNumIncomingValues() == 2);
    Value *PreHeaderValue =
        InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());
    Value *LatchValue =
        InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());

    // The incoming value from the outer loop must be the PHI node in the
    // outer loop header, with no modifications made in the top of the outer
    // loop.
    PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
    if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) {
      LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n");
      return false;
    }

    // The other incoming value must come from the inner loop, without any
    // modifications in the tail end of the outer loop. We are in LCSSA form,
    // so this will actually be a PHI in the inner loop's exit block, which
    // only uses values from inside the inner loop.
    PHINode *LCSSAPHI = dyn_cast<PHINode>(
        OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch()));
    if (!LCSSAPHI) {
      LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n");
      return false;
    }

    // The value used by the LCSSA PHI must be the same one that the inner
    // loop's PHI uses.
    if (LCSSAPHI->hasConstantValue() != LatchValue) {
      LLVM_DEBUG(
          dbgs() << "LCSSA PHI incoming value does not match latch value\n");
      return false;
    }

    LLVM_DEBUG(dbgs() << "PHI pair is safe:\n");
    LLVM_DEBUG(dbgs() << "  Inner: "; InnerPHI.dump());
    LLVM_DEBUG(dbgs() << "  Outer: "; OuterPHI->dump());
    SafeOuterPHIs.insert(OuterPHI);
    FI.InnerPHIsToTransform.insert(&InnerPHI);
  }

  // Any outer header PHI that we did not prove safe above blocks the
  // transformation.
  for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
    if (FI.isNarrowInductionPhi(&OuterPHI))
      continue;
    if (!SafeOuterPHIs.count(&OuterPHI)) {
      LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());
      return false;
    }
  }

  LLVM_DEBUG(dbgs() << "checkPHIs: OK\n");
  return true;
}
static bool
checkOuterLoopInsts(FlattenInfo &FI,
                    SmallPtrSetImpl<Instruction *> &IterationInstructions,
                    const TargetTransformInfo *TTI) {
  // Check for instructions in the outer but not inner loop. If any of these
  // have side-effects then this transformation is not legal, and if there is
  // a significant amount of code here which can't be optimised out that it's
  // not profitable (as these instructions would get executed for each
  // iteration of the inner loop).
  InstructionCost RepeatedInstrCost = 0;
  for (auto *B : FI.OuterLoop->getBlocks()) {
    if (FI.InnerLoop->contains(B))
      continue;

    for (auto &I : *B) {
      // Legality check: instructions that cannot be speculated must not be
      // repeated once per inner-loop iteration.
      if (!isa<PHINode>(&I) && !I.isTerminator() &&
          !isSafeToSpeculativelyExecute(&I)) {
        LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have "
                             "side effects: ";
                   I.dump());
        return false;
      }
      // The execution count of the outer loop's iteration instructions
      // (increment, compare and branch) will be increased, but the
      // equivalent instructions will be removed from the inner loop, so
      // they make a net difference of zero.
      if (IterationInstructions.count(&I))
        continue;
      // The unconditional branch to the inner loop's header will turn into
      // a fall-through, so adds no cost.
      BranchInst *Br = dyn_cast<BranchInst>(&I);
      if (Br && Br->isUnconditional() &&
          Br->getSuccessor(0) == FI.InnerLoop->getHeader())
        continue;
      // Multiplies of the outer iteration variable and inner iteration
      // count will be optimised out.
      if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
                            m_Specific(FI.InnerTripCount))))
        continue;
      // Everything else would run once per inner-loop iteration after
      // flattening; accumulate its size/latency cost.
      InstructionCost Cost =
          TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
      LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());
      RepeatedInstrCost += Cost;
    }
  }
  LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: "
                    << RepeatedInstrCost << "\n");
  // Bail out if flattening the loops would cause instructions in the outer
  // loop but not in the inner loop to be executed extra times.
  if (RepeatedInstrCost > RepeatedInstructionThreshold) {
    LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n");
  return true;
}
  535. // We require all uses of both induction variables to match this pattern:
  536. //
  537. // (OuterPHI * InnerTripCount) + InnerPHI
  538. //
  539. // Any uses of the induction variables not matching that pattern would
  540. // require a div/mod to reconstruct in the flattened loop, so the
  541. // transformation wouldn't be profitable.
  542. static bool checkIVUsers(FlattenInfo &FI) {
  543. // Check that all uses of the inner loop's induction variable match the
  544. // expected pattern, recording the uses of the outer IV.
  545. SmallPtrSet<Value *, 4> ValidOuterPHIUses;
  546. if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
  547. return false;
  548. // Check that there are no uses of the outer IV other than the ones found
  549. // as part of the pattern above.
  550. if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
  551. return false;
  552. LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
  553. dbgs() << "Found " << FI.LinearIVUses.size()
  554. << " value(s) that can be replaced:\n";
  555. for (Value *V : FI.LinearIVUses) {
  556. dbgs() << " ";
  557. V->dump();
  558. });
  559. return true;
  560. }
// Return an OverflowResult dependent on if overflow of the multiplication of
// InnerTripCount and OuterTripCount can be assumed not to happen.
static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
                                    AssumptionCache *AC) {
  Function *F = FI.OuterLoop->getHeader()->getParent();
  const DataLayout &DL = F->getParent()->getDataLayout();

  // For debugging/testing.
  if (AssumeNoOverflow)
    return OverflowResult::NeverOverflows;

  // Check if the multiply could not overflow due to known ranges of the
  // input values.
  OverflowResult OR = computeOverflowForUnsignedMul(
      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
      FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
  if (OR != OverflowResult::MayOverflow)
    return OR;

  // Range analysis was inconclusive. Look for a load/store through an
  // inbounds GEP of the linear IV that executes on every inner-loop
  // iteration: such a use would make IV overflow UB, so it can be ruled out.
  for (Value *V : FI.LinearIVUses) {
    for (Value *U : V->users()) {
      if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
        for (Value *GEPUser : U->users()) {
          auto *GEPUserInst = cast<Instruction>(GEPUser);
          // Only loads, and stores whose *pointer* operand is the GEP, pin
          // the GEP's address to a real memory access.
          if (!isa<LoadInst>(GEPUserInst) &&
              !(isa<StoreInst>(GEPUserInst) &&
                GEP == GEPUserInst->getOperand(1)))
            continue;
          if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
                                                      FI.InnerLoop))
            continue;
          // The IV is used as the operand of a GEP which dominates the loop
          // latch, and the IV is at least as wide as the address space of the
          // GEP. In this case, the GEP would wrap around the address space
          // before the IV increment wraps, which would be UB.
          if (GEP->isInBounds() &&
              V->getType()->getIntegerBitWidth() >=
                  DL.getPointerTypeSizeInBits(GEP->getType())) {
            LLVM_DEBUG(
                dbgs() << "use of linear IV would be UB if overflow occurred: ";
                GEP->dump());
            return OverflowResult::NeverOverflows;
          }
        }
      }
    }
  }

  return OverflowResult::MayOverflow;
}
  607. static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
  608. ScalarEvolution *SE, AssumptionCache *AC,
  609. const TargetTransformInfo *TTI) {
  610. SmallPtrSet<Instruction *, 8> IterationInstructions;
  611. if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
  612. FI.InnerInductionPHI, FI.InnerTripCount,
  613. FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
  614. return false;
  615. if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
  616. FI.OuterInductionPHI, FI.OuterTripCount,
  617. FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
  618. return false;
  619. // Both of the loop trip count values must be invariant in the outer loop
  620. // (non-instructions are all inherently invariant).
  621. if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
  622. LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
  623. return false;
  624. }
  625. if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
  626. LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
  627. return false;
  628. }
  629. if (!checkPHIs(FI, TTI))
  630. return false;
  631. // FIXME: it should be possible to handle different types correctly.
  632. if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())
  633. return false;
  634. if (!checkOuterLoopInsts(FI, IterationInstructions, TTI))
  635. return false;
  636. // Find the values in the loop that can be replaced with the linearized
  637. // induction variable, and check that there are no other uses of the inner
  638. // or outer induction variable. If there were, we could still do this
  639. // transformation, but we'd have to insert a div/mod to calculate the
  640. // original IVs, so it wouldn't be profitable.
  641. if (!checkIVUsers(FI))
  642. return false;
  643. LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n");
  644. return true;
  645. }
// Carry out the flattening transformation itself: build the product trip
// count, rewire the outer loop's exit condition to use it, remove the inner
// loop's backedge, replace the linearized IV uses with the outer induction
// variable, and erase the inner loop from all analyses. Always returns true
// since the IR is modified.
static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
                              ScalarEvolution *SE, AssumptionCache *AC,
                              const TargetTransformInfo *TTI, LPMUpdater *U,
                              MemorySSAUpdater *MSSAU) {
  Function *F = FI.OuterLoop->getHeader()->getParent();
  LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");

  // Emit an optimization remark recording that this loop pair was flattened.
  {
    using namespace ore;
    OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(),
                              FI.InnerLoop->getHeader());
    OptimizationRemarkEmitter ORE(F);
    Remark << "Flattened into outer loop";
    ORE.emit(Remark);
  }

  // The flattened trip count is InnerTripCount * OuterTripCount, materialized
  // in the outer loop's preheader.
  Value *NewTripCount = BinaryOperator::CreateMul(
      FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
      FI.OuterLoop->getLoopPreheader()->getTerminator());
  LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
             NewTripCount->dump());

  // Fix up PHI nodes that take values from the inner loop back-edge, which
  // we are about to remove.
  FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());

  // The old Phi will be optimised away later, but for now we can't leave it
  // in an invalid state, so we are updating the other inner PHIs too.
  for (PHINode *PHI : FI.InnerPHIsToTransform)
    PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());

  // Modify the trip count of the outer loop to be the product of the two
  // trip counts.
  cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);

  // Replace the inner loop backedge with an unconditional branch to the exit.
  BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
  BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
  InnerExitingBlock->getTerminator()->eraseFromParent();
  BranchInst::Create(InnerExitBlock, InnerExitingBlock);

  // Update the DomTree and MemorySSA.
  DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
  if (MSSAU)
    MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());

  // Replace all uses of the polynomial calculated from the two induction
  // variables with the one new one.
  IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator());
  for (Value *V : FI.LinearIVUses) {
    Value *OuterValue = FI.OuterInductionPHI;
    // If the IVs were widened, the linear uses still have the narrow type,
    // so truncate the wide outer IV back down for the replacement.
    if (FI.Widened)
      OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
                                       "flatten.trunciv");

    LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
               OuterValue->dump());
    V->replaceAllUsesWith(OuterValue);
  }

  // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
  // deleted, and invalidate any outer loop information.
  SE->forgetLoop(FI.OuterLoop);
  SE->forgetBlockAndLoopDispositions();
  if (U)
    U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName());
  LI->erase(FI.InnerLoop);

  // Increment statistic value.
  NumFlattened++;

  return true;
}
// Attempt to widen both induction variables to the widest legal integer type,
// so that computing (OuterTripCount * InnerTripCount) as the new trip count
// is known to be safe. On success, re-runs CanFlattenLoopPair on the widened
// pair. Returns false when widening is disabled, not applicable, or the
// widened pair still cannot be flattened.
static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
                       ScalarEvolution *SE, AssumptionCache *AC,
                       const TargetTransformInfo *TTI) {
  if (!WidenIV) {
    LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "Try widening the IVs\n");
  Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
  auto &DL = M->getDataLayout();
  auto *InnerType = FI.InnerInductionPHI->getType();
  auto *OuterType = FI.OuterInductionPHI->getType();
  unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();
  auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());

  // If both induction types are less than the maximum legal integer width,
  // promote both to the widest type available so we know calculating
  // (OuterTripCount * InnerTripCount) as the new trip count is safe.
  // The widest type must hold at least twice the narrow bit-width so the
  // product cannot overflow.
  if (InnerType != OuterType ||
      InnerType->getScalarSizeInBits() >= MaxLegalSize ||
      MaxLegalType->getScalarSizeInBits() <
          InnerType->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
    return false;
  }

  SCEVExpander Rewriter(*SE, DL, "loopflatten");
  SmallVector<WeakTrackingVH, 4> DeadInsts;
  unsigned ElimExt = 0;
  unsigned Widened = 0;

  // Widen one narrow induction PHI. Deleted reports whether the old narrow
  // PHI became dead and was removed.
  auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {
    PHINode *WidePhi =
        createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
                     true /* HasGuards */, true /* UsePostIncrementRanges */);
    if (!WidePhi)
      return false;
    LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
    LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump());
    Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV);
    return true;
  };

  bool Deleted;
  if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted))
    return false;
  // Add the narrow phi to list, so that it will be adjusted later when the
  // transformation is performed.
  if (!Deleted)
    FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);

  if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted))
    return false;

  assert(Widened && "Widened IV expected");
  FI.Widened = true;

  // Save the old/narrow induction phis, which we need to ignore in CheckPHIs.
  FI.NarrowInnerInductionPHI = FI.InnerInductionPHI;
  FI.NarrowOuterInductionPHI = FI.OuterInductionPHI;

  // After widening, rediscover all the loop components.
  return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
}
  763. static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
  764. ScalarEvolution *SE, AssumptionCache *AC,
  765. const TargetTransformInfo *TTI, LPMUpdater *U,
  766. MemorySSAUpdater *MSSAU) {
  767. LLVM_DEBUG(
  768. dbgs() << "Loop flattening running on outer loop "
  769. << FI.OuterLoop->getHeader()->getName() << " and inner loop "
  770. << FI.InnerLoop->getHeader()->getName() << " in "
  771. << FI.OuterLoop->getHeader()->getParent()->getName() << "\n");
  772. if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI))
  773. return false;
  774. // Check if we can widen the induction variables to avoid overflow checks.
  775. bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI);
  776. // It can happen that after widening of the IV, flattening may not be
  777. // possible/happening, e.g. when it is deemed unprofitable. So bail here if
  778. // that is the case.
  779. // TODO: IV widening without performing the actual flattening transformation
  780. // is not ideal. While this codegen change should not matter much, it is an
  781. // unnecessary change which is better to avoid. It's unlikely this happens
  782. // often, because if it's unprofitibale after widening, it should be
  783. // unprofitabe before widening as checked in the first round of checks. But
  784. // 'RepeatedInstructionThreshold' is set to only 2, which can probably be
  785. // relaxed. Because this is making a code change (the IV widening, but not
  786. // the flattening), we return true here.
  787. if (FI.Widened && !CanFlatten)
  788. return true;
  789. // If we have widened and can perform the transformation, do that here.
  790. if (CanFlatten)
  791. return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
  792. // Otherwise, if we haven't widened the IV, check if the new iteration
  793. // variable might overflow. In this case, we need to version the loop, and
  794. // select the original version at runtime if the iteration space is too
  795. // large.
  796. // TODO: We currently don't version the loop.
  797. OverflowResult OR = checkOverflow(FI, DT, AC);
  798. if (OR == OverflowResult::AlwaysOverflowsHigh ||
  799. OR == OverflowResult::AlwaysOverflowsLow) {
  800. LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
  801. return false;
  802. } else if (OR == OverflowResult::MayOverflow) {
  803. LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
  804. return false;
  805. }
  806. LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
  807. return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
  808. }
  809. bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
  810. AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
  811. MemorySSAUpdater *MSSAU) {
  812. bool Changed = false;
  813. for (Loop *InnerLoop : LN.getLoops()) {
  814. auto *OuterLoop = InnerLoop->getParentLoop();
  815. if (!OuterLoop)
  816. continue;
  817. FlattenInfo FI(OuterLoop, InnerLoop);
  818. Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
  819. }
  820. return Changed;
  821. }
  822. PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
  823. LoopStandardAnalysisResults &AR,
  824. LPMUpdater &U) {
  825. bool Changed = false;
  826. std::optional<MemorySSAUpdater> MSSAU;
  827. if (AR.MSSA) {
  828. MSSAU = MemorySSAUpdater(AR.MSSA);
  829. if (VerifyMemorySSA)
  830. AR.MSSA->verifyMemorySSA();
  831. }
  832. // The loop flattening pass requires loops to be
  833. // in simplified form, and also needs LCSSA. Running
  834. // this pass will simplify all loops that contain inner loops,
  835. // regardless of whether anything ends up being flattened.
  836. Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
  837. MSSAU ? &*MSSAU : nullptr);
  838. if (!Changed)
  839. return PreservedAnalyses::all();
  840. if (AR.MSSA && VerifyMemorySSA)
  841. AR.MSSA->verifyMemorySSA();
  842. auto PA = getLoopPassPreservedAnalyses();
  843. if (AR.MSSA)
  844. PA.preserve<MemorySSAAnalysis>();
  845. return PA;
  846. }
namespace {
/// Legacy pass-manager wrapper for the loop-flattening transformation.
class LoopFlattenLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopFlattenLegacyPass() : FunctionPass(ID) {
    initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  // Possibly flatten loop L into its child.
  bool runOnFunction(Function &F) override;

  // Declare the analyses consumed and preserved; the loop analysis usage
  // brings in the standard loop-pass requirements (LoopInfo, SCEV, etc.).
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    getLoopAnalysisUsage(AU);
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<TargetTransformInfoWrapperPass>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addPreserved<AssumptionCacheTracker>();
    AU.addPreserved<MemorySSAWrapperPass>();
  }
};
} // namespace
char LoopFlattenLegacyPass::ID = 0;

// Register the legacy pass and its analysis dependencies with the pass
// registry under the "loop-flatten" command-line name.
INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                    false, false)
// Factory used by the legacy pass pipeline; the caller takes ownership of
// the returned pass.
FunctionPass *llvm::createLoopFlattenPass() {
  return new LoopFlattenLegacyPass();
}
  876. bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
  877. ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  878. LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  879. auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
  880. DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
  881. auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
  882. auto *TTI = &TTIP.getTTI(F);
  883. auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  884. auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
  885. std::optional<MemorySSAUpdater> MSSAU;
  886. if (MSSA)
  887. MSSAU = MemorySSAUpdater(&MSSA->getMSSA());
  888. bool Changed = false;
  889. for (Loop *L : *LI) {
  890. auto LN = LoopNest::getLoopNest(*L, *SE);
  891. Changed |=
  892. Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr);
  893. }
  894. return Changed;
  895. }