MIRSampleProfile.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file provides the implementation of the MIRSampleProfile loader, mainly
  10. // for flow sensitive SampleFDO.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/MIRSampleProfile.h"
  14. #include "llvm/ADT/DenseMap.h"
  15. #include "llvm/ADT/DenseSet.h"
  16. #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
  17. #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
  18. #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
  19. #include "llvm/CodeGen/MachineDominators.h"
  20. #include "llvm/CodeGen/MachineLoopInfo.h"
  21. #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
  22. #include "llvm/CodeGen/MachinePostDominators.h"
  23. #include "llvm/CodeGen/Passes.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/InitializePasses.h"
  26. #include "llvm/Support/CommandLine.h"
  27. #include "llvm/Support/Debug.h"
  28. #include "llvm/Support/raw_ostream.h"
  29. #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
  30. #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
  31. using namespace llvm;
  32. using namespace sampleprof;
  33. using namespace llvm::sampleprofutil;
  34. using ProfileCount = Function::ProfileCount;
  35. #define DEBUG_TYPE "fs-profile-loader"
  36. static cl::opt<bool> ShowFSBranchProb(
  37. "show-fs-branchprob", cl::Hidden, cl::init(false),
  38. cl::desc("Print setting flow sensitive branch probabilities"));
  39. static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
  40. "fs-profile-debug-prob-diff-threshold", cl::init(10),
  41. cl::desc("Only show debug message if the branch probility is greater than "
  42. "this value (in percentage)."));
  43. static cl::opt<unsigned> FSProfileDebugBWThreshold(
  44. "fs-profile-debug-bw-threshold", cl::init(10000),
  45. cl::desc("Only show debug message if the source branch weight is greater "
  46. " than this value."));
  47. static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
  48. cl::init(false),
  49. cl::desc("View BFI before MIR loader"));
  50. static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
  51. cl::init(false),
  52. cl::desc("View BFI after MIR loader"));
// Pass identification object; a reference to it is exported below as
// llvm::MIRProfileLoaderPassID.
char MIRProfileLoaderPass::ID = 0;
// Register the pass under DEBUG_TYPE ("fs-profile-loader"). The dependency
// list mirrors the analyses requested in
// MIRProfileLoaderPass::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)
// Exported handle that refers to the pass identification object above.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
  65. FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
  66. std::string RemappingFile,
  67. FSDiscriminatorPass P) {
  68. return new MIRProfileLoaderPass(File, RemappingFile, P);
  69. }
  70. namespace llvm {
// Options defined in other translation units; this file only reads them in
// MIRProfileLoaderPass::runOnMachineFunction to decide whether to display the
// block-frequency graph.
//
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
// An empty value selects every function.
extern cl::opt<std::string> ViewBlockFreqFuncName;
  78. namespace afdo_detail {
  79. template <> struct IRTraits<MachineBasicBlock> {
  80. using InstructionT = MachineInstr;
  81. using BasicBlockT = MachineBasicBlock;
  82. using FunctionT = MachineFunction;
  83. using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  84. using LoopT = MachineLoop;
  85. using LoopInfoPtrT = MachineLoopInfo *;
  86. using DominatorTreePtrT = MachineDominatorTree *;
  87. using PostDominatorTreePtrT = MachinePostDominatorTree *;
  88. using PostDominatorTreeT = MachinePostDominatorTree;
  89. using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  90. using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
  91. using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  92. using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  93. static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
  94. static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
  95. return GraphTraits<const MachineFunction *>::getEntryNode(F);
  96. }
  97. static PredRangeT getPredecessors(MachineBasicBlock *BB) {
  98. return BB->predecessors();
  99. }
  100. static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
  101. return BB->successors();
  102. }
  103. };
  104. } // namespace afdo_detail
  105. class MIRProfileLoader final
  106. : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
  107. public:
  108. void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
  109. MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
  110. MachineOptimizationRemarkEmitter *MORE) {
  111. DT = MDT;
  112. PDT = MPDT;
  113. LI = MLI;
  114. BFI = MBFI;
  115. ORE = MORE;
  116. }
  117. void setFSPass(FSDiscriminatorPass Pass) {
  118. P = Pass;
  119. LowBit = getFSPassBitBegin(P);
  120. HighBit = getFSPassBitEnd(P);
  121. assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
  122. }
  123. MIRProfileLoader(StringRef Name, StringRef RemapName)
  124. : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
  125. }
  126. void setBranchProbs(MachineFunction &F);
  127. bool runOnFunction(MachineFunction &F);
  128. bool doInitialization(Module &M);
  129. bool isValid() const { return ProfileIsValid; }
  130. protected:
  131. friend class SampleCoverageTracker;
  132. /// Hold the information of the basic block frequency.
  133. MachineBlockFrequencyInfo *BFI;
  134. /// PassNum is the sequence number this pass is called, start from 1.
  135. FSDiscriminatorPass P;
  136. // LowBit in the FS discriminator used by this instance. Note the number is
  137. // 0-based. Base discrimnator use bit 0 to bit 11.
  138. unsigned LowBit;
  139. // HighwBit in the FS discriminator used by this instance. Note the number
  140. // is 0-based.
  141. unsigned HighBit;
  142. bool ProfileIsValid = true;
  143. };
