BranchProbabilityInfo.cpp 50 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321
  1. //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Loops should be simplified before this analysis.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Analysis/BranchProbabilityInfo.h"
  13. #include "llvm/ADT/PostOrderIterator.h"
  14. #include "llvm/ADT/SCCIterator.h"
  15. #include "llvm/ADT/STLExtras.h"
  16. #include "llvm/ADT/SmallVector.h"
  17. #include "llvm/Analysis/ConstantFolding.h"
  18. #include "llvm/Analysis/LoopInfo.h"
  19. #include "llvm/Analysis/PostDominators.h"
  20. #include "llvm/Analysis/TargetLibraryInfo.h"
  21. #include "llvm/IR/Attributes.h"
  22. #include "llvm/IR/BasicBlock.h"
  23. #include "llvm/IR/CFG.h"
  24. #include "llvm/IR/Constants.h"
  25. #include "llvm/IR/Dominators.h"
  26. #include "llvm/IR/Function.h"
  27. #include "llvm/IR/InstrTypes.h"
  28. #include "llvm/IR/Instruction.h"
  29. #include "llvm/IR/Instructions.h"
  30. #include "llvm/IR/LLVMContext.h"
  31. #include "llvm/IR/Metadata.h"
  32. #include "llvm/IR/PassManager.h"
  33. #include "llvm/IR/ProfDataUtils.h"
  34. #include "llvm/IR/Type.h"
  35. #include "llvm/IR/Value.h"
  36. #include "llvm/InitializePasses.h"
  37. #include "llvm/Pass.h"
  38. #include "llvm/Support/BranchProbability.h"
  39. #include "llvm/Support/Casting.h"
  40. #include "llvm/Support/CommandLine.h"
  41. #include "llvm/Support/Debug.h"
  42. #include "llvm/Support/raw_ostream.h"
  43. #include <cassert>
  44. #include <cstdint>
  45. #include <iterator>
  46. #include <map>
  47. #include <utility>
  48. using namespace llvm;
  49. #define DEBUG_TYPE "branch-prob"
  50. static cl::opt<bool> PrintBranchProb(
  51. "print-bpi", cl::init(false), cl::Hidden,
  52. cl::desc("Print the branch probability info."));
  53. cl::opt<std::string> PrintBranchProbFuncName(
  54. "print-bpi-func-name", cl::Hidden,
  55. cl::desc("The option to specify the name of the function "
  56. "whose branch probability info is printed."));
  57. INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
  58. "Branch Probability Analysis", false, true)
  59. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  60. INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
  61. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  62. INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
  63. INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
  64. "Branch Probability Analysis", false, true)
  65. BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
  66. : FunctionPass(ID) {
  67. initializeBranchProbabilityInfoWrapperPassPass(
  68. *PassRegistry::getPassRegistry());
  69. }
  70. char BranchProbabilityInfoWrapperPass::ID = 0;
  71. // Weights are for internal use only. They are used by heuristics to help to
  72. // estimate edges' probability. Example:
  73. //
  74. // Using "Loop Branch Heuristics" we predict weights of edges for the
  75. // block BB2.
  76. // ...
  77. // |
  78. // V
  79. // BB1<-+
  80. // | |
  81. // | | (Weight = 124)
  82. // V |
  83. // BB2--+
  84. // |
  85. // | (Weight = 4)
  86. // V
  87. // BB3
  88. //
  89. // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
  90. // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
  91. static const uint32_t LBH_TAKEN_WEIGHT = 124;
  92. static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
  93. /// Unreachable-terminating branch taken probability.
  94. ///
  95. /// This is the probability for a branch being taken to a block that terminates
  96. /// (eventually) in unreachable. These are predicted as unlikely as possible.
  97. /// All reachable probability will proportionally share the remaining part.
  98. static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
  99. /// Heuristics and lookup tables for non-loop branches:
  100. /// Pointer Heuristics (PH)
  101. static const uint32_t PH_TAKEN_WEIGHT = 20;
  102. static const uint32_t PH_NONTAKEN_WEIGHT = 12;
  103. static const BranchProbability
  104. PtrTakenProb(PH_TAKEN_WEIGHT, PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
  105. static const BranchProbability
  106. PtrUntakenProb(PH_NONTAKEN_WEIGHT, PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
  107. using ProbabilityList = SmallVector<BranchProbability>;
  108. using ProbabilityTable = std::map<CmpInst::Predicate, ProbabilityList>;
  109. /// Pointer comparisons:
  110. static const ProbabilityTable PointerTable{
  111. {ICmpInst::ICMP_NE, {PtrTakenProb, PtrUntakenProb}}, /// p != q -> Likely
  112. {ICmpInst::ICMP_EQ, {PtrUntakenProb, PtrTakenProb}}, /// p == q -> Unlikely
  113. };
  114. /// Zero Heuristics (ZH)
  115. static const uint32_t ZH_TAKEN_WEIGHT = 20;
  116. static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
  117. static const BranchProbability
  118. ZeroTakenProb(ZH_TAKEN_WEIGHT, ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
  119. static const BranchProbability
  120. ZeroUntakenProb(ZH_NONTAKEN_WEIGHT, ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
  121. /// Integer compares with 0:
  122. static const ProbabilityTable ICmpWithZeroTable{
  123. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}}, /// X == 0 -> Unlikely
  124. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}}, /// X != 0 -> Likely
  125. {CmpInst::ICMP_SLT, {ZeroUntakenProb, ZeroTakenProb}}, /// X < 0 -> Unlikely
  126. {CmpInst::ICMP_SGT, {ZeroTakenProb, ZeroUntakenProb}}, /// X > 0 -> Likely
  127. };
  128. /// Integer compares with -1:
  129. static const ProbabilityTable ICmpWithMinusOneTable{
  130. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}}, /// X == -1 -> Unlikely
  131. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}}, /// X != -1 -> Likely
  132. // InstCombine canonicalizes X >= 0 into X > -1
  133. {CmpInst::ICMP_SGT, {ZeroTakenProb, ZeroUntakenProb}}, /// X >= 0 -> Likely
  134. };
  135. /// Integer compares with 1:
  136. static const ProbabilityTable ICmpWithOneTable{
  137. // InstCombine canonicalizes X <= 0 into X < 1
  138. {CmpInst::ICMP_SLT, {ZeroUntakenProb, ZeroTakenProb}}, /// X <= 0 -> Unlikely
  139. };
  140. /// strcmp and similar functions return zero, negative, or positive, if the
  141. /// first string is equal, less, or greater than the second. We consider it
  142. /// likely that the strings are not equal, so a comparison with zero is
  143. /// probably false, but also a comparison with any other number is also
  144. /// probably false given that what exactly is returned for nonzero values is
  145. /// not specified. Any kind of comparison other than equality we know
  146. /// nothing about.
  147. static const ProbabilityTable ICmpWithLibCallTable{
  148. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}},
  149. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}},
  150. };
  151. // Floating-Point Heuristics (FPH)
  152. static const uint32_t FPH_TAKEN_WEIGHT = 20;
  153. static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
  154. /// This is the probability for an ordered floating point comparison.
  155. static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
  156. /// This is the probability for an unordered floating point comparison, it means
  157. /// one or two of the operands are NaN. Usually it is used to test for an
  158. /// exceptional case, so the result is unlikely.
  159. static const uint32_t FPH_UNO_WEIGHT = 1;
  160. static const BranchProbability FPOrdTakenProb(FPH_ORD_WEIGHT,
  161. FPH_ORD_WEIGHT + FPH_UNO_WEIGHT);
  162. static const BranchProbability
  163. FPOrdUntakenProb(FPH_UNO_WEIGHT, FPH_ORD_WEIGHT + FPH_UNO_WEIGHT);
  164. static const BranchProbability
  165. FPTakenProb(FPH_TAKEN_WEIGHT, FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
  166. static const BranchProbability
  167. FPUntakenProb(FPH_NONTAKEN_WEIGHT, FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
  168. /// Floating-Point compares:
  169. static const ProbabilityTable FCmpTable{
  170. {FCmpInst::FCMP_ORD, {FPOrdTakenProb, FPOrdUntakenProb}}, /// !isnan -> Likely
  171. {FCmpInst::FCMP_UNO, {FPOrdUntakenProb, FPOrdTakenProb}}, /// isnan -> Unlikely
  172. };
  173. /// Set of dedicated "absolute" execution weights for a block. These weights are
  174. /// meaningful relative to each other and their derivatives only.
  175. enum class BlockExecWeight : std::uint32_t {
  176. /// Special weight used for cases with exact zero probability.
  177. ZERO = 0x0,
  178. /// Minimal possible non zero weight.
  179. LOWEST_NON_ZERO = 0x1,
  180. /// Weight to an 'unreachable' block.
  181. UNREACHABLE = ZERO,
  182. /// Weight to a block containing non returning call.
  183. NORETURN = LOWEST_NON_ZERO,
  184. /// Weight to 'unwind' block of an invoke instruction.
  185. UNWIND = LOWEST_NON_ZERO,
  186. /// Weight to a 'cold' block. Cold blocks are the ones containing calls marked
  187. /// with attribute 'cold'.
  188. COLD = 0xffff,
  189. /// Default weight is used in cases when there is no dedicated execution
  190. /// weight set. It is not propagated through the domination line either.
  191. DEFAULT = 0xfffff
  192. };
  193. BranchProbabilityInfo::SccInfo::SccInfo(const Function &F) {
  194. // Record SCC numbers of blocks in the CFG to identify irreducible loops.
  195. // FIXME: We could only calculate this if the CFG is known to be irreducible
  196. // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
  197. int SccNum = 0;
  198. for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
  199. ++It, ++SccNum) {
  200. // Ignore single-block SCCs since they either aren't loops or LoopInfo will
  201. // catch them.
  202. const std::vector<const BasicBlock *> &Scc = *It;
  203. if (Scc.size() == 1)
  204. continue;
  205. LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
  206. for (const auto *BB : Scc) {
  207. LLVM_DEBUG(dbgs() << " " << BB->getName());
  208. SccNums[BB] = SccNum;
  209. calculateSccBlockType(BB, SccNum);
  210. }
  211. LLVM_DEBUG(dbgs() << "\n");
  212. }
  213. }
  214. int BranchProbabilityInfo::SccInfo::getSCCNum(const BasicBlock *BB) const {
  215. auto SccIt = SccNums.find(BB);
  216. if (SccIt == SccNums.end())
  217. return -1;
  218. return SccIt->second;
  219. }
  220. void BranchProbabilityInfo::SccInfo::getSccEnterBlocks(
  221. int SccNum, SmallVectorImpl<BasicBlock *> &Enters) const {
  222. for (auto MapIt : SccBlocks[SccNum]) {
  223. const auto *BB = MapIt.first;
  224. if (isSCCHeader(BB, SccNum))
  225. for (const auto *Pred : predecessors(BB))
  226. if (getSCCNum(Pred) != SccNum)
  227. Enters.push_back(const_cast<BasicBlock *>(BB));
  228. }
  229. }
  230. void BranchProbabilityInfo::SccInfo::getSccExitBlocks(
  231. int SccNum, SmallVectorImpl<BasicBlock *> &Exits) const {
  232. for (auto MapIt : SccBlocks[SccNum]) {
  233. const auto *BB = MapIt.first;
  234. if (isSCCExitingBlock(BB, SccNum))
  235. for (const auto *Succ : successors(BB))
  236. if (getSCCNum(Succ) != SccNum)
  237. Exits.push_back(const_cast<BasicBlock *>(Succ));
  238. }
  239. }
  240. uint32_t BranchProbabilityInfo::SccInfo::getSccBlockType(const BasicBlock *BB,
  241. int SccNum) const {
  242. assert(getSCCNum(BB) == SccNum);
  243. assert(SccBlocks.size() > static_cast<unsigned>(SccNum) && "Unknown SCC");
  244. const auto &SccBlockTypes = SccBlocks[SccNum];
  245. auto It = SccBlockTypes.find(BB);
  246. if (It != SccBlockTypes.end()) {
  247. return It->second;
  248. }
  249. return Inner;
  250. }
  251. void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
  252. int SccNum) {
  253. assert(getSCCNum(BB) == SccNum);
  254. uint32_t BlockType = Inner;
  255. if (llvm::any_of(predecessors(BB), [&](const BasicBlock *Pred) {
  256. // Consider any block that is an entry point to the SCC as
  257. // a header.
  258. return getSCCNum(Pred) != SccNum;
  259. }))
  260. BlockType |= Header;
  261. if (llvm::any_of(successors(BB), [&](const BasicBlock *Succ) {
  262. return getSCCNum(Succ) != SccNum;
  263. }))
  264. BlockType |= Exiting;
  265. // Lazily compute the set of headers for a given SCC and cache the results
  266. // in the SccHeaderMap.
  267. if (SccBlocks.size() <= static_cast<unsigned>(SccNum))
  268. SccBlocks.resize(SccNum + 1);
  269. auto &SccBlockTypes = SccBlocks[SccNum];
  270. if (BlockType != Inner) {
  271. bool IsInserted;
  272. std::tie(std::ignore, IsInserted) =
  273. SccBlockTypes.insert(std::make_pair(BB, BlockType));
  274. assert(IsInserted && "Duplicated block in SCC");
  275. }
  276. }
  277. BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
  278. const LoopInfo &LI,
  279. const SccInfo &SccI)
  280. : BB(BB) {
  281. LD.first = LI.getLoopFor(BB);
  282. if (!LD.first) {
  283. LD.second = SccI.getSCCNum(BB);
  284. }
  285. }
  286. bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
  287. const auto &SrcBlock = Edge.first;
  288. const auto &DstBlock = Edge.second;
  289. return (DstBlock.getLoop() &&
  290. !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
  291. // Assume that SCCs can't be nested.
  292. (DstBlock.getSccNum() != -1 &&
  293. SrcBlock.getSccNum() != DstBlock.getSccNum());
  294. }
  295. bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
  296. return isLoopEnteringEdge({Edge.second, Edge.first});
  297. }
  298. bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
  299. const LoopEdge &Edge) const {
  300. return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
  301. }
  302. bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
  303. const auto &SrcBlock = Edge.first;
  304. const auto &DstBlock = Edge.second;
  305. return SrcBlock.belongsToSameLoop(DstBlock) &&
  306. ((DstBlock.getLoop() &&
  307. DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
  308. (DstBlock.getSccNum() != -1 &&
  309. SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
  310. }
  311. void BranchProbabilityInfo::getLoopEnterBlocks(
  312. const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
  313. if (LB.getLoop()) {
  314. auto *Header = LB.getLoop()->getHeader();
  315. Enters.append(pred_begin(Header), pred_end(Header));
  316. } else {
  317. assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
  318. SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
  319. }
  320. }
  321. void BranchProbabilityInfo::getLoopExitBlocks(
  322. const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
  323. if (LB.getLoop()) {
  324. LB.getLoop()->getExitBlocks(Exits);
  325. } else {
  326. assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
  327. SccI->getSccExitBlocks(LB.getSccNum(), Exits);
  328. }
  329. }
  330. // Propagate existing explicit probabilities from either profile data or
  331. // 'expect' intrinsic processing. Examine metadata against unreachable
  332. // heuristic. The probability of the edge coming to unreachable block is
  333. // set to min of metadata and unreachable heuristic.
  334. bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  335. const Instruction *TI = BB->getTerminator();
  336. assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  337. if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI) ||
  338. isa<InvokeInst>(TI) || isa<CallBrInst>(TI)))
  339. return false;
  340. MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
  341. if (!WeightsNode)
  342. return false;
  343. // Check that the number of successors is manageable.
  344. assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
  345. // Build up the final weights that will be used in a temporary buffer.
  346. // Compute the sum of all weights to later decide whether they need to
  347. // be scaled to fit in 32 bits.
  348. uint64_t WeightSum = 0;
  349. SmallVector<uint32_t, 2> Weights;
  350. SmallVector<unsigned, 2> UnreachableIdxs;
  351. SmallVector<unsigned, 2> ReachableIdxs;
  352. extractBranchWeights(WeightsNode, Weights);
  353. for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
  354. WeightSum += Weights[I];
  355. const LoopBlock SrcLoopBB = getLoopBlock(BB);
  356. const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
  357. auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
  358. if (EstimatedWeight &&
  359. *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
  360. UnreachableIdxs.push_back(I);
  361. else
  362. ReachableIdxs.push_back(I);
  363. }
  364. assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
  365. // If the sum of weights does not fit in 32 bits, scale every weight down
  366. // accordingly.
  367. uint64_t ScalingFactor =
  368. (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
  369. if (ScalingFactor > 1) {
  370. WeightSum = 0;
  371. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
  372. Weights[I] /= ScalingFactor;
  373. WeightSum += Weights[I];
  374. }
  375. }
  376. assert(WeightSum <= UINT32_MAX &&
  377. "Expected weights to scale down to 32 bits");
  378. if (WeightSum == 0 || ReachableIdxs.size() == 0) {
  379. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
  380. Weights[I] = 1;
  381. WeightSum = TI->getNumSuccessors();
  382. }
  383. // Set the probability.
  384. SmallVector<BranchProbability, 2> BP;
  385. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
  386. BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
  387. // Examine the metadata against unreachable heuristic.
  388. // If the unreachable heuristic is more strong then we use it for this edge.
  389. if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) {
  390. setEdgeProbability(BB, BP);
  391. return true;
  392. }
  393. auto UnreachableProb = UR_TAKEN_PROB;
  394. for (auto I : UnreachableIdxs)
  395. if (UnreachableProb < BP[I]) {
  396. BP[I] = UnreachableProb;
  397. }
  398. // Sum of all edge probabilities must be 1.0. If we modified the probability
  399. // of some edges then we must distribute the introduced difference over the
  400. // reachable blocks.
  401. //
  402. // Proportional distribution: the relation between probabilities of the
  403. // reachable edges is kept unchanged. That is for any reachable edges i and j:
  404. // newBP[i] / newBP[j] == oldBP[i] / oldBP[j] =>
  405. // newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K
  406. // Where K is independent of i,j.
  407. // newBP[i] == oldBP[i] * K
  408. // We need to find K.
  409. // Make sum of all reachables of the left and right parts:
  410. // sum_of_reachable(newBP) == K * sum_of_reachable(oldBP)
  411. // Sum of newBP must be equal to 1.0:
  412. // sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 =>
  413. // sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP)
  414. // Where sum_of_unreachable(newBP) is what has been just changed.
  415. // Finally:
  416. // K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) =>
  417. // K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP)
  418. BranchProbability NewUnreachableSum = BranchProbability::getZero();
  419. for (auto I : UnreachableIdxs)
  420. NewUnreachableSum += BP[I];
  421. BranchProbability NewReachableSum =
  422. BranchProbability::getOne() - NewUnreachableSum;
  423. BranchProbability OldReachableSum = BranchProbability::getZero();
  424. for (auto I : ReachableIdxs)
  425. OldReachableSum += BP[I];
  426. if (OldReachableSum != NewReachableSum) { // Anything to dsitribute?
  427. if (OldReachableSum.isZero()) {
  428. // If all oldBP[i] are zeroes then the proportional distribution results
  429. // in all zero probabilities and the error stays big. In this case we
  430. // evenly spread NewReachableSum over the reachable edges.
  431. BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size();
  432. for (auto I : ReachableIdxs)
  433. BP[I] = PerEdge;
  434. } else {
  435. for (auto I : ReachableIdxs) {
  436. // We use uint64_t to avoid double rounding error of the following
  437. // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum
  438. // The formula is taken from the private constructor
  439. // BranchProbability(uint32_t Numerator, uint32_t Denominator)
  440. uint64_t Mul = static_cast<uint64_t>(NewReachableSum.getNumerator()) *
  441. BP[I].getNumerator();
  442. uint32_t Div = static_cast<uint32_t>(
  443. divideNearest(Mul, OldReachableSum.getNumerator()));
  444. BP[I] = BranchProbability::getRaw(Div);
  445. }
  446. }
  447. }
  448. setEdgeProbability(BB, BP);
  449. return true;
  450. }
  451. // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
  452. // between two pointer or pointer and NULL will fail.
  453. bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
  454. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  455. if (!BI || !BI->isConditional())
  456. return false;
  457. Value *Cond = BI->getCondition();
  458. ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
  459. if (!CI || !CI->isEquality())
  460. return false;
  461. Value *LHS = CI->getOperand(0);
  462. if (!LHS->getType()->isPointerTy())
  463. return false;
  464. assert(CI->getOperand(1)->getType()->isPointerTy());
  465. auto Search = PointerTable.find(CI->getPredicate());
  466. if (Search == PointerTable.end())
  467. return false;
  468. setEdgeProbability(BB, Search->second);
  469. return true;
  470. }
  471. // Compute the unlikely successors to the block BB in the loop L, specifically
  472. // those that are unlikely because this is a loop, and add them to the
  473. // UnlikelyBlocks set.
  474. static void
  475. computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
  476. SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
  477. // Sometimes in a loop we have a branch whose condition is made false by
  478. // taking it. This is typically something like
  479. // int n = 0;
  480. // while (...) {
  481. // if (++n >= MAX) {
  482. // n = 0;
  483. // }
  484. // }
  485. // In this sort of situation taking the branch means that at the very least it
  486. // won't be taken again in the next iteration of the loop, so we should
  487. // consider it less likely than a typical branch.
  488. //
  489. // We detect this by looking back through the graph of PHI nodes that sets the
  490. // value that the condition depends on, and seeing if we can reach a successor
  491. // block which can be determined to make the condition false.
  492. //
  493. // FIXME: We currently consider unlikely blocks to be half as likely as other
  494. // blocks, but if we consider the example above the likelyhood is actually
  495. // 1/MAX. We could therefore be more precise in how unlikely we consider
  496. // blocks to be, but it would require more careful examination of the form
  497. // of the comparison expression.
  498. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  499. if (!BI || !BI->isConditional())
  500. return;
  501. // Check if the branch is based on an instruction compared with a constant
  502. CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
  503. if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
  504. !isa<Constant>(CI->getOperand(1)))
  505. return;
  506. // Either the instruction must be a PHI, or a chain of operations involving
  507. // constants that ends in a PHI which we can then collapse into a single value
  508. // if the PHI value is known.
  509. Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
  510. PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
  511. Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
  512. // Collect the instructions until we hit a PHI
  513. SmallVector<BinaryOperator *, 1> InstChain;
  514. while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
  515. isa<Constant>(CmpLHS->getOperand(1))) {
  516. // Stop if the chain extends outside of the loop
  517. if (!L->contains(CmpLHS))
  518. return;
  519. InstChain.push_back(cast<BinaryOperator>(CmpLHS));
  520. CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
  521. if (CmpLHS)
  522. CmpPHI = dyn_cast<PHINode>(CmpLHS);
  523. }
  524. if (!CmpPHI || !L->contains(CmpPHI))
  525. return;
  526. // Trace the phi node to find all values that come from successors of BB
  527. SmallPtrSet<PHINode*, 8> VisitedInsts;
  528. SmallVector<PHINode*, 8> WorkList;
  529. WorkList.push_back(CmpPHI);
  530. VisitedInsts.insert(CmpPHI);
  531. while (!WorkList.empty()) {
  532. PHINode *P = WorkList.pop_back_val();
  533. for (BasicBlock *B : P->blocks()) {
  534. // Skip blocks that aren't part of the loop
  535. if (!L->contains(B))
  536. continue;
  537. Value *V = P->getIncomingValueForBlock(B);
  538. // If the source is a PHI add it to the work list if we haven't
  539. // already visited it.
  540. if (PHINode *PN = dyn_cast<PHINode>(V)) {
  541. if (VisitedInsts.insert(PN).second)
  542. WorkList.push_back(PN);
  543. continue;
  544. }
  545. // If this incoming value is a constant and B is a successor of BB, then
  546. // we can constant-evaluate the compare to see if it makes the branch be
  547. // taken or not.
  548. Constant *CmpLHSConst = dyn_cast<Constant>(V);
  549. if (!CmpLHSConst || !llvm::is_contained(successors(BB), B))
  550. continue;
  551. // First collapse InstChain
  552. const DataLayout &DL = BB->getModule()->getDataLayout();
  553. for (Instruction *I : llvm::reverse(InstChain)) {
  554. CmpLHSConst = ConstantFoldBinaryOpOperands(
  555. I->getOpcode(), CmpLHSConst, cast<Constant>(I->getOperand(1)), DL);
  556. if (!CmpLHSConst)
  557. break;
  558. }
  559. if (!CmpLHSConst)
  560. continue;
  561. // Now constant-evaluate the compare
  562. Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
  563. CmpLHSConst, CmpConst, true);
  564. // If the result means we don't branch to the block then that block is
  565. // unlikely.
  566. if (Result &&
  567. ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
  568. (Result->isOneValue() && B == BI->getSuccessor(1))))
  569. UnlikelyBlocks.insert(B);
  570. }
  571. }
  572. }
  573. std::optional<uint32_t>
  574. BranchProbabilityInfo::getEstimatedBlockWeight(const BasicBlock *BB) const {
  575. auto WeightIt = EstimatedBlockWeight.find(BB);
  576. if (WeightIt == EstimatedBlockWeight.end())
  577. return std::nullopt;
  578. return WeightIt->second;
  579. }
  580. std::optional<uint32_t>
  581. BranchProbabilityInfo::getEstimatedLoopWeight(const LoopData &L) const {
  582. auto WeightIt = EstimatedLoopWeight.find(L);
  583. if (WeightIt == EstimatedLoopWeight.end())
  584. return std::nullopt;
  585. return WeightIt->second;
  586. }
  587. std::optional<uint32_t>
  588. BranchProbabilityInfo::getEstimatedEdgeWeight(const LoopEdge &Edge) const {
  589. // For edges entering a loop take weight of a loop rather than an individual
  590. // block in the loop.
  591. return isLoopEnteringEdge(Edge)
  592. ? getEstimatedLoopWeight(Edge.second.getLoopData())
  593. : getEstimatedBlockWeight(Edge.second.getBlock());
  594. }
  595. template <class IterT>
  596. std::optional<uint32_t> BranchProbabilityInfo::getMaxEstimatedEdgeWeight(
  597. const LoopBlock &SrcLoopBB, iterator_range<IterT> Successors) const {
  598. SmallVector<uint32_t, 4> Weights;
  599. std::optional<uint32_t> MaxWeight;
  600. for (const BasicBlock *DstBB : Successors) {
  601. const LoopBlock DstLoopBB = getLoopBlock(DstBB);
  602. auto Weight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
  603. if (!Weight)
  604. return std::nullopt;
  605. if (!MaxWeight || *MaxWeight < *Weight)
  606. MaxWeight = Weight;
  607. }
  608. return MaxWeight;
  609. }
  610. // Updates \p LoopBB's weight and returns true. If \p LoopBB has already
  611. // an associated weight it is unchanged and false is returned.
  612. //
  613. // Please note by the algorithm the weight is not expected to change once set
  614. // thus 'false' status is used to track visited blocks.
  615. bool BranchProbabilityInfo::updateEstimatedBlockWeight(
  616. LoopBlock &LoopBB, uint32_t BBWeight,
  617. SmallVectorImpl<BasicBlock *> &BlockWorkList,
  618. SmallVectorImpl<LoopBlock> &LoopWorkList) {
  619. BasicBlock *BB = LoopBB.getBlock();
  620. // In general, weight is assigned to a block when it has final value and
  621. // can't/shouldn't be changed. However, there are cases when a block
  622. // inherently has several (possibly "contradicting") weights. For example,
  623. // "unwind" block may also contain "cold" call. In that case the first
  624. // set weight is favored and all consequent weights are ignored.
  625. if (!EstimatedBlockWeight.insert({BB, BBWeight}).second)
  626. return false;
  627. for (BasicBlock *PredBlock : predecessors(BB)) {
  628. LoopBlock PredLoop = getLoopBlock(PredBlock);
  629. // Add affected block/loop to a working list.
  630. if (isLoopExitingEdge({PredLoop, LoopBB})) {
  631. if (!EstimatedLoopWeight.count(PredLoop.getLoopData()))
  632. LoopWorkList.push_back(PredLoop);
  633. } else if (!EstimatedBlockWeight.count(PredBlock))
  634. BlockWorkList.push_back(PredBlock);
  635. }
  636. return true;
  637. }
  638. // Starting from \p BB traverse through dominator blocks and assign \p BBWeight
  639. // to all such blocks that are post dominated by \BB. In other words to all
  640. // blocks that the one is executed if and only if another one is executed.
  641. // Importantly, we skip loops here for two reasons. First weights of blocks in
  642. // a loop should be scaled by trip count (yet possibly unknown). Second there is
  643. // no any value in doing that because that doesn't give any additional
  644. // information regarding distribution of probabilities inside the loop.
  645. // Exception is loop 'enter' and 'exit' edges that are handled in a special way
  646. // at calcEstimatedHeuristics.
  647. //
  648. // In addition, \p WorkList is populated with basic blocks if at leas one
  649. // successor has updated estimated weight.
  650. void BranchProbabilityInfo::propagateEstimatedBlockWeight(
  651. const LoopBlock &LoopBB, DominatorTree *DT, PostDominatorTree *PDT,
  652. uint32_t BBWeight, SmallVectorImpl<BasicBlock *> &BlockWorkList,
  653. SmallVectorImpl<LoopBlock> &LoopWorkList) {
  654. const BasicBlock *BB = LoopBB.getBlock();
  655. const auto *DTStartNode = DT->getNode(BB);
  656. const auto *PDTStartNode = PDT->getNode(BB);
  657. // TODO: Consider propagating weight down the domination line as well.
  658. for (const auto *DTNode = DTStartNode; DTNode != nullptr;
  659. DTNode = DTNode->getIDom()) {
  660. auto *DomBB = DTNode->getBlock();
  661. // Consider blocks which lie on one 'line'.
  662. if (!PDT->dominates(PDTStartNode, PDT->getNode(DomBB)))
  663. // If BB doesn't post dominate DomBB it will not post dominate dominators
  664. // of DomBB as well.
  665. break;
  666. LoopBlock DomLoopBB = getLoopBlock(DomBB);
  667. const LoopEdge Edge{DomLoopBB, LoopBB};
  668. // Don't propagate weight to blocks belonging to different loops.
  669. if (!isLoopEnteringExitingEdge(Edge)) {
  670. if (!updateEstimatedBlockWeight(DomLoopBB, BBWeight, BlockWorkList,
  671. LoopWorkList))
  672. // If DomBB has weight set then all it's predecessors are already
  673. // processed (since we propagate weight up to the top of IR each time).
  674. break;
  675. } else if (isLoopExitingEdge(Edge)) {
  676. LoopWorkList.push_back(DomLoopBB);
  677. }
  678. }
  679. }
  680. std::optional<uint32_t>
  681. BranchProbabilityInfo::getInitialEstimatedBlockWeight(const BasicBlock *BB) {
  682. // Returns true if \p BB has call marked with "NoReturn" attribute.
  683. auto hasNoReturn = [&](const BasicBlock *BB) {
  684. for (const auto &I : reverse(*BB))
  685. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  686. if (CI->hasFnAttr(Attribute::NoReturn))
  687. return true;
  688. return false;
  689. };
  690. // Important note regarding the order of checks. They are ordered by weight
  691. // from lowest to highest. Doing that allows to avoid "unstable" results
  692. // when several conditions heuristics can be applied simultaneously.
  693. if (isa<UnreachableInst>(BB->getTerminator()) ||
  694. // If this block is terminated by a call to
  695. // @llvm.experimental.deoptimize then treat it like an unreachable
  696. // since it is expected to practically never execute.
  697. // TODO: Should we actually treat as never returning call?
  698. BB->getTerminatingDeoptimizeCall())
  699. return hasNoReturn(BB)
  700. ? static_cast<uint32_t>(BlockExecWeight::NORETURN)
  701. : static_cast<uint32_t>(BlockExecWeight::UNREACHABLE);
  702. // Check if the block is 'unwind' handler of some invoke instruction.
  703. for (const auto *Pred : predecessors(BB))
  704. if (Pred)
  705. if (const auto *II = dyn_cast<InvokeInst>(Pred->getTerminator()))
  706. if (II->getUnwindDest() == BB)
  707. return static_cast<uint32_t>(BlockExecWeight::UNWIND);
  708. // Check if the block contains 'cold' call.
  709. for (const auto &I : *BB)
  710. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  711. if (CI->hasFnAttr(Attribute::Cold))
  712. return static_cast<uint32_t>(BlockExecWeight::COLD);
  713. return std::nullopt;
  714. }
  715. // Does RPO traversal over all blocks in \p F and assigns weights to
  716. // 'unreachable', 'noreturn', 'cold', 'unwind' blocks. In addition it does its
  717. // best to propagate the weight to up/down the IR.
  718. void BranchProbabilityInfo::computeEestimateBlockWeight(
  719. const Function &F, DominatorTree *DT, PostDominatorTree *PDT) {
  720. SmallVector<BasicBlock *, 8> BlockWorkList;
  721. SmallVector<LoopBlock, 8> LoopWorkList;
  722. // By doing RPO we make sure that all predecessors already have weights
  723. // calculated before visiting theirs successors.
  724. ReversePostOrderTraversal<const Function *> RPOT(&F);
  725. for (const auto *BB : RPOT)
  726. if (auto BBWeight = getInitialEstimatedBlockWeight(BB))
  727. // If we were able to find estimated weight for the block set it to this
  728. // block and propagate up the IR.
  729. propagateEstimatedBlockWeight(getLoopBlock(BB), DT, PDT, *BBWeight,
  730. BlockWorkList, LoopWorkList);
  731. // BlockWorklist/LoopWorkList contains blocks/loops with at least one
  732. // successor/exit having estimated weight. Try to propagate weight to such
  733. // blocks/loops from successors/exits.
  734. // Process loops and blocks. Order is not important.
  735. do {
  736. while (!LoopWorkList.empty()) {
  737. const LoopBlock LoopBB = LoopWorkList.pop_back_val();
  738. if (EstimatedLoopWeight.count(LoopBB.getLoopData()))
  739. continue;
  740. SmallVector<BasicBlock *, 4> Exits;
  741. getLoopExitBlocks(LoopBB, Exits);
  742. auto LoopWeight = getMaxEstimatedEdgeWeight(
  743. LoopBB, make_range(Exits.begin(), Exits.end()));
  744. if (LoopWeight) {
  745. // If we never exit the loop then we can enter it once at maximum.
  746. if (LoopWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
  747. LoopWeight = static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);
  748. EstimatedLoopWeight.insert({LoopBB.getLoopData(), *LoopWeight});
  749. // Add all blocks entering the loop into working list.
  750. getLoopEnterBlocks(LoopBB, BlockWorkList);
  751. }
  752. }
  753. while (!BlockWorkList.empty()) {
  754. // We can reach here only if BlockWorkList is not empty.
  755. const BasicBlock *BB = BlockWorkList.pop_back_val();
  756. if (EstimatedBlockWeight.count(BB))
  757. continue;
  758. // We take maximum over all weights of successors. In other words we take
  759. // weight of "hot" path. In theory we can probably find a better function
  760. // which gives higher accuracy results (comparing to "maximum") but I
  761. // can't
  762. // think of any right now. And I doubt it will make any difference in
  763. // practice.
  764. const LoopBlock LoopBB = getLoopBlock(BB);
  765. auto MaxWeight = getMaxEstimatedEdgeWeight(LoopBB, successors(BB));
  766. if (MaxWeight)
  767. propagateEstimatedBlockWeight(LoopBB, DT, PDT, *MaxWeight,
  768. BlockWorkList, LoopWorkList);
  769. }
  770. } while (!BlockWorkList.empty() || !LoopWorkList.empty());
  771. }
  772. // Calculate edge probabilities based on block's estimated weight.
  773. // Note that gathered weights were not scaled for loops. Thus edges entering
  774. // and exiting loops requires special processing.
  775. bool BranchProbabilityInfo::calcEstimatedHeuristics(const BasicBlock *BB) {
  776. assert(BB->getTerminator()->getNumSuccessors() > 1 &&
  777. "expected more than one successor!");
  778. const LoopBlock LoopBB = getLoopBlock(BB);
  779. SmallPtrSet<const BasicBlock *, 8> UnlikelyBlocks;
  780. uint32_t TC = LBH_TAKEN_WEIGHT / LBH_NONTAKEN_WEIGHT;
  781. if (LoopBB.getLoop())
  782. computeUnlikelySuccessors(BB, LoopBB.getLoop(), UnlikelyBlocks);
  783. // Changed to 'true' if at least one successor has estimated weight.
  784. bool FoundEstimatedWeight = false;
  785. SmallVector<uint32_t, 4> SuccWeights;
  786. uint64_t TotalWeight = 0;
  787. // Go over all successors of BB and put their weights into SuccWeights.
  788. for (const BasicBlock *SuccBB : successors(BB)) {
  789. std::optional<uint32_t> Weight;
  790. const LoopBlock SuccLoopBB = getLoopBlock(SuccBB);
  791. const LoopEdge Edge{LoopBB, SuccLoopBB};
  792. Weight = getEstimatedEdgeWeight(Edge);
  793. if (isLoopExitingEdge(Edge) &&
  794. // Avoid adjustment of ZERO weight since it should remain unchanged.
  795. Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
  796. // Scale down loop exiting weight by trip count.
  797. Weight = std::max(
  798. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
  799. Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
  800. TC);
  801. }
  802. bool IsUnlikelyEdge = LoopBB.getLoop() && UnlikelyBlocks.contains(SuccBB);
  803. if (IsUnlikelyEdge &&
  804. // Avoid adjustment of ZERO weight since it should remain unchanged.
  805. Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
  806. // 'Unlikely' blocks have twice lower weight.
  807. Weight = std::max(
  808. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
  809. Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) / 2);
  810. }
  811. if (Weight)
  812. FoundEstimatedWeight = true;
  813. auto WeightVal =
  814. Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT));
  815. TotalWeight += WeightVal;
  816. SuccWeights.push_back(WeightVal);
  817. }
  818. // If non of blocks have estimated weight bail out.
  819. // If TotalWeight is 0 that means weight of each successor is 0 as well and
  820. // equally likely. Bail out early to not deal with devision by zero.
  821. if (!FoundEstimatedWeight || TotalWeight == 0)
  822. return false;
  823. assert(SuccWeights.size() == succ_size(BB) && "Missed successor?");
  824. const unsigned SuccCount = SuccWeights.size();
  825. // If the sum of weights does not fit in 32 bits, scale every weight down
  826. // accordingly.
  827. if (TotalWeight > UINT32_MAX) {
  828. uint64_t ScalingFactor = TotalWeight / UINT32_MAX + 1;
  829. TotalWeight = 0;
  830. for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
  831. SuccWeights[Idx] /= ScalingFactor;
  832. if (SuccWeights[Idx] == static_cast<uint32_t>(BlockExecWeight::ZERO))
  833. SuccWeights[Idx] =
  834. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);
  835. TotalWeight += SuccWeights[Idx];
  836. }
  837. assert(TotalWeight <= UINT32_MAX && "Total weight overflows");
  838. }
  839. // Finally set probabilities to edges according to estimated block weights.
  840. SmallVector<BranchProbability, 4> EdgeProbabilities(
  841. SuccCount, BranchProbability::getUnknown());
  842. for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
  843. EdgeProbabilities[Idx] =
  844. BranchProbability(SuccWeights[Idx], (uint32_t)TotalWeight);
  845. }
  846. setEdgeProbability(BB, EdgeProbabilities);
  847. return true;
  848. }
  849. bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
  850. const TargetLibraryInfo *TLI) {
  851. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  852. if (!BI || !BI->isConditional())
  853. return false;
  854. Value *Cond = BI->getCondition();
  855. ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
  856. if (!CI)
  857. return false;
  858. auto GetConstantInt = [](Value *V) {
  859. if (auto *I = dyn_cast<BitCastInst>(V))
  860. return dyn_cast<ConstantInt>(I->getOperand(0));
  861. return dyn_cast<ConstantInt>(V);
  862. };
  863. Value *RHS = CI->getOperand(1);
  864. ConstantInt *CV = GetConstantInt(RHS);
  865. if (!CV)
  866. return false;
  867. // If the LHS is the result of AND'ing a value with a single bit bitmask,
  868. // we don't have information about probabilities.
  869. if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
  870. if (LHS->getOpcode() == Instruction::And)
  871. if (ConstantInt *AndRHS = GetConstantInt(LHS->getOperand(1)))
  872. if (AndRHS->getValue().isPowerOf2())
  873. return false;
  874. // Check if the LHS is the return value of a library function
  875. LibFunc Func = NumLibFuncs;
  876. if (TLI)
  877. if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
  878. if (Function *CalledFn = Call->getCalledFunction())
  879. TLI->getLibFunc(*CalledFn, Func);
  880. ProbabilityTable::const_iterator Search;
  881. if (Func == LibFunc_strcasecmp ||
  882. Func == LibFunc_strcmp ||
  883. Func == LibFunc_strncasecmp ||
  884. Func == LibFunc_strncmp ||
  885. Func == LibFunc_memcmp ||
  886. Func == LibFunc_bcmp) {
  887. Search = ICmpWithLibCallTable.find(CI->getPredicate());
  888. if (Search == ICmpWithLibCallTable.end())
  889. return false;
  890. } else if (CV->isZero()) {
  891. Search = ICmpWithZeroTable.find(CI->getPredicate());
  892. if (Search == ICmpWithZeroTable.end())
  893. return false;
  894. } else if (CV->isOne()) {
  895. Search = ICmpWithOneTable.find(CI->getPredicate());
  896. if (Search == ICmpWithOneTable.end())
  897. return false;
  898. } else if (CV->isMinusOne()) {
  899. Search = ICmpWithMinusOneTable.find(CI->getPredicate());
  900. if (Search == ICmpWithMinusOneTable.end())
  901. return false;
  902. } else {
  903. return false;
  904. }
  905. setEdgeProbability(BB, Search->second);
  906. return true;
  907. }
  908. bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
  909. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  910. if (!BI || !BI->isConditional())
  911. return false;
  912. Value *Cond = BI->getCondition();
  913. FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
  914. if (!FCmp)
  915. return false;
  916. ProbabilityList ProbList;
  917. if (FCmp->isEquality()) {
  918. ProbList = !FCmp->isTrueWhenEqual() ?
  919. // f1 == f2 -> Unlikely
  920. ProbabilityList({FPTakenProb, FPUntakenProb}) :
  921. // f1 != f2 -> Likely
  922. ProbabilityList({FPUntakenProb, FPTakenProb});
  923. } else {
  924. auto Search = FCmpTable.find(FCmp->getPredicate());
  925. if (Search == FCmpTable.end())
  926. return false;
  927. ProbList = Search->second;
  928. }
  929. setEdgeProbability(BB, ProbList);
  930. return true;
  931. }
  932. void BranchProbabilityInfo::releaseMemory() {
  933. Probs.clear();
  934. Handles.clear();
  935. }
  936. bool BranchProbabilityInfo::invalidate(Function &, const PreservedAnalyses &PA,
  937. FunctionAnalysisManager::Invalidator &) {
  938. // Check whether the analysis, all analyses on functions, or the function's
  939. // CFG have been preserved.
  940. auto PAC = PA.getChecker<BranchProbabilityAnalysis>();
  941. return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
  942. PAC.preservedSet<CFGAnalyses>());
  943. }
  944. void BranchProbabilityInfo::print(raw_ostream &OS) const {
  945. OS << "---- Branch Probabilities ----\n";
  946. // We print the probabilities from the last function the analysis ran over,
  947. // or the function it is currently running over.
  948. assert(LastF && "Cannot print prior to running over a function");
  949. for (const auto &BI : *LastF) {
  950. for (const BasicBlock *Succ : successors(&BI))
  951. printEdgeProbability(OS << " ", &BI, Succ);
  952. }
  953. }
  954. bool BranchProbabilityInfo::
  955. isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
  956. // Hot probability is at least 4/5 = 80%
  957. // FIXME: Compare against a static "hot" BranchProbability.
  958. return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
  959. }
  960. /// Get the raw edge probability for the edge. If can't find it, return a
  961. /// default probability 1/N where N is the number of successors. Here an edge is
  962. /// specified using PredBlock and an
  963. /// index to the successors.
  964. BranchProbability
  965. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  966. unsigned IndexInSuccessors) const {
  967. auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
  968. assert((Probs.end() == Probs.find(std::make_pair(Src, 0))) ==
  969. (Probs.end() == I) &&
  970. "Probability for I-th successor must always be defined along with the "
  971. "probability for the first successor");
  972. if (I != Probs.end())
  973. return I->second;
  974. return {1, static_cast<uint32_t>(succ_size(Src))};
  975. }
  976. BranchProbability
  977. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  978. const_succ_iterator Dst) const {
  979. return getEdgeProbability(Src, Dst.getSuccessorIndex());
  980. }
  981. /// Get the raw edge probability calculated for the block pair. This returns the
  982. /// sum of all raw edge probabilities from Src to Dst.
  983. BranchProbability
  984. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  985. const BasicBlock *Dst) const {
  986. if (!Probs.count(std::make_pair(Src, 0)))
  987. return BranchProbability(llvm::count(successors(Src), Dst), succ_size(Src));
  988. auto Prob = BranchProbability::getZero();
  989. for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
  990. if (*I == Dst)
  991. Prob += Probs.find(std::make_pair(Src, I.getSuccessorIndex()))->second;
  992. return Prob;
  993. }
  994. /// Set the edge probability for all edges at once.
  995. void BranchProbabilityInfo::setEdgeProbability(
  996. const BasicBlock *Src, const SmallVectorImpl<BranchProbability> &Probs) {
  997. assert(Src->getTerminator()->getNumSuccessors() == Probs.size());
  998. eraseBlock(Src); // Erase stale data if any.
  999. if (Probs.size() == 0)
  1000. return; // Nothing to set.
  1001. Handles.insert(BasicBlockCallbackVH(Src, this));
  1002. uint64_t TotalNumerator = 0;
  1003. for (unsigned SuccIdx = 0; SuccIdx < Probs.size(); ++SuccIdx) {
  1004. this->Probs[std::make_pair(Src, SuccIdx)] = Probs[SuccIdx];
  1005. LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << SuccIdx
  1006. << " successor probability to " << Probs[SuccIdx]
  1007. << "\n");
  1008. TotalNumerator += Probs[SuccIdx].getNumerator();
  1009. }
  1010. // Because of rounding errors the total probability cannot be checked to be
  1011. // 1.0 exactly. That is TotalNumerator == BranchProbability::getDenominator.
  1012. // Instead, every single probability in Probs must be as accurate as possible.
  1013. // This results in error 1/denominator at most, thus the total absolute error
  1014. // should be within Probs.size / BranchProbability::getDenominator.
  1015. assert(TotalNumerator <= BranchProbability::getDenominator() + Probs.size());
  1016. assert(TotalNumerator >= BranchProbability::getDenominator() - Probs.size());
  1017. (void)TotalNumerator;
  1018. }
  1019. void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src,
  1020. BasicBlock *Dst) {
  1021. eraseBlock(Dst); // Erase stale data if any.
  1022. unsigned NumSuccessors = Src->getTerminator()->getNumSuccessors();
  1023. assert(NumSuccessors == Dst->getTerminator()->getNumSuccessors());
  1024. if (NumSuccessors == 0)
  1025. return; // Nothing to set.
  1026. if (this->Probs.find(std::make_pair(Src, 0)) == this->Probs.end())
  1027. return; // No probability is set for edges from Src. Keep the same for Dst.
  1028. Handles.insert(BasicBlockCallbackVH(Dst, this));
  1029. for (unsigned SuccIdx = 0; SuccIdx < NumSuccessors; ++SuccIdx) {
  1030. auto Prob = this->Probs[std::make_pair(Src, SuccIdx)];
  1031. this->Probs[std::make_pair(Dst, SuccIdx)] = Prob;
  1032. LLVM_DEBUG(dbgs() << "set edge " << Dst->getName() << " -> " << SuccIdx
  1033. << " successor probability to " << Prob << "\n");
  1034. }
  1035. }
  1036. raw_ostream &
  1037. BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
  1038. const BasicBlock *Src,
  1039. const BasicBlock *Dst) const {
  1040. const BranchProbability Prob = getEdgeProbability(Src, Dst);
  1041. OS << "edge " << Src->getName() << " -> " << Dst->getName()
  1042. << " probability is " << Prob
  1043. << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
  1044. return OS;
  1045. }
  1046. void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
  1047. LLVM_DEBUG(dbgs() << "eraseBlock " << BB->getName() << "\n");
  1048. // Note that we cannot use successors of BB because the terminator of BB may
  1049. // have changed when eraseBlock is called as a BasicBlockCallbackVH callback.
  1050. // Instead we remove prob data for the block by iterating successors by their
  1051. // indices from 0 till the last which exists. There could not be prob data for
  1052. // a pair (BB, N) if there is no data for (BB, N-1) because the data is always
  1053. // set for all successors from 0 to M at once by the method
  1054. // setEdgeProbability().
  1055. Handles.erase(BasicBlockCallbackVH(BB, this));
  1056. for (unsigned I = 0;; ++I) {
  1057. auto MapI = Probs.find(std::make_pair(BB, I));
  1058. if (MapI == Probs.end()) {
  1059. assert(Probs.count(std::make_pair(BB, I + 1)) == 0 &&
  1060. "Must be no more successors");
  1061. return;
  1062. }
  1063. Probs.erase(MapI);
  1064. }
  1065. }
  1066. void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LoopI,
  1067. const TargetLibraryInfo *TLI,
  1068. DominatorTree *DT,
  1069. PostDominatorTree *PDT) {
  1070. LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
  1071. << " ----\n\n");
  1072. LastF = &F; // Store the last function we ran on for printing.
  1073. LI = &LoopI;
  1074. SccI = std::make_unique<SccInfo>(F);
  1075. assert(EstimatedBlockWeight.empty());
  1076. assert(EstimatedLoopWeight.empty());
  1077. std::unique_ptr<DominatorTree> DTPtr;
  1078. std::unique_ptr<PostDominatorTree> PDTPtr;
  1079. if (!DT) {
  1080. DTPtr = std::make_unique<DominatorTree>(const_cast<Function &>(F));
  1081. DT = DTPtr.get();
  1082. }
  1083. if (!PDT) {
  1084. PDTPtr = std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
  1085. PDT = PDTPtr.get();
  1086. }
  1087. computeEestimateBlockWeight(F, DT, PDT);
  1088. // Walk the basic blocks in post-order so that we can build up state about
  1089. // the successors of a block iteratively.
  1090. for (const auto *BB : post_order(&F.getEntryBlock())) {
  1091. LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
  1092. << "\n");
  1093. // If there is no at least two successors, no sense to set probability.
  1094. if (BB->getTerminator()->getNumSuccessors() < 2)
  1095. continue;
  1096. if (calcMetadataWeights(BB))
  1097. continue;
  1098. if (calcEstimatedHeuristics(BB))
  1099. continue;
  1100. if (calcPointerHeuristics(BB))
  1101. continue;
  1102. if (calcZeroHeuristics(BB, TLI))
  1103. continue;
  1104. if (calcFloatingPointHeuristics(BB))
  1105. continue;
  1106. }
  1107. EstimatedLoopWeight.clear();
  1108. EstimatedBlockWeight.clear();
  1109. SccI.reset();
  1110. if (PrintBranchProb &&
  1111. (PrintBranchProbFuncName.empty() ||
  1112. F.getName().equals(PrintBranchProbFuncName))) {
  1113. print(dbgs());
  1114. }
  1115. }
  1116. void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
  1117. AnalysisUsage &AU) const {
  1118. // We require DT so it's available when LI is available. The LI updating code
  1119. // asserts that DT is also present so if we don't make sure that we have DT
  1120. // here, that assert will trigger.
  1121. AU.addRequired<DominatorTreeWrapperPass>();
  1122. AU.addRequired<LoopInfoWrapperPass>();
  1123. AU.addRequired<TargetLibraryInfoWrapperPass>();
  1124. AU.addRequired<DominatorTreeWrapperPass>();
  1125. AU.addRequired<PostDominatorTreeWrapperPass>();
  1126. AU.setPreservesAll();
  1127. }
  1128. bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
  1129. const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  1130. const TargetLibraryInfo &TLI =
  1131. getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  1132. DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  1133. PostDominatorTree &PDT =
  1134. getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
  1135. BPI.calculate(F, LI, &TLI, &DT, &PDT);
  1136. return false;
  1137. }
  1138. void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
  1139. void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
  1140. const Module *) const {
  1141. BPI.print(OS);
  1142. }
  1143. AnalysisKey BranchProbabilityAnalysis::Key;
  1144. BranchProbabilityInfo
  1145. BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  1146. BranchProbabilityInfo BPI;
  1147. BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
  1148. &AM.getResult<TargetLibraryAnalysis>(F),
  1149. &AM.getResult<DominatorTreeAnalysis>(F),
  1150. &AM.getResult<PostDominatorTreeAnalysis>(F));
  1151. return BPI;
  1152. }
  1153. PreservedAnalyses
  1154. BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
  1155. OS << "Printing analysis results of BPI for function "
  1156. << "'" << F.getName() << "':"
  1157. << "\n";
  1158. AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
  1159. return PreservedAnalyses::all();
  1160. }