DurationRewriter.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. //===--- DurationRewriter.cpp - clang-tidy --------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include <cmath>
  9. #include <optional>
  10. #include "DurationRewriter.h"
  11. #include "clang/Tooling/FixIt.h"
  12. #include "llvm/ADT/IndexedMap.h"
  13. using namespace clang::ast_matchers;
  14. namespace clang::tidy::abseil {
  15. struct DurationScale2IndexFunctor {
  16. using argument_type = DurationScale;
  17. unsigned operator()(DurationScale Scale) const {
  18. return static_cast<unsigned>(Scale);
  19. }
  20. };
  21. /// Returns an integer if the fractional part of a `FloatingLiteral` is `0`.
  22. static std::optional<llvm::APSInt>
  23. truncateIfIntegral(const FloatingLiteral &FloatLiteral) {
  24. double Value = FloatLiteral.getValueAsApproximateDouble();
  25. if (std::fmod(Value, 1) == 0) {
  26. if (Value >= static_cast<double>(1u << 31))
  27. return std::nullopt;
  28. return llvm::APSInt::get(static_cast<int64_t>(Value));
  29. }
  30. return std::nullopt;
  31. }
  32. const std::pair<llvm::StringRef, llvm::StringRef> &
  33. getDurationInverseForScale(DurationScale Scale) {
  34. static const llvm::IndexedMap<std::pair<llvm::StringRef, llvm::StringRef>,
  35. DurationScale2IndexFunctor>
  36. InverseMap = []() {
  37. // TODO: Revisit the immediately invoked lambda technique when
  38. // IndexedMap gets an initializer list constructor.
  39. llvm::IndexedMap<std::pair<llvm::StringRef, llvm::StringRef>,
  40. DurationScale2IndexFunctor>
  41. InverseMap;
  42. InverseMap.resize(6);
  43. InverseMap[DurationScale::Hours] =
  44. std::make_pair("::absl::ToDoubleHours", "::absl::ToInt64Hours");
  45. InverseMap[DurationScale::Minutes] =
  46. std::make_pair("::absl::ToDoubleMinutes", "::absl::ToInt64Minutes");
  47. InverseMap[DurationScale::Seconds] =
  48. std::make_pair("::absl::ToDoubleSeconds", "::absl::ToInt64Seconds");
  49. InverseMap[DurationScale::Milliseconds] = std::make_pair(
  50. "::absl::ToDoubleMilliseconds", "::absl::ToInt64Milliseconds");
  51. InverseMap[DurationScale::Microseconds] = std::make_pair(
  52. "::absl::ToDoubleMicroseconds", "::absl::ToInt64Microseconds");
  53. InverseMap[DurationScale::Nanoseconds] = std::make_pair(
  54. "::absl::ToDoubleNanoseconds", "::absl::ToInt64Nanoseconds");
  55. return InverseMap;
  56. }();
  57. return InverseMap[Scale];
  58. }
  59. /// If `Node` is a call to the inverse of `Scale`, return that inverse's
  60. /// argument, otherwise std::nullopt.
  61. static std::optional<std::string>
  62. rewriteInverseDurationCall(const MatchFinder::MatchResult &Result,
  63. DurationScale Scale, const Expr &Node) {
  64. const std::pair<llvm::StringRef, llvm::StringRef> &InverseFunctions =
  65. getDurationInverseForScale(Scale);
  66. if (const auto *MaybeCallArg = selectFirst<const Expr>(
  67. "e",
  68. match(callExpr(callee(functionDecl(hasAnyName(
  69. InverseFunctions.first, InverseFunctions.second))),
  70. hasArgument(0, expr().bind("e"))),
  71. Node, *Result.Context))) {
  72. return tooling::fixit::getText(*MaybeCallArg, *Result.Context).str();
  73. }
  74. return std::nullopt;
  75. }
  76. /// If `Node` is a call to the inverse of `Scale`, return that inverse's
  77. /// argument, otherwise std::nullopt.
  78. static std::optional<std::string>
  79. rewriteInverseTimeCall(const MatchFinder::MatchResult &Result,
  80. DurationScale Scale, const Expr &Node) {
  81. llvm::StringRef InverseFunction = getTimeInverseForScale(Scale);
  82. if (const auto *MaybeCallArg = selectFirst<const Expr>(
  83. "e", match(callExpr(callee(functionDecl(hasName(InverseFunction))),
  84. hasArgument(0, expr().bind("e"))),
  85. Node, *Result.Context))) {
  86. return tooling::fixit::getText(*MaybeCallArg, *Result.Context).str();
  87. }
  88. return std::nullopt;
  89. }
  90. /// Returns the factory function name for a given `Scale`.
  91. llvm::StringRef getDurationFactoryForScale(DurationScale Scale) {
  92. switch (Scale) {
  93. case DurationScale::Hours:
  94. return "absl::Hours";
  95. case DurationScale::Minutes:
  96. return "absl::Minutes";
  97. case DurationScale::Seconds:
  98. return "absl::Seconds";
  99. case DurationScale::Milliseconds:
  100. return "absl::Milliseconds";
  101. case DurationScale::Microseconds:
  102. return "absl::Microseconds";
  103. case DurationScale::Nanoseconds:
  104. return "absl::Nanoseconds";
  105. }
  106. llvm_unreachable("unknown scaling factor");
  107. }
  108. llvm::StringRef getTimeFactoryForScale(DurationScale Scale) {
  109. switch (Scale) {
  110. case DurationScale::Hours:
  111. return "absl::FromUnixHours";
  112. case DurationScale::Minutes:
  113. return "absl::FromUnixMinutes";
  114. case DurationScale::Seconds:
  115. return "absl::FromUnixSeconds";
  116. case DurationScale::Milliseconds:
  117. return "absl::FromUnixMillis";
  118. case DurationScale::Microseconds:
  119. return "absl::FromUnixMicros";
  120. case DurationScale::Nanoseconds:
  121. return "absl::FromUnixNanos";
  122. }
  123. llvm_unreachable("unknown scaling factor");
  124. }
  125. /// Returns the Time factory function name for a given `Scale`.
  126. llvm::StringRef getTimeInverseForScale(DurationScale Scale) {
  127. switch (Scale) {
  128. case DurationScale::Hours:
  129. return "absl::ToUnixHours";
  130. case DurationScale::Minutes:
  131. return "absl::ToUnixMinutes";
  132. case DurationScale::Seconds:
  133. return "absl::ToUnixSeconds";
  134. case DurationScale::Milliseconds:
  135. return "absl::ToUnixMillis";
  136. case DurationScale::Microseconds:
  137. return "absl::ToUnixMicros";
  138. case DurationScale::Nanoseconds:
  139. return "absl::ToUnixNanos";
  140. }
  141. llvm_unreachable("unknown scaling factor");
  142. }
  143. /// Returns `true` if `Node` is a value which evaluates to a literal `0`.
  144. bool isLiteralZero(const MatchFinder::MatchResult &Result, const Expr &Node) {
  145. auto ZeroMatcher =
  146. anyOf(integerLiteral(equals(0)), floatLiteral(equals(0.0)));
  147. // Check to see if we're using a zero directly.
  148. if (selectFirst<const clang::Expr>(
  149. "val", match(expr(ignoringImpCasts(ZeroMatcher)).bind("val"), Node,
  150. *Result.Context)) != nullptr)
  151. return true;
  152. // Now check to see if we're using a functional cast with a scalar
  153. // initializer expression, e.g. `int{0}`.
  154. if (selectFirst<const clang::Expr>(
  155. "val", match(cxxFunctionalCastExpr(
  156. hasDestinationType(
  157. anyOf(isInteger(), realFloatingPointType())),
  158. hasSourceExpression(initListExpr(
  159. hasInit(0, ignoringParenImpCasts(ZeroMatcher)))))
  160. .bind("val"),
  161. Node, *Result.Context)) != nullptr)
  162. return true;
  163. return false;
  164. }
  165. std::optional<std::string>
  166. stripFloatCast(const ast_matchers::MatchFinder::MatchResult &Result,
  167. const Expr &Node) {
  168. if (const Expr *MaybeCastArg = selectFirst<const Expr>(
  169. "cast_arg",
  170. match(expr(anyOf(cxxStaticCastExpr(
  171. hasDestinationType(realFloatingPointType()),
  172. hasSourceExpression(expr().bind("cast_arg"))),
  173. cStyleCastExpr(
  174. hasDestinationType(realFloatingPointType()),
  175. hasSourceExpression(expr().bind("cast_arg"))),
  176. cxxFunctionalCastExpr(
  177. hasDestinationType(realFloatingPointType()),
  178. hasSourceExpression(expr().bind("cast_arg"))))),
  179. Node, *Result.Context)))
  180. return tooling::fixit::getText(*MaybeCastArg, *Result.Context).str();
  181. return std::nullopt;
  182. }
  183. std::optional<std::string>
  184. stripFloatLiteralFraction(const MatchFinder::MatchResult &Result,
  185. const Expr &Node) {
  186. if (const auto *LitFloat = llvm::dyn_cast<FloatingLiteral>(&Node))
  187. // Attempt to simplify a `Duration` factory call with a literal argument.
  188. if (std::optional<llvm::APSInt> IntValue = truncateIfIntegral(*LitFloat))
  189. return toString(*IntValue, /*radix=*/10);
  190. return std::nullopt;
  191. }
  192. std::string simplifyDurationFactoryArg(const MatchFinder::MatchResult &Result,
  193. const Expr &Node) {
  194. // Check for an explicit cast to `float` or `double`.
  195. if (std::optional<std::string> MaybeArg = stripFloatCast(Result, Node))
  196. return *MaybeArg;
  197. // Check for floats without fractional components.
  198. if (std::optional<std::string> MaybeArg =
  199. stripFloatLiteralFraction(Result, Node))
  200. return *MaybeArg;
  201. // We couldn't simplify any further, so return the argument text.
  202. return tooling::fixit::getText(Node, *Result.Context).str();
  203. }
  204. std::optional<DurationScale> getScaleForDurationInverse(llvm::StringRef Name) {
  205. static const llvm::StringMap<DurationScale> ScaleMap(
  206. {{"ToDoubleHours", DurationScale::Hours},
  207. {"ToInt64Hours", DurationScale::Hours},
  208. {"ToDoubleMinutes", DurationScale::Minutes},
  209. {"ToInt64Minutes", DurationScale::Minutes},
  210. {"ToDoubleSeconds", DurationScale::Seconds},
  211. {"ToInt64Seconds", DurationScale::Seconds},
  212. {"ToDoubleMilliseconds", DurationScale::Milliseconds},
  213. {"ToInt64Milliseconds", DurationScale::Milliseconds},
  214. {"ToDoubleMicroseconds", DurationScale::Microseconds},
  215. {"ToInt64Microseconds", DurationScale::Microseconds},
  216. {"ToDoubleNanoseconds", DurationScale::Nanoseconds},
  217. {"ToInt64Nanoseconds", DurationScale::Nanoseconds}});
  218. auto ScaleIter = ScaleMap.find(std::string(Name));
  219. if (ScaleIter == ScaleMap.end())
  220. return std::nullopt;
  221. return ScaleIter->second;
  222. }
  223. std::optional<DurationScale> getScaleForTimeInverse(llvm::StringRef Name) {
  224. static const llvm::StringMap<DurationScale> ScaleMap(
  225. {{"ToUnixHours", DurationScale::Hours},
  226. {"ToUnixMinutes", DurationScale::Minutes},
  227. {"ToUnixSeconds", DurationScale::Seconds},
  228. {"ToUnixMillis", DurationScale::Milliseconds},
  229. {"ToUnixMicros", DurationScale::Microseconds},
  230. {"ToUnixNanos", DurationScale::Nanoseconds}});
  231. auto ScaleIter = ScaleMap.find(std::string(Name));
  232. if (ScaleIter == ScaleMap.end())
  233. return std::nullopt;
  234. return ScaleIter->second;
  235. }
  236. std::string rewriteExprFromNumberToDuration(
  237. const ast_matchers::MatchFinder::MatchResult &Result, DurationScale Scale,
  238. const Expr *Node) {
  239. const Expr &RootNode = *Node->IgnoreParenImpCasts();
  240. // First check to see if we can undo a complementary function call.
  241. if (std::optional<std::string> MaybeRewrite =
  242. rewriteInverseDurationCall(Result, Scale, RootNode))
  243. return *MaybeRewrite;
  244. if (isLiteralZero(Result, RootNode))
  245. return std::string("absl::ZeroDuration()");
  246. return (llvm::Twine(getDurationFactoryForScale(Scale)) + "(" +
  247. simplifyDurationFactoryArg(Result, RootNode) + ")")
  248. .str();
  249. }
  250. std::string rewriteExprFromNumberToTime(
  251. const ast_matchers::MatchFinder::MatchResult &Result, DurationScale Scale,
  252. const Expr *Node) {
  253. const Expr &RootNode = *Node->IgnoreParenImpCasts();
  254. // First check to see if we can undo a complementary function call.
  255. if (std::optional<std::string> MaybeRewrite =
  256. rewriteInverseTimeCall(Result, Scale, RootNode))
  257. return *MaybeRewrite;
  258. if (isLiteralZero(Result, RootNode))
  259. return std::string("absl::UnixEpoch()");
  260. return (llvm::Twine(getTimeFactoryForScale(Scale)) + "(" +
  261. tooling::fixit::getText(RootNode, *Result.Context) + ")")
  262. .str();
  263. }
  264. bool isInMacro(const MatchFinder::MatchResult &Result, const Expr *E) {
  265. if (!E->getBeginLoc().isMacroID())
  266. return false;
  267. SourceLocation Loc = E->getBeginLoc();
  268. // We want to get closer towards the initial macro typed into the source only
  269. // if the location is being expanded as a macro argument.
  270. while (Result.SourceManager->isMacroArgExpansion(Loc)) {
  271. // We are calling getImmediateMacroCallerLoc, but note it is essentially
  272. // equivalent to calling getImmediateSpellingLoc in this context according
  273. // to Clang implementation. We are not calling getImmediateSpellingLoc
  274. // because Clang comment says it "should not generally be used by clients."
  275. Loc = Result.SourceManager->getImmediateMacroCallerLoc(Loc);
  276. }
  277. return Loc.isMacroID();
  278. }
  279. } // namespace clang::tidy::abseil