ScalarEvolutionDivision.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
//===- ScalarEvolutionDivision.cpp - Class to divide SCEVs ------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
// This file implements the class that knows how to divide SCEVs.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Analysis/ScalarEvolutionDivision.h"
  13. #include "llvm/ADT/APInt.h"
  14. #include "llvm/ADT/DenseMap.h"
  15. #include "llvm/ADT/SmallVector.h"
  16. #include "llvm/Analysis/ScalarEvolution.h"
  17. #include "llvm/Support/Casting.h"
  18. #include <cassert>
  19. #include <cstdint>
  20. namespace llvm {
  21. class Type;
  22. }
  23. using namespace llvm;
  24. namespace {
  25. static inline int sizeOfSCEV(const SCEV *S) {
  26. struct FindSCEVSize {
  27. int Size = 0;
  28. FindSCEVSize() = default;
  29. bool follow(const SCEV *S) {
  30. ++Size;
  31. // Keep looking at all operands of S.
  32. return true;
  33. }
  34. bool isDone() const { return false; }
  35. };
  36. FindSCEVSize F;
  37. SCEVTraversal<FindSCEVSize> ST(F);
  38. ST.visitAll(S);
  39. return F.Size;
  40. }
  41. } // namespace
  42. // Computes the Quotient and Remainder of the division of Numerator by
  43. // Denominator.
  44. void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
  45. const SCEV *Denominator, const SCEV **Quotient,
  46. const SCEV **Remainder) {
  47. assert(Numerator && Denominator && "Uninitialized SCEV");
  48. SCEVDivision D(SE, Numerator, Denominator);
  49. // Check for the trivial case here to avoid having to check for it in the
  50. // rest of the code.
  51. if (Numerator == Denominator) {
  52. *Quotient = D.One;
  53. *Remainder = D.Zero;
  54. return;
  55. }
  56. if (Numerator->isZero()) {
  57. *Quotient = D.Zero;
  58. *Remainder = D.Zero;
  59. return;
  60. }
  61. // A simple case when N/1. The quotient is N.
  62. if (Denominator->isOne()) {
  63. *Quotient = Numerator;
  64. *Remainder = D.Zero;
  65. return;
  66. }
  67. // Split the Denominator when it is a product.
  68. if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
  69. const SCEV *Q, *R;
  70. *Quotient = Numerator;
  71. for (const SCEV *Op : T->operands()) {
  72. divide(SE, *Quotient, Op, &Q, &R);
  73. *Quotient = Q;
  74. // Bail out when the Numerator is not divisible by one of the terms of
  75. // the Denominator.
  76. if (!R->isZero()) {
  77. *Quotient = D.Zero;
  78. *Remainder = Numerator;
  79. return;
  80. }
  81. }
  82. *Remainder = D.Zero;
  83. return;
  84. }
  85. D.visit(Numerator);
  86. *Quotient = D.Quotient;
  87. *Remainder = D.Remainder;
  88. }
  89. void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
  90. if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
  91. APInt NumeratorVal = Numerator->getAPInt();
  92. APInt DenominatorVal = D->getAPInt();
  93. uint32_t NumeratorBW = NumeratorVal.getBitWidth();
  94. uint32_t DenominatorBW = DenominatorVal.getBitWidth();
  95. if (NumeratorBW > DenominatorBW)
  96. DenominatorVal = DenominatorVal.sext(NumeratorBW);
  97. else if (NumeratorBW < DenominatorBW)
  98. NumeratorVal = NumeratorVal.sext(DenominatorBW);
  99. APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
  100. APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
  101. APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
  102. Quotient = SE.getConstant(QuotientVal);
  103. Remainder = SE.getConstant(RemainderVal);
  104. return;
  105. }
  106. }
  107. void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
  108. const SCEV *StartQ, *StartR, *StepQ, *StepR;
  109. if (!Numerator->isAffine())
  110. return cannotDivide(Numerator);
  111. divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
  112. divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
  113. // Bail out if the types do not match.
  114. Type *Ty = Denominator->getType();
  115. if (Ty != StartQ->getType() || Ty != StartR->getType() ||
  116. Ty != StepQ->getType() || Ty != StepR->getType())
  117. return cannotDivide(Numerator);
  118. Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
  119. Numerator->getNoWrapFlags());
  120. Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
  121. Numerator->getNoWrapFlags());
  122. }
  123. void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
  124. SmallVector<const SCEV *, 2> Qs, Rs;
  125. Type *Ty = Denominator->getType();
  126. for (const SCEV *Op : Numerator->operands()) {
  127. const SCEV *Q, *R;
  128. divide(SE, Op, Denominator, &Q, &R);
  129. // Bail out if types do not match.
  130. if (Ty != Q->getType() || Ty != R->getType())
  131. return cannotDivide(Numerator);
  132. Qs.push_back(Q);
  133. Rs.push_back(R);
  134. }
  135. if (Qs.size() == 1) {
  136. Quotient = Qs[0];
  137. Remainder = Rs[0];
  138. return;
  139. }
  140. Quotient = SE.getAddExpr(Qs);
  141. Remainder = SE.getAddExpr(Rs);
  142. }
  143. void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
  144. SmallVector<const SCEV *, 2> Qs;
  145. Type *Ty = Denominator->getType();
  146. bool FoundDenominatorTerm = false;
  147. for (const SCEV *Op : Numerator->operands()) {
  148. // Bail out if types do not match.
  149. if (Ty != Op->getType())
  150. return cannotDivide(Numerator);
  151. if (FoundDenominatorTerm) {
  152. Qs.push_back(Op);
  153. continue;
  154. }
  155. // Check whether Denominator divides one of the product operands.
  156. const SCEV *Q, *R;
  157. divide(SE, Op, Denominator, &Q, &R);
  158. if (!R->isZero()) {
  159. Qs.push_back(Op);
  160. continue;
  161. }
  162. // Bail out if types do not match.
  163. if (Ty != Q->getType())
  164. return cannotDivide(Numerator);
  165. FoundDenominatorTerm = true;
  166. Qs.push_back(Q);
  167. }
  168. if (FoundDenominatorTerm) {
  169. Remainder = Zero;
  170. if (Qs.size() == 1)
  171. Quotient = Qs[0];
  172. else
  173. Quotient = SE.getMulExpr(Qs);
  174. return;
  175. }
  176. if (!isa<SCEVUnknown>(Denominator))
  177. return cannotDivide(Numerator);
  178. // The Remainder is obtained by replacing Denominator by 0 in Numerator.
  179. ValueToSCEVMapTy RewriteMap;
  180. RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
  181. Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
  182. if (Remainder->isZero()) {
  183. // The Quotient is obtained by replacing Denominator by 1 in Numerator.
  184. RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
  185. Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
  186. return;
  187. }
  188. // Quotient is (Numerator - Remainder) divided by Denominator.
  189. const SCEV *Q, *R;
  190. const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
  191. // This SCEV does not seem to simplify: fail the division here.
  192. if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
  193. return cannotDivide(Numerator);
  194. divide(SE, Diff, Denominator, &Q, &R);
  195. if (R != Zero)
  196. return cannotDivide(Numerator);
  197. Quotient = Q;
  198. }
  199. SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
  200. const SCEV *Denominator)
  201. : SE(S), Denominator(Denominator) {
  202. Zero = SE.getZero(Denominator->getType());
  203. One = SE.getOne(Denominator->getType());
  204. // We generally do not know how to divide Expr by Denominator. We initialize
  205. // the division to a "cannot divide" state to simplify the rest of the code.
  206. cannotDivide(Numerator);
  207. }
  208. // Convenience function for giving up on the division. We set the quotient to
  209. // be equal to zero and the remainder to be equal to the numerator.
  210. void SCEVDivision::cannotDivide(const SCEV *Numerator) {
  211. Quotient = Zero;
  212. Remainder = Numerator;
  213. }