CalledValuePropagation.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. //===- CalledValuePropagation.cpp - Propagate called values -----*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements a transformation that attaches !callees metadata to
  10. // indirect call sites. For a given call site, the metadata, if present,
  11. // indicates the set of functions the call site could possibly target at
  12. // run-time. This metadata is added to indirect call sites when the set of
  13. // possible targets can be determined by analysis and is known to be small. The
  14. // analysis driving the transformation is similar to constant propagation and
  15. // makes use of the generic sparse propagation solver.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #include "llvm/Transforms/IPO/CalledValuePropagation.h"
  19. #include "llvm/Analysis/SparsePropagation.h"
  20. #include "llvm/Analysis/ValueLatticeUtils.h"
  21. #include "llvm/IR/MDBuilder.h"
  22. #include "llvm/InitializePasses.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Support/CommandLine.h"
  25. #include "llvm/Transforms/IPO.h"
  26. using namespace llvm;
  27. #define DEBUG_TYPE "called-value-propagation"
  28. /// The maximum number of functions to track per lattice value. Once the number
  29. /// of functions a call site can possibly target exceeds this threshold, it's
  30. /// lattice value becomes overdefined. The number of possible lattice values is
  31. /// bounded by Ch(F, M), where F is the number of functions in the module and M
  32. /// is MaxFunctionsPerValue. As such, this value should be kept very small. We
  33. /// likely can't do anything useful for call sites with a large number of
  34. /// possible targets, anyway.
  35. static cl::opt<unsigned> MaxFunctionsPerValue(
  36. "cvp-max-functions-per-value", cl::Hidden, cl::init(4),
  37. cl::desc("The maximum number of functions to track per lattice value"));
  38. namespace {
  /// Groups that LLVM values are assigned to so the analysis can work
  /// interprocedurally. A single LLVM Value can technically belong to more than
  /// one group; keeping the groups apart lets us, for example, track a global
  /// variable separately from the value stored at its location.
  enum class IPOGrouping {
    /// SSA registers.
    Register,
    /// Return values of functions.
    Return,
    /// In-memory values.
    Memory
  };
  46. /// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
  47. using CVPLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
  48. /// The lattice value type used by our custom lattice function. It holds the
  49. /// lattice state, and a set of functions.
  50. class CVPLatticeVal {
  51. public:
  52. /// The states of the lattice values. Only the FunctionSet state is
  53. /// interesting. It indicates the set of functions to which an LLVM value may
  54. /// refer.
  55. enum CVPLatticeStateTy { Undefined, FunctionSet, Overdefined, Untracked };
  56. /// Comparator for sorting the functions set. We want to keep the order
  57. /// deterministic for testing, etc.
  58. struct Compare {
  59. bool operator()(const Function *LHS, const Function *RHS) const {
  60. return LHS->getName() < RHS->getName();
  61. }
  62. };
  63. CVPLatticeVal() : LatticeState(Undefined) {}
  64. CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {}
  65. CVPLatticeVal(std::vector<Function *> &&Functions)
  66. : LatticeState(FunctionSet), Functions(std::move(Functions)) {
  67. assert(llvm::is_sorted(this->Functions, Compare()));
  68. }
  69. /// Get a reference to the functions held by this lattice value. The number
  70. /// of functions will be zero for states other than FunctionSet.
  71. const std::vector<Function *> &getFunctions() const {
  72. return Functions;
  73. }
  74. /// Returns true if the lattice value is in the FunctionSet state.
  75. bool isFunctionSet() const { return LatticeState == FunctionSet; }
  76. bool operator==(const CVPLatticeVal &RHS) const {
  77. return LatticeState == RHS.LatticeState && Functions == RHS.Functions;
  78. }
  79. bool operator!=(const CVPLatticeVal &RHS) const {
  80. return LatticeState != RHS.LatticeState || Functions != RHS.Functions;
  81. }
  82. private:
  83. /// Holds the state this lattice value is in.
  84. CVPLatticeStateTy LatticeState;
  85. /// Holds functions indicating the possible targets of call sites. This set
  86. /// is empty for lattice values in the undefined, overdefined, and untracked
  87. /// states. The maximum size of the set is controlled by
  88. /// MaxFunctionsPerValue. Since most LLVM values are expected to be in
  89. /// uninteresting states (i.e., overdefined), CVPLatticeVal objects should be
  90. /// small and efficiently copyable.
  91. // FIXME: This could be a TinyPtrVector and/or merge with LatticeState.
  92. std::vector<Function *> Functions;
  93. };
  94. /// The custom lattice function used by the generic sparse propagation solver.
  95. /// It handles merging lattice values and computing new lattice values for
  96. /// constants, arguments, values returned from trackable functions, and values
  97. /// located in trackable global variables. It also computes the lattice values
  98. /// that change as a result of executing instructions.
  99. class CVPLatticeFunc
  100. : public AbstractLatticeFunction<CVPLatticeKey, CVPLatticeVal> {
  101. public:
  102. CVPLatticeFunc()
  103. : AbstractLatticeFunction(CVPLatticeVal(CVPLatticeVal::Undefined),
  104. CVPLatticeVal(CVPLatticeVal::Overdefined),
  105. CVPLatticeVal(CVPLatticeVal::Untracked)) {}
  106. /// Compute and return a CVPLatticeVal for the given CVPLatticeKey.
  107. CVPLatticeVal ComputeLatticeVal(CVPLatticeKey Key) override {
  108. switch (Key.getInt()) {
  109. case IPOGrouping::Register:
  110. if (isa<Instruction>(Key.getPointer())) {
  111. return getUndefVal();
  112. } else if (auto *A = dyn_cast<Argument>(Key.getPointer())) {
  113. if (canTrackArgumentsInterprocedurally(A->getParent()))
  114. return getUndefVal();
  115. } else if (auto *C = dyn_cast<Constant>(Key.getPointer())) {
  116. return computeConstant(C);
  117. }
  118. return getOverdefinedVal();
  119. case IPOGrouping::Memory:
  120. case IPOGrouping::Return:
  121. if (auto *GV = dyn_cast<GlobalVariable>(Key.getPointer())) {
  122. if (canTrackGlobalVariableInterprocedurally(GV))
  123. return computeConstant(GV->getInitializer());
  124. } else if (auto *F = cast<Function>(Key.getPointer()))
  125. if (canTrackReturnsInterprocedurally(F))
  126. return getUndefVal();
  127. }
  128. return getOverdefinedVal();
  129. }
  130. /// Merge the two given lattice values. The interesting cases are merging two
  131. /// FunctionSet values and a FunctionSet value with an Undefined value. For
  132. /// these cases, we simply union the function sets. If the size of the union
  133. /// is greater than the maximum functions we track, the merged value is
  134. /// overdefined.
  135. CVPLatticeVal MergeValues(CVPLatticeVal X, CVPLatticeVal Y) override {
  136. if (X == getOverdefinedVal() || Y == getOverdefinedVal())
  137. return getOverdefinedVal();
  138. if (X == getUndefVal() && Y == getUndefVal())
  139. return getUndefVal();
  140. std::vector<Function *> Union;
  141. std::set_union(X.getFunctions().begin(), X.getFunctions().end(),
  142. Y.getFunctions().begin(), Y.getFunctions().end(),
  143. std::back_inserter(Union), CVPLatticeVal::Compare{});
  144. if (Union.size() > MaxFunctionsPerValue)
  145. return getOverdefinedVal();
  146. return CVPLatticeVal(std::move(Union));
  147. }
  148. /// Compute the lattice values that change as a result of executing the given
  149. /// instruction. The changed values are stored in \p ChangedValues. We handle
  150. /// just a few kinds of instructions since we're only propagating values that
  151. /// can be called.
  152. void ComputeInstructionState(
  153. Instruction &I, DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  154. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) override {
  155. switch (I.getOpcode()) {
  156. case Instruction::Call:
  157. case Instruction::Invoke:
  158. return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
  159. case Instruction::Load:
  160. return visitLoad(*cast<LoadInst>(&I), ChangedValues, SS);
  161. case Instruction::Ret:
  162. return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
  163. case Instruction::Select:
  164. return visitSelect(*cast<SelectInst>(&I), ChangedValues, SS);
  165. case Instruction::Store:
  166. return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
  167. default:
  168. return visitInst(I, ChangedValues, SS);
  169. }
  170. }
  171. /// Print the given CVPLatticeVal to the specified stream.
  172. void PrintLatticeVal(CVPLatticeVal LV, raw_ostream &OS) override {
  173. if (LV == getUndefVal())
  174. OS << "Undefined ";
  175. else if (LV == getOverdefinedVal())
  176. OS << "Overdefined";
  177. else if (LV == getUntrackedVal())
  178. OS << "Untracked ";
  179. else
  180. OS << "FunctionSet";
  181. }
  182. /// Print the given CVPLatticeKey to the specified stream.
  183. void PrintLatticeKey(CVPLatticeKey Key, raw_ostream &OS) override {
  184. if (Key.getInt() == IPOGrouping::Register)
  185. OS << "<reg> ";
  186. else if (Key.getInt() == IPOGrouping::Memory)
  187. OS << "<mem> ";
  188. else if (Key.getInt() == IPOGrouping::Return)
  189. OS << "<ret> ";
  190. if (isa<Function>(Key.getPointer()))
  191. OS << Key.getPointer()->getName();
  192. else
  193. OS << *Key.getPointer();
  194. }
  195. /// We collect a set of indirect calls when visiting call sites. This method
  196. /// returns a reference to that set.
  197. SmallPtrSetImpl<CallBase *> &getIndirectCalls() { return IndirectCalls; }
  198. private:
  199. /// Holds the indirect calls we encounter during the analysis. We will attach
  200. /// metadata to these calls after the analysis indicating the functions the
  201. /// calls can possibly target.
  202. SmallPtrSet<CallBase *, 32> IndirectCalls;
  203. /// Compute a new lattice value for the given constant. The constant, after
  204. /// stripping any pointer casts, should be a Function. We ignore null
  205. /// pointers as an optimization, since calling these values is undefined
  206. /// behavior.
  207. CVPLatticeVal computeConstant(Constant *C) {
  208. if (isa<ConstantPointerNull>(C))
  209. return CVPLatticeVal(CVPLatticeVal::FunctionSet);
  210. if (auto *F = dyn_cast<Function>(C->stripPointerCasts()))
  211. return CVPLatticeVal({F});
  212. return getOverdefinedVal();
  213. }
  214. /// Handle return instructions. The function's return state is the merge of
  215. /// the returned value state and the function's return state.
  216. void visitReturn(ReturnInst &I,
  217. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  218. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  219. Function *F = I.getParent()->getParent();
  220. if (F->getReturnType()->isVoidTy())
  221. return;
  222. auto RegI = CVPLatticeKey(I.getReturnValue(), IPOGrouping::Register);
  223. auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
  224. ChangedValues[RetF] =
  225. MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  226. }
  227. /// Handle call sites. The state of a called function's formal arguments is
  228. /// the merge of the argument state with the call sites corresponding actual
  229. /// argument state. The call site state is the merge of the call site state
  230. /// with the returned value state of the called function.
  231. void visitCallBase(CallBase &CB,
  232. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  233. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  234. Function *F = CB.getCalledFunction();
  235. auto RegI = CVPLatticeKey(&CB, IPOGrouping::Register);
  236. // If this is an indirect call, save it so we can quickly revisit it when
  237. // attaching metadata.
  238. if (!F)
  239. IndirectCalls.insert(&CB);
  240. // If we can't track the function's return values, there's nothing to do.
  241. if (!F || !canTrackReturnsInterprocedurally(F)) {
  242. // Void return, No need to create and update CVPLattice state as no one
  243. // can use it.
  244. if (CB.getType()->isVoidTy())
  245. return;
  246. ChangedValues[RegI] = getOverdefinedVal();
  247. return;
  248. }
  249. // Inform the solver that the called function is executable, and perform
  250. // the merges for the arguments and return value.
  251. SS.MarkBlockExecutable(&F->front());
  252. auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
  253. for (Argument &A : F->args()) {
  254. auto RegFormal = CVPLatticeKey(&A, IPOGrouping::Register);
  255. auto RegActual =
  256. CVPLatticeKey(CB.getArgOperand(A.getArgNo()), IPOGrouping::Register);
  257. ChangedValues[RegFormal] =
  258. MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
  259. }
  260. // Void return, No need to create and update CVPLattice state as no one can
  261. // use it.
  262. if (CB.getType()->isVoidTy())
  263. return;
  264. ChangedValues[RegI] =
  265. MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  266. }
  267. /// Handle select instructions. The select instruction state is the merge the
  268. /// true and false value states.
  269. void visitSelect(SelectInst &I,
  270. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  271. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  272. auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
  273. auto RegT = CVPLatticeKey(I.getTrueValue(), IPOGrouping::Register);
  274. auto RegF = CVPLatticeKey(I.getFalseValue(), IPOGrouping::Register);
  275. ChangedValues[RegI] =
  276. MergeValues(SS.getValueState(RegT), SS.getValueState(RegF));
  277. }
  278. /// Handle load instructions. If the pointer operand of the load is a global
  279. /// variable, we attempt to track the value. The loaded value state is the
  280. /// merge of the loaded value state with the global variable state.
  281. void visitLoad(LoadInst &I,
  282. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  283. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  284. auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
  285. if (auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand())) {
  286. auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
  287. ChangedValues[RegI] =
  288. MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
  289. } else {
  290. ChangedValues[RegI] = getOverdefinedVal();
  291. }
  292. }
  293. /// Handle store instructions. If the pointer operand of the store is a
  294. /// global variable, we attempt to track the value. The global variable state
  295. /// is the merge of the stored value state with the global variable state.
  296. void visitStore(StoreInst &I,
  297. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  298. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  299. auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
  300. if (!GV)
  301. return;
  302. auto RegI = CVPLatticeKey(I.getValueOperand(), IPOGrouping::Register);
  303. auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
  304. ChangedValues[MemGV] =
  305. MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
  306. }
  307. /// Handle all other instructions. All other instructions are marked
  308. /// overdefined.
  309. void visitInst(Instruction &I,
  310. DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
  311. SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
  312. // Simply bail if this instruction has no user.
  313. if (I.use_empty())
  314. return;
  315. auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
  316. ChangedValues[RegI] = getOverdefinedVal();
  317. }
  318. };
  319. } // namespace
  320. namespace llvm {
  321. /// A specialization of LatticeKeyInfo for CVPLatticeKeys. The generic solver
  322. /// must translate between LatticeKeys and LLVM Values when adding Values to
  323. /// its work list and inspecting the state of control-flow related values.
  324. template <> struct LatticeKeyInfo<CVPLatticeKey> {
  325. static inline Value *getValueFromLatticeKey(CVPLatticeKey Key) {
  326. return Key.getPointer();
  327. }
  328. static inline CVPLatticeKey getLatticeKeyFromValue(Value *V) {
  329. return CVPLatticeKey(V, IPOGrouping::Register);
  330. }
  331. };
  332. } // namespace llvm
  333. static bool runCVP(Module &M) {
  334. // Our custom lattice function and generic sparse propagation solver.
  335. CVPLatticeFunc Lattice;
  336. SparseSolver<CVPLatticeKey, CVPLatticeVal> Solver(&Lattice);
  337. // For each function in the module, if we can't track its arguments, let the
  338. // generic solver assume it is executable.
  339. for (Function &F : M)
  340. if (!F.isDeclaration() && !canTrackArgumentsInterprocedurally(&F))
  341. Solver.MarkBlockExecutable(&F.front());
  342. // Solver our custom lattice. In doing so, we will also build a set of
  343. // indirect call sites.
  344. Solver.Solve();
  345. // Attach metadata to the indirect call sites that were collected indicating
  346. // the set of functions they can possibly target.
  347. bool Changed = false;
  348. MDBuilder MDB(M.getContext());
  349. for (CallBase *C : Lattice.getIndirectCalls()) {
  350. auto RegI = CVPLatticeKey(C->getCalledOperand(), IPOGrouping::Register);
  351. CVPLatticeVal LV = Solver.getExistingValueState(RegI);
  352. if (!LV.isFunctionSet() || LV.getFunctions().empty())
  353. continue;
  354. MDNode *Callees = MDB.createCallees(LV.getFunctions());
  355. C->setMetadata(LLVMContext::MD_callees, Callees);
  356. Changed = true;
  357. }
  358. return Changed;
  359. }
  360. PreservedAnalyses CalledValuePropagationPass::run(Module &M,
  361. ModuleAnalysisManager &) {
  362. runCVP(M);
  363. return PreservedAnalyses::all();
  364. }
  365. namespace {
  366. class CalledValuePropagationLegacyPass : public ModulePass {
  367. public:
  368. static char ID;
  369. void getAnalysisUsage(AnalysisUsage &AU) const override {
  370. AU.setPreservesAll();
  371. }
  372. CalledValuePropagationLegacyPass() : ModulePass(ID) {
  373. initializeCalledValuePropagationLegacyPassPass(
  374. *PassRegistry::getPassRegistry());
  375. }
  376. bool runOnModule(Module &M) override {
  377. if (skipModule(M))
  378. return false;
  379. return runCVP(M);
  380. }
  381. };
  382. } // namespace
  383. char CalledValuePropagationLegacyPass::ID = 0;
  384. INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation",
  385. "Called Value Propagation", false, false)
  386. ModulePass *llvm::createCalledValuePropagationPass() {
  387. return new CalledValuePropagationLegacyPass();
  388. }