LowerExpectIntrinsic.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. //===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass lowers the 'expect' intrinsic to LLVM metadata.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
  13. #include "llvm/ADT/SmallVector.h"
  14. #include "llvm/ADT/Statistic.h"
  15. #include "llvm/ADT/iterator_range.h"
  16. #include "llvm/IR/BasicBlock.h"
  17. #include "llvm/IR/Constants.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/Instructions.h"
  20. #include "llvm/IR/Intrinsics.h"
  21. #include "llvm/IR/LLVMContext.h"
  22. #include "llvm/IR/MDBuilder.h"
  23. #include "llvm/IR/Metadata.h"
  24. #include "llvm/InitializePasses.h"
  25. #include "llvm/Pass.h"
  26. #include "llvm/Support/CommandLine.h"
  27. #include "llvm/Support/Debug.h"
  28. #include "llvm/Transforms/Scalar.h"
  29. using namespace llvm;
  30. #define DEBUG_TYPE "lower-expect-intrinsic"
  31. STATISTIC(ExpectIntrinsicsHandled,
  32. "Number of 'expect' intrinsic instructions handled");
  33. // These default values are chosen to represent an extremely skewed outcome for
  34. // a condition, but they leave some room for interpretation by later passes.
  35. //
  36. // If the documentation for __builtin_expect() was made explicit that it should
  37. // only be used in extreme cases, we could make this ratio higher. As it stands,
  38. // programmers may be using __builtin_expect() / llvm.expect to annotate that a
  39. // branch is likely or unlikely to be taken.
  40. // WARNING: these values are internal implementation detail of the pass.
  41. // They should not be exposed to the outside of the pass, front-end codegen
  42. // should emit @llvm.expect intrinsics instead of using these weights directly.
  43. // Transforms should use TargetTransformInfo's getPredictableBranchThreshold().
  44. static cl::opt<uint32_t> LikelyBranchWeight(
  45. "likely-branch-weight", cl::Hidden, cl::init(2000),
  46. cl::desc("Weight of the branch likely to be taken (default = 2000)"));
  47. static cl::opt<uint32_t> UnlikelyBranchWeight(
  48. "unlikely-branch-weight", cl::Hidden, cl::init(1),
  49. cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
  50. static std::tuple<uint32_t, uint32_t>
  51. getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
  52. if (IntrinsicID == Intrinsic::expect) {
  53. // __builtin_expect
  54. return std::make_tuple(LikelyBranchWeight.getValue(),
  55. UnlikelyBranchWeight.getValue());
  56. } else {
  57. // __builtin_expect_with_probability
  58. assert(CI->getNumOperands() >= 3 &&
  59. "expect with probability must have 3 arguments");
  60. auto *Confidence = cast<ConstantFP>(CI->getArgOperand(2));
  61. double TrueProb = Confidence->getValueAPF().convertToDouble();
  62. assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
  63. "probability value must be in the range [0.0, 1.0]");
  64. double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
  65. uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
  66. uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
  67. return std::make_tuple(LikelyBW, UnlikelyBW);
  68. }
  69. }
  70. static bool handleSwitchExpect(SwitchInst &SI) {
  71. CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
  72. if (!CI)
  73. return false;
  74. Function *Fn = CI->getCalledFunction();
  75. if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
  76. Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
  77. return false;
  78. Value *ArgValue = CI->getArgOperand(0);
  79. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
  80. if (!ExpectedValue)
  81. return false;
  82. SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
  83. unsigned n = SI.getNumCases(); // +1 for default case.
  84. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  85. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
  86. getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
  87. SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
  88. uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
  89. Weights[Index] = LikelyBranchWeightVal;
  90. SI.setCondition(ArgValue);
  91. SI.setMetadata(LLVMContext::MD_prof,
  92. MDBuilder(CI->getContext()).createBranchWeights(Weights));
  93. return true;
  94. }
  95. /// Handler for PHINodes that define the value argument to an
  96. /// @llvm.expect call.
  97. ///
  98. /// If the operand of the phi has a constant value and it 'contradicts'
  99. /// with the expected value of phi def, then the corresponding incoming
  100. /// edge of the phi is unlikely to be taken. Using that information,
  101. /// the branch probability info for the originating branch can be inferred.
  102. static void handlePhiDef(CallInst *Expect) {
  103. Value &Arg = *Expect->getArgOperand(0);
  104. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1));
  105. if (!ExpectedValue)
  106. return;
  107. const APInt &ExpectedPhiValue = ExpectedValue->getValue();
  108. // Walk up in backward a list of instructions that
  109. // have 'copy' semantics by 'stripping' the copies
  110. // until a PHI node or an instruction of unknown kind
  111. // is reached. Negation via xor is also handled.
  112. //
  113. // C = PHI(...);
  114. // B = C;
  115. // A = B;
  116. // D = __builtin_expect(A, 0);
  117. //
  118. Value *V = &Arg;
  119. SmallVector<Instruction *, 4> Operations;
  120. while (!isa<PHINode>(V)) {
  121. if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) {
  122. V = ZExt->getOperand(0);
  123. Operations.push_back(ZExt);
  124. continue;
  125. }
  126. if (SExtInst *SExt = dyn_cast<SExtInst>(V)) {
  127. V = SExt->getOperand(0);
  128. Operations.push_back(SExt);
  129. continue;
  130. }
  131. BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
  132. if (!BinOp || BinOp->getOpcode() != Instruction::Xor)
  133. return;
  134. ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1));
  135. if (!CInt)
  136. return;
  137. V = BinOp->getOperand(0);
  138. Operations.push_back(BinOp);
  139. }
  140. // Executes the recorded operations on input 'Value'.
  141. auto ApplyOperations = [&](const APInt &Value) {
  142. APInt Result = Value;
  143. for (auto Op : llvm::reverse(Operations)) {
  144. switch (Op->getOpcode()) {
  145. case Instruction::Xor:
  146. Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue();
  147. break;
  148. case Instruction::ZExt:
  149. Result = Result.zext(Op->getType()->getIntegerBitWidth());
  150. break;
  151. case Instruction::SExt:
  152. Result = Result.sext(Op->getType()->getIntegerBitWidth());
  153. break;
  154. default:
  155. llvm_unreachable("Unexpected operation");
  156. }
  157. }
  158. return Result;
  159. };
  160. auto *PhiDef = cast<PHINode>(V);
  161. // Get the first dominating conditional branch of the operand
  162. // i's incoming block.
  163. auto GetDomConditional = [&](unsigned i) -> BranchInst * {
  164. BasicBlock *BB = PhiDef->getIncomingBlock(i);
  165. BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  166. if (BI && BI->isConditional())
  167. return BI;
  168. BB = BB->getSinglePredecessor();
  169. if (!BB)
  170. return nullptr;
  171. BI = dyn_cast<BranchInst>(BB->getTerminator());
  172. if (!BI || BI->isUnconditional())
  173. return nullptr;
  174. return BI;
  175. };
  176. // Now walk through all Phi operands to find phi oprerands with values
  177. // conflicting with the expected phi output value. Any such operand
  178. // indicates the incoming edge to that operand is unlikely.
  179. for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) {
  180. Value *PhiOpnd = PhiDef->getIncomingValue(i);
  181. ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);
  182. if (!CI)
  183. continue;
  184. // Not an interesting case when IsUnlikely is false -- we can not infer
  185. // anything useful when the operand value matches the expected phi
  186. // output.
  187. if (ExpectedPhiValue == ApplyOperations(CI->getValue()))
  188. continue;
  189. BranchInst *BI = GetDomConditional(i);
  190. if (!BI)
  191. continue;
  192. MDBuilder MDB(PhiDef->getContext());
  193. // There are two situations in which an operand of the PhiDef comes
  194. // from a given successor of a branch instruction BI.
  195. // 1) When the incoming block of the operand is the successor block;
  196. // 2) When the incoming block is BI's enclosing block and the
  197. // successor is the PhiDef's enclosing block.
  198. //
  199. // Returns true if the operand which comes from OpndIncomingBB
  200. // comes from outgoing edge of BI that leads to Succ block.
  201. auto *OpndIncomingBB = PhiDef->getIncomingBlock(i);
  202. auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) {
  203. if (OpndIncomingBB == Succ)
  204. // If this successor is the incoming block for this
  205. // Phi operand, then this successor does lead to the Phi.
  206. return true;
  207. if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent())
  208. // Otherwise, if the edge is directly from the branch
  209. // to the Phi, this successor is the one feeding this
  210. // Phi operand.
  211. return true;
  212. return false;
  213. };
  214. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  215. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
  216. Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
  217. if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
  218. BI->setMetadata(LLVMContext::MD_prof,
  219. MDB.createBranchWeights(LikelyBranchWeightVal,
  220. UnlikelyBranchWeightVal));
  221. else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
  222. BI->setMetadata(LLVMContext::MD_prof,
  223. MDB.createBranchWeights(UnlikelyBranchWeightVal,
  224. LikelyBranchWeightVal));
  225. }
  226. }
  227. // Handle both BranchInst and SelectInst.
  228. template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
  229. // Handle non-optimized IR code like:
  230. // %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1)
  231. // %tobool = icmp ne i64 %expval, 0
  232. // br i1 %tobool, label %if.then, label %if.end
  233. //
  234. // Or the following simpler case:
  235. // %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1)
  236. // br i1 %expval, label %if.then, label %if.end
  237. CallInst *CI;
  238. ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition());
  239. CmpInst::Predicate Predicate;
  240. ConstantInt *CmpConstOperand = nullptr;
  241. if (!CmpI) {
  242. CI = dyn_cast<CallInst>(BSI.getCondition());
  243. Predicate = CmpInst::ICMP_NE;
  244. } else {
  245. Predicate = CmpI->getPredicate();
  246. if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ)
  247. return false;
  248. CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1));
  249. if (!CmpConstOperand)
  250. return false;
  251. CI = dyn_cast<CallInst>(CmpI->getOperand(0));
  252. }
  253. if (!CI)
  254. return false;
  255. uint64_t ValueComparedTo = 0;
  256. if (CmpConstOperand) {
  257. if (CmpConstOperand->getBitWidth() > 64)
  258. return false;
  259. ValueComparedTo = CmpConstOperand->getZExtValue();
  260. }
  261. Function *Fn = CI->getCalledFunction();
  262. if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
  263. Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
  264. return false;
  265. Value *ArgValue = CI->getArgOperand(0);
  266. ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
  267. if (!ExpectedValue)
  268. return false;
  269. MDBuilder MDB(CI->getContext());
  270. MDNode *Node;
  271. uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
  272. std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
  273. getBranchWeight(Fn->getIntrinsicID(), CI, 2);
  274. if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
  275. (Predicate == CmpInst::ICMP_EQ)) {
  276. Node =
  277. MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
  278. } else {
  279. Node =
  280. MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
  281. }
  282. if (CmpI)
  283. CmpI->setOperand(0, ArgValue);
  284. else
  285. BSI.setCondition(ArgValue);
  286. BSI.setMetadata(LLVMContext::MD_prof, Node);
  287. return true;
  288. }
  289. static bool handleBranchExpect(BranchInst &BI) {
  290. if (BI.isUnconditional())
  291. return false;
  292. return handleBrSelExpect<BranchInst>(BI);
  293. }
  294. static bool lowerExpectIntrinsic(Function &F) {
  295. bool Changed = false;
  296. for (BasicBlock &BB : F) {
  297. // Create "block_weights" metadata.
  298. if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
  299. if (handleBranchExpect(*BI))
  300. ExpectIntrinsicsHandled++;
  301. } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
  302. if (handleSwitchExpect(*SI))
  303. ExpectIntrinsicsHandled++;
  304. }
  305. // Remove llvm.expect intrinsics. Iterate backwards in order
  306. // to process select instructions before the intrinsic gets
  307. // removed.
  308. for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) {
  309. CallInst *CI = dyn_cast<CallInst>(&Inst);
  310. if (!CI) {
  311. if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) {
  312. if (handleBrSelExpect(*SI))
  313. ExpectIntrinsicsHandled++;
  314. }
  315. continue;
  316. }
  317. Function *Fn = CI->getCalledFunction();
  318. if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
  319. Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
  320. // Before erasing the llvm.expect, walk backward to find
  321. // phi that define llvm.expect's first arg, and
  322. // infer branch probability:
  323. handlePhiDef(CI);
  324. Value *Exp = CI->getArgOperand(0);
  325. CI->replaceAllUsesWith(Exp);
  326. CI->eraseFromParent();
  327. Changed = true;
  328. }
  329. }
  330. }
  331. return Changed;
  332. }
  333. PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
  334. FunctionAnalysisManager &) {
  335. if (lowerExpectIntrinsic(F))
  336. return PreservedAnalyses::none();
  337. return PreservedAnalyses::all();
  338. }
  339. namespace {
  340. /// Legacy pass for lowering expect intrinsics out of the IR.
  341. ///
  342. /// When this pass is run over a function it uses expect intrinsics which feed
  343. /// branches and switches to provide branch weight metadata for those
  344. /// terminators. It then removes the expect intrinsics from the IR so the rest
  345. /// of the optimizer can ignore them.
  346. class LowerExpectIntrinsic : public FunctionPass {
  347. public:
  348. static char ID;
  349. LowerExpectIntrinsic() : FunctionPass(ID) {
  350. initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
  351. }
  352. bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); }
  353. };
  354. }
  355. char LowerExpectIntrinsic::ID = 0;
  356. INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect",
  357. "Lower 'expect' Intrinsics", false, false)
  358. FunctionPass *llvm::createLowerExpectIntrinsicPass() {
  359. return new LowerExpectIntrinsic();
  360. }