BranchProbabilityInfo.cpp 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333
  1. //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Loops should be simplified before this analysis.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Analysis/BranchProbabilityInfo.h"
  13. #include "llvm/ADT/PostOrderIterator.h"
  14. #include "llvm/ADT/SCCIterator.h"
  15. #include "llvm/ADT/STLExtras.h"
  16. #include "llvm/ADT/SmallVector.h"
  17. #include "llvm/Analysis/LoopInfo.h"
  18. #include "llvm/Analysis/PostDominators.h"
  19. #include "llvm/Analysis/TargetLibraryInfo.h"
  20. #include "llvm/IR/Attributes.h"
  21. #include "llvm/IR/BasicBlock.h"
  22. #include "llvm/IR/CFG.h"
  23. #include "llvm/IR/Constants.h"
  24. #include "llvm/IR/Dominators.h"
  25. #include "llvm/IR/Function.h"
  26. #include "llvm/IR/InstrTypes.h"
  27. #include "llvm/IR/Instruction.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/LLVMContext.h"
  30. #include "llvm/IR/Metadata.h"
  31. #include "llvm/IR/PassManager.h"
  32. #include "llvm/IR/Type.h"
  33. #include "llvm/IR/Value.h"
  34. #include "llvm/InitializePasses.h"
  35. #include "llvm/Pass.h"
  36. #include "llvm/Support/BranchProbability.h"
  37. #include "llvm/Support/Casting.h"
  38. #include "llvm/Support/CommandLine.h"
  39. #include "llvm/Support/Debug.h"
  40. #include "llvm/Support/raw_ostream.h"
  41. #include <cassert>
  42. #include <cstdint>
  43. #include <iterator>
  44. #include <map>
  45. #include <utility>
  46. using namespace llvm;
  47. #define DEBUG_TYPE "branch-prob"
  48. static cl::opt<bool> PrintBranchProb(
  49. "print-bpi", cl::init(false), cl::Hidden,
  50. cl::desc("Print the branch probability info."));
  51. cl::opt<std::string> PrintBranchProbFuncName(
  52. "print-bpi-func-name", cl::Hidden,
  53. cl::desc("The option to specify the name of the function "
  54. "whose branch probability info is printed."));
  55. INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
  56. "Branch Probability Analysis", false, true)
  57. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  58. INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
  59. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  60. INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
  61. INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
  62. "Branch Probability Analysis", false, true)
  63. BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
  64. : FunctionPass(ID) {
  65. initializeBranchProbabilityInfoWrapperPassPass(
  66. *PassRegistry::getPassRegistry());
  67. }
  68. char BranchProbabilityInfoWrapperPass::ID = 0;
  69. // Weights are for internal use only. They are used by heuristics to help to
  70. // estimate edges' probability. Example:
  71. //
  72. // Using "Loop Branch Heuristics" we predict weights of edges for the
  73. // block BB2.
  74. // ...
  75. // |
  76. // V
  77. // BB1<-+
  78. // | |
  79. // | | (Weight = 124)
  80. // V |
  81. // BB2--+
  82. // |
  83. // | (Weight = 4)
  84. // V
  85. // BB3
  86. //
  87. // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
  88. // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
  89. static const uint32_t LBH_TAKEN_WEIGHT = 124;
  90. static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
  91. /// Unreachable-terminating branch taken probability.
  92. ///
  93. /// This is the probability for a branch being taken to a block that terminates
  94. /// (eventually) in unreachable. These are predicted as unlikely as possible.
  95. /// All reachable probability will proportionally share the remaining part.
  96. static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
  97. /// Heuristics and lookup tables for non-loop branches:
  98. /// Pointer Heuristics (PH)
  99. static const uint32_t PH_TAKEN_WEIGHT = 20;
  100. static const uint32_t PH_NONTAKEN_WEIGHT = 12;
  101. static const BranchProbability
  102. PtrTakenProb(PH_TAKEN_WEIGHT, PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
  103. static const BranchProbability
  104. PtrUntakenProb(PH_NONTAKEN_WEIGHT, PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
  105. using ProbabilityList = SmallVector<BranchProbability>;
  106. using ProbabilityTable = std::map<CmpInst::Predicate, ProbabilityList>;
  107. /// Pointer comparisons:
  108. static const ProbabilityTable PointerTable{
  109. {ICmpInst::ICMP_NE, {PtrTakenProb, PtrUntakenProb}}, /// p != q -> Likely
  110. {ICmpInst::ICMP_EQ, {PtrUntakenProb, PtrTakenProb}}, /// p == q -> Unlikely
  111. };
  112. /// Zero Heuristics (ZH)
  113. static const uint32_t ZH_TAKEN_WEIGHT = 20;
  114. static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
  115. static const BranchProbability
  116. ZeroTakenProb(ZH_TAKEN_WEIGHT, ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
  117. static const BranchProbability
  118. ZeroUntakenProb(ZH_NONTAKEN_WEIGHT, ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
  119. /// Integer compares with 0:
  120. static const ProbabilityTable ICmpWithZeroTable{
  121. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}}, /// X == 0 -> Unlikely
  122. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}}, /// X != 0 -> Likely
  123. {CmpInst::ICMP_SLT, {ZeroUntakenProb, ZeroTakenProb}}, /// X < 0 -> Unlikely
  124. {CmpInst::ICMP_SGT, {ZeroTakenProb, ZeroUntakenProb}}, /// X > 0 -> Likely
  125. };
  126. /// Integer compares with -1:
  127. static const ProbabilityTable ICmpWithMinusOneTable{
  128. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}}, /// X == -1 -> Unlikely
  129. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}}, /// X != -1 -> Likely
  130. // InstCombine canonicalizes X >= 0 into X > -1
  131. {CmpInst::ICMP_SGT, {ZeroTakenProb, ZeroUntakenProb}}, /// X >= 0 -> Likely
  132. };
  133. /// Integer compares with 1:
  134. static const ProbabilityTable ICmpWithOneTable{
  135. // InstCombine canonicalizes X <= 0 into X < 1
  136. {CmpInst::ICMP_SLT, {ZeroUntakenProb, ZeroTakenProb}}, /// X <= 0 -> Unlikely
  137. };
  138. /// strcmp and similar functions return zero, negative, or positive, if the
  139. /// first string is equal, less, or greater than the second. We consider it
  140. /// likely that the strings are not equal, so a comparison with zero is
  141. /// probably false, but also a comparison with any other number is also
  142. /// probably false given that what exactly is returned for nonzero values is
  143. /// not specified. Any kind of comparison other than equality we know
  144. /// nothing about.
  145. static const ProbabilityTable ICmpWithLibCallTable{
  146. {CmpInst::ICMP_EQ, {ZeroUntakenProb, ZeroTakenProb}},
  147. {CmpInst::ICMP_NE, {ZeroTakenProb, ZeroUntakenProb}},
  148. };
  149. // Floating-Point Heuristics (FPH)
  150. static const uint32_t FPH_TAKEN_WEIGHT = 20;
  151. static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
  152. /// This is the probability for an ordered floating point comparison.
  153. static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
  154. /// This is the probability for an unordered floating point comparison, it means
  155. /// one or two of the operands are NaN. Usually it is used to test for an
  156. /// exceptional case, so the result is unlikely.
  157. static const uint32_t FPH_UNO_WEIGHT = 1;
  158. static const BranchProbability FPOrdTakenProb(FPH_ORD_WEIGHT,
  159. FPH_ORD_WEIGHT + FPH_UNO_WEIGHT);
  160. static const BranchProbability
  161. FPOrdUntakenProb(FPH_UNO_WEIGHT, FPH_ORD_WEIGHT + FPH_UNO_WEIGHT);
  162. static const BranchProbability
  163. FPTakenProb(FPH_TAKEN_WEIGHT, FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
  164. static const BranchProbability
  165. FPUntakenProb(FPH_NONTAKEN_WEIGHT, FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
  166. /// Floating-Point compares:
  167. static const ProbabilityTable FCmpTable{
  168. {FCmpInst::FCMP_ORD, {FPOrdTakenProb, FPOrdUntakenProb}}, /// !isnan -> Likely
  169. {FCmpInst::FCMP_UNO, {FPOrdUntakenProb, FPOrdTakenProb}}, /// isnan -> Unlikely
  170. };
  /// Dedicated "absolute" execution weights for a block. The values only carry
  /// meaning relative to one another and to weights derived from them.
  enum class BlockExecWeight : std::uint32_t {
    /// Special weight used for cases with exact zero probability.
    ZERO = 0,
    /// Minimal possible non zero weight.
    LOWEST_NON_ZERO = 1,
    /// Weight to an 'unreachable' block.
    UNREACHABLE = ZERO,
    /// Weight to a block containing non returning call.
    NORETURN = LOWEST_NON_ZERO,
    /// Weight to 'unwind' block of an invoke instruction.
    UNWIND = LOWEST_NON_ZERO,
    /// Weight to a 'cold' block. Cold blocks are the ones containing calls
    /// marked with attribute 'cold'.
    COLD = 0xFFFF,
    /// Default weight, used when no dedicated execution weight is set. It is
    /// not propagated through the domination line either.
    DEFAULT = 0xFFFFF
  };
  191. BranchProbabilityInfo::SccInfo::SccInfo(const Function &F) {
  192. // Record SCC numbers of blocks in the CFG to identify irreducible loops.
  193. // FIXME: We could only calculate this if the CFG is known to be irreducible
  194. // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
  195. int SccNum = 0;
  196. for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
  197. ++It, ++SccNum) {
  198. // Ignore single-block SCCs since they either aren't loops or LoopInfo will
  199. // catch them.
  200. const std::vector<const BasicBlock *> &Scc = *It;
  201. if (Scc.size() == 1)
  202. continue;
  203. LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
  204. for (const auto *BB : Scc) {
  205. LLVM_DEBUG(dbgs() << " " << BB->getName());
  206. SccNums[BB] = SccNum;
  207. calculateSccBlockType(BB, SccNum);
  208. }
  209. LLVM_DEBUG(dbgs() << "\n");
  210. }
  211. }
  212. int BranchProbabilityInfo::SccInfo::getSCCNum(const BasicBlock *BB) const {
  213. auto SccIt = SccNums.find(BB);
  214. if (SccIt == SccNums.end())
  215. return -1;
  216. return SccIt->second;
  217. }
  218. void BranchProbabilityInfo::SccInfo::getSccEnterBlocks(
  219. int SccNum, SmallVectorImpl<BasicBlock *> &Enters) const {
  220. for (auto MapIt : SccBlocks[SccNum]) {
  221. const auto *BB = MapIt.first;
  222. if (isSCCHeader(BB, SccNum))
  223. for (const auto *Pred : predecessors(BB))
  224. if (getSCCNum(Pred) != SccNum)
  225. Enters.push_back(const_cast<BasicBlock *>(BB));
  226. }
  227. }
  228. void BranchProbabilityInfo::SccInfo::getSccExitBlocks(
  229. int SccNum, SmallVectorImpl<BasicBlock *> &Exits) const {
  230. for (auto MapIt : SccBlocks[SccNum]) {
  231. const auto *BB = MapIt.first;
  232. if (isSCCExitingBlock(BB, SccNum))
  233. for (const auto *Succ : successors(BB))
  234. if (getSCCNum(Succ) != SccNum)
  235. Exits.push_back(const_cast<BasicBlock *>(Succ));
  236. }
  237. }
  238. uint32_t BranchProbabilityInfo::SccInfo::getSccBlockType(const BasicBlock *BB,
  239. int SccNum) const {
  240. assert(getSCCNum(BB) == SccNum);
  241. assert(SccBlocks.size() > static_cast<unsigned>(SccNum) && "Unknown SCC");
  242. const auto &SccBlockTypes = SccBlocks[SccNum];
  243. auto It = SccBlockTypes.find(BB);
  244. if (It != SccBlockTypes.end()) {
  245. return It->second;
  246. }
  247. return Inner;
  248. }
  249. void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
  250. int SccNum) {
  251. assert(getSCCNum(BB) == SccNum);
  252. uint32_t BlockType = Inner;
  253. if (llvm::any_of(predecessors(BB), [&](const BasicBlock *Pred) {
  254. // Consider any block that is an entry point to the SCC as
  255. // a header.
  256. return getSCCNum(Pred) != SccNum;
  257. }))
  258. BlockType |= Header;
  259. if (llvm::any_of(successors(BB), [&](const BasicBlock *Succ) {
  260. return getSCCNum(Succ) != SccNum;
  261. }))
  262. BlockType |= Exiting;
  263. // Lazily compute the set of headers for a given SCC and cache the results
  264. // in the SccHeaderMap.
  265. if (SccBlocks.size() <= static_cast<unsigned>(SccNum))
  266. SccBlocks.resize(SccNum + 1);
  267. auto &SccBlockTypes = SccBlocks[SccNum];
  268. if (BlockType != Inner) {
  269. bool IsInserted;
  270. std::tie(std::ignore, IsInserted) =
  271. SccBlockTypes.insert(std::make_pair(BB, BlockType));
  272. assert(IsInserted && "Duplicated block in SCC");
  273. }
  274. }
  275. BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
  276. const LoopInfo &LI,
  277. const SccInfo &SccI)
  278. : BB(BB) {
  279. LD.first = LI.getLoopFor(BB);
  280. if (!LD.first) {
  281. LD.second = SccI.getSCCNum(BB);
  282. }
  283. }
  284. bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
  285. const auto &SrcBlock = Edge.first;
  286. const auto &DstBlock = Edge.second;
  287. return (DstBlock.getLoop() &&
  288. !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
  289. // Assume that SCCs can't be nested.
  290. (DstBlock.getSccNum() != -1 &&
  291. SrcBlock.getSccNum() != DstBlock.getSccNum());
  292. }
  293. bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
  294. return isLoopEnteringEdge({Edge.second, Edge.first});
  295. }
  296. bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
  297. const LoopEdge &Edge) const {
  298. return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
  299. }
  300. bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
  301. const auto &SrcBlock = Edge.first;
  302. const auto &DstBlock = Edge.second;
  303. return SrcBlock.belongsToSameLoop(DstBlock) &&
  304. ((DstBlock.getLoop() &&
  305. DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
  306. (DstBlock.getSccNum() != -1 &&
  307. SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
  308. }
  309. void BranchProbabilityInfo::getLoopEnterBlocks(
  310. const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
  311. if (LB.getLoop()) {
  312. auto *Header = LB.getLoop()->getHeader();
  313. Enters.append(pred_begin(Header), pred_end(Header));
  314. } else {
  315. assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
  316. SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
  317. }
  318. }
  319. void BranchProbabilityInfo::getLoopExitBlocks(
  320. const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
  321. if (LB.getLoop()) {
  322. LB.getLoop()->getExitBlocks(Exits);
  323. } else {
  324. assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
  325. SccI->getSccExitBlocks(LB.getSccNum(), Exits);
  326. }
  327. }
  328. // Propagate existing explicit probabilities from either profile data or
  329. // 'expect' intrinsic processing. Examine metadata against unreachable
  330. // heuristic. The probability of the edge coming to unreachable block is
  331. // set to min of metadata and unreachable heuristic.
  332. bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  333. const Instruction *TI = BB->getTerminator();
  334. assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  335. if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI) ||
  336. isa<InvokeInst>(TI)))
  337. return false;
  338. MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
  339. if (!WeightsNode)
  340. return false;
  341. // Check that the number of successors is manageable.
  342. assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
  343. // Ensure there are weights for all of the successors. Note that the first
  344. // operand to the metadata node is a name, not a weight.
  345. if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
  346. return false;
  347. // Build up the final weights that will be used in a temporary buffer.
  348. // Compute the sum of all weights to later decide whether they need to
  349. // be scaled to fit in 32 bits.
  350. uint64_t WeightSum = 0;
  351. SmallVector<uint32_t, 2> Weights;
  352. SmallVector<unsigned, 2> UnreachableIdxs;
  353. SmallVector<unsigned, 2> ReachableIdxs;
  354. Weights.reserve(TI->getNumSuccessors());
  355. for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
  356. ConstantInt *Weight =
  357. mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
  358. if (!Weight)
  359. return false;
  360. assert(Weight->getValue().getActiveBits() <= 32 &&
  361. "Too many bits for uint32_t");
  362. Weights.push_back(Weight->getZExtValue());
  363. WeightSum += Weights.back();
  364. const LoopBlock SrcLoopBB = getLoopBlock(BB);
  365. const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
  366. auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
  367. if (EstimatedWeight &&
  368. EstimatedWeight.getValue() <=
  369. static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
  370. UnreachableIdxs.push_back(I - 1);
  371. else
  372. ReachableIdxs.push_back(I - 1);
  373. }
  374. assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
  375. // If the sum of weights does not fit in 32 bits, scale every weight down
  376. // accordingly.
  377. uint64_t ScalingFactor =
  378. (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
  379. if (ScalingFactor > 1) {
  380. WeightSum = 0;
  381. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
  382. Weights[I] /= ScalingFactor;
  383. WeightSum += Weights[I];
  384. }
  385. }
  386. assert(WeightSum <= UINT32_MAX &&
  387. "Expected weights to scale down to 32 bits");
  388. if (WeightSum == 0 || ReachableIdxs.size() == 0) {
  389. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
  390. Weights[I] = 1;
  391. WeightSum = TI->getNumSuccessors();
  392. }
  393. // Set the probability.
  394. SmallVector<BranchProbability, 2> BP;
  395. for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
  396. BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
  397. // Examine the metadata against unreachable heuristic.
  398. // If the unreachable heuristic is more strong then we use it for this edge.
  399. if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) {
  400. setEdgeProbability(BB, BP);
  401. return true;
  402. }
  403. auto UnreachableProb = UR_TAKEN_PROB;
  404. for (auto I : UnreachableIdxs)
  405. if (UnreachableProb < BP[I]) {
  406. BP[I] = UnreachableProb;
  407. }
  408. // Sum of all edge probabilities must be 1.0. If we modified the probability
  409. // of some edges then we must distribute the introduced difference over the
  410. // reachable blocks.
  411. //
  412. // Proportional distribution: the relation between probabilities of the
  413. // reachable edges is kept unchanged. That is for any reachable edges i and j:
  414. // newBP[i] / newBP[j] == oldBP[i] / oldBP[j] =>
  415. // newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K
  416. // Where K is independent of i,j.
  417. // newBP[i] == oldBP[i] * K
  418. // We need to find K.
  419. // Make sum of all reachables of the left and right parts:
  420. // sum_of_reachable(newBP) == K * sum_of_reachable(oldBP)
  421. // Sum of newBP must be equal to 1.0:
  422. // sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 =>
  423. // sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP)
  424. // Where sum_of_unreachable(newBP) is what has been just changed.
  425. // Finally:
  426. // K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) =>
  427. // K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP)
  428. BranchProbability NewUnreachableSum = BranchProbability::getZero();
  429. for (auto I : UnreachableIdxs)
  430. NewUnreachableSum += BP[I];
  431. BranchProbability NewReachableSum =
  432. BranchProbability::getOne() - NewUnreachableSum;
  433. BranchProbability OldReachableSum = BranchProbability::getZero();
  434. for (auto I : ReachableIdxs)
  435. OldReachableSum += BP[I];
  436. if (OldReachableSum != NewReachableSum) { // Anything to dsitribute?
  437. if (OldReachableSum.isZero()) {
  438. // If all oldBP[i] are zeroes then the proportional distribution results
  439. // in all zero probabilities and the error stays big. In this case we
  440. // evenly spread NewReachableSum over the reachable edges.
  441. BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size();
  442. for (auto I : ReachableIdxs)
  443. BP[I] = PerEdge;
  444. } else {
  445. for (auto I : ReachableIdxs) {
  446. // We use uint64_t to avoid double rounding error of the following
  447. // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum
  448. // The formula is taken from the private constructor
  449. // BranchProbability(uint32_t Numerator, uint32_t Denominator)
  450. uint64_t Mul = static_cast<uint64_t>(NewReachableSum.getNumerator()) *
  451. BP[I].getNumerator();
  452. uint32_t Div = static_cast<uint32_t>(
  453. divideNearest(Mul, OldReachableSum.getNumerator()));
  454. BP[I] = BranchProbability::getRaw(Div);
  455. }
  456. }
  457. }
  458. setEdgeProbability(BB, BP);
  459. return true;
  460. }
  461. // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
  462. // between two pointer or pointer and NULL will fail.
  463. bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
  464. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  465. if (!BI || !BI->isConditional())
  466. return false;
  467. Value *Cond = BI->getCondition();
  468. ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
  469. if (!CI || !CI->isEquality())
  470. return false;
  471. Value *LHS = CI->getOperand(0);
  472. if (!LHS->getType()->isPointerTy())
  473. return false;
  474. assert(CI->getOperand(1)->getType()->isPointerTy());
  475. auto Search = PointerTable.find(CI->getPredicate());
  476. if (Search == PointerTable.end())
  477. return false;
  478. setEdgeProbability(BB, Search->second);
  479. return true;
  480. }
  481. // Compute the unlikely successors to the block BB in the loop L, specifically
  482. // those that are unlikely because this is a loop, and add them to the
  483. // UnlikelyBlocks set.
  484. static void
  485. computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
  486. SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
  487. // Sometimes in a loop we have a branch whose condition is made false by
  488. // taking it. This is typically something like
  489. // int n = 0;
  490. // while (...) {
  491. // if (++n >= MAX) {
  492. // n = 0;
  493. // }
  494. // }
  495. // In this sort of situation taking the branch means that at the very least it
  496. // won't be taken again in the next iteration of the loop, so we should
  497. // consider it less likely than a typical branch.
  498. //
  499. // We detect this by looking back through the graph of PHI nodes that sets the
  500. // value that the condition depends on, and seeing if we can reach a successor
  501. // block which can be determined to make the condition false.
  502. //
  503. // FIXME: We currently consider unlikely blocks to be half as likely as other
  504. // blocks, but if we consider the example above the likelyhood is actually
  505. // 1/MAX. We could therefore be more precise in how unlikely we consider
  506. // blocks to be, but it would require more careful examination of the form
  507. // of the comparison expression.
  508. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  509. if (!BI || !BI->isConditional())
  510. return;
  511. // Check if the branch is based on an instruction compared with a constant
  512. CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
  513. if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
  514. !isa<Constant>(CI->getOperand(1)))
  515. return;
  516. // Either the instruction must be a PHI, or a chain of operations involving
  517. // constants that ends in a PHI which we can then collapse into a single value
  518. // if the PHI value is known.
  519. Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
  520. PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
  521. Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
  522. // Collect the instructions until we hit a PHI
  523. SmallVector<BinaryOperator *, 1> InstChain;
  524. while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
  525. isa<Constant>(CmpLHS->getOperand(1))) {
  526. // Stop if the chain extends outside of the loop
  527. if (!L->contains(CmpLHS))
  528. return;
  529. InstChain.push_back(cast<BinaryOperator>(CmpLHS));
  530. CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
  531. if (CmpLHS)
  532. CmpPHI = dyn_cast<PHINode>(CmpLHS);
  533. }
  534. if (!CmpPHI || !L->contains(CmpPHI))
  535. return;
  536. // Trace the phi node to find all values that come from successors of BB
  537. SmallPtrSet<PHINode*, 8> VisitedInsts;
  538. SmallVector<PHINode*, 8> WorkList;
  539. WorkList.push_back(CmpPHI);
  540. VisitedInsts.insert(CmpPHI);
  541. while (!WorkList.empty()) {
  542. PHINode *P = WorkList.pop_back_val();
  543. for (BasicBlock *B : P->blocks()) {
  544. // Skip blocks that aren't part of the loop
  545. if (!L->contains(B))
  546. continue;
  547. Value *V = P->getIncomingValueForBlock(B);
  548. // If the source is a PHI add it to the work list if we haven't
  549. // already visited it.
  550. if (PHINode *PN = dyn_cast<PHINode>(V)) {
  551. if (VisitedInsts.insert(PN).second)
  552. WorkList.push_back(PN);
  553. continue;
  554. }
  555. // If this incoming value is a constant and B is a successor of BB, then
  556. // we can constant-evaluate the compare to see if it makes the branch be
  557. // taken or not.
  558. Constant *CmpLHSConst = dyn_cast<Constant>(V);
  559. if (!CmpLHSConst || !llvm::is_contained(successors(BB), B))
  560. continue;
  561. // First collapse InstChain
  562. for (Instruction *I : llvm::reverse(InstChain)) {
  563. CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
  564. cast<Constant>(I->getOperand(1)), true);
  565. if (!CmpLHSConst)
  566. break;
  567. }
  568. if (!CmpLHSConst)
  569. continue;
  570. // Now constant-evaluate the compare
  571. Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
  572. CmpLHSConst, CmpConst, true);
  573. // If the result means we don't branch to the block then that block is
  574. // unlikely.
  575. if (Result &&
  576. ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
  577. (Result->isOneValue() && B == BI->getSuccessor(1))))
  578. UnlikelyBlocks.insert(B);
  579. }
  580. }
  581. }
  582. Optional<uint32_t>
  583. BranchProbabilityInfo::getEstimatedBlockWeight(const BasicBlock *BB) const {
  584. auto WeightIt = EstimatedBlockWeight.find(BB);
  585. if (WeightIt == EstimatedBlockWeight.end())
  586. return None;
  587. return WeightIt->second;
  588. }
  589. Optional<uint32_t>
  590. BranchProbabilityInfo::getEstimatedLoopWeight(const LoopData &L) const {
  591. auto WeightIt = EstimatedLoopWeight.find(L);
  592. if (WeightIt == EstimatedLoopWeight.end())
  593. return None;
  594. return WeightIt->second;
  595. }
  596. Optional<uint32_t>
  597. BranchProbabilityInfo::getEstimatedEdgeWeight(const LoopEdge &Edge) const {
  598. // For edges entering a loop take weight of a loop rather than an individual
  599. // block in the loop.
  600. return isLoopEnteringEdge(Edge)
  601. ? getEstimatedLoopWeight(Edge.second.getLoopData())
  602. : getEstimatedBlockWeight(Edge.second.getBlock());
  603. }
  604. template <class IterT>
  605. Optional<uint32_t> BranchProbabilityInfo::getMaxEstimatedEdgeWeight(
  606. const LoopBlock &SrcLoopBB, iterator_range<IterT> Successors) const {
  607. SmallVector<uint32_t, 4> Weights;
  608. Optional<uint32_t> MaxWeight;
  609. for (const BasicBlock *DstBB : Successors) {
  610. const LoopBlock DstLoopBB = getLoopBlock(DstBB);
  611. auto Weight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
  612. if (!Weight)
  613. return None;
  614. if (!MaxWeight || MaxWeight.getValue() < Weight.getValue())
  615. MaxWeight = Weight;
  616. }
  617. return MaxWeight;
  618. }
  619. // Updates \p LoopBB's weight and returns true. If \p LoopBB has already
  620. // an associated weight it is unchanged and false is returned.
  621. //
  622. // Please note by the algorithm the weight is not expected to change once set
  623. // thus 'false' status is used to track visited blocks.
  624. bool BranchProbabilityInfo::updateEstimatedBlockWeight(
  625. LoopBlock &LoopBB, uint32_t BBWeight,
  626. SmallVectorImpl<BasicBlock *> &BlockWorkList,
  627. SmallVectorImpl<LoopBlock> &LoopWorkList) {
  628. BasicBlock *BB = LoopBB.getBlock();
  629. // In general, weight is assigned to a block when it has final value and
  630. // can't/shouldn't be changed. However, there are cases when a block
  631. // inherently has several (possibly "contradicting") weights. For example,
  632. // "unwind" block may also contain "cold" call. In that case the first
  633. // set weight is favored and all consequent weights are ignored.
  634. if (!EstimatedBlockWeight.insert({BB, BBWeight}).second)
  635. return false;
  636. for (BasicBlock *PredBlock : predecessors(BB)) {
  637. LoopBlock PredLoop = getLoopBlock(PredBlock);
  638. // Add affected block/loop to a working list.
  639. if (isLoopExitingEdge({PredLoop, LoopBB})) {
  640. if (!EstimatedLoopWeight.count(PredLoop.getLoopData()))
  641. LoopWorkList.push_back(PredLoop);
  642. } else if (!EstimatedBlockWeight.count(PredBlock))
  643. BlockWorkList.push_back(PredBlock);
  644. }
  645. return true;
  646. }
  647. // Starting from \p BB traverse through dominator blocks and assign \p BBWeight
  648. // to all such blocks that are post dominated by \BB. In other words to all
  649. // blocks that the one is executed if and only if another one is executed.
  650. // Importantly, we skip loops here for two reasons. First weights of blocks in
  651. // a loop should be scaled by trip count (yet possibly unknown). Second there is
  652. // no any value in doing that because that doesn't give any additional
  653. // information regarding distribution of probabilities inside the loop.
  654. // Exception is loop 'enter' and 'exit' edges that are handled in a special way
  655. // at calcEstimatedHeuristics.
  656. //
  657. // In addition, \p WorkList is populated with basic blocks if at leas one
  658. // successor has updated estimated weight.
  659. void BranchProbabilityInfo::propagateEstimatedBlockWeight(
  660. const LoopBlock &LoopBB, DominatorTree *DT, PostDominatorTree *PDT,
  661. uint32_t BBWeight, SmallVectorImpl<BasicBlock *> &BlockWorkList,
  662. SmallVectorImpl<LoopBlock> &LoopWorkList) {
  663. const BasicBlock *BB = LoopBB.getBlock();
  664. const auto *DTStartNode = DT->getNode(BB);
  665. const auto *PDTStartNode = PDT->getNode(BB);
  666. // TODO: Consider propagating weight down the domination line as well.
  667. for (const auto *DTNode = DTStartNode; DTNode != nullptr;
  668. DTNode = DTNode->getIDom()) {
  669. auto *DomBB = DTNode->getBlock();
  670. // Consider blocks which lie on one 'line'.
  671. if (!PDT->dominates(PDTStartNode, PDT->getNode(DomBB)))
  672. // If BB doesn't post dominate DomBB it will not post dominate dominators
  673. // of DomBB as well.
  674. break;
  675. LoopBlock DomLoopBB = getLoopBlock(DomBB);
  676. const LoopEdge Edge{DomLoopBB, LoopBB};
  677. // Don't propagate weight to blocks belonging to different loops.
  678. if (!isLoopEnteringExitingEdge(Edge)) {
  679. if (!updateEstimatedBlockWeight(DomLoopBB, BBWeight, BlockWorkList,
  680. LoopWorkList))
  681. // If DomBB has weight set then all it's predecessors are already
  682. // processed (since we propagate weight up to the top of IR each time).
  683. break;
  684. } else if (isLoopExitingEdge(Edge)) {
  685. LoopWorkList.push_back(DomLoopBB);
  686. }
  687. }
  688. }
  689. Optional<uint32_t> BranchProbabilityInfo::getInitialEstimatedBlockWeight(
  690. const BasicBlock *BB) {
  691. // Returns true if \p BB has call marked with "NoReturn" attribute.
  692. auto hasNoReturn = [&](const BasicBlock *BB) {
  693. for (const auto &I : reverse(*BB))
  694. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  695. if (CI->hasFnAttr(Attribute::NoReturn))
  696. return true;
  697. return false;
  698. };
  699. // Important note regarding the order of checks. They are ordered by weight
  700. // from lowest to highest. Doing that allows to avoid "unstable" results
  701. // when several conditions heuristics can be applied simultaneously.
  702. if (isa<UnreachableInst>(BB->getTerminator()) ||
  703. // If this block is terminated by a call to
  704. // @llvm.experimental.deoptimize then treat it like an unreachable
  705. // since it is expected to practically never execute.
  706. // TODO: Should we actually treat as never returning call?
  707. BB->getTerminatingDeoptimizeCall())
  708. return hasNoReturn(BB)
  709. ? static_cast<uint32_t>(BlockExecWeight::NORETURN)
  710. : static_cast<uint32_t>(BlockExecWeight::UNREACHABLE);
  711. // Check if the block is 'unwind' handler of some invoke instruction.
  712. for (const auto *Pred : predecessors(BB))
  713. if (Pred)
  714. if (const auto *II = dyn_cast<InvokeInst>(Pred->getTerminator()))
  715. if (II->getUnwindDest() == BB)
  716. return static_cast<uint32_t>(BlockExecWeight::UNWIND);
  717. // Check if the block contains 'cold' call.
  718. for (const auto &I : *BB)
  719. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  720. if (CI->hasFnAttr(Attribute::Cold))
  721. return static_cast<uint32_t>(BlockExecWeight::COLD);
  722. return None;
  723. }
  724. // Does RPO traversal over all blocks in \p F and assigns weights to
  725. // 'unreachable', 'noreturn', 'cold', 'unwind' blocks. In addition it does its
  726. // best to propagate the weight to up/down the IR.
  727. void BranchProbabilityInfo::computeEestimateBlockWeight(
  728. const Function &F, DominatorTree *DT, PostDominatorTree *PDT) {
  729. SmallVector<BasicBlock *, 8> BlockWorkList;
  730. SmallVector<LoopBlock, 8> LoopWorkList;
  731. // By doing RPO we make sure that all predecessors already have weights
  732. // calculated before visiting theirs successors.
  733. ReversePostOrderTraversal<const Function *> RPOT(&F);
  734. for (const auto *BB : RPOT)
  735. if (auto BBWeight = getInitialEstimatedBlockWeight(BB))
  736. // If we were able to find estimated weight for the block set it to this
  737. // block and propagate up the IR.
  738. propagateEstimatedBlockWeight(getLoopBlock(BB), DT, PDT,
  739. BBWeight.getValue(), BlockWorkList,
  740. LoopWorkList);
  741. // BlockWorklist/LoopWorkList contains blocks/loops with at least one
  742. // successor/exit having estimated weight. Try to propagate weight to such
  743. // blocks/loops from successors/exits.
  744. // Process loops and blocks. Order is not important.
  745. do {
  746. while (!LoopWorkList.empty()) {
  747. const LoopBlock LoopBB = LoopWorkList.pop_back_val();
  748. if (EstimatedLoopWeight.count(LoopBB.getLoopData()))
  749. continue;
  750. SmallVector<BasicBlock *, 4> Exits;
  751. getLoopExitBlocks(LoopBB, Exits);
  752. auto LoopWeight = getMaxEstimatedEdgeWeight(
  753. LoopBB, make_range(Exits.begin(), Exits.end()));
  754. if (LoopWeight) {
  755. // If we never exit the loop then we can enter it once at maximum.
  756. if (LoopWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
  757. LoopWeight = static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);
  758. EstimatedLoopWeight.insert(
  759. {LoopBB.getLoopData(), LoopWeight.getValue()});
  760. // Add all blocks entering the loop into working list.
  761. getLoopEnterBlocks(LoopBB, BlockWorkList);
  762. }
  763. }
  764. while (!BlockWorkList.empty()) {
  765. // We can reach here only if BlockWorkList is not empty.
  766. const BasicBlock *BB = BlockWorkList.pop_back_val();
  767. if (EstimatedBlockWeight.count(BB))
  768. continue;
  769. // We take maximum over all weights of successors. In other words we take
  770. // weight of "hot" path. In theory we can probably find a better function
  771. // which gives higher accuracy results (comparing to "maximum") but I
  772. // can't
  773. // think of any right now. And I doubt it will make any difference in
  774. // practice.
  775. const LoopBlock LoopBB = getLoopBlock(BB);
  776. auto MaxWeight = getMaxEstimatedEdgeWeight(LoopBB, successors(BB));
  777. if (MaxWeight)
  778. propagateEstimatedBlockWeight(LoopBB, DT, PDT, MaxWeight.getValue(),
  779. BlockWorkList, LoopWorkList);
  780. }
  781. } while (!BlockWorkList.empty() || !LoopWorkList.empty());
  782. }
  783. // Calculate edge probabilities based on block's estimated weight.
  784. // Note that gathered weights were not scaled for loops. Thus edges entering
  785. // and exiting loops requires special processing.
  786. bool BranchProbabilityInfo::calcEstimatedHeuristics(const BasicBlock *BB) {
  787. assert(BB->getTerminator()->getNumSuccessors() > 1 &&
  788. "expected more than one successor!");
  789. const LoopBlock LoopBB = getLoopBlock(BB);
  790. SmallPtrSet<const BasicBlock *, 8> UnlikelyBlocks;
  791. uint32_t TC = LBH_TAKEN_WEIGHT / LBH_NONTAKEN_WEIGHT;
  792. if (LoopBB.getLoop())
  793. computeUnlikelySuccessors(BB, LoopBB.getLoop(), UnlikelyBlocks);
  794. // Changed to 'true' if at least one successor has estimated weight.
  795. bool FoundEstimatedWeight = false;
  796. SmallVector<uint32_t, 4> SuccWeights;
  797. uint64_t TotalWeight = 0;
  798. // Go over all successors of BB and put their weights into SuccWeights.
  799. for (const BasicBlock *SuccBB : successors(BB)) {
  800. Optional<uint32_t> Weight;
  801. const LoopBlock SuccLoopBB = getLoopBlock(SuccBB);
  802. const LoopEdge Edge{LoopBB, SuccLoopBB};
  803. Weight = getEstimatedEdgeWeight(Edge);
  804. if (isLoopExitingEdge(Edge) &&
  805. // Avoid adjustment of ZERO weight since it should remain unchanged.
  806. Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
  807. // Scale down loop exiting weight by trip count.
  808. Weight = std::max(
  809. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
  810. Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
  811. TC);
  812. }
  813. bool IsUnlikelyEdge = LoopBB.getLoop() && UnlikelyBlocks.contains(SuccBB);
  814. if (IsUnlikelyEdge &&
  815. // Avoid adjustment of ZERO weight since it should remain unchanged.
  816. Weight != static_cast<uint32_t>(BlockExecWeight::ZERO)) {
  817. // 'Unlikely' blocks have twice lower weight.
  818. Weight = std::max(
  819. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
  820. Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
  821. 2);
  822. }
  823. if (Weight)
  824. FoundEstimatedWeight = true;
  825. auto WeightVal =
  826. Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT));
  827. TotalWeight += WeightVal;
  828. SuccWeights.push_back(WeightVal);
  829. }
  830. // If non of blocks have estimated weight bail out.
  831. // If TotalWeight is 0 that means weight of each successor is 0 as well and
  832. // equally likely. Bail out early to not deal with devision by zero.
  833. if (!FoundEstimatedWeight || TotalWeight == 0)
  834. return false;
  835. assert(SuccWeights.size() == succ_size(BB) && "Missed successor?");
  836. const unsigned SuccCount = SuccWeights.size();
  837. // If the sum of weights does not fit in 32 bits, scale every weight down
  838. // accordingly.
  839. if (TotalWeight > UINT32_MAX) {
  840. uint64_t ScalingFactor = TotalWeight / UINT32_MAX + 1;
  841. TotalWeight = 0;
  842. for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
  843. SuccWeights[Idx] /= ScalingFactor;
  844. if (SuccWeights[Idx] == static_cast<uint32_t>(BlockExecWeight::ZERO))
  845. SuccWeights[Idx] =
  846. static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);
  847. TotalWeight += SuccWeights[Idx];
  848. }
  849. assert(TotalWeight <= UINT32_MAX && "Total weight overflows");
  850. }
  851. // Finally set probabilities to edges according to estimated block weights.
  852. SmallVector<BranchProbability, 4> EdgeProbabilities(
  853. SuccCount, BranchProbability::getUnknown());
  854. for (unsigned Idx = 0; Idx < SuccCount; ++Idx) {
  855. EdgeProbabilities[Idx] =
  856. BranchProbability(SuccWeights[Idx], (uint32_t)TotalWeight);
  857. }
  858. setEdgeProbability(BB, EdgeProbabilities);
  859. return true;
  860. }
  861. bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
  862. const TargetLibraryInfo *TLI) {
  863. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  864. if (!BI || !BI->isConditional())
  865. return false;
  866. Value *Cond = BI->getCondition();
  867. ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
  868. if (!CI)
  869. return false;
  870. auto GetConstantInt = [](Value *V) {
  871. if (auto *I = dyn_cast<BitCastInst>(V))
  872. return dyn_cast<ConstantInt>(I->getOperand(0));
  873. return dyn_cast<ConstantInt>(V);
  874. };
  875. Value *RHS = CI->getOperand(1);
  876. ConstantInt *CV = GetConstantInt(RHS);
  877. if (!CV)
  878. return false;
  879. // If the LHS is the result of AND'ing a value with a single bit bitmask,
  880. // we don't have information about probabilities.
  881. if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
  882. if (LHS->getOpcode() == Instruction::And)
  883. if (ConstantInt *AndRHS = GetConstantInt(LHS->getOperand(1)))
  884. if (AndRHS->getValue().isPowerOf2())
  885. return false;
  886. // Check if the LHS is the return value of a library function
  887. LibFunc Func = NumLibFuncs;
  888. if (TLI)
  889. if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
  890. if (Function *CalledFn = Call->getCalledFunction())
  891. TLI->getLibFunc(*CalledFn, Func);
  892. ProbabilityTable::const_iterator Search;
  893. if (Func == LibFunc_strcasecmp ||
  894. Func == LibFunc_strcmp ||
  895. Func == LibFunc_strncasecmp ||
  896. Func == LibFunc_strncmp ||
  897. Func == LibFunc_memcmp ||
  898. Func == LibFunc_bcmp) {
  899. Search = ICmpWithLibCallTable.find(CI->getPredicate());
  900. if (Search == ICmpWithLibCallTable.end())
  901. return false;
  902. } else if (CV->isZero()) {
  903. Search = ICmpWithZeroTable.find(CI->getPredicate());
  904. if (Search == ICmpWithZeroTable.end())
  905. return false;
  906. } else if (CV->isOne()) {
  907. Search = ICmpWithOneTable.find(CI->getPredicate());
  908. if (Search == ICmpWithOneTable.end())
  909. return false;
  910. } else if (CV->isMinusOne()) {
  911. Search = ICmpWithMinusOneTable.find(CI->getPredicate());
  912. if (Search == ICmpWithMinusOneTable.end())
  913. return false;
  914. } else {
  915. return false;
  916. }
  917. setEdgeProbability(BB, Search->second);
  918. return true;
  919. }
  920. bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
  921. const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  922. if (!BI || !BI->isConditional())
  923. return false;
  924. Value *Cond = BI->getCondition();
  925. FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
  926. if (!FCmp)
  927. return false;
  928. ProbabilityList ProbList;
  929. if (FCmp->isEquality()) {
  930. ProbList = !FCmp->isTrueWhenEqual() ?
  931. // f1 == f2 -> Unlikely
  932. ProbabilityList({FPTakenProb, FPUntakenProb}) :
  933. // f1 != f2 -> Likely
  934. ProbabilityList({FPUntakenProb, FPTakenProb});
  935. } else {
  936. auto Search = FCmpTable.find(FCmp->getPredicate());
  937. if (Search == FCmpTable.end())
  938. return false;
  939. ProbList = Search->second;
  940. }
  941. setEdgeProbability(BB, ProbList);
  942. return true;
  943. }
  944. void BranchProbabilityInfo::releaseMemory() {
  945. Probs.clear();
  946. Handles.clear();
  947. }
  948. bool BranchProbabilityInfo::invalidate(Function &, const PreservedAnalyses &PA,
  949. FunctionAnalysisManager::Invalidator &) {
  950. // Check whether the analysis, all analyses on functions, or the function's
  951. // CFG have been preserved.
  952. auto PAC = PA.getChecker<BranchProbabilityAnalysis>();
  953. return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
  954. PAC.preservedSet<CFGAnalyses>());
  955. }
  956. void BranchProbabilityInfo::print(raw_ostream &OS) const {
  957. OS << "---- Branch Probabilities ----\n";
  958. // We print the probabilities from the last function the analysis ran over,
  959. // or the function it is currently running over.
  960. assert(LastF && "Cannot print prior to running over a function");
  961. for (const auto &BI : *LastF) {
  962. for (const BasicBlock *Succ : successors(&BI))
  963. printEdgeProbability(OS << " ", &BI, Succ);
  964. }
  965. }
  966. bool BranchProbabilityInfo::
  967. isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
  968. // Hot probability is at least 4/5 = 80%
  969. // FIXME: Compare against a static "hot" BranchProbability.
  970. return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
  971. }
  972. /// Get the raw edge probability for the edge. If can't find it, return a
  973. /// default probability 1/N where N is the number of successors. Here an edge is
  974. /// specified using PredBlock and an
  975. /// index to the successors.
  976. BranchProbability
  977. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  978. unsigned IndexInSuccessors) const {
  979. auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
  980. assert((Probs.end() == Probs.find(std::make_pair(Src, 0))) ==
  981. (Probs.end() == I) &&
  982. "Probability for I-th successor must always be defined along with the "
  983. "probability for the first successor");
  984. if (I != Probs.end())
  985. return I->second;
  986. return {1, static_cast<uint32_t>(succ_size(Src))};
  987. }
  988. BranchProbability
  989. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  990. const_succ_iterator Dst) const {
  991. return getEdgeProbability(Src, Dst.getSuccessorIndex());
  992. }
  993. /// Get the raw edge probability calculated for the block pair. This returns the
  994. /// sum of all raw edge probabilities from Src to Dst.
  995. BranchProbability
  996. BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
  997. const BasicBlock *Dst) const {
  998. if (!Probs.count(std::make_pair(Src, 0)))
  999. return BranchProbability(llvm::count(successors(Src), Dst), succ_size(Src));
  1000. auto Prob = BranchProbability::getZero();
  1001. for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
  1002. if (*I == Dst)
  1003. Prob += Probs.find(std::make_pair(Src, I.getSuccessorIndex()))->second;
  1004. return Prob;
  1005. }
  1006. /// Set the edge probability for all edges at once.
  1007. void BranchProbabilityInfo::setEdgeProbability(
  1008. const BasicBlock *Src, const SmallVectorImpl<BranchProbability> &Probs) {
  1009. assert(Src->getTerminator()->getNumSuccessors() == Probs.size());
  1010. eraseBlock(Src); // Erase stale data if any.
  1011. if (Probs.size() == 0)
  1012. return; // Nothing to set.
  1013. Handles.insert(BasicBlockCallbackVH(Src, this));
  1014. uint64_t TotalNumerator = 0;
  1015. for (unsigned SuccIdx = 0; SuccIdx < Probs.size(); ++SuccIdx) {
  1016. this->Probs[std::make_pair(Src, SuccIdx)] = Probs[SuccIdx];
  1017. LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << SuccIdx
  1018. << " successor probability to " << Probs[SuccIdx]
  1019. << "\n");
  1020. TotalNumerator += Probs[SuccIdx].getNumerator();
  1021. }
  1022. // Because of rounding errors the total probability cannot be checked to be
  1023. // 1.0 exactly. That is TotalNumerator == BranchProbability::getDenominator.
  1024. // Instead, every single probability in Probs must be as accurate as possible.
  1025. // This results in error 1/denominator at most, thus the total absolute error
  1026. // should be within Probs.size / BranchProbability::getDenominator.
  1027. assert(TotalNumerator <= BranchProbability::getDenominator() + Probs.size());
  1028. assert(TotalNumerator >= BranchProbability::getDenominator() - Probs.size());
  1029. (void)TotalNumerator;
  1030. }
  1031. void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src,
  1032. BasicBlock *Dst) {
  1033. eraseBlock(Dst); // Erase stale data if any.
  1034. unsigned NumSuccessors = Src->getTerminator()->getNumSuccessors();
  1035. assert(NumSuccessors == Dst->getTerminator()->getNumSuccessors());
  1036. if (NumSuccessors == 0)
  1037. return; // Nothing to set.
  1038. if (this->Probs.find(std::make_pair(Src, 0)) == this->Probs.end())
  1039. return; // No probability is set for edges from Src. Keep the same for Dst.
  1040. Handles.insert(BasicBlockCallbackVH(Dst, this));
  1041. for (unsigned SuccIdx = 0; SuccIdx < NumSuccessors; ++SuccIdx) {
  1042. auto Prob = this->Probs[std::make_pair(Src, SuccIdx)];
  1043. this->Probs[std::make_pair(Dst, SuccIdx)] = Prob;
  1044. LLVM_DEBUG(dbgs() << "set edge " << Dst->getName() << " -> " << SuccIdx
  1045. << " successor probability to " << Prob << "\n");
  1046. }
  1047. }
  1048. raw_ostream &
  1049. BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
  1050. const BasicBlock *Src,
  1051. const BasicBlock *Dst) const {
  1052. const BranchProbability Prob = getEdgeProbability(Src, Dst);
  1053. OS << "edge " << Src->getName() << " -> " << Dst->getName()
  1054. << " probability is " << Prob
  1055. << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
  1056. return OS;
  1057. }
  1058. void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
  1059. LLVM_DEBUG(dbgs() << "eraseBlock " << BB->getName() << "\n");
  1060. // Note that we cannot use successors of BB because the terminator of BB may
  1061. // have changed when eraseBlock is called as a BasicBlockCallbackVH callback.
  1062. // Instead we remove prob data for the block by iterating successors by their
  1063. // indices from 0 till the last which exists. There could not be prob data for
  1064. // a pair (BB, N) if there is no data for (BB, N-1) because the data is always
  1065. // set for all successors from 0 to M at once by the method
  1066. // setEdgeProbability().
  1067. Handles.erase(BasicBlockCallbackVH(BB, this));
  1068. for (unsigned I = 0;; ++I) {
  1069. auto MapI = Probs.find(std::make_pair(BB, I));
  1070. if (MapI == Probs.end()) {
  1071. assert(Probs.count(std::make_pair(BB, I + 1)) == 0 &&
  1072. "Must be no more successors");
  1073. return;
  1074. }
  1075. Probs.erase(MapI);
  1076. }
  1077. }
  1078. void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LoopI,
  1079. const TargetLibraryInfo *TLI,
  1080. DominatorTree *DT,
  1081. PostDominatorTree *PDT) {
  1082. LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
  1083. << " ----\n\n");
  1084. LastF = &F; // Store the last function we ran on for printing.
  1085. LI = &LoopI;
  1086. SccI = std::make_unique<SccInfo>(F);
  1087. assert(EstimatedBlockWeight.empty());
  1088. assert(EstimatedLoopWeight.empty());
  1089. std::unique_ptr<DominatorTree> DTPtr;
  1090. std::unique_ptr<PostDominatorTree> PDTPtr;
  1091. if (!DT) {
  1092. DTPtr = std::make_unique<DominatorTree>(const_cast<Function &>(F));
  1093. DT = DTPtr.get();
  1094. }
  1095. if (!PDT) {
  1096. PDTPtr = std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
  1097. PDT = PDTPtr.get();
  1098. }
  1099. computeEestimateBlockWeight(F, DT, PDT);
  1100. // Walk the basic blocks in post-order so that we can build up state about
  1101. // the successors of a block iteratively.
  1102. for (auto BB : post_order(&F.getEntryBlock())) {
  1103. LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
  1104. << "\n");
  1105. // If there is no at least two successors, no sense to set probability.
  1106. if (BB->getTerminator()->getNumSuccessors() < 2)
  1107. continue;
  1108. if (calcMetadataWeights(BB))
  1109. continue;
  1110. if (calcEstimatedHeuristics(BB))
  1111. continue;
  1112. if (calcPointerHeuristics(BB))
  1113. continue;
  1114. if (calcZeroHeuristics(BB, TLI))
  1115. continue;
  1116. if (calcFloatingPointHeuristics(BB))
  1117. continue;
  1118. }
  1119. EstimatedLoopWeight.clear();
  1120. EstimatedBlockWeight.clear();
  1121. SccI.reset();
  1122. if (PrintBranchProb &&
  1123. (PrintBranchProbFuncName.empty() ||
  1124. F.getName().equals(PrintBranchProbFuncName))) {
  1125. print(dbgs());
  1126. }
  1127. }
  1128. void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
  1129. AnalysisUsage &AU) const {
  1130. // We require DT so it's available when LI is available. The LI updating code
  1131. // asserts that DT is also present so if we don't make sure that we have DT
  1132. // here, that assert will trigger.
  1133. AU.addRequired<DominatorTreeWrapperPass>();
  1134. AU.addRequired<LoopInfoWrapperPass>();
  1135. AU.addRequired<TargetLibraryInfoWrapperPass>();
  1136. AU.addRequired<DominatorTreeWrapperPass>();
  1137. AU.addRequired<PostDominatorTreeWrapperPass>();
  1138. AU.setPreservesAll();
  1139. }
  1140. bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
  1141. const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  1142. const TargetLibraryInfo &TLI =
  1143. getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  1144. DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  1145. PostDominatorTree &PDT =
  1146. getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
  1147. BPI.calculate(F, LI, &TLI, &DT, &PDT);
  1148. return false;
  1149. }
  1150. void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
  1151. void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
  1152. const Module *) const {
  1153. BPI.print(OS);
  1154. }
  1155. AnalysisKey BranchProbabilityAnalysis::Key;
  1156. BranchProbabilityInfo
  1157. BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  1158. BranchProbabilityInfo BPI;
  1159. BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
  1160. &AM.getResult<TargetLibraryAnalysis>(F),
  1161. &AM.getResult<DominatorTreeAnalysis>(F),
  1162. &AM.getResult<PostDominatorTreeAnalysis>(F));
  1163. return BPI;
  1164. }
  1165. PreservedAnalyses
  1166. BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
  1167. OS << "Printing analysis results of BPI for function "
  1168. << "'" << F.getName() << "':"
  1169. << "\n";
  1170. AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
  1171. return PreservedAnalyses::all();
  1172. }