MLInlineAdvisor.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
  14. #define LLVM_ANALYSIS_MLINLINEADVISOR_H
  #include "llvm/Analysis/CallGraph.h"
  #include "llvm/Analysis/InlineAdvisor.h"
  #include "llvm/Analysis/MLModelRunner.h"
  #include "llvm/IR/PassManager.h"

  #include <cstdint>
  #include <map>
  #include <memory>
  #include <unordered_map>
  21. namespace llvm {
  22. class Module;
  23. class MLInlineAdvice;
  24. class MLInlineAdvisor : public InlineAdvisor {
  25. public:
  26. MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
  27. std::unique_ptr<MLModelRunner> ModelRunner);
  28. CallGraph *callGraph() const { return CG.get(); }
  29. virtual ~MLInlineAdvisor() = default;
  30. void onPassEntry() override;
  31. int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
  32. void onSuccessfulInlining(const MLInlineAdvice &Advice,
  33. bool CalleeWasDeleted);
  34. bool isForcedToStop() const { return ForceStop; }
  35. int64_t getLocalCalls(Function &F);
  36. const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
  37. protected:
  38. std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
  39. std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
  40. bool Advice) override;
  41. virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
  42. virtual std::unique_ptr<MLInlineAdvice>
  43. getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
  44. std::unique_ptr<MLModelRunner> ModelRunner;
  45. private:
  46. int64_t getModuleIRSize() const;
  47. std::unique_ptr<CallGraph> CG;
  48. int64_t NodeCount = 0;
  49. int64_t EdgeCount = 0;
  50. std::map<const Function *, unsigned> FunctionLevels;
  51. const int32_t InitialIRSize = 0;
  52. int32_t CurrentIRSize = 0;
  53. bool ForceStop = false;
  54. };
  55. /// InlineAdvice that tracks changes post inlining. For that reason, it only
  56. /// overrides the "successful inlining" extension points.
  57. class MLInlineAdvice : public InlineAdvice {
  58. public:
  59. MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
  60. OptimizationRemarkEmitter &ORE, bool Recommendation)
  61. : InlineAdvice(Advisor, CB, ORE, Recommendation),
  62. CallerIRSize(Advisor->isForcedToStop() ? 0
  63. : Advisor->getIRSize(*Caller)),
  64. CalleeIRSize(Advisor->isForcedToStop() ? 0
  65. : Advisor->getIRSize(*Callee)),
  66. CallerAndCalleeEdges(Advisor->isForcedToStop()
  67. ? 0
  68. : (Advisor->getLocalCalls(*Caller) +
  69. Advisor->getLocalCalls(*Callee))) {}
  70. virtual ~MLInlineAdvice() = default;
  71. void recordInliningImpl() override;
  72. void recordInliningWithCalleeDeletedImpl() override;
  73. void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
  74. void recordUnattemptedInliningImpl() override;
  75. Function *getCaller() const { return Caller; }
  76. Function *getCallee() const { return Callee; }
  77. const int64_t CallerIRSize;
  78. const int64_t CalleeIRSize;
  79. const int64_t CallerAndCalleeEdges;
  80. private:
  81. void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
  82. MLInlineAdvisor *getAdvisor() const {
  83. return static_cast<MLInlineAdvisor *>(Advisor);
  84. };
  85. };
  86. } // namespace llvm
  87. #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
  88. #ifdef __GNUC__
  89. #pragma GCC diagnostic pop
  90. #endif