DurationFactoryScaleCheck.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
//===--- DurationFactoryScaleCheck.cpp - clang-tidy -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
  8. #include "DurationFactoryScaleCheck.h"
  9. #include "DurationRewriter.h"
  10. #include "clang/AST/ASTContext.h"
  11. #include "clang/ASTMatchers/ASTMatchFinder.h"
  12. #include "clang/Tooling/FixIt.h"
  13. #include <optional>
  14. using namespace clang::ast_matchers;
  15. namespace clang::tidy::abseil {
  16. // Given the name of a duration factory function, return the appropriate
  17. // `DurationScale` for that factory. If no factory can be found for
  18. // `FactoryName`, return `std::nullopt`.
  19. static std::optional<DurationScale>
  20. getScaleForFactory(llvm::StringRef FactoryName) {
  21. return llvm::StringSwitch<std::optional<DurationScale>>(FactoryName)
  22. .Case("Nanoseconds", DurationScale::Nanoseconds)
  23. .Case("Microseconds", DurationScale::Microseconds)
  24. .Case("Milliseconds", DurationScale::Milliseconds)
  25. .Case("Seconds", DurationScale::Seconds)
  26. .Case("Minutes", DurationScale::Minutes)
  27. .Case("Hours", DurationScale::Hours)
  28. .Default(std::nullopt);
  29. }
  30. // Given either an integer or float literal, return its value.
  31. // One and only one of `IntLit` and `FloatLit` should be provided.
  32. static double getValue(const IntegerLiteral *IntLit,
  33. const FloatingLiteral *FloatLit) {
  34. if (IntLit)
  35. return IntLit->getValue().getLimitedValue();
  36. assert(FloatLit != nullptr && "Neither IntLit nor FloatLit set");
  37. return FloatLit->getValueAsApproximateDouble();
  38. }
  39. // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
  40. // would produce a new scale. If so, return a tuple containing the new scale
  41. // and a suitable Multiplier for that scale, otherwise `std::nullopt`.
  42. static std::optional<std::tuple<DurationScale, double>>
  43. getNewScaleSingleStep(DurationScale OldScale, double Multiplier) {
  44. switch (OldScale) {
  45. case DurationScale::Hours:
  46. if (Multiplier <= 1.0 / 60.0)
  47. return std::make_tuple(DurationScale::Minutes, Multiplier * 60.0);
  48. break;
  49. case DurationScale::Minutes:
  50. if (Multiplier >= 60.0)
  51. return std::make_tuple(DurationScale::Hours, Multiplier / 60.0);
  52. if (Multiplier <= 1.0 / 60.0)
  53. return std::make_tuple(DurationScale::Seconds, Multiplier * 60.0);
  54. break;
  55. case DurationScale::Seconds:
  56. if (Multiplier >= 60.0)
  57. return std::make_tuple(DurationScale::Minutes, Multiplier / 60.0);
  58. if (Multiplier <= 1e-3)
  59. return std::make_tuple(DurationScale::Milliseconds, Multiplier * 1e3);
  60. break;
  61. case DurationScale::Milliseconds:
  62. if (Multiplier >= 1e3)
  63. return std::make_tuple(DurationScale::Seconds, Multiplier / 1e3);
  64. if (Multiplier <= 1e-3)
  65. return std::make_tuple(DurationScale::Microseconds, Multiplier * 1e3);
  66. break;
  67. case DurationScale::Microseconds:
  68. if (Multiplier >= 1e3)
  69. return std::make_tuple(DurationScale::Milliseconds, Multiplier / 1e3);
  70. if (Multiplier <= 1e-3)
  71. return std::make_tuple(DurationScale::Nanoseconds, Multiplier * 1e-3);
  72. break;
  73. case DurationScale::Nanoseconds:
  74. if (Multiplier >= 1e3)
  75. return std::make_tuple(DurationScale::Microseconds, Multiplier / 1e3);
  76. break;
  77. }
  78. return std::nullopt;
  79. }
  80. // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
  81. // would produce a new scale. If so, return it, otherwise `std::nullopt`.
  82. static std::optional<DurationScale> getNewScale(DurationScale OldScale,
  83. double Multiplier) {
  84. while (Multiplier != 1.0) {
  85. std::optional<std::tuple<DurationScale, double>> Result =
  86. getNewScaleSingleStep(OldScale, Multiplier);
  87. if (!Result)
  88. break;
  89. if (std::get<1>(*Result) == 1.0)
  90. return std::get<0>(*Result);
  91. Multiplier = std::get<1>(*Result);
  92. OldScale = std::get<0>(*Result);
  93. }
  94. return std::nullopt;
  95. }
  96. void DurationFactoryScaleCheck::registerMatchers(MatchFinder *Finder) {
  97. Finder->addMatcher(
  98. callExpr(
  99. callee(functionDecl(DurationFactoryFunction()).bind("call_decl")),
  100. hasArgument(
  101. 0,
  102. ignoringImpCasts(anyOf(
  103. cxxFunctionalCastExpr(
  104. hasDestinationType(
  105. anyOf(isInteger(), realFloatingPointType())),
  106. hasSourceExpression(initListExpr())),
  107. integerLiteral(equals(0)), floatLiteral(equals(0.0)),
  108. binaryOperator(hasOperatorName("*"),
  109. hasEitherOperand(ignoringImpCasts(
  110. anyOf(integerLiteral(), floatLiteral()))))
  111. .bind("mult_binop"),
  112. binaryOperator(hasOperatorName("/"), hasRHS(floatLiteral()))
  113. .bind("div_binop")))))
  114. .bind("call"),
  115. this);
  116. }
  117. void DurationFactoryScaleCheck::check(const MatchFinder::MatchResult &Result) {
  118. const auto *Call = Result.Nodes.getNodeAs<CallExpr>("call");
  119. // Don't try to replace things inside of macro definitions.
  120. if (Call->getExprLoc().isMacroID())
  121. return;
  122. const Expr *Arg = Call->getArg(0)->IgnoreParenImpCasts();
  123. // Arguments which are macros are ignored.
  124. if (Arg->getBeginLoc().isMacroID())
  125. return;
  126. // We first handle the cases of literal zero (both float and integer).
  127. if (isLiteralZero(Result, *Arg)) {
  128. diag(Call->getBeginLoc(),
  129. "use ZeroDuration() for zero-length time intervals")
  130. << FixItHint::CreateReplacement(Call->getSourceRange(),
  131. "absl::ZeroDuration()");
  132. return;
  133. }
  134. const auto *CallDecl = Result.Nodes.getNodeAs<FunctionDecl>("call_decl");
  135. std::optional<DurationScale> MaybeScale =
  136. getScaleForFactory(CallDecl->getName());
  137. if (!MaybeScale)
  138. return;
  139. DurationScale Scale = *MaybeScale;
  140. const Expr *Remainder;
  141. std::optional<DurationScale> NewScale;
  142. // We next handle the cases of multiplication and division.
  143. if (const auto *MultBinOp =
  144. Result.Nodes.getNodeAs<BinaryOperator>("mult_binop")) {
  145. // For multiplication, we need to look at both operands, and consider the
  146. // cases where a user is multiplying by something such as 1e-3.
  147. // First check the LHS
  148. const auto *IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getLHS());
  149. const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getLHS());
  150. if (IntLit || FloatLit) {
  151. NewScale = getNewScale(Scale, getValue(IntLit, FloatLit));
  152. if (NewScale)
  153. Remainder = MultBinOp->getRHS();
  154. }
  155. // If we weren't able to scale based on the LHS, check the RHS
  156. if (!NewScale) {
  157. IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getRHS());
  158. FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getRHS());
  159. if (IntLit || FloatLit) {
  160. NewScale = getNewScale(Scale, getValue(IntLit, FloatLit));
  161. if (NewScale)
  162. Remainder = MultBinOp->getLHS();
  163. }
  164. }
  165. } else if (const auto *DivBinOp =
  166. Result.Nodes.getNodeAs<BinaryOperator>("div_binop")) {
  167. // We next handle division.
  168. // For division, we only check the RHS.
  169. const auto *FloatLit = llvm::cast<FloatingLiteral>(DivBinOp->getRHS());
  170. std::optional<DurationScale> NewScale =
  171. getNewScale(Scale, 1.0 / FloatLit->getValueAsApproximateDouble());
  172. if (NewScale) {
  173. const Expr *Remainder = DivBinOp->getLHS();
  174. // We've found an appropriate scaling factor and the new scale, so output
  175. // the relevant fix.
  176. diag(Call->getBeginLoc(), "internal duration scaling can be removed")
  177. << FixItHint::CreateReplacement(
  178. Call->getSourceRange(),
  179. (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
  180. tooling::fixit::getText(*Remainder, *Result.Context) + ")")
  181. .str());
  182. }
  183. }
  184. if (NewScale) {
  185. assert(Remainder && "No remainder found");
  186. // We've found an appropriate scaling factor and the new scale, so output
  187. // the relevant fix.
  188. diag(Call->getBeginLoc(), "internal duration scaling can be removed")
  189. << FixItHint::CreateReplacement(
  190. Call->getSourceRange(),
  191. (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
  192. tooling::fixit::getText(*Remainder, *Result.Context) + ")")
  193. .str());
  194. }
  195. }
  196. } // namespace clang::tidy::abseil