MLInlineAdvisor.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
  14. #define LLVM_ANALYSIS_MLINLINEADVISOR_H
  15. #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
  16. #include "llvm/Analysis/InlineAdvisor.h"
  17. #include "llvm/Analysis/LazyCallGraph.h"
  18. #include "llvm/Analysis/MLModelRunner.h"
  19. #include "llvm/IR/PassManager.h"
  20. #include <deque>
  21. #include <map>
  22. #include <memory>
  23. #include <optional>
  24. namespace llvm {
  // Forward declarations. Within this header these types are only named
  // through references, pointers or std::unique_ptr, so a full definition is
  // not needed at this point (MLInlineAdvice is defined further down).
  class DiagnosticInfoOptimizationBase;
  class MLInlineAdvice;
  class Module;
  28. class MLInlineAdvisor : public InlineAdvisor {
  29. public:
  30. MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
  31. std::unique_ptr<MLModelRunner> ModelRunner);
  32. virtual ~MLInlineAdvisor() = default;
  33. void onPassEntry(LazyCallGraph::SCC *SCC) override;
  34. void onPassExit(LazyCallGraph::SCC *SCC) override;
  35. int64_t getIRSize(Function &F) const {
  36. return getCachedFPI(F).TotalInstructionCount;
  37. }
  38. void onSuccessfulInlining(const MLInlineAdvice &Advice,
  39. bool CalleeWasDeleted);
  40. bool isForcedToStop() const { return ForceStop; }
  41. int64_t getLocalCalls(Function &F);
  42. const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
  43. FunctionPropertiesInfo &getCachedFPI(Function &) const;
  44. protected:
  45. std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
  46. std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
  47. bool Advice) override;
  48. virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
  49. virtual std::unique_ptr<MLInlineAdvice>
  50. getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
  51. // Get the initial 'level' of the function, or 0 if the function has been
  52. // introduced afterwards.
  53. // TODO: should we keep this updated?
  54. unsigned getInitialFunctionLevel(const Function &F) const;
  55. std::unique_ptr<MLModelRunner> ModelRunner;
  56. private:
  57. int64_t getModuleIRSize() const;
  58. std::unique_ptr<InlineAdvice>
  59. getSkipAdviceIfUnreachableCallsite(CallBase &CB);
  60. void print(raw_ostream &OS) const override;
  61. // Using std::map to benefit from its iterator / reference non-invalidating
  62. // semantics, which make it easy to use `getCachedFPI` results from multiple
  63. // calls without needing to copy to avoid invalidation effects.
  64. mutable std::map<const Function *, FunctionPropertiesInfo> FPICache;
  65. LazyCallGraph &CG;
  66. int64_t NodeCount = 0;
  67. int64_t EdgeCount = 0;
  68. int64_t EdgesOfLastSeenNodes = 0;
  69. std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
  70. const int32_t InitialIRSize = 0;
  71. int32_t CurrentIRSize = 0;
  72. llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC;
  73. DenseSet<const LazyCallGraph::Node *> AllNodes;
  74. bool ForceStop = false;
  75. };
  76. /// InlineAdvice that tracks changes post inlining. For that reason, it only
  77. /// overrides the "successful inlining" extension points.
  78. class MLInlineAdvice : public InlineAdvice {
  79. public:
  80. MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
  81. OptimizationRemarkEmitter &ORE, bool Recommendation);
  82. virtual ~MLInlineAdvice() = default;
  83. void recordInliningImpl() override;
  84. void recordInliningWithCalleeDeletedImpl() override;
  85. void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
  86. void recordUnattemptedInliningImpl() override;
  87. Function *getCaller() const { return Caller; }
  88. Function *getCallee() const { return Callee; }
  89. const int64_t CallerIRSize;
  90. const int64_t CalleeIRSize;
  91. const int64_t CallerAndCalleeEdges;
  92. void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const;
  93. private:
  94. void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
  95. MLInlineAdvisor *getAdvisor() const {
  96. return static_cast<MLInlineAdvisor *>(Advisor);
  97. };
  98. // Make a copy of the FPI of the caller right before inlining. If inlining
  99. // fails, we can just update the cache with that value.
  100. const FunctionPropertiesInfo PreInlineCallerFPI;
  101. std::optional<FunctionPropertiesUpdater> FPU;
  102. };
  103. } // namespace llvm
  104. #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
  105. #ifdef __GNUC__
  106. #pragma GCC diagnostic pop
  107. #endif