ReleaseModeModelRunner.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- ReleaseModeModelRunner.h - Fast, precompiled model runner ---------===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file implements a model runner wrapping an AOT compiled ML model.
  15. // Only inference is supported.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #ifndef LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  19. #define LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  20. #include "llvm/Analysis/MLModelRunner.h"
  21. #include "llvm/Analysis/TensorSpec.h"
  22. #include "llvm/Support/ErrorHandling.h"
  23. #include <memory>
  24. #include <vector>
  25. namespace llvm {
  26. /// ReleaseModeModelRunner - production mode implementation of the
  27. /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
  28. template <class TGen>
  29. class ReleaseModeModelRunner final : public MLModelRunner {
  30. public:
  31. /// FeatureNames' type should be an indexed collection of std::string, like
  32. /// std::array or std::vector, that has a size() method.
  33. template <class FType>
  34. ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
  35. StringRef DecisionName, StringRef FeedPrefix = "feed_",
  36. StringRef FetchPrefix = "fetch_")
  37. : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
  38. CompiledModel(std::make_unique<TGen>()) {
  39. assert(CompiledModel && "The CompiledModel should be valid");
  40. for (size_t I = 0; I < InputSpec.size(); ++I) {
  41. const int Index =
  42. CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
  43. void *Buffer = nullptr;
  44. if (Index >= 0)
  45. Buffer = CompiledModel->arg_data(Index);
  46. setUpBufferForTensor(I, InputSpec[I], Buffer);
  47. }
  48. ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
  49. DecisionName.str());
  50. assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
  51. }
  52. virtual ~ReleaseModeModelRunner() = default;
  53. static bool classof(const MLModelRunner *R) {
  54. return R->getKind() == MLModelRunner::Kind::Release;
  55. }
  56. private:
  57. void *evaluateUntyped() override {
  58. CompiledModel->Run();
  59. return CompiledModel->result_data(ResultIndex);
  60. }
  61. int32_t ResultIndex = -1;
  62. std::unique_ptr<TGen> CompiledModel;
  63. };
  64. /// A mock class satisfying the interface expected by ReleaseModeModelRunner for
  65. /// its `TGen` parameter. Useful to avoid conditional compilation complexity, as
  66. /// a compile-time replacement for a real AOT-ed model.
  67. class NoopSavedModelImpl final {
  68. #define NOOP_MODEL_ERRMSG \
  69. "The mock AOT-ed saved model is a compile-time stub and should not be " \
  70. "called."
  71. public:
  72. NoopSavedModelImpl() = default;
  73. int LookupArgIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  74. int LookupResultIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  75. void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  76. void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  77. void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  78. #undef NOOP_MODEL_ERRMSG
  79. };
  80. } // namespace llvm
  81. #endif // LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  82. #ifdef __GNUC__
  83. #pragma GCC diagnostic pop
  84. #endif