ReleaseModeModelRunner.h 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- ReleaseModeModelRunner.h - Fast, precompiled model runner ---------===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file implements a model runner wrapping an AOT compiled ML model.
  15. // Only inference is supported.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #ifndef LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  19. #define LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  20. #include "llvm/Analysis/MLModelRunner.h"
  21. #include <memory>
  22. #include <vector>
  23. using namespace llvm;
  24. namespace llvm {
  25. /// ReleaseModeModelRunner - production mode implementation of the
  26. /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
  27. template <class TGen>
  28. class ReleaseModeModelRunner final : public MLModelRunner {
  29. public:
  30. /// FeatureNames' type should be an indexed collection of std::string, like
  31. /// std::array or std::vector, that has a size() method.
  32. template <class FType>
  33. ReleaseModeModelRunner(LLVMContext &Ctx, const FType &FeatureNames,
  34. StringRef DecisionName, StringRef FeedPrefix = "feed_",
  35. StringRef FetchPrefix = "fetch_")
  36. : MLModelRunner(Ctx, MLModelRunner::Kind::Release),
  37. CompiledModel(std::make_unique<TGen>()) {
  38. assert(CompiledModel && "The CompiledModel should be valid");
  39. const size_t FeatureCount = FeatureNames.size();
  40. FeatureIndices.resize(FeatureCount);
  41. for (size_t I = 0; I < FeatureCount; ++I) {
  42. const int Index =
  43. CompiledModel->LookupArgIndex(FeedPrefix.str() + FeatureNames[I]);
  44. assert(Index >= 0 && "Cannot find Feature in inlining model");
  45. FeatureIndices[I] = Index;
  46. }
  47. ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
  48. DecisionName.str());
  49. assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
  50. }
  51. virtual ~ReleaseModeModelRunner() = default;
  52. static bool classof(const MLModelRunner *R) {
  53. return R->getKind() == MLModelRunner::Kind::Release;
  54. }
  55. private:
  56. void *evaluateUntyped() override {
  57. CompiledModel->Run();
  58. return CompiledModel->result_data(ResultIndex);
  59. }
  60. void *getTensorUntyped(size_t Index) override {
  61. return reinterpret_cast<char *>(
  62. CompiledModel->arg_data(FeatureIndices[Index]));
  63. }
  64. std::vector<int32_t> FeatureIndices;
  65. int32_t ResultIndex = -1;
  66. std::unique_ptr<TGen> CompiledModel;
  67. };
  68. } // namespace llvm
  69. #endif // LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
  70. #ifdef __GNUC__
  71. #pragma GCC diagnostic pop
  72. #endif