InstructionCost.h 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- InstructionCost.h ----------------------------------------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. /// \file
  /// This file defines an InstructionCost class that is used when calculating
  /// the cost of an instruction, or a group of instructions. In addition to a
  /// numeric value representing the cost, the class also contains a state that
  /// can be used to encode particular properties, such as a cost being invalid.
  /// Operations on InstructionCost implement saturating arithmetic, so that
  /// accumulating large cost values does not overflow.
  20. ///
  21. //===----------------------------------------------------------------------===//
  22. #ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
  23. #define LLVM_SUPPORT_INSTRUCTIONCOST_H
  24. #include "llvm/ADT/Optional.h"
  25. #include "llvm/Support/MathExtras.h"
  26. #include <limits>
  27. namespace llvm {
  28. class raw_ostream;
  29. class InstructionCost {
  30. public:
  31. using CostType = int64_t;
  32. /// CostState describes the state of a cost.
  33. enum CostState {
  34. Valid, /// < The cost value represents a valid cost, even when the
  35. /// cost-value is large.
  36. Invalid /// < Invalid indicates there is no way to represent the cost as a
  37. /// numeric value. This state exists to represent a possible issue,
  38. /// e.g. if the cost-model knows the operation cannot be expanded
  39. /// into a valid code-sequence by the code-generator. While some
  40. /// passes may assert that the calculated cost must be valid, it is
  41. /// up to individual passes how to interpret an Invalid cost. For
  42. /// example, a transformation pass could choose not to perform a
  43. /// transformation if the resulting cost would end up Invalid.
  44. /// Because some passes may assert a cost is Valid, it is not
  45. /// recommended to use Invalid costs to model 'Unknown'.
  46. /// Note that Invalid is semantically different from a (very) high,
  47. /// but valid cost, which intentionally indicates no issue, but
  48. /// rather a strong preference not to select a certain operation.
  49. };
  50. private:
  51. CostType Value = 0;
  52. CostState State = Valid;
  53. void propagateState(const InstructionCost &RHS) {
  54. if (RHS.State == Invalid)
  55. State = Invalid;
  56. }
  57. static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); }
  58. static CostType getMinValue() { return std::numeric_limits<CostType>::min(); }
  59. public:
  60. // A default constructed InstructionCost is a valid zero cost
  61. InstructionCost() = default;
  62. InstructionCost(CostState) = delete;
  63. InstructionCost(CostType Val) : Value(Val), State(Valid) {}
  64. static InstructionCost getMax() { return getMaxValue(); }
  65. static InstructionCost getMin() { return getMinValue(); }
  66. static InstructionCost getInvalid(CostType Val = 0) {
  67. InstructionCost Tmp(Val);
  68. Tmp.setInvalid();
  69. return Tmp;
  70. }
  71. bool isValid() const { return State == Valid; }
  72. void setValid() { State = Valid; }
  73. void setInvalid() { State = Invalid; }
  74. CostState getState() const { return State; }
  75. /// This function is intended to be used as sparingly as possible, since the
  76. /// class provides the full range of operator support required for arithmetic
  77. /// and comparisons.
  78. Optional<CostType> getValue() const {
  79. if (isValid())
  80. return Value;
  81. return None;
  82. }
  83. /// For all of the arithmetic operators provided here any invalid state is
  84. /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
  85. /// invalid, and it also inherits any invalid state from the RHS.
  86. /// Arithmetic work on the actual values is implemented with saturation,
  87. /// to avoid overflow when using more extreme cost values.
  88. InstructionCost &operator+=(const InstructionCost &RHS) {
  89. propagateState(RHS);
  90. // Saturating addition.
  91. InstructionCost::CostType Result;
  92. if (AddOverflow(Value, RHS.Value, Result))
  93. Result = RHS.Value > 0 ? getMaxValue() : getMinValue();
  94. Value = Result;
  95. return *this;
  96. }
  97. InstructionCost &operator+=(const CostType RHS) {
  98. InstructionCost RHS2(RHS);
  99. *this += RHS2;
  100. return *this;
  101. }
  102. InstructionCost &operator-=(const InstructionCost &RHS) {
  103. propagateState(RHS);
  104. // Saturating subtract.
  105. InstructionCost::CostType Result;
  106. if (SubOverflow(Value, RHS.Value, Result))
  107. Result = RHS.Value > 0 ? getMinValue() : getMaxValue();
  108. Value = Result;
  109. return *this;
  110. }
  111. InstructionCost &operator-=(const CostType RHS) {
  112. InstructionCost RHS2(RHS);
  113. *this -= RHS2;
  114. return *this;
  115. }
  116. InstructionCost &operator*=(const InstructionCost &RHS) {
  117. propagateState(RHS);
  118. // Saturating multiply.
  119. InstructionCost::CostType Result;
  120. if (MulOverflow(Value, RHS.Value, Result)) {
  121. if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0))
  122. Result = getMaxValue();
  123. else
  124. Result = getMinValue();
  125. }
  126. Value = Result;
  127. return *this;
  128. }
  129. InstructionCost &operator*=(const CostType RHS) {
  130. InstructionCost RHS2(RHS);
  131. *this *= RHS2;
  132. return *this;
  133. }
  134. InstructionCost &operator/=(const InstructionCost &RHS) {
  135. propagateState(RHS);
  136. Value /= RHS.Value;
  137. return *this;
  138. }
  139. InstructionCost &operator/=(const CostType RHS) {
  140. InstructionCost RHS2(RHS);
  141. *this /= RHS2;
  142. return *this;
  143. }
  144. InstructionCost &operator++() {
  145. *this += 1;
  146. return *this;
  147. }
  148. InstructionCost operator++(int) {
  149. InstructionCost Copy = *this;
  150. ++*this;
  151. return Copy;
  152. }
  153. InstructionCost &operator--() {
  154. *this -= 1;
  155. return *this;
  156. }
  157. InstructionCost operator--(int) {
  158. InstructionCost Copy = *this;
  159. --*this;
  160. return Copy;
  161. }
  162. /// For the comparison operators we have chosen to use lexicographical
  163. /// ordering where valid costs are always considered to be less than invalid
  164. /// costs. This avoids having to add asserts to the comparison operators that
  165. /// the states are valid and users can test for validity of the cost
  166. /// explicitly.
  167. bool operator<(const InstructionCost &RHS) const {
  168. if (State != RHS.State)
  169. return State < RHS.State;
  170. return Value < RHS.Value;
  171. }
  172. // Implement in terms of operator< to ensure that the two comparisons stay in
  173. // sync
  174. bool operator==(const InstructionCost &RHS) const {
  175. return !(*this < RHS) && !(RHS < *this);
  176. }
  177. bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
  178. bool operator==(const CostType RHS) const {
  179. InstructionCost RHS2(RHS);
  180. return *this == RHS2;
  181. }
  182. bool operator!=(const CostType RHS) const { return !(*this == RHS); }
  183. bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
  184. bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
  185. bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
  186. bool operator<(const CostType RHS) const {
  187. InstructionCost RHS2(RHS);
  188. return *this < RHS2;
  189. }
  190. bool operator>(const CostType RHS) const {
  191. InstructionCost RHS2(RHS);
  192. return *this > RHS2;
  193. }
  194. bool operator<=(const CostType RHS) const {
  195. InstructionCost RHS2(RHS);
  196. return *this <= RHS2;
  197. }
  198. bool operator>=(const CostType RHS) const {
  199. InstructionCost RHS2(RHS);
  200. return *this >= RHS2;
  201. }
  202. void print(raw_ostream &OS) const;
  203. template <class Function>
  204. auto map(const Function &F) const -> InstructionCost {
  205. if (isValid())
  206. return F(*getValue());
  207. return getInvalid();
  208. }
  209. };
  210. inline InstructionCost operator+(const InstructionCost &LHS,
  211. const InstructionCost &RHS) {
  212. InstructionCost LHS2(LHS);
  213. LHS2 += RHS;
  214. return LHS2;
  215. }
  216. inline InstructionCost operator-(const InstructionCost &LHS,
  217. const InstructionCost &RHS) {
  218. InstructionCost LHS2(LHS);
  219. LHS2 -= RHS;
  220. return LHS2;
  221. }
  222. inline InstructionCost operator*(const InstructionCost &LHS,
  223. const InstructionCost &RHS) {
  224. InstructionCost LHS2(LHS);
  225. LHS2 *= RHS;
  226. return LHS2;
  227. }
  228. inline InstructionCost operator/(const InstructionCost &LHS,
  229. const InstructionCost &RHS) {
  230. InstructionCost LHS2(LHS);
  231. LHS2 /= RHS;
  232. return LHS2;
  233. }
  234. inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
  235. V.print(OS);
  236. return OS;
  237. }
  238. } // namespace llvm
  239. #endif
  240. #ifdef __GNUC__
  241. #pragma GCC diagnostic pop
  242. #endif