// Intentionally empty for Machine IR: nothing is computed here.
// MIRProfileLoader has its dominator tree, post-dominator tree and loop info
// pointers assigned through setInitVals() instead.
template <>
void SampleProfileLoaderBaseImpl<
    MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
  147. void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
  148. LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
  149. for (auto &BI : F) {
  150. MachineBasicBlock *BB = &BI;
  151. if (BB->succ_size() < 2)
  152. continue;
  153. const MachineBasicBlock *EC = EquivalenceClass[BB];
  154. uint64_t BBWeight = BlockWeights[EC];
  155. uint64_t SumEdgeWeight = 0;
  156. for (MachineBasicBlock *Succ : BB->successors()) {
  157. Edge E = std::make_pair(BB, Succ);
  158. SumEdgeWeight += EdgeWeights[E];
  159. }
  160. if (BBWeight != SumEdgeWeight) {
  161. LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
  162. << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
  163. << "\n");
  164. BBWeight = SumEdgeWeight;
  165. }
  166. if (BBWeight == 0) {
  167. LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
  168. continue;
  169. }
  170. #ifndef NDEBUG
  171. uint64_t BBWeightOrig = BBWeight;
  172. #endif
  173. uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
  174. uint32_t Factor = 1;
  175. if (BBWeight > MaxWeight) {
  176. Factor = BBWeight / MaxWeight + 1;
  177. BBWeight /= Factor;
  178. LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
  179. }
  180. for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
  181. SE = BB->succ_end();
  182. SI != SE; ++SI) {
  183. MachineBasicBlock *Succ = *SI;
  184. Edge E = std::make_pair(BB, Succ);
  185. uint64_t EdgeWeight = EdgeWeights[E];
  186. EdgeWeight /= Factor;
  187. assert(BBWeight >= EdgeWeight &&
  188. "BBweight is larger than EdgeWeight -- should not happen.\n");
  189. BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
  190. BranchProbability NewProb(EdgeWeight, BBWeight);
  191. if (OldProb == NewProb)
  192. continue;
  193. BB->setSuccProbability(SI, NewProb);
  194. #ifndef NDEBUG
  195. if (!ShowFSBranchProb)
  196. continue;
  197. bool Show = false;
  198. BranchProbability Diff;
  199. if (OldProb > NewProb)
  200. Diff = OldProb - NewProb;
  201. else
  202. Diff = NewProb - OldProb;
  203. Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
  204. Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
  205. auto DIL = BB->findBranchDebugLoc();
  206. auto SuccDIL = Succ->findBranchDebugLoc();
  207. if (Show) {
  208. dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
  209. << Succ->getNumber() << "): ";
  210. if (DIL)
  211. dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
  212. << DIL->getColumn();
  213. if (SuccDIL)
  214. dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
  215. << ":" << SuccDIL->getColumn();
  216. dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
  217. << "\n";
  218. }
  219. #endif
  220. }
  221. }
  222. }
  223. bool MIRProfileLoader::doInitialization(Module &M) {
  224. auto &Ctx = M.getContext();
  225. auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
  226. RemappingFilename);
  227. if (std::error_code EC = ReaderOrErr.getError()) {
  228. std::string Msg = "Could not open profile: " + EC.message();
  229. Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
  230. return false;
  231. }
  232. Reader = std::move(ReaderOrErr.get());
  233. Reader->setModule(&M);
  234. ProfileIsValid = (Reader->read() == sampleprof_error::success);
  235. Reader->getSummary();
  236. return true;
  237. }
  238. bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
  239. Function &Func = MF.getFunction();
  240. clearFunctionData(false);
  241. Samples = Reader->getSamplesFor(Func);
  242. if (!Samples || Samples->empty())
  243. return false;
  244. if (getFunctionLoc(MF) == 0)
  245. return false;
  246. DenseSet<GlobalValue::GUID> InlinedGUIDs;
  247. bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
  248. // Set the new BPI, BFI.
  249. setBranchProbs(MF);
  250. return Changed;
  251. }
  252. } // namespace llvm
  253. MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
  254. std::string RemappingFileName,
  255. FSDiscriminatorPass P)
  256. : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
  257. MIRSampleLoader(
  258. std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
  259. LowBit = getFSPassBitBegin(P);
  260. HighBit = getFSPassBitEnd(P);
  261. assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
  262. }
  263. bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
  264. if (!MIRSampleLoader->isValid())
  265. return false;
  266. LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
  267. << MF.getFunction().getName() << "\n");
  268. MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
  269. MIRSampleLoader->setInitVals(
  270. &getAnalysis<MachineDominatorTree>(),
  271. &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
  272. MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
  273. MF.RenumberBlocks();
  274. if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
  275. (ViewBlockFreqFuncName.empty() ||
  276. MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
  277. MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
  278. }
  279. bool Changed = MIRSampleLoader->runOnFunction(MF);
  280. if (Changed)
  281. MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
  282. if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
  283. (ViewBlockFreqFuncName.empty() ||
  284. MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
  285. MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
  286. }
  287. return Changed;
  288. }
  289. bool MIRProfileLoaderPass::doInitialization(Module &M) {
  290. LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
  291. << "\n");
  292. MIRSampleLoader->setFSPass(P);
  293. return MIRSampleLoader->doInitialization(M);
  294. }
/// Declares the analyses this pass needs: block frequencies, dominator and
/// post-dominator trees, loop info and the optimization remark emitter (the
/// same set that runOnMachineFunction fetches with getAnalysis).
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // NOTE(review): the pass declares that it preserves all analyses even
  // though runOnMachineFunction may update branch probabilities and
  // recalculate MBFI -- confirm this is intended.
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachinePostDominatorTree>();
  // Loop info is the only analysis requested as a transitive requirement.
  AU.addRequiredTransitive<MachineLoopInfo>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}