ScalarEvolutionNormalization.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements utilities for working with "normalized" expressions.
  10. // See the comments at the top of ScalarEvolutionNormalization.h for details.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Analysis/ScalarEvolutionNormalization.h"
  14. #include "llvm/Analysis/LoopInfo.h"
  15. #include "llvm/Analysis/ScalarEvolution.h"
  16. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  17. using namespace llvm;
  namespace {
  /// TransformKind - The two directions in which NormalizeDenormalizeRewriter
  /// can rewrite add recurrences. normalizeForPostIncUse and
  /// normalizeForPostIncUseIf use Normalize; denormalizeForPostIncUse uses
  /// Denormalize.
  ///
  /// Kept in an anonymous namespace: these are file-local helpers, and generic
  /// names like `Normalize` should not get external linkage in the global
  /// namespace.
  enum TransformKind {
    /// Normalize - "Partial decrement": rewrite each selected add recurrence
    /// so that incrementing the result with respect to its loop yields the
    /// original expression.
    Normalize,
    /// Denormalize - "Partial increment": perform the inverse transform on the
    /// expression with the given loop set.
    Denormalize
  };
  } // namespace
  27. namespace {
  28. struct NormalizeDenormalizeRewriter
  29. : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
  30. const TransformKind Kind;
  31. // NB! Pred is a function_ref. Storing it here is okay only because
  32. // we're careful about the lifetime of NormalizeDenormalizeRewriter.
  33. const NormalizePredTy Pred;
  34. NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
  35. ScalarEvolution &SE)
  36. : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
  37. Pred(Pred) {}
  38. const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
  39. };
  40. } // namespace
  41. const SCEV *
  42. NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
  43. SmallVector<const SCEV *, 8> Operands;
  44. transform(AR->operands(), std::back_inserter(Operands),
  45. [&](const SCEV *Op) { return visit(Op); });
  46. if (!Pred(AR))
  47. return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
  48. // Normalization and denormalization are fancy names for decrementing and
  49. // incrementing a SCEV expression with respect to a set of loops. Since
  50. // Pred(AR) has returned true, we know we need to normalize or denormalize AR
  51. // with respect to its loop.
  52. if (Kind == Denormalize) {
  53. // Denormalization / "partial increment" is essentially the same as \c
  54. // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the
  55. // symmetry with Normalization clear.
  56. for (int i = 0, e = Operands.size() - 1; i < e; i++)
  57. Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
  58. } else {
  59. assert(Kind == Normalize && "Only two possibilities!");
  60. // Normalization / "partial decrement" is a bit more subtle. Since
  61. // incrementing a SCEV expression (in general) changes the step of the SCEV
  62. // expression as well, we cannot use the step of the current expression.
  63. // Instead, we have to use the step of the very expression we're trying to
  64. // compute!
  65. //
  66. // We solve the issue by recursively building up the result, starting from
  67. // the "least significant" operand in the add recurrence:
  68. //
  69. // Base case:
  70. // Single operand add recurrence. It's its own normalization.
  71. //
  72. // N-operand case:
  73. // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
  74. //
  75. // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
  76. // normalization by induction. We subtract the normalized step
  77. // recurrence from S_{N-1} to get the normalization of S.
  78. for (int i = Operands.size() - 2; i >= 0; i--)
  79. Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
  80. }
  81. return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
  82. }
  83. const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
  84. const PostIncLoopSet &Loops,
  85. ScalarEvolution &SE) {
  86. auto Pred = [&](const SCEVAddRecExpr *AR) {
  87. return Loops.count(AR->getLoop());
  88. };
  89. return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
  90. }
  91. const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
  92. ScalarEvolution &SE) {
  93. return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
  94. }
  95. const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
  96. const PostIncLoopSet &Loops,
  97. ScalarEvolution &SE) {
  98. auto Pred = [&](const SCEVAddRecExpr *AR) {
  99. return Loops.count(AR->getLoop());
  100. };
  101. return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
  102. }