MLModelRunner.h 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
  15. #define LLVM_ANALYSIS_MLMODELRUNNER_H
  16. #include "llvm/IR/LLVMContext.h"
  17. #include "llvm/IR/PassManager.h"
  18. namespace llvm {
  19. /// MLModelRunner interface: abstraction of a mechanism for evaluating a
  20. /// tensorflow "saved model".
  21. /// NOTE: feature indices are expected to be consistent all accross
  22. /// MLModelRunners (pertaining to the same model), and also Loggers (see
  23. /// TFUtils.h)
  24. class MLModelRunner {
  25. public:
  26. // Disallows copy and assign.
  27. MLModelRunner(const MLModelRunner &) = delete;
  28. MLModelRunner &operator=(const MLModelRunner &) = delete;
  29. virtual ~MLModelRunner() = default;
  30. template <typename T> T evaluate() {
  31. return *reinterpret_cast<T *>(evaluateUntyped());
  32. }
  33. template <typename T, typename I> T *getTensor(I FeatureID) {
  34. return reinterpret_cast<T *>(
  35. getTensorUntyped(static_cast<size_t>(FeatureID)));
  36. }
  37. template <typename T, typename I> const T *getTensor(I FeatureID) const {
  38. return reinterpret_cast<const T *>(
  39. getTensorUntyped(static_cast<size_t>(FeatureID)));
  40. }
  41. virtual void *getTensorUntyped(size_t Index) = 0;
  42. const void *getTensorUntyped(size_t Index) const {
  43. return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
  44. }
  45. enum class Kind : int { Unknown, Release, Development, NoOp };
  46. Kind getKind() const { return Type; }
  47. protected:
  48. MLModelRunner(LLVMContext &Ctx, Kind Type) : Ctx(Ctx), Type(Type) {
  49. assert(Type != Kind::Unknown);
  50. }
  51. virtual void *evaluateUntyped() = 0;
  52. LLVMContext &Ctx;
  53. const Kind Type;
  54. };
  55. } // namespace llvm
  56. #endif // LLVM_ANALYSIS_MLMODELRUNNER_H
  57. #ifdef __GNUC__
  58. #pragma GCC diagnostic pop
  59. #endif