MIRSampleProfile.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
  //
  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  // See https://llvm.org/LICENSE.txt for license information.
  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  //
  //===----------------------------------------------------------------------===//
  //
  // This file provides the implementation of the MIRSampleProfile loader, mainly
  // for flow-sensitive SampleFDO.
  //
  //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/MIRSampleProfile.h"
  14. #include "llvm/ADT/DenseMap.h"
  15. #include "llvm/ADT/DenseSet.h"
  16. #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
  17. #include "llvm/IR/Function.h"
  18. #include "llvm/Support/CommandLine.h"
  19. #include "llvm/Support/Debug.h"
  20. #include "llvm/Support/raw_ostream.h"
  21. #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
  22. #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
  23. using namespace llvm;
  24. using namespace sampleprof;
  25. using namespace llvm::sampleprofutil;
  26. using ProfileCount = Function::ProfileCount;
  27. #define DEBUG_TYPE "fs-profile-loader"
  28. static cl::opt<bool> ShowFSBranchProb(
  29. "show-fs-branchprob", cl::Hidden, cl::init(false),
  30. cl::desc("Print setting flow sensitive branch probabilities"));
  31. static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
  32. "fs-profile-debug-prob-diff-threshold", cl::init(10),
  33. cl::desc("Only show debug message if the branch probility is greater than "
  34. "this value (in percentage)."));
  35. static cl::opt<unsigned> FSProfileDebugBWThreshold(
  36. "fs-profile-debug-bw-threshold", cl::init(10000),
  37. cl::desc("Only show debug message if the source branch weight is greater "
  38. " than this value."));
  39. static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
  40. cl::init(false),
  41. cl::desc("View BFI before MIR loader"));
  42. static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
  43. cl::init(false),
  44. cl::desc("View BFI after MIR loader"));
  45. char MIRProfileLoaderPass::ID = 0;
  46. INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
  47. "Load MIR Sample Profile",
  48. /* cfg = */ false, /* is_analysis = */ false)
  49. INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
  50. INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
  51. INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
  52. INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
  53. INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
  54. INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
  55. /* cfg = */ false, /* is_analysis = */ false)
  56. char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
  57. FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
  58. std::string RemappingFile,
  59. FSDiscriminatorPass P) {
  60. return new MIRProfileLoaderPass(File, RemappingFile, P);
  61. }
  62. namespace llvm {
  63. // Internal option used to control BFI display only after MBP pass.
  64. // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
  65. // -view-block-layout-with-bfi={none | fraction | integer | count}
  66. extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
  67. // Command line option to specify the name of the function for CFG dump
  68. // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
  69. extern cl::opt<std::string> ViewBlockFreqFuncName;
  70. namespace afdo_detail {
  71. template <> struct IRTraits<MachineBasicBlock> {
  72. using InstructionT = MachineInstr;
  73. using BasicBlockT = MachineBasicBlock;
  74. using FunctionT = MachineFunction;
  75. using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  76. using LoopT = MachineLoop;
  77. using LoopInfoPtrT = MachineLoopInfo *;
  78. using DominatorTreePtrT = MachineDominatorTree *;
  79. using PostDominatorTreePtrT = MachinePostDominatorTree *;
  80. using PostDominatorTreeT = MachinePostDominatorTree;
  81. using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  82. using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  83. using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  84. using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  85. static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  86. static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
  87. return GraphTraits<const MachineFunction *>::getEntryNode(F);
  88. }
  89. static PredRangeT getPredecessors(MachineBasicBlock *BB) {
  90. return BB->predecessors();
  91. }
  92. static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
  93. return BB->successors();
  94. }
  95. };
  96. } // namespace afdo_detail
  97. class MIRProfileLoader final
  98. : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
  99. public:
  100. void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
  101. MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
  102. MachineOptimizationRemarkEmitter *MORE) {
  103. DT = MDT;
  104. PDT = MPDT;
  105. LI = MLI;
  106. BFI = MBFI;
  107. ORE = MORE;
  108. }
  109. void setFSPass(FSDiscriminatorPass Pass) {
  110. P = Pass;
  111. LowBit = getFSPassBitBegin(P);
  112. HighBit = getFSPassBitEnd(P);
  113. assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
  114. }
  115. MIRProfileLoader(StringRef Name, StringRef RemapName)
  116. : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
  117. }
  118. void setBranchProbs(MachineFunction &F);
  119. bool runOnFunction(MachineFunction &F);
  120. bool doInitialization(Module &M);
  121. bool isValid() const { return ProfileIsValid; }
  122. protected:
  123. friend class SampleCoverageTracker;
  124. /// Hold the information of the basic block frequency.
  125. MachineBlockFrequencyInfo *BFI;
  126. /// PassNum is the sequence number this pass is called, start from 1.
  127. FSDiscriminatorPass P;
  128. // LowBit in the FS discriminator used by this instance. Note the number is
  129. // 0-based. Base discrimnator use bit 0 to bit 11.
  130. unsigned LowBit;
  131. // HighwBit in the FS discriminator used by this instance. Note the number
  132. // is 0-based.
  133. unsigned HighBit;
  134. bool ProfileIsValid = true;
  135. };
  136. template <>
  137. void SampleProfileLoaderBaseImpl<
  138. MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
  139. void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
  140. LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
  141. for (auto &BI : F) {
  142. MachineBasicBlock *BB = &BI;
  143. if (BB->succ_size() < 2)
  144. continue;
  145. const MachineBasicBlock *EC = EquivalenceClass[BB];
  146. uint64_t BBWeight = BlockWeights[EC];
  147. uint64_t SumEdgeWeight = 0;
  148. for (MachineBasicBlock *Succ : BB->successors()) {
  149. Edge E = std::make_pair(BB, Succ);
  150. SumEdgeWeight += EdgeWeights[E];
  151. }
  152. if (BBWeight != SumEdgeWeight) {
  153. LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
  154. << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
  155. << "\n");
  156. BBWeight = SumEdgeWeight;
  157. }
  158. if (BBWeight == 0) {
  159. LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
  160. continue;
  161. }
  162. #ifndef NDEBUG
  163. uint64_t BBWeightOrig = BBWeight;
  164. #endif
  165. uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
  166. uint32_t Factor = 1;
  167. if (BBWeight > MaxWeight) {
  168. Factor = BBWeight / MaxWeight + 1;
  169. BBWeight /= Factor;
  170. LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
  171. }
  172. for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
  173. SE = BB->succ_end();
  174. SI != SE; ++SI) {
  175. MachineBasicBlock *Succ = *SI;
  176. Edge E = std::make_pair(BB, Succ);
  177. uint64_t EdgeWeight = EdgeWeights[E];
  178. EdgeWeight /= Factor;
  179. assert(BBWeight >= EdgeWeight &&
  180. "BBweight is larger than EdgeWeight -- should not happen.\n");
  181. BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
  182. BranchProbability NewProb(EdgeWeight, BBWeight);
  183. if (OldProb == NewProb)
  184. continue;
  185. BB->setSuccProbability(SI, NewProb);
  186. #ifndef NDEBUG
  187. if (!ShowFSBranchProb)
  188. continue;
  189. bool Show = false;
  190. BranchProbability Diff;
  191. if (OldProb > NewProb)
  192. Diff = OldProb - NewProb;
  193. else
  194. Diff = NewProb - OldProb;
  195. Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
  196. Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
  197. auto DIL = BB->findBranchDebugLoc();
  198. auto SuccDIL = Succ->findBranchDebugLoc();
  199. if (Show) {
  200. dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
  201. << Succ->getNumber() << "): ";
  202. if (DIL)
  203. dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
  204. << DIL->getColumn();
  205. if (SuccDIL)
  206. dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
  207. << ":" << SuccDIL->getColumn();
  208. dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
  209. << "\n";
  210. }
  211. #endif
  212. }
  213. }
  214. }
  215. bool MIRProfileLoader::doInitialization(Module &M) {
  216. auto &Ctx = M.getContext();
  217. auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
  218. RemappingFilename);
  219. if (std::error_code EC = ReaderOrErr.getError()) {
  220. std::string Msg = "Could not open profile: " + EC.message();
  221. Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
  222. return false;
  223. }
  224. Reader = std::move(ReaderOrErr.get());
  225. Reader->setModule(&M);
  226. ProfileIsValid = (Reader->read() == sampleprof_error::success);
  227. Reader->getSummary();
  228. return true;
  229. }
  230. bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
  231. Function &Func = MF.getFunction();
  232. clearFunctionData(false);
  233. Samples = Reader->getSamplesFor(Func);
  234. if (!Samples || Samples->empty())
  235. return false;
  236. if (getFunctionLoc(MF) == 0)
  237. return false;
  238. DenseSet<GlobalValue::GUID> InlinedGUIDs;
  239. bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
  240. // Set the new BPI, BFI.
  241. setBranchProbs(MF);
  242. return Changed;
  243. }
  244. } // namespace llvm
  245. MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
  246. std::string RemappingFileName,
  247. FSDiscriminatorPass P)
  248. : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
  249. MIRSampleLoader(
  250. std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
  251. LowBit = getFSPassBitBegin(P);
  252. HighBit = getFSPassBitEnd(P);
  253. assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
  254. }
  255. bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
  256. if (!MIRSampleLoader->isValid())
  257. return false;
  258. LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
  259. << MF.getFunction().getName() << "\n");
  260. MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
  261. MIRSampleLoader->setInitVals(
  262. &getAnalysis<MachineDominatorTree>(),
  263. &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
  264. MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
  265. MF.RenumberBlocks();
  266. if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
  267. (ViewBlockFreqFuncName.empty() ||
  268. MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
  269. MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
  270. }
  271. bool Changed = MIRSampleLoader->runOnFunction(MF);
  272. if (Changed)
  273. MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
  274. if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
  275. (ViewBlockFreqFuncName.empty() ||
  276. MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
  277. MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
  278. }
  279. return Changed;
  280. }
  281. bool MIRProfileLoaderPass::doInitialization(Module &M) {
  282. LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
  283. << "\n");
  284. MIRSampleLoader->setFSPass(P);
  285. return MIRSampleLoader->doInitialization(M);
  286. }
  287. void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  288. AU.setPreservesAll();
  289. AU.addRequired<MachineBlockFrequencyInfo>();
  290. AU.addRequired<MachineDominatorTree>();
  291. AU.addRequired<MachinePostDominatorTree>();
  292. AU.addRequiredTransitive<MachineLoopInfo>();
  293. AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  294. MachineFunctionPass::getAnalysisUsage(AU);
  295. }