ScalarEvolutionDivision.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
// This file implements the SCEVDivision class, which divides SCEVs.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Analysis/ScalarEvolutionDivision.h"
  13. #include "llvm/ADT/APInt.h"
  14. #include "llvm/ADT/DenseMap.h"
  15. #include "llvm/ADT/SmallVector.h"
  16. #include "llvm/Analysis/ScalarEvolution.h"
  17. #include "llvm/IR/Constants.h"
  18. #include "llvm/Support/Casting.h"
  19. #include "llvm/Support/ErrorHandling.h"
  20. #include <cassert>
  21. #include <cstdint>
  22. namespace llvm {
  23. class Type;
  24. }
  25. using namespace llvm;
  26. namespace {
  27. static inline int sizeOfSCEV(const SCEV *S) {
  28. struct FindSCEVSize {
  29. int Size = 0;
  30. FindSCEVSize() = default;
  31. bool follow(const SCEV *S) {
  32. ++Size;
  33. // Keep looking at all operands of S.
  34. return true;
  35. }
  36. bool isDone() const { return false; }
  37. };
  38. FindSCEVSize F;
  39. SCEVTraversal<FindSCEVSize> ST(F);
  40. ST.visitAll(S);
  41. return F.Size;
  42. }
  43. } // namespace
  44. // Computes the Quotient and Remainder of the division of Numerator by
  45. // Denominator.
  46. void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
  47. const SCEV *Denominator, const SCEV **Quotient,
  48. const SCEV **Remainder) {
  49. assert(Numerator && Denominator && "Uninitialized SCEV");
  50. SCEVDivision D(SE, Numerator, Denominator);
  51. // Check for the trivial case here to avoid having to check for it in the
  52. // rest of the code.
  53. if (Numerator == Denominator) {
  54. *Quotient = D.One;
  55. *Remainder = D.Zero;
  56. return;
  57. }
  58. if (Numerator->isZero()) {
  59. *Quotient = D.Zero;
  60. *Remainder = D.Zero;
  61. return;
  62. }
  63. // A simple case when N/1. The quotient is N.
  64. if (Denominator->isOne()) {
  65. *Quotient = Numerator;
  66. *Remainder = D.Zero;
  67. return;
  68. }
  69. // Split the Denominator when it is a product.
  70. if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
  71. const SCEV *Q, *R;
  72. *Quotient = Numerator;
  73. for (const SCEV *Op : T->operands()) {
  74. divide(SE, *Quotient, Op, &Q, &R);
  75. *Quotient = Q;
  76. // Bail out when the Numerator is not divisible by one of the terms of
  77. // the Denominator.
  78. if (!R->isZero()) {
  79. *Quotient = D.Zero;
  80. *Remainder = Numerator;
  81. return;
  82. }
  83. }
  84. *Remainder = D.Zero;
  85. return;
  86. }
  87. D.visit(Numerator);
  88. *Quotient = D.Quotient;
  89. *Remainder = D.Remainder;
  90. }
  91. void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
  92. if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
  93. APInt NumeratorVal = Numerator->getAPInt();
  94. APInt DenominatorVal = D->getAPInt();
  95. uint32_t NumeratorBW = NumeratorVal.getBitWidth();
  96. uint32_t DenominatorBW = DenominatorVal.getBitWidth();
  97. if (NumeratorBW > DenominatorBW)
  98. DenominatorVal = DenominatorVal.sext(NumeratorBW);
  99. else if (NumeratorBW < DenominatorBW)
  100. NumeratorVal = NumeratorVal.sext(DenominatorBW);
  101. APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
  102. APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
  103. APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
  104. Quotient = SE.getConstant(QuotientVal);
  105. Remainder = SE.getConstant(RemainderVal);
  106. return;
  107. }
  108. }
  109. void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
  110. const SCEV *StartQ, *StartR, *StepQ, *StepR;
  111. if (!Numerator->isAffine())
  112. return cannotDivide(Numerator);
  113. divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
  114. divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
  115. // Bail out if the types do not match.
  116. Type *Ty = Denominator->getType();
  117. if (Ty != StartQ->getType() || Ty != StartR->getType() ||
  118. Ty != StepQ->getType() || Ty != StepR->getType())
  119. return cannotDivide(Numerator);
  120. Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
  121. Numerator->getNoWrapFlags());
  122. Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
  123. Numerator->getNoWrapFlags());
  124. }
  125. void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
  126. SmallVector<const SCEV *, 2> Qs, Rs;
  127. Type *Ty = Denominator->getType();
  128. for (const SCEV *Op : Numerator->operands()) {
  129. const SCEV *Q, *R;
  130. divide(SE, Op, Denominator, &Q, &R);
  131. // Bail out if types do not match.
  132. if (Ty != Q->getType() || Ty != R->getType())
  133. return cannotDivide(Numerator);
  134. Qs.push_back(Q);
  135. Rs.push_back(R);
  136. }
  137. if (Qs.size() == 1) {
  138. Quotient = Qs[0];
  139. Remainder = Rs[0];
  140. return;
  141. }
  142. Quotient = SE.getAddExpr(Qs);
  143. Remainder = SE.getAddExpr(Rs);
  144. }
  145. void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
  146. SmallVector<const SCEV *, 2> Qs;
  147. Type *Ty = Denominator->getType();
  148. bool FoundDenominatorTerm = false;
  149. for (const SCEV *Op : Numerator->operands()) {
  150. // Bail out if types do not match.
  151. if (Ty != Op->getType())
  152. return cannotDivide(Numerator);
  153. if (FoundDenominatorTerm) {
  154. Qs.push_back(Op);
  155. continue;
  156. }
  157. // Check whether Denominator divides one of the product operands.
  158. const SCEV *Q, *R;
  159. divide(SE, Op, Denominator, &Q, &R);
  160. if (!R->isZero()) {
  161. Qs.push_back(Op);
  162. continue;
  163. }
  164. // Bail out if types do not match.
  165. if (Ty != Q->getType())
  166. return cannotDivide(Numerator);
  167. FoundDenominatorTerm = true;
  168. Qs.push_back(Q);
  169. }
  170. if (FoundDenominatorTerm) {
  171. Remainder = Zero;
  172. if (Qs.size() == 1)
  173. Quotient = Qs[0];
  174. else
  175. Quotient = SE.getMulExpr(Qs);
  176. return;
  177. }
  178. if (!isa<SCEVUnknown>(Denominator))
  179. return cannotDivide(Numerator);
  180. // The Remainder is obtained by replacing Denominator by 0 in Numerator.
  181. ValueToSCEVMapTy RewriteMap;
  182. RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
  183. Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
  184. if (Remainder->isZero()) {
  185. // The Quotient is obtained by replacing Denominator by 1 in Numerator.
  186. RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
  187. Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
  188. return;
  189. }
  190. // Quotient is (Numerator - Remainder) divided by Denominator.
  191. const SCEV *Q, *R;
  192. const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
  193. // This SCEV does not seem to simplify: fail the division here.
  194. if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
  195. return cannotDivide(Numerator);
  196. divide(SE, Diff, Denominator, &Q, &R);
  197. if (R != Zero)
  198. return cannotDivide(Numerator);
  199. Quotient = Q;
  200. }
  201. SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
  202. const SCEV *Denominator)
  203. : SE(S), Denominator(Denominator) {
  204. Zero = SE.getZero(Denominator->getType());
  205. One = SE.getOne(Denominator->getType());
  206. // We generally do not know how to divide Expr by Denominator. We initialize
  207. // the division to a "cannot divide" state to simplify the rest of the code.
  208. cannotDivide(Numerator);
  209. }
  210. // Convenience function for giving up on the division. We set the quotient to
  211. // be equal to zero and the remainder to be equal to the numerator.
  212. void SCEVDivision::cannotDivide(const SCEV *Numerator) {
  213. Quotient = Zero;
  214. Remainder = Numerator;
  215. }