LowerExpectIntrinsic.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. //===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass lowers the 'expect' intrinsic to LLVM metadata.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
  13. #include "llvm/ADT/SmallVector.h"
  14. #include "llvm/ADT/Statistic.h"
  15. #include "llvm/ADT/iterator_range.h"
  16. #include "llvm/IR/BasicBlock.h"
  17. #include "llvm/IR/Constants.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/Instructions.h"
  20. #include "llvm/IR/Intrinsics.h"
  21. #include "llvm/IR/LLVMContext.h"
  22. #include "llvm/IR/MDBuilder.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/CommandLine.h"
  26. #include "llvm/Transforms/Scalar.h"
  27. #include "llvm/Transforms/Utils/MisExpect.h"
  28. #include <cmath>
  29. using namespace llvm;
  30. #define DEBUG_TYPE "lower-expect-intrinsic"
  31. STATISTIC(ExpectIntrinsicsHandled,
  32. "Number of 'expect' intrinsic instructions handled");
  33. // These default values are chosen to represent an extremely skewed outcome for
  34. // a condition, but they leave some room for interpretation by later passes.
  35. //
  36. // If the documentation for __builtin_expect() was made explicit that it should
  37. // only be used in extreme cases, we could make this ratio higher. As it stands,
  38. // programmers may be using __builtin_expect() / llvm.expect to annotate that a
  39. // branch is likely or unlikely to be taken.
  40. // WARNING: these values are internal implementation detail of the pass.
  41. // They should not be exposed to the outside of the pass, front-end codegen
  42. // should emit @llvm.expect intrinsics instead of using these weights directly.
  43. // Transforms should use TargetTransformInfo's getPredictableBranchThreshold().
  44. static cl::opt<uint32_t> LikelyBranchWeight(
  45. "likely-branch-weight", cl::Hidden, cl::init(2000),
  46. cl::desc("Weight of the branch likely to be taken (default = 2000)"));
  47. static cl::opt<uint32_t> UnlikelyBranchWeight(
  48. "unlikely-branch-weight", cl::Hidden, cl::init(1),
  49. cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
  50. static std::tuple<uint32_t, uint32_t>
  51. getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
  52. if (IntrinsicID == Intrinsic::expect) {
  53. // __builtin_expect
  54. return std::make_tuple(LikelyBranchWeight.getValue(),
  55. UnlikelyBranchWeight.getValue());
  56. } else {
  57. // __builtin_expect_with_probability
  58. assert(CI->getNumOperands() >= 3 &&
  59. "expect with probability must have 3 arguments");
  60. auto *Confidence = cast<ConstantFP>(CI->getArgOperand(2));
  61. double TrueProb = Confidence->getValueAPF().convertToDouble();
  62. assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
  63. "probability value must be in the range [0.0, 1.0]");
  64. double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
  65. uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
  66. uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
  67. return std::make_tuple(LikelyBW, UnlikelyBW);
  68. }
  69. }
  70. static bool handleSwitchExpect(SwitchInst &SI) {
  71. CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
  72. if (!CI)
  73. return false;
  74. Function *Fn = CI->getCalledFunction();
  75. if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
  76. Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
  77. return false;
  78. Value *ArgValue = CI->getArgOperand(0);
  79. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
  80. if (!ExpectedValue)
  81. return false;
  82. SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
  83. unsigned n = SI.getNumCases(); // +1 for default case.
  84. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  85. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
  86. getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
  87. SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
  88. uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
  89. Weights[Index] = LikelyBranchWeightVal;
  90. misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
  91. SI.setCondition(ArgValue);
  92. SI.setMetadata(LLVMContext::MD_prof,
  93. MDBuilder(CI->getContext()).createBranchWeights(Weights));
  94. return true;
  95. }
  96. /// Handler for PHINodes that define the value argument to an
  97. /// @llvm.expect call.
  98. ///
  99. /// If the operand of the phi has a constant value and it 'contradicts'
  100. /// with the expected value of phi def, then the corresponding incoming
  101. /// edge of the phi is unlikely to be taken. Using that information,
  102. /// the branch probability info for the originating branch can be inferred.
  103. static void handlePhiDef(CallInst *Expect) {
  104. Value &Arg = *Expect->getArgOperand(0);
  105. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1));
  106. if (!ExpectedValue)
  107. return;
  108. const APInt &ExpectedPhiValue = ExpectedValue->getValue();
  109. bool ExpectedValueIsLikely = true;
  110. Function *Fn = Expect->getCalledFunction();
  111. // If the function is expect_with_probability, then we need to take the
  112. // probability into consideration. For example, in
  113. // expect.with.probability.i64(i64 %a, i64 1, double 0.0), the
  114. // "ExpectedValue" 1 is unlikely. This affects probability propagation later.
  115. if (Fn->getIntrinsicID() == Intrinsic::expect_with_probability) {
  116. auto *Confidence = cast<ConstantFP>(Expect->getArgOperand(2));
  117. double TrueProb = Confidence->getValueAPF().convertToDouble();
  118. ExpectedValueIsLikely = (TrueProb > 0.5);
  119. }
  120. // Walk up in backward a list of instructions that
  121. // have 'copy' semantics by 'stripping' the copies
  122. // until a PHI node or an instruction of unknown kind
  123. // is reached. Negation via xor is also handled.
  124. //
  125. // C = PHI(...);
  126. // B = C;
  127. // A = B;
  128. // D = __builtin_expect(A, 0);
  129. //
  130. Value *V = &Arg;
  131. SmallVector<Instruction *, 4> Operations;
  132. while (!isa<PHINode>(V)) {
  133. if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) {
  134. V = ZExt->getOperand(0);
  135. Operations.push_back(ZExt);
  136. continue;
  137. }
  138. if (SExtInst *SExt = dyn_cast<SExtInst>(V)) {
  139. V = SExt->getOperand(0);
  140. Operations.push_back(SExt);
  141. continue;
  142. }
  143. BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
  144. if (!BinOp || BinOp->getOpcode() != Instruction::Xor)
  145. return;
  146. ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1));
  147. if (!CInt)
  148. return;
  149. V = BinOp->getOperand(0);
  150. Operations.push_back(BinOp);
  151. }
  152. // Executes the recorded operations on input 'Value'.
  153. auto ApplyOperations = [&](const APInt &Value) {
  154. APInt Result = Value;
  155. for (auto *Op : llvm::reverse(Operations)) {
  156. switch (Op->getOpcode()) {
  157. case Instruction::Xor:
  158. Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue();
  159. break;
  160. case Instruction::ZExt:
  161. Result = Result.zext(Op->getType()->getIntegerBitWidth());
  162. break;
  163. case Instruction::SExt:
  164. Result = Result.sext(Op->getType()->getIntegerBitWidth());
  165. break;
  166. default:
  167. llvm_unreachable("Unexpected operation");
  168. }
  169. }
  170. return Result;
  171. };
  172. auto *PhiDef = cast<PHINode>(V);
  173. // Get the first dominating conditional branch of the operand
  174. // i's incoming block.
  175. auto GetDomConditional = [&](unsigned i) -> BranchInst * {
  176. BasicBlock *BB = PhiDef->getIncomingBlock(i);
  177. BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  178. if (BI && BI->isConditional())
  179. return BI;
  180. BB = BB->getSinglePredecessor();
  181. if (!BB)
  182. return nullptr;
  183. BI = dyn_cast<BranchInst>(BB->getTerminator());
  184. if (!BI || BI->isUnconditional())
  185. return nullptr;
  186. return BI;
  187. };
  188. // Now walk through all Phi operands to find phi oprerands with values
  189. // conflicting with the expected phi output value. Any such operand
  190. // indicates the incoming edge to that operand is unlikely.
  191. for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) {
  192. Value *PhiOpnd = PhiDef->getIncomingValue(i);
  193. ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);
  194. if (!CI)
  195. continue;
  196. // Not an interesting case when IsUnlikely is false -- we can not infer
  197. // anything useful when:
  198. // (1) We expect some phi output and the operand value matches it, or
  199. // (2) We don't expect some phi output (i.e. the "ExpectedValue" has low
  200. // probability) and the operand value doesn't match that.
  201. const APInt &CurrentPhiValue = ApplyOperations(CI->getValue());
  202. if (ExpectedValueIsLikely == (ExpectedPhiValue == CurrentPhiValue))
  203. continue;
  204. BranchInst *BI = GetDomConditional(i);
  205. if (!BI)
  206. continue;
  207. MDBuilder MDB(PhiDef->getContext());
  208. // There are two situations in which an operand of the PhiDef comes
  209. // from a given successor of a branch instruction BI.
  210. // 1) When the incoming block of the operand is the successor block;
  211. // 2) When the incoming block is BI's enclosing block and the
  212. // successor is the PhiDef's enclosing block.
  213. //
  214. // Returns true if the operand which comes from OpndIncomingBB
  215. // comes from outgoing edge of BI that leads to Succ block.
  216. auto *OpndIncomingBB = PhiDef->getIncomingBlock(i);
  217. auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) {
  218. if (OpndIncomingBB == Succ)
  219. // If this successor is the incoming block for this
  220. // Phi operand, then this successor does lead to the Phi.
  221. return true;
  222. if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent())
  223. // Otherwise, if the edge is directly from the branch
  224. // to the Phi, this successor is the one feeding this
  225. // Phi operand.
  226. return true;
  227. return false;
  228. };
  229. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  230. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
  231. Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
  232. if (!ExpectedValueIsLikely)
  233. std::swap(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
  234. if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
  235. BI->setMetadata(LLVMContext::MD_prof,
  236. MDB.createBranchWeights(LikelyBranchWeightVal,
  237. UnlikelyBranchWeightVal));
  238. else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
  239. BI->setMetadata(LLVMContext::MD_prof,
  240. MDB.createBranchWeights(UnlikelyBranchWeightVal,
  241. LikelyBranchWeightVal));
  242. }
  243. }
  244. // Handle both BranchInst and SelectInst.
  245. template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
  246. // Handle non-optimized IR code like:
  247. // %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1)
  248. // %tobool = icmp ne i64 %expval, 0
  249. // br i1 %tobool, label %if.then, label %if.end
  250. //
  251. // Or the following simpler case:
  252. // %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1)
  253. // br i1 %expval, label %if.then, label %if.end
  254. CallInst *CI;
  255. ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition());
  256. CmpInst::Predicate Predicate;
  257. ConstantInt *CmpConstOperand = nullptr;
  258. if (!CmpI) {
  259. CI = dyn_cast<CallInst>(BSI.getCondition());
  260. Predicate = CmpInst::ICMP_NE;
  261. } else {
  262. Predicate = CmpI->getPredicate();
  263. if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ)
  264. return false;
  265. CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1));
  266. if (!CmpConstOperand)
  267. return false;
  268. CI = dyn_cast<CallInst>(CmpI->getOperand(0));
  269. }
  270. if (!CI)
  271. return false;
  272. uint64_t ValueComparedTo = 0;
  273. if (CmpConstOperand) {
  274. if (CmpConstOperand->getBitWidth() > 64)
  275. return false;
  276. ValueComparedTo = CmpConstOperand->getZExtValue();
  277. }
  278. Function *Fn = CI->getCalledFunction();
  279. if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
  280. Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
  281. return false;
  282. Value *ArgValue = CI->getArgOperand(0);
  283. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
  284. if (!ExpectedValue)
  285. return false;
  286. MDBuilder MDB(CI->getContext());
  287. MDNode *Node;
  288. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  289. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
  290. getBranchWeight(Fn->getIntrinsicID(), CI, 2);
  291. SmallVector<uint32_t, 4> ExpectedWeights;
  292. if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
  293. (Predicate == CmpInst::ICMP_EQ)) {
  294. Node =
  295. MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
  296. ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal};
  297. } else {
  298. Node =
  299. MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
  300. ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal};
  301. }
  302. if (CmpI)
  303. CmpI->setOperand(0, ArgValue);
  304. else
  305. BSI.setCondition(ArgValue);
  306. misexpect::checkFrontendInstrumentation(BSI, ExpectedWeights);
  307. BSI.setMetadata(LLVMContext::MD_prof, Node);
  308. return true;
  309. }
  310. static bool handleBranchExpect(BranchInst &BI) {
  311. if (BI.isUnconditional())
  312. return false;
  313. return handleBrSelExpect<BranchInst>(BI);
  314. }
  315. static bool lowerExpectIntrinsic(Function &F) {
  316. bool Changed = false;
  317. for (BasicBlock &BB : F) {
  318. // Create "block_weights" metadata.
  319. if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
  320. if (handleBranchExpect(*BI))
  321. ExpectIntrinsicsHandled++;
  322. } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
  323. if (handleSwitchExpect(*SI))
  324. ExpectIntrinsicsHandled++;
  325. }
  326. // Remove llvm.expect intrinsics. Iterate backwards in order
  327. // to process select instructions before the intrinsic gets
  328. // removed.
  329. for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) {
  330. CallInst *CI = dyn_cast<CallInst>(&Inst);
  331. if (!CI) {
  332. if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) {
  333. if (handleBrSelExpect(*SI))
  334. ExpectIntrinsicsHandled++;
  335. }
  336. continue;
  337. }
  338. Function *Fn = CI->getCalledFunction();
  339. if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
  340. Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
  341. // Before erasing the llvm.expect, walk backward to find
  342. // phi that define llvm.expect's first arg, and
  343. // infer branch probability:
  344. handlePhiDef(CI);
  345. Value *Exp = CI->getArgOperand(0);
  346. CI->replaceAllUsesWith(Exp);
  347. CI->eraseFromParent();
  348. Changed = true;
  349. }
  350. }
  351. }
  352. return Changed;
  353. }
  354. PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
  355. FunctionAnalysisManager &) {
  356. if (lowerExpectIntrinsic(F))
  357. return PreservedAnalyses::none();
  358. return PreservedAnalyses::all();
  359. }
  360. namespace {
  361. /// Legacy pass for lowering expect intrinsics out of the IR.
  362. ///
  363. /// When this pass is run over a function it uses expect intrinsics which feed
  364. /// branches and switches to provide branch weight metadata for those
  365. /// terminators. It then removes the expect intrinsics from the IR so the rest
  366. /// of the optimizer can ignore them.
  367. class LowerExpectIntrinsic : public FunctionPass {
  368. public:
  369. static char ID;
  370. LowerExpectIntrinsic() : FunctionPass(ID) {
  371. initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
  372. }
  373. bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); }
  374. };
  375. } // namespace
  376. char LowerExpectIntrinsic::ID = 0;
  377. INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect",
  378. "Lower 'expect' Intrinsics", false, false)
  379. FunctionPass *llvm::createLowerExpectIntrinsicPass() {
  380. return new LowerExpectIntrinsic();
  381. }