BranchProbability.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. //===-------------- lib/Support/BranchProbability.cpp -----------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements Branch Probability class.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Support/BranchProbability.h"
  13. #include "llvm/Config/llvm-config.h"
  14. #include "llvm/Support/Debug.h"
  15. #include "llvm/Support/Format.h"
  16. #include "llvm/Support/raw_ostream.h"
  17. #include <cassert>
  18. #include <cmath>
  19. using namespace llvm;
  20. constexpr uint32_t BranchProbability::D;
  21. raw_ostream &BranchProbability::print(raw_ostream &OS) const {
  22. if (isUnknown())
  23. return OS << "?%";
  24. // Get a percentage rounded to two decimal digits. This avoids
  25. // implementation-defined rounding inside printf.
  26. double Percent = rint(((double)N / D) * 100.0 * 100.0) / 100.0;
  27. return OS << format("0x%08" PRIx32 " / 0x%08" PRIx32 " = %.2f%%", N, D,
  28. Percent);
  29. }
  30. #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  31. LLVM_DUMP_METHOD void BranchProbability::dump() const { print(dbgs()) << '\n'; }
  32. #endif
  33. BranchProbability::BranchProbability(uint32_t Numerator, uint32_t Denominator) {
  34. assert(Denominator > 0 && "Denominator cannot be 0!");
  35. assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
  36. if (Denominator == D)
  37. N = Numerator;
  38. else {
  39. uint64_t Prob64 =
  40. (Numerator * static_cast<uint64_t>(D) + Denominator / 2) / Denominator;
  41. N = static_cast<uint32_t>(Prob64);
  42. }
  43. }
  44. BranchProbability
  45. BranchProbability::getBranchProbability(uint64_t Numerator,
  46. uint64_t Denominator) {
  47. assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
  48. // Scale down Denominator to fit in a 32-bit integer.
  49. int Scale = 0;
  50. while (Denominator > UINT32_MAX) {
  51. Denominator >>= 1;
  52. Scale++;
  53. }
  54. return BranchProbability(Numerator >> Scale, Denominator);
  55. }
  // Compute floor(Num * N / D) without losing precision. The result saturates
  // to UINT64_MAX when it does not fit in 64 bits.
  // If ConstD is nonzero it replaces D. Dividing by a compile-time constant
  // lets the compiler optimize the divisions/modulos even when this function
  // is not inlined.
  template <uint32_t ConstD>
  static uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
    if (ConstD > 0)
      D = ConstD;
    assert(D && "divide by 0");

    // Scaling zero, or scaling by exactly 1.0, is the identity.
    if (Num == 0 || N == D)
      return Num;

    // Form the 96-bit product Num * N as three 32-bit digits Hi:Mid:Lo from two
    // 32x32->64 partial products.
    const uint64_t HighPart = (Num >> 32) * N;
    const uint64_t LowPart = (Num & UINT32_MAX) * N;
    const uint32_t Lo = static_cast<uint32_t>(LowPart);
    const uint32_t MidFromHigh = static_cast<uint32_t>(HighPart);
    const uint32_t Mid = MidFromHigh + static_cast<uint32_t>(LowPart >> 32);
    uint32_t Hi = static_cast<uint32_t>(HighPart >> 32);
    // Propagate the carry out of the middle digit (the add wrapped around).
    if (Mid < MidFromHigh)
      ++Hi;

    // Long division of Hi:Mid:Lo by D, one 64-bit step at a time.
    const uint64_t Top = (static_cast<uint64_t>(Hi) << 32) | Mid;
    const uint64_t QuotHigh = Top / D;
    // The upper quotient digit must fit in 32 bits, or the result overflows.
    if (QuotHigh > UINT32_MAX)
      return UINT64_MAX;
    const uint64_t Bottom = ((Top % D) << 32) | Lo;
    const uint64_t QuotLow = Bottom / D;
    const uint64_t Quotient = (QuotHigh << 32) + QuotLow;
    // Defensive overflow check on recombination.
    return Quotient < QuotLow ? UINT64_MAX : Quotient;
  }
  88. uint64_t BranchProbability::scale(uint64_t Num) const {
  89. return ::scale<D>(Num, N, D);
  90. }
  91. uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
  92. return ::scale<0>(Num, D, N);
  93. }