CallPrinter.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
// This file defines '-dot-callgraph', which emits a <module-id>.callgraph.dot
// file (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is set)
// containing the call graph of a module.
  11. //
  12. // There is also a pass available to directly call dotty ('-view-callgraph').
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Analysis/CallPrinter.h"
  16. #include "llvm/ADT/DenseMap.h"
  17. #include "llvm/ADT/SmallSet.h"
  18. #include "llvm/Analysis/BlockFrequencyInfo.h"
  19. #include "llvm/Analysis/CallGraph.h"
  20. #include "llvm/Analysis/HeatUtils.h"
  21. #include "llvm/IR/Instructions.h"
  22. #include "llvm/InitializePasses.h"
  23. #include "llvm/Support/CommandLine.h"
  24. #include "llvm/Support/DOTGraphTraits.h"
  25. #include "llvm/Support/GraphWriter.h"
  26. using namespace llvm;
  27. namespace llvm {
  28. template <class GraphType> struct GraphTraits;
  29. }
  30. // This option shows static (relative) call counts.
  31. // FIXME:
  32. // Need to show real counts when profile data is available
  33. static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
  34. cl::Hidden,
  35. cl::desc("Show heat colors in call-graph"));
  36. static cl::opt<bool>
  37. ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
  38. cl::desc("Show edges labeled with weights"));
  39. static cl::opt<bool>
  40. CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
  41. cl::desc("Show call-multigraph (do not remove parallel edges)"));
  42. static cl::opt<std::string> CallGraphDotFilenamePrefix(
  43. "callgraph-dot-filename-prefix", cl::Hidden,
  44. cl::desc("The prefix used for the CallGraph dot file names."));
  45. namespace llvm {
  46. class CallGraphDOTInfo {
  47. private:
  48. Module *M;
  49. CallGraph *CG;
  50. DenseMap<const Function *, uint64_t> Freq;
  51. uint64_t MaxFreq;
  52. public:
  53. std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
  54. CallGraphDOTInfo(Module *M, CallGraph *CG,
  55. function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
  56. : M(M), CG(CG), LookupBFI(LookupBFI) {
  57. MaxFreq = 0;
  58. for (Function &F : M->getFunctionList()) {
  59. uint64_t localSumFreq = 0;
  60. SmallSet<Function *, 16> Callers;
  61. for (User *U : F.users())
  62. if (isa<CallInst>(U))
  63. Callers.insert(cast<Instruction>(U)->getFunction());
  64. for (Function *Caller : Callers)
  65. localSumFreq += getNumOfCalls(*Caller, F);
  66. if (localSumFreq >= MaxFreq)
  67. MaxFreq = localSumFreq;
  68. Freq[&F] = localSumFreq;
  69. }
  70. if (!CallMultiGraph)
  71. removeParallelEdges();
  72. }
  73. Module *getModule() const { return M; }
  74. CallGraph *getCallGraph() const { return CG; }
  75. uint64_t getFreq(const Function *F) { return Freq[F]; }
  76. uint64_t getMaxFreq() { return MaxFreq; }
  77. private:
  78. void removeParallelEdges() {
  79. for (auto &I : (*CG)) {
  80. CallGraphNode *Node = I.second.get();
  81. bool FoundParallelEdge = true;
  82. while (FoundParallelEdge) {
  83. SmallSet<Function *, 16> Visited;
  84. FoundParallelEdge = false;
  85. for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
  86. if (!(Visited.insert(CI->second->getFunction())).second) {
  87. FoundParallelEdge = true;
  88. Node->removeCallEdge(CI);
  89. break;
  90. }
  91. }
  92. }
  93. }
  94. }
  95. };
  96. template <>
  97. struct GraphTraits<CallGraphDOTInfo *>
  98. : public GraphTraits<const CallGraphNode *> {
  99. static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
  100. // Start at the external node!
  101. return CGInfo->getCallGraph()->getExternalCallingNode();
  102. }
  103. typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
  104. PairTy;
  105. static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
  106. return P.second.get();
  107. }
  108. // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
  109. typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
  110. nodes_iterator;
  111. static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
  112. return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
  113. }
  114. static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
  115. return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
  116. }
  117. };
  118. template <>
  119. struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
  120. DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
  121. static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
  122. return "Call graph: " +
  123. std::string(CGInfo->getModule()->getModuleIdentifier());
  124. }
  125. static bool isNodeHidden(const CallGraphNode *Node,
  126. const CallGraphDOTInfo *CGInfo) {
  127. if (CallMultiGraph || Node->getFunction())
  128. return false;
  129. return true;
  130. }
  131. std::string getNodeLabel(const CallGraphNode *Node,
  132. CallGraphDOTInfo *CGInfo) {
  133. if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
  134. return "external caller";
  135. if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
  136. return "external callee";
  137. if (Function *Func = Node->getFunction())
  138. return std::string(Func->getName());
  139. return "external node";
  140. }
  141. static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
  142. return P.second;
  143. }
  144. // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
  145. typedef mapped_iterator<CallGraphNode::const_iterator,
  146. decltype(&CGGetValuePtr)>
  147. nodes_iterator;
  148. std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
  149. CallGraphDOTInfo *CGInfo) {
  150. if (!ShowEdgeWeight)
  151. return "";
  152. Function *Caller = Node->getFunction();
  153. if (Caller == nullptr || Caller->isDeclaration())
  154. return "";
  155. Function *Callee = (*I)->getFunction();
  156. if (Callee == nullptr)
  157. return "";
  158. uint64_t Counter = getNumOfCalls(*Caller, *Callee);
  159. double Width =
  160. 1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
  161. std::string Attrs = "label=\"" + std::to_string(Counter) +
  162. "\" penwidth=" + std::to_string(Width);
  163. return Attrs;
  164. }
  165. std::string getNodeAttributes(const CallGraphNode *Node,
  166. CallGraphDOTInfo *CGInfo) {
  167. Function *F = Node->getFunction();
  168. if (F == nullptr)
  169. return "";
  170. std::string attrs;
  171. if (ShowHeatColors) {
  172. uint64_t freq = CGInfo->getFreq(F);
  173. std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
  174. std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
  175. ? getHeatColor(0)
  176. : getHeatColor(1);
  177. attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
  178. color + "80\"";
  179. }
  180. return attrs;
  181. }
  182. };
  183. } // end llvm namespace
  184. namespace {
  185. void doCallGraphDOTPrinting(
  186. Module &M, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
  187. std::string Filename;
  188. if (!CallGraphDotFilenamePrefix.empty())
  189. Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
  190. else
  191. Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
  192. errs() << "Writing '" << Filename << "'...";
  193. std::error_code EC;
  194. raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
  195. CallGraph CG(M);
  196. CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
  197. if (!EC)
  198. WriteGraph(File, &CFGInfo);
  199. else
  200. errs() << " error opening file for writing!";
  201. errs() << "\n";
  202. }
  203. void viewCallGraph(Module &M,
  204. function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
  205. CallGraph CG(M);
  206. CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
  207. std::string Title =
  208. DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
  209. ViewGraph(&CFGInfo, "callgraph", true, Title);
  210. }
  211. } // namespace
  212. namespace llvm {
  213. PreservedAnalyses CallGraphDOTPrinterPass::run(Module &M,
  214. ModuleAnalysisManager &AM) {
  215. FunctionAnalysisManager &FAM =
  216. AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  217. auto LookupBFI = [&FAM](Function &F) {
  218. return &FAM.getResult<BlockFrequencyAnalysis>(F);
  219. };
  220. doCallGraphDOTPrinting(M, LookupBFI);
  221. return PreservedAnalyses::all();
  222. }
  223. PreservedAnalyses CallGraphViewerPass::run(Module &M,
  224. ModuleAnalysisManager &AM) {
  225. FunctionAnalysisManager &FAM =
  226. AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  227. auto LookupBFI = [&FAM](Function &F) {
  228. return &FAM.getResult<BlockFrequencyAnalysis>(F);
  229. };
  230. viewCallGraph(M, LookupBFI);
  231. return PreservedAnalyses::all();
  232. }
  233. } // namespace llvm
  234. namespace {
  235. // Viewer
  236. class CallGraphViewer : public ModulePass {
  237. public:
  238. static char ID;
  239. CallGraphViewer() : ModulePass(ID) {}
  240. void getAnalysisUsage(AnalysisUsage &AU) const override;
  241. bool runOnModule(Module &M) override;
  242. };
  243. void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
  244. ModulePass::getAnalysisUsage(AU);
  245. AU.addRequired<BlockFrequencyInfoWrapperPass>();
  246. AU.setPreservesAll();
  247. }
  248. bool CallGraphViewer::runOnModule(Module &M) {
  249. auto LookupBFI = [this](Function &F) {
  250. return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
  251. };
  252. viewCallGraph(M, LookupBFI);
  253. return false;
  254. }
  255. // DOT Printer
  256. class CallGraphDOTPrinter : public ModulePass {
  257. public:
  258. static char ID;
  259. CallGraphDOTPrinter() : ModulePass(ID) {}
  260. void getAnalysisUsage(AnalysisUsage &AU) const override;
  261. bool runOnModule(Module &M) override;
  262. };
  263. void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
  264. ModulePass::getAnalysisUsage(AU);
  265. AU.addRequired<BlockFrequencyInfoWrapperPass>();
  266. AU.setPreservesAll();
  267. }
  268. bool CallGraphDOTPrinter::runOnModule(Module &M) {
  269. auto LookupBFI = [this](Function &F) {
  270. return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
  271. };
  272. doCallGraphDOTPrinting(M, LookupBFI);
  273. return false;
  274. }
  275. } // end anonymous namespace
  276. char CallGraphViewer::ID = 0;
  277. INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
  278. false)
  279. char CallGraphDOTPrinter::ID = 0;
  280. INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
  281. "Print call graph to 'dot' file", false, false)
  282. // Create methods available outside of this file, to use them
  283. // "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by
  284. // the link time optimization.
  285. ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
  286. ModulePass *llvm::createCallGraphDOTPrinterPass() {
  287. return new CallGraphDOTPrinter();
  288. }