ScalarEvolutionExpressions.h 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file defines the classes used to represent and build scalar expressions.
  15. //
  16. //===----------------------------------------------------------------------===//
  17. #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
  18. #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
  19. #include "llvm/ADT/DenseMap.h"
  20. #include "llvm/ADT/FoldingSet.h"
  21. #include "llvm/ADT/SmallPtrSet.h"
  22. #include "llvm/ADT/SmallVector.h"
  23. #include "llvm/ADT/iterator_range.h"
  24. #include "llvm/Analysis/ScalarEvolution.h"
  25. #include "llvm/IR/Constants.h"
  26. #include "llvm/IR/Value.h"
  27. #include "llvm/IR/ValueHandle.h"
  28. #include "llvm/Support/Casting.h"
  29. #include "llvm/Support/ErrorHandling.h"
  30. #include <cassert>
  31. #include <cstddef>
  32. namespace llvm {
  33. class APInt;
  34. class Constant;
  35. class ConstantRange;
  36. class Loop;
  37. class Type;
  /// Kind tag for SCEV nodes. The classof() predicates and the visitors in this
  /// file dispatch on the value returned by SCEV::getSCEVType().
  enum SCEVTypes : unsigned short {
    // These should be ordered in terms of increasing complexity to make the
    // folders simpler.
    scConstant,
    scTruncate,
    scZeroExtend,
    scSignExtend,
    scAddExpr,
    scMulExpr,
    scUDivExpr,
    scAddRecExpr,
    scUMaxExpr,
    scSMaxExpr,
    scUMinExpr,
    scSMinExpr,
    scPtrToInt,
    scUnknown,
    scCouldNotCompute
  };
  45. /// This class represents a constant integer value.
  46. class SCEVConstant : public SCEV {
  47. friend class ScalarEvolution;
  48. ConstantInt *V;
  49. SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
  50. SCEV(ID, scConstant, 1), V(v) {}
  51. public:
  52. ConstantInt *getValue() const { return V; }
  53. const APInt &getAPInt() const { return getValue()->getValue(); }
  54. Type *getType() const { return V->getType(); }
  55. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  56. static bool classof(const SCEV *S) {
  57. return S->getSCEVType() == scConstant;
  58. }
  59. };
  60. inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
  61. APInt Size(16, 1);
  62. for (auto *Arg : Args)
  63. Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
  64. return (unsigned short)Size.getZExtValue();
  65. }
  66. /// This is the base class for unary cast operator classes.
  67. class SCEVCastExpr : public SCEV {
  68. protected:
  69. std::array<const SCEV *, 1> Operands;
  70. Type *Ty;
  71. SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
  72. Type *ty);
  73. public:
  74. const SCEV *getOperand() const { return Operands[0]; }
  75. const SCEV *getOperand(unsigned i) const {
  76. assert(i == 0 && "Operand index out of range!");
  77. return Operands[0];
  78. }
  79. using op_iterator = std::array<const SCEV *, 1>::const_iterator;
  80. using op_range = iterator_range<op_iterator>;
  81. op_range operands() const {
  82. return make_range(Operands.begin(), Operands.end());
  83. }
  84. size_t getNumOperands() const { return 1; }
  85. Type *getType() const { return Ty; }
  86. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  87. static bool classof(const SCEV *S) {
  88. return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
  89. S->getSCEVType() == scZeroExtend ||
  90. S->getSCEVType() == scSignExtend;
  91. }
  92. };
  93. /// This class represents a cast from a pointer to a pointer-sized integer
  94. /// value.
  95. class SCEVPtrToIntExpr : public SCEVCastExpr {
  96. friend class ScalarEvolution;
  97. SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
  98. public:
  99. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  100. static bool classof(const SCEV *S) {
  101. return S->getSCEVType() == scPtrToInt;
  102. }
  103. };
  104. /// This is the base class for unary integral cast operator classes.
  105. class SCEVIntegralCastExpr : public SCEVCastExpr {
  106. protected:
  107. SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
  108. const SCEV *op, Type *ty);
  109. public:
  110. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  111. static bool classof(const SCEV *S) {
  112. return S->getSCEVType() == scTruncate ||
  113. S->getSCEVType() == scZeroExtend ||
  114. S->getSCEVType() == scSignExtend;
  115. }
  116. };
  117. /// This class represents a truncation of an integer value to a
  118. /// smaller integer value.
  119. class SCEVTruncateExpr : public SCEVIntegralCastExpr {
  120. friend class ScalarEvolution;
  121. SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
  122. const SCEV *op, Type *ty);
  123. public:
  124. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  125. static bool classof(const SCEV *S) {
  126. return S->getSCEVType() == scTruncate;
  127. }
  128. };
  129. /// This class represents a zero extension of a small integer value
  130. /// to a larger integer value.
  131. class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
  132. friend class ScalarEvolution;
  133. SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
  134. const SCEV *op, Type *ty);
  135. public:
  136. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  137. static bool classof(const SCEV *S) {
  138. return S->getSCEVType() == scZeroExtend;
  139. }
  140. };
  141. /// This class represents a sign extension of a small integer value
  142. /// to a larger integer value.
  143. class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
  144. friend class ScalarEvolution;
  145. SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
  146. const SCEV *op, Type *ty);
  147. public:
  148. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  149. static bool classof(const SCEV *S) {
  150. return S->getSCEVType() == scSignExtend;
  151. }
  152. };
  153. /// This node is a base class providing common functionality for
  154. /// n'ary operators.
  155. class SCEVNAryExpr : public SCEV {
  156. protected:
  157. // Since SCEVs are immutable, ScalarEvolution allocates operand
  158. // arrays with its SCEVAllocator, so this class just needs a simple
  159. // pointer rather than a more elaborate vector-like data structure.
  160. // This also avoids the need for a non-trivial destructor.
  161. const SCEV *const *Operands;
  162. size_t NumOperands;
  163. SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
  164. const SCEV *const *O, size_t N)
  165. : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
  166. NumOperands(N) {}
  167. public:
  168. size_t getNumOperands() const { return NumOperands; }
  169. const SCEV *getOperand(unsigned i) const {
  170. assert(i < NumOperands && "Operand index out of range!");
  171. return Operands[i];
  172. }
  173. using op_iterator = const SCEV *const *;
  174. using op_range = iterator_range<op_iterator>;
  175. op_iterator op_begin() const { return Operands; }
  176. op_iterator op_end() const { return Operands + NumOperands; }
  177. op_range operands() const {
  178. return make_range(op_begin(), op_end());
  179. }
  180. Type *getType() const { return getOperand(0)->getType(); }
  181. NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
  182. return (NoWrapFlags)(SubclassData & Mask);
  183. }
  184. bool hasNoUnsignedWrap() const {
  185. return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
  186. }
  187. bool hasNoSignedWrap() const {
  188. return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
  189. }
  190. bool hasNoSelfWrap() const {
  191. return getNoWrapFlags(FlagNW) != FlagAnyWrap;
  192. }
  193. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  194. static bool classof(const SCEV *S) {
  195. return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
  196. S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
  197. S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
  198. S->getSCEVType() == scAddRecExpr;
  199. }
  200. };
  201. /// This node is the base class for n'ary commutative operators.
  202. class SCEVCommutativeExpr : public SCEVNAryExpr {
  203. protected:
  204. SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
  205. enum SCEVTypes T, const SCEV *const *O, size_t N)
  206. : SCEVNAryExpr(ID, T, O, N) {}
  207. public:
  208. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  209. static bool classof(const SCEV *S) {
  210. return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
  211. S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
  212. S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
  213. }
  214. /// Set flags for a non-recurrence without clearing previously set flags.
  215. void setNoWrapFlags(NoWrapFlags Flags) {
  216. SubclassData |= Flags;
  217. }
  218. };
  219. /// This node represents an addition of some number of SCEVs.
  220. class SCEVAddExpr : public SCEVCommutativeExpr {
  221. friend class ScalarEvolution;
  222. Type *Ty;
  223. SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
  224. : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
  225. auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
  226. return Op->getType()->isPointerTy();
  227. });
  228. if (FirstPointerTypedOp != operands().end())
  229. Ty = (*FirstPointerTypedOp)->getType();
  230. else
  231. Ty = getOperand(0)->getType();
  232. }
  233. public:
  234. Type *getType() const { return Ty; }
  235. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  236. static bool classof(const SCEV *S) {
  237. return S->getSCEVType() == scAddExpr;
  238. }
  239. };
  240. /// This node represents multiplication of some number of SCEVs.
  241. class SCEVMulExpr : public SCEVCommutativeExpr {
  242. friend class ScalarEvolution;
  243. SCEVMulExpr(const FoldingSetNodeIDRef ID,
  244. const SCEV *const *O, size_t N)
  245. : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
  246. public:
  247. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  248. static bool classof(const SCEV *S) {
  249. return S->getSCEVType() == scMulExpr;
  250. }
  251. };
  252. /// This class represents a binary unsigned division operation.
  253. class SCEVUDivExpr : public SCEV {
  254. friend class ScalarEvolution;
  255. std::array<const SCEV *, 2> Operands;
  256. SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
  257. : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
  258. Operands[0] = lhs;
  259. Operands[1] = rhs;
  260. }
  261. public:
  262. const SCEV *getLHS() const { return Operands[0]; }
  263. const SCEV *getRHS() const { return Operands[1]; }
  264. size_t getNumOperands() const { return 2; }
  265. const SCEV *getOperand(unsigned i) const {
  266. assert((i == 0 || i == 1) && "Operand index out of range!");
  267. return i == 0 ? getLHS() : getRHS();
  268. }
  269. using op_iterator = std::array<const SCEV *, 2>::const_iterator;
  270. using op_range = iterator_range<op_iterator>;
  271. op_range operands() const {
  272. return make_range(Operands.begin(), Operands.end());
  273. }
  274. Type *getType() const {
  275. // In most cases the types of LHS and RHS will be the same, but in some
  276. // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
  277. // depend on the type for correctness, but handling types carefully can
  278. // avoid extra casts in the SCEVExpander. The LHS is more likely to be
  279. // a pointer type than the RHS, so use the RHS' type here.
  280. return getRHS()->getType();
  281. }
  282. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  283. static bool classof(const SCEV *S) {
  284. return S->getSCEVType() == scUDivExpr;
  285. }
  286. };
  287. /// This node represents a polynomial recurrence on the trip count
  288. /// of the specified loop. This is the primary focus of the
  289. /// ScalarEvolution framework; all the other SCEV subclasses are
  290. /// mostly just supporting infrastructure to allow SCEVAddRecExpr
  291. /// expressions to be created and analyzed.
  292. ///
  293. /// All operands of an AddRec are required to be loop invariant.
  294. ///
  295. class SCEVAddRecExpr : public SCEVNAryExpr {
  296. friend class ScalarEvolution;
  297. const Loop *L;
  298. SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
  299. const SCEV *const *O, size_t N, const Loop *l)
  300. : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
  301. public:
  302. const SCEV *getStart() const { return Operands[0]; }
  303. const Loop *getLoop() const { return L; }
  304. /// Constructs and returns the recurrence indicating how much this
  305. /// expression steps by. If this is a polynomial of degree N, it
  306. /// returns a chrec of degree N-1. We cannot determine whether
  307. /// the step recurrence has self-wraparound.
  308. const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
  309. if (isAffine()) return getOperand(1);
  310. return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
  311. op_end()),
  312. getLoop(), FlagAnyWrap);
  313. }
  314. /// Return true if this represents an expression A + B*x where A
  315. /// and B are loop invariant values.
  316. bool isAffine() const {
  317. // We know that the start value is invariant. This expression is thus
  318. // affine iff the step is also invariant.
  319. return getNumOperands() == 2;
  320. }
  321. /// Return true if this represents an expression A + B*x + C*x^2
  322. /// where A, B and C are loop invariant values. This corresponds
  323. /// to an addrec of the form {L,+,M,+,N}
  324. bool isQuadratic() const {
  325. return getNumOperands() == 3;
  326. }
  327. /// Set flags for a recurrence without clearing any previously set flags.
  328. /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
  329. /// to make it easier to propagate flags.
  330. void setNoWrapFlags(NoWrapFlags Flags) {
  331. if (Flags & (FlagNUW | FlagNSW))
  332. Flags = ScalarEvolution::setFlags(Flags, FlagNW);
  333. SubclassData |= Flags;
  334. }
  335. /// Return the value of this chain of recurrences at the specified
  336. /// iteration number.
  337. const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
  338. /// Return the number of iterations of this loop that produce
  339. /// values in the specified constant range. Another way of
  340. /// looking at this is that it returns the first iteration number
  341. /// where the value is not in the condition, thus computing the
  342. /// exit count. If the iteration count can't be computed, an
  343. /// instance of SCEVCouldNotCompute is returned.
  344. const SCEV *getNumIterationsInRange(const ConstantRange &Range,
  345. ScalarEvolution &SE) const;
  346. /// Return an expression representing the value of this expression
  347. /// one iteration of the loop ahead.
  348. const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
  349. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  350. static bool classof(const SCEV *S) {
  351. return S->getSCEVType() == scAddRecExpr;
  352. }
  353. };
  354. /// This node is the base class min/max selections.
  355. class SCEVMinMaxExpr : public SCEVCommutativeExpr {
  356. friend class ScalarEvolution;
  357. static bool isMinMaxType(enum SCEVTypes T) {
  358. return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
  359. T == scUMinExpr;
  360. }
  361. protected:
  362. /// Note: Constructing subclasses via this constructor is allowed
  363. SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
  364. const SCEV *const *O, size_t N)
  365. : SCEVCommutativeExpr(ID, T, O, N) {
  366. assert(isMinMaxType(T));
  367. // Min and max never overflow
  368. setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
  369. }
  370. public:
  371. static bool classof(const SCEV *S) {
  372. return isMinMaxType(S->getSCEVType());
  373. }
  374. static enum SCEVTypes negate(enum SCEVTypes T) {
  375. switch (T) {
  376. case scSMaxExpr:
  377. return scSMinExpr;
  378. case scSMinExpr:
  379. return scSMaxExpr;
  380. case scUMaxExpr:
  381. return scUMinExpr;
  382. case scUMinExpr:
  383. return scUMaxExpr;
  384. default:
  385. llvm_unreachable("Not a min or max SCEV type!");
  386. }
  387. }
  388. };
  389. /// This class represents a signed maximum selection.
  390. class SCEVSMaxExpr : public SCEVMinMaxExpr {
  391. friend class ScalarEvolution;
  392. SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
  393. : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
  394. public:
  395. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  396. static bool classof(const SCEV *S) {
  397. return S->getSCEVType() == scSMaxExpr;
  398. }
  399. };
  400. /// This class represents an unsigned maximum selection.
  401. class SCEVUMaxExpr : public SCEVMinMaxExpr {
  402. friend class ScalarEvolution;
  403. SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
  404. : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
  405. public:
  406. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  407. static bool classof(const SCEV *S) {
  408. return S->getSCEVType() == scUMaxExpr;
  409. }
  410. };
  411. /// This class represents a signed minimum selection.
  412. class SCEVSMinExpr : public SCEVMinMaxExpr {
  413. friend class ScalarEvolution;
  414. SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
  415. : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
  416. public:
  417. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  418. static bool classof(const SCEV *S) {
  419. return S->getSCEVType() == scSMinExpr;
  420. }
  421. };
  422. /// This class represents an unsigned minimum selection.
  423. class SCEVUMinExpr : public SCEVMinMaxExpr {
  424. friend class ScalarEvolution;
  425. SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
  426. : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
  427. public:
  428. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  429. static bool classof(const SCEV *S) {
  430. return S->getSCEVType() == scUMinExpr;
  431. }
  432. };
  433. /// This means that we are dealing with an entirely unknown SCEV
  434. /// value, and only represent it as its LLVM Value. This is the
  435. /// "bottom" value for the analysis.
  436. class SCEVUnknown final : public SCEV, private CallbackVH {
  437. friend class ScalarEvolution;
  438. /// The parent ScalarEvolution value. This is used to update the
  439. /// parent's maps when the value associated with a SCEVUnknown is
  440. /// deleted or RAUW'd.
  441. ScalarEvolution *SE;
  442. /// The next pointer in the linked list of all SCEVUnknown
  443. /// instances owned by a ScalarEvolution.
  444. SCEVUnknown *Next;
  445. SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
  446. ScalarEvolution *se, SCEVUnknown *next) :
  447. SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
  448. // Implement CallbackVH.
  449. void deleted() override;
  450. void allUsesReplacedWith(Value *New) override;
  451. public:
  452. Value *getValue() const { return getValPtr(); }
  453. /// @{
  454. /// Test whether this is a special constant representing a type
  455. /// size, alignment, or field offset in a target-independent
  456. /// manner, and hasn't happened to have been folded with other
  457. /// operations into something unrecognizable. This is mainly only
  458. /// useful for pretty-printing and other situations where it isn't
  459. /// absolutely required for these to succeed.
  460. bool isSizeOf(Type *&AllocTy) const;
  461. bool isAlignOf(Type *&AllocTy) const;
  462. bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
  463. /// @}
  464. Type *getType() const { return getValPtr()->getType(); }
  465. /// Methods for support type inquiry through isa, cast, and dyn_cast:
  466. static bool classof(const SCEV *S) {
  467. return S->getSCEVType() == scUnknown;
  468. }
  469. };
  470. /// This class defines a simple visitor class that may be used for
  471. /// various SCEV analysis purposes.
  472. template<typename SC, typename RetVal=void>
  473. struct SCEVVisitor {
  474. RetVal visit(const SCEV *S) {
  475. switch (S->getSCEVType()) {
  476. case scConstant:
  477. return ((SC*)this)->visitConstant((const SCEVConstant*)S);
  478. case scPtrToInt:
  479. return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
  480. case scTruncate:
  481. return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
  482. case scZeroExtend:
  483. return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
  484. case scSignExtend:
  485. return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
  486. case scAddExpr:
  487. return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
  488. case scMulExpr:
  489. return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
  490. case scUDivExpr:
  491. return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
  492. case scAddRecExpr:
  493. return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
  494. case scSMaxExpr:
  495. return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
  496. case scUMaxExpr:
  497. return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
  498. case scSMinExpr:
  499. return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
  500. case scUMinExpr:
  501. return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
  502. case scUnknown:
  503. return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
  504. case scCouldNotCompute:
  505. return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
  506. }
  507. llvm_unreachable("Unknown SCEV kind!");
  508. }
  509. RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
  510. llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
  511. }
  512. };
  513. /// Visit all nodes in the expression tree using worklist traversal.
  514. ///
  515. /// Visitor implements:
  516. /// // return true to follow this node.
  517. /// bool follow(const SCEV *S);
  518. /// // return true to terminate the search.
  519. /// bool isDone();
  520. template<typename SV>
  521. class SCEVTraversal {
  522. SV &Visitor;
  523. SmallVector<const SCEV *, 8> Worklist;
  524. SmallPtrSet<const SCEV *, 8> Visited;
  525. void push(const SCEV *S) {
  526. if (Visited.insert(S).second && Visitor.follow(S))
  527. Worklist.push_back(S);
  528. }
  529. public:
  530. SCEVTraversal(SV& V): Visitor(V) {}
  531. void visitAll(const SCEV *Root) {
  532. push(Root);
  533. while (!Worklist.empty() && !Visitor.isDone()) {
  534. const SCEV *S = Worklist.pop_back_val();
  535. switch (S->getSCEVType()) {
  536. case scConstant:
  537. case scUnknown:
  538. continue;
  539. case scPtrToInt:
  540. case scTruncate:
  541. case scZeroExtend:
  542. case scSignExtend:
  543. push(cast<SCEVCastExpr>(S)->getOperand());
  544. continue;
  545. case scAddExpr:
  546. case scMulExpr:
  547. case scSMaxExpr:
  548. case scUMaxExpr:
  549. case scSMinExpr:
  550. case scUMinExpr:
  551. case scAddRecExpr:
  552. for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
  553. push(Op);
  554. continue;
  555. case scUDivExpr: {
  556. const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
  557. push(UDiv->getLHS());
  558. push(UDiv->getRHS());
  559. continue;
  560. }
  561. case scCouldNotCompute:
  562. llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  563. }
  564. llvm_unreachable("Unknown SCEV kind!");
  565. }
  566. }
  567. };
  568. /// Use SCEVTraversal to visit all nodes in the given expression tree.
  569. template<typename SV>
  570. void visitAll(const SCEV *Root, SV& Visitor) {
  571. SCEVTraversal<SV> T(Visitor);
  572. T.visitAll(Root);
  573. }
  574. /// Return true if any node in \p Root satisfies the predicate \p Pred.
  575. template <typename PredTy>
  576. bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
  577. struct FindClosure {
  578. bool Found = false;
  579. PredTy Pred;
  580. FindClosure(PredTy Pred) : Pred(Pred) {}
  581. bool follow(const SCEV *S) {
  582. if (!Pred(S))
  583. return true;
  584. Found = true;
  585. return false;
  586. }
  587. bool isDone() const { return Found; }
  588. };
  589. FindClosure FC(Pred);
  590. visitAll(Root, FC);
  591. return FC.Found;
  592. }
  593. /// This visitor recursively visits a SCEV expression and re-writes it.
  594. /// The result from each visit is cached, so it will return the same
  595. /// SCEV for the same input.
  596. template<typename SC>
  597. class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
  598. protected:
  599. ScalarEvolution &SE;
  600. // Memoize the result of each visit so that we only compute once for
  601. // the same input SCEV. This is to avoid redundant computations when
  602. // a SCEV is referenced by multiple SCEVs. Without memoization, this
  603. // visit algorithm would have exponential time complexity in the worst
  604. // case, causing the compiler to hang on certain tests.
  605. DenseMap<const SCEV *, const SCEV *> RewriteResults;
  606. public:
  607. SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
  608. const SCEV *visit(const SCEV *S) {
  609. auto It = RewriteResults.find(S);
  610. if (It != RewriteResults.end())
  611. return It->second;
  612. auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
  613. auto Result = RewriteResults.try_emplace(S, Visited);
  614. assert(Result.second && "Should insert a new entry");
  615. return Result.first->second;
  616. }
  617. const SCEV *visitConstant(const SCEVConstant *Constant) {
  618. return Constant;
  619. }
  620. const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
  621. const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
  622. return Operand == Expr->getOperand()
  623. ? Expr
  624. : SE.getPtrToIntExpr(Operand, Expr->getType());
  625. }
  626. const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
  627. const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
  628. return Operand == Expr->getOperand()
  629. ? Expr
  630. : SE.getTruncateExpr(Operand, Expr->getType());
  631. }
  632. const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
  633. const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
  634. return Operand == Expr->getOperand()
  635. ? Expr
  636. : SE.getZeroExtendExpr(Operand, Expr->getType());
  637. }
  638. const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
  639. const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
  640. return Operand == Expr->getOperand()
  641. ? Expr
  642. : SE.getSignExtendExpr(Operand, Expr->getType());
  643. }
  644. const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
  645. SmallVector<const SCEV *, 2> Operands;
  646. bool Changed = false;
  647. for (auto *Op : Expr->operands()) {
  648. Operands.push_back(((SC*)this)->visit(Op));
  649. Changed |= Op != Operands.back();
  650. }
  651. return !Changed ? Expr : SE.getAddExpr(Operands);
  652. }
  653. const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
  654. SmallVector<const SCEV *, 2> Operands;
  655. bool Changed = false;
  656. for (auto *Op : Expr->operands()) {
  657. Operands.push_back(((SC*)this)->visit(Op));
  658. Changed |= Op != Operands.back();
  659. }
  660. return !Changed ? Expr : SE.getMulExpr(Operands);
  661. }
  662. const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
  663. auto *LHS = ((SC *)this)->visit(Expr->getLHS());
  664. auto *RHS = ((SC *)this)->visit(Expr->getRHS());
  665. bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
  666. return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
  667. }
  668. const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
  669. SmallVector<const SCEV *, 2> Operands;
  670. bool Changed = false;
  671. for (auto *Op : Expr->operands()) {
  672. Operands.push_back(((SC*)this)->visit(Op));
  673. Changed |= Op != Operands.back();
  674. }
  675. return !Changed ? Expr
  676. : SE.getAddRecExpr(Operands, Expr->getLoop(),
  677. Expr->getNoWrapFlags());
  678. }
  679. const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
  680. SmallVector<const SCEV *, 2> Operands;
  681. bool Changed = false;
  682. for (auto *Op : Expr->operands()) {
  683. Operands.push_back(((SC *)this)->visit(Op));
  684. Changed |= Op != Operands.back();
  685. }
  686. return !Changed ? Expr : SE.getSMaxExpr(Operands);
  687. }
  688. const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
  689. SmallVector<const SCEV *, 2> Operands;
  690. bool Changed = false;
  691. for (auto *Op : Expr->operands()) {
  692. Operands.push_back(((SC*)this)->visit(Op));
  693. Changed |= Op != Operands.back();
  694. }
  695. return !Changed ? Expr : SE.getUMaxExpr(Operands);
  696. }
  697. const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
  698. SmallVector<const SCEV *, 2> Operands;
  699. bool Changed = false;
  700. for (auto *Op : Expr->operands()) {
  701. Operands.push_back(((SC *)this)->visit(Op));
  702. Changed |= Op != Operands.back();
  703. }
  704. return !Changed ? Expr : SE.getSMinExpr(Operands);
  705. }
  706. const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
  707. SmallVector<const SCEV *, 2> Operands;
  708. bool Changed = false;
  709. for (auto *Op : Expr->operands()) {
  710. Operands.push_back(((SC *)this)->visit(Op));
  711. Changed |= Op != Operands.back();
  712. }
  713. return !Changed ? Expr : SE.getUMinExpr(Operands);
  714. }
  715. const SCEV *visitUnknown(const SCEVUnknown *Expr) {
  716. return Expr;
  717. }
  718. const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
  719. return Expr;
  720. }
  721. };
  722. using ValueToValueMap = DenseMap<const Value *, Value *>;
  723. using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
  724. /// The SCEVParameterRewriter takes a scalar evolution expression and updates
  725. /// the SCEVUnknown components following the Map (Value -> SCEV).
  726. class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
  727. public:
  728. static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
  729. ValueToSCEVMapTy &Map) {
  730. SCEVParameterRewriter Rewriter(SE, Map);
  731. return Rewriter.visit(Scev);
  732. }
  733. SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
  734. : SCEVRewriteVisitor(SE), Map(M) {}
  735. const SCEV *visitUnknown(const SCEVUnknown *Expr) {
  736. auto I = Map.find(Expr->getValue());
  737. if (I == Map.end())
  738. return Expr;
  739. return I->second;
  740. }
  741. private:
  742. ValueToSCEVMapTy &Map;
  743. };
  744. using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
  745. /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
  746. /// the Map (Loop -> SCEV) to all AddRecExprs.
  747. class SCEVLoopAddRecRewriter
  748. : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
  749. public:
  750. SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
  751. : SCEVRewriteVisitor(SE), Map(M) {}
  752. static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
  753. ScalarEvolution &SE) {
  754. SCEVLoopAddRecRewriter Rewriter(SE, Map);
  755. return Rewriter.visit(Scev);
  756. }
  757. const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
  758. SmallVector<const SCEV *, 2> Operands;
  759. for (const SCEV *Op : Expr->operands())
  760. Operands.push_back(visit(Op));
  761. const Loop *L = Expr->getLoop();
  762. const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
  763. if (0 == Map.count(L))
  764. return Res;
  765. const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res);
  766. return Rec->evaluateAtIteration(Map[L], SE);
  767. }
  768. private:
  769. LoopToScevMapT &Map;
  770. };
  771. } // end namespace llvm
  772. #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
  773. #ifdef __GNUC__
  774. #pragma GCC diagnostic pop
  775. #endif