SyntheticCountsUtils.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. //===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines utilities for propagating synthetic counts.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Analysis/SyntheticCountsUtils.h"
  13. #include "llvm/ADT/DenseSet.h"
  14. #include "llvm/ADT/SCCIterator.h"
  15. #include "llvm/Analysis/CallGraph.h"
  16. #include "llvm/IR/Function.h"
  17. #include "llvm/IR/InstIterator.h"
  18. #include "llvm/IR/Instructions.h"
  19. #include "llvm/IR/ModuleSummaryIndex.h"
  20. using namespace llvm;
  21. // Given an SCC, propagate entry counts along the edge of the SCC nodes.
  22. template <typename CallGraphType>
  23. void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
  24. const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
  25. DenseSet<NodeRef> SCCNodes;
  26. SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
  27. for (auto &Node : SCC)
  28. SCCNodes.insert(Node);
  29. // Partition the edges coming out of the SCC into those whose destination is
  30. // in the SCC and the rest.
  31. for (const auto &Node : SCCNodes) {
  32. for (auto &E : children_edges<CallGraphType>(Node)) {
  33. if (SCCNodes.count(CGT::edge_dest(E)))
  34. SCCEdges.emplace_back(Node, E);
  35. else
  36. NonSCCEdges.emplace_back(Node, E);
  37. }
  38. }
  39. // For nodes in the same SCC, update the counts in two steps:
  40. // 1. Compute the additional count for each node by propagating the counts
  41. // along all incoming edges to the node that originate from within the same
  42. // SCC and summing them up.
  43. // 2. Add the additional counts to the nodes in the SCC.
  44. // This ensures that the order of
  45. // traversal of nodes within the SCC doesn't affect the final result.
  46. DenseMap<NodeRef, Scaled64> AdditionalCounts;
  47. for (auto &E : SCCEdges) {
  48. auto OptProfCount = GetProfCount(E.first, E.second);
  49. if (!OptProfCount)
  50. continue;
  51. auto Callee = CGT::edge_dest(E.second);
  52. AdditionalCounts[Callee] += OptProfCount.getValue();
  53. }
  54. // Update the counts for the nodes in the SCC.
  55. for (auto &Entry : AdditionalCounts)
  56. AddCount(Entry.first, Entry.second);
  57. // Now update the counts for nodes outside the SCC.
  58. for (auto &E : NonSCCEdges) {
  59. auto OptProfCount = GetProfCount(E.first, E.second);
  60. if (!OptProfCount)
  61. continue;
  62. auto Callee = CGT::edge_dest(E.second);
  63. AddCount(Callee, OptProfCount.getValue());
  64. }
  65. }
  66. /// Propgate synthetic entry counts on a callgraph \p CG.
  67. ///
  68. /// This performs a reverse post-order traversal of the callgraph SCC. For each
  69. /// SCC, it first propagates the entry counts to the nodes within the SCC
  70. /// through call edges and updates them in one shot. Then the entry counts are
  71. /// propagated to nodes outside the SCC. This requires \p GraphTraits
  72. /// to have a specialization for \p CallGraphType.
  73. template <typename CallGraphType>
  74. void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
  75. GetProfCountTy GetProfCount,
  76. AddCountTy AddCount) {
  77. std::vector<SccTy> SCCs;
  78. // Collect all the SCCs.
  79. for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
  80. SCCs.push_back(*I);
  81. // The callgraph-scc needs to be visited in top-down order for propagation.
  82. // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
  83. // and call propagateFromSCC.
  84. for (auto &SCC : reverse(SCCs))
  85. propagateFromSCC(SCC, GetProfCount, AddCount);
  86. }
  87. template class llvm::SyntheticCountsUtils<const CallGraph *>;
  88. template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;