123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the class that knows how to divide SCEVs.
//
//===----------------------------------------------------------------------===//
- #include "llvm/Analysis/ScalarEvolutionDivision.h"
- #include "llvm/ADT/APInt.h"
- #include "llvm/ADT/DenseMap.h"
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/Analysis/ScalarEvolution.h"
- #include "llvm/Support/Casting.h"
- #include <cassert>
- #include <cstdint>
- namespace llvm {
- class Type;
- }
- using namespace llvm;
- namespace {
- static inline int sizeOfSCEV(const SCEV *S) {
- struct FindSCEVSize {
- int Size = 0;
- FindSCEVSize() = default;
- bool follow(const SCEV *S) {
- ++Size;
- // Keep looking at all operands of S.
- return true;
- }
- bool isDone() const { return false; }
- };
- FindSCEVSize F;
- SCEVTraversal<FindSCEVSize> ST(F);
- ST.visitAll(S);
- return F.Size;
- }
- } // namespace
- // Computes the Quotient and Remainder of the division of Numerator by
- // Denominator.
- void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
- const SCEV *Denominator, const SCEV **Quotient,
- const SCEV **Remainder) {
- assert(Numerator && Denominator && "Uninitialized SCEV");
- SCEVDivision D(SE, Numerator, Denominator);
- // Check for the trivial case here to avoid having to check for it in the
- // rest of the code.
- if (Numerator == Denominator) {
- *Quotient = D.One;
- *Remainder = D.Zero;
- return;
- }
- if (Numerator->isZero()) {
- *Quotient = D.Zero;
- *Remainder = D.Zero;
- return;
- }
- // A simple case when N/1. The quotient is N.
- if (Denominator->isOne()) {
- *Quotient = Numerator;
- *Remainder = D.Zero;
- return;
- }
- // Split the Denominator when it is a product.
- if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
- const SCEV *Q, *R;
- *Quotient = Numerator;
- for (const SCEV *Op : T->operands()) {
- divide(SE, *Quotient, Op, &Q, &R);
- *Quotient = Q;
- // Bail out when the Numerator is not divisible by one of the terms of
- // the Denominator.
- if (!R->isZero()) {
- *Quotient = D.Zero;
- *Remainder = Numerator;
- return;
- }
- }
- *Remainder = D.Zero;
- return;
- }
- D.visit(Numerator);
- *Quotient = D.Quotient;
- *Remainder = D.Remainder;
- }
- void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
- if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
- APInt NumeratorVal = Numerator->getAPInt();
- APInt DenominatorVal = D->getAPInt();
- uint32_t NumeratorBW = NumeratorVal.getBitWidth();
- uint32_t DenominatorBW = DenominatorVal.getBitWidth();
- if (NumeratorBW > DenominatorBW)
- DenominatorVal = DenominatorVal.sext(NumeratorBW);
- else if (NumeratorBW < DenominatorBW)
- NumeratorVal = NumeratorVal.sext(DenominatorBW);
- APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
- APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
- APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
- Quotient = SE.getConstant(QuotientVal);
- Remainder = SE.getConstant(RemainderVal);
- return;
- }
- }
- void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
- const SCEV *StartQ, *StartR, *StepQ, *StepR;
- if (!Numerator->isAffine())
- return cannotDivide(Numerator);
- divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
- divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
- // Bail out if the types do not match.
- Type *Ty = Denominator->getType();
- if (Ty != StartQ->getType() || Ty != StartR->getType() ||
- Ty != StepQ->getType() || Ty != StepR->getType())
- return cannotDivide(Numerator);
- Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
- Numerator->getNoWrapFlags());
- Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
- Numerator->getNoWrapFlags());
- }
- void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
- SmallVector<const SCEV *, 2> Qs, Rs;
- Type *Ty = Denominator->getType();
- for (const SCEV *Op : Numerator->operands()) {
- const SCEV *Q, *R;
- divide(SE, Op, Denominator, &Q, &R);
- // Bail out if types do not match.
- if (Ty != Q->getType() || Ty != R->getType())
- return cannotDivide(Numerator);
- Qs.push_back(Q);
- Rs.push_back(R);
- }
- if (Qs.size() == 1) {
- Quotient = Qs[0];
- Remainder = Rs[0];
- return;
- }
- Quotient = SE.getAddExpr(Qs);
- Remainder = SE.getAddExpr(Rs);
- }
- void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
- SmallVector<const SCEV *, 2> Qs;
- Type *Ty = Denominator->getType();
- bool FoundDenominatorTerm = false;
- for (const SCEV *Op : Numerator->operands()) {
- // Bail out if types do not match.
- if (Ty != Op->getType())
- return cannotDivide(Numerator);
- if (FoundDenominatorTerm) {
- Qs.push_back(Op);
- continue;
- }
- // Check whether Denominator divides one of the product operands.
- const SCEV *Q, *R;
- divide(SE, Op, Denominator, &Q, &R);
- if (!R->isZero()) {
- Qs.push_back(Op);
- continue;
- }
- // Bail out if types do not match.
- if (Ty != Q->getType())
- return cannotDivide(Numerator);
- FoundDenominatorTerm = true;
- Qs.push_back(Q);
- }
- if (FoundDenominatorTerm) {
- Remainder = Zero;
- if (Qs.size() == 1)
- Quotient = Qs[0];
- else
- Quotient = SE.getMulExpr(Qs);
- return;
- }
- if (!isa<SCEVUnknown>(Denominator))
- return cannotDivide(Numerator);
- // The Remainder is obtained by replacing Denominator by 0 in Numerator.
- ValueToSCEVMapTy RewriteMap;
- RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
- Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
- if (Remainder->isZero()) {
- // The Quotient is obtained by replacing Denominator by 1 in Numerator.
- RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
- Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
- return;
- }
- // Quotient is (Numerator - Remainder) divided by Denominator.
- const SCEV *Q, *R;
- const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
- // This SCEV does not seem to simplify: fail the division here.
- if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
- return cannotDivide(Numerator);
- divide(SE, Diff, Denominator, &Q, &R);
- if (R != Zero)
- return cannotDivide(Numerator);
- Quotient = Q;
- }
- SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
- const SCEV *Denominator)
- : SE(S), Denominator(Denominator) {
- Zero = SE.getZero(Denominator->getType());
- One = SE.getOne(Denominator->getType());
- // We generally do not know how to divide Expr by Denominator. We initialize
- // the division to a "cannot divide" state to simplify the rest of the code.
- cannotDivide(Numerator);
- }
- // Convenience function for giving up on the division. We set the quotient to
- // be equal to zero and the remainder to be equal to the numerator.
- void SCEVDivision::cannotDivide(const SCEV *Numerator) {
- Quotient = Zero;
- Remainder = Numerator;
- }
|