CallPrinter.cpp 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
// This file defines '-dot-callgraph', which emits a <module>.callgraph.dot file
// (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is given)
// containing the call graph of a module.
  11. //
  12. // There is also a pass available to directly call dotty ('-view-callgraph').
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Analysis/CallPrinter.h"
  16. #include "llvm/Analysis/BlockFrequencyInfo.h"
  17. #include "llvm/Analysis/BranchProbabilityInfo.h"
  18. #include "llvm/Analysis/CallGraph.h"
  19. #include "llvm/Analysis/DOTGraphTraitsPass.h"
  20. #include "llvm/Analysis/HeatUtils.h"
  21. #include "llvm/Support/CommandLine.h"
  22. #include "llvm/InitializePasses.h"
  23. #include "llvm/ADT/DenseMap.h"
  24. #include "llvm/ADT/SmallSet.h"
  25. using namespace llvm;
  26. // This option shows static (relative) call counts.
  27. // FIXME:
  28. // Need to show real counts when profile data is available
  29. static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
  30. cl::Hidden,
  31. cl::desc("Show heat colors in call-graph"));
  32. static cl::opt<bool>
  33. ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
  34. cl::desc("Show edges labeled with weights"));
  35. static cl::opt<bool>
  36. CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
  37. cl::desc("Show call-multigraph (do not remove parallel edges)"));
  38. static cl::opt<std::string> CallGraphDotFilenamePrefix(
  39. "callgraph-dot-filename-prefix", cl::Hidden,
  40. cl::desc("The prefix used for the CallGraph dot file names."));
  41. namespace llvm {
  42. class CallGraphDOTInfo {
  43. private:
  44. Module *M;
  45. CallGraph *CG;
  46. DenseMap<const Function *, uint64_t> Freq;
  47. uint64_t MaxFreq;
  48. public:
  49. std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
  50. CallGraphDOTInfo(Module *M, CallGraph *CG,
  51. function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
  52. : M(M), CG(CG), LookupBFI(LookupBFI) {
  53. MaxFreq = 0;
  54. for (Function &F : M->getFunctionList()) {
  55. uint64_t localSumFreq = 0;
  56. SmallSet<Function *, 16> Callers;
  57. for (User *U : F.users())
  58. if (isa<CallInst>(U))
  59. Callers.insert(cast<Instruction>(U)->getFunction());
  60. for (Function *Caller : Callers)
  61. localSumFreq += getNumOfCalls(*Caller, F);
  62. if (localSumFreq >= MaxFreq)
  63. MaxFreq = localSumFreq;
  64. Freq[&F] = localSumFreq;
  65. }
  66. if (!CallMultiGraph)
  67. removeParallelEdges();
  68. }
  69. Module *getModule() const { return M; }
  70. CallGraph *getCallGraph() const { return CG; }
  71. uint64_t getFreq(const Function *F) { return Freq[F]; }
  72. uint64_t getMaxFreq() { return MaxFreq; }
  73. private:
  74. void removeParallelEdges() {
  75. for (auto &I : (*CG)) {
  76. CallGraphNode *Node = I.second.get();
  77. bool FoundParallelEdge = true;
  78. while (FoundParallelEdge) {
  79. SmallSet<Function *, 16> Visited;
  80. FoundParallelEdge = false;
  81. for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
  82. if (!(Visited.insert(CI->second->getFunction())).second) {
  83. FoundParallelEdge = true;
  84. Node->removeCallEdge(CI);
  85. break;
  86. }
  87. }
  88. }
  89. }
  90. }
  91. };
  92. template <>
  93. struct GraphTraits<CallGraphDOTInfo *>
  94. : public GraphTraits<const CallGraphNode *> {
  95. static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
  96. // Start at the external node!
  97. return CGInfo->getCallGraph()->getExternalCallingNode();
  98. }
  99. typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
  100. PairTy;
  101. static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
  102. return P.second.get();
  103. }
  104. // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
  105. typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
  106. nodes_iterator;
  107. static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
  108. return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
  109. }
  110. static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
  111. return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
  112. }
  113. };
  114. template <>
  115. struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
  116. DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
  117. static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
  118. return "Call graph: " +
  119. std::string(CGInfo->getModule()->getModuleIdentifier());
  120. }
  121. static bool isNodeHidden(const CallGraphNode *Node,
  122. const CallGraphDOTInfo *CGInfo) {
  123. if (CallMultiGraph || Node->getFunction())
  124. return false;
  125. return true;
  126. }
  127. std::string getNodeLabel(const CallGraphNode *Node,
  128. CallGraphDOTInfo *CGInfo) {
  129. if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
  130. return "external caller";
  131. if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
  132. return "external callee";
  133. if (Function *Func = Node->getFunction())
  134. return std::string(Func->getName());
  135. return "external node";
  136. }
  137. static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
  138. return P.second;
  139. }
  140. // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
  141. typedef mapped_iterator<CallGraphNode::const_iterator,
  142. decltype(&CGGetValuePtr)>
  143. nodes_iterator;
  144. std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
  145. CallGraphDOTInfo *CGInfo) {
  146. if (!ShowEdgeWeight)
  147. return "";
  148. Function *Caller = Node->getFunction();
  149. if (Caller == nullptr || Caller->isDeclaration())
  150. return "";
  151. Function *Callee = (*I)->getFunction();
  152. if (Callee == nullptr)
  153. return "";
  154. uint64_t Counter = getNumOfCalls(*Caller, *Callee);
  155. double Width =
  156. 1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
  157. std::string Attrs = "label=\"" + std::to_string(Counter) +
  158. "\" penwidth=" + std::to_string(Width);
  159. return Attrs;
  160. }
  161. std::string getNodeAttributes(const CallGraphNode *Node,
  162. CallGraphDOTInfo *CGInfo) {
  163. Function *F = Node->getFunction();
  164. if (F == nullptr)
  165. return "";
  166. std::string attrs;
  167. if (ShowHeatColors) {
  168. uint64_t freq = CGInfo->getFreq(F);
  169. std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
  170. std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
  171. ? getHeatColor(0)
  172. : getHeatColor(1);
  173. attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
  174. color + "80\"";
  175. }
  176. return attrs;
  177. }
  178. };
  179. } // end llvm namespace
  180. namespace {
  181. // Viewer
  182. class CallGraphViewer : public ModulePass {
  183. public:
  184. static char ID;
  185. CallGraphViewer() : ModulePass(ID) {}
  186. void getAnalysisUsage(AnalysisUsage &AU) const override;
  187. bool runOnModule(Module &M) override;
  188. };
  189. void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
  190. ModulePass::getAnalysisUsage(AU);
  191. AU.addRequired<BlockFrequencyInfoWrapperPass>();
  192. AU.setPreservesAll();
  193. }
  194. bool CallGraphViewer::runOnModule(Module &M) {
  195. auto LookupBFI = [this](Function &F) {
  196. return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
  197. };
  198. CallGraph CG(M);
  199. CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
  200. std::string Title =
  201. DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
  202. ViewGraph(&CFGInfo, "callgraph", true, Title);
  203. return false;
  204. }
  205. // DOT Printer
  206. class CallGraphDOTPrinter : public ModulePass {
  207. public:
  208. static char ID;
  209. CallGraphDOTPrinter() : ModulePass(ID) {}
  210. void getAnalysisUsage(AnalysisUsage &AU) const override;
  211. bool runOnModule(Module &M) override;
  212. };
  213. void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
  214. ModulePass::getAnalysisUsage(AU);
  215. AU.addRequired<BlockFrequencyInfoWrapperPass>();
  216. AU.setPreservesAll();
  217. }
  218. bool CallGraphDOTPrinter::runOnModule(Module &M) {
  219. auto LookupBFI = [this](Function &F) {
  220. return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
  221. };
  222. std::string Filename;
  223. if (!CallGraphDotFilenamePrefix.empty())
  224. Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
  225. else
  226. Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
  227. errs() << "Writing '" << Filename << "'...";
  228. std::error_code EC;
  229. raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
  230. CallGraph CG(M);
  231. CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
  232. if (!EC)
  233. WriteGraph(File, &CFGInfo);
  234. else
  235. errs() << " error opening file for writing!";
  236. errs() << "\n";
  237. return false;
  238. }
  239. } // end anonymous namespace
  240. char CallGraphViewer::ID = 0;
  241. INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
  242. false)
  243. char CallGraphDOTPrinter::ID = 0;
  244. INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
  245. "Print call graph to 'dot' file", false, false)
  246. // Create methods available outside of this file, to use them
  247. // "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by
  248. // the link time optimization.
  249. ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
  250. ModulePass *llvm::createCallGraphDOTPrinterPass() {
  251. return new CallGraphDOTPrinter();
  252. }