MLModelRunner.h

#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
//===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
#define LLVM_ANALYSIS_MLMODELRUNNER_H

#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/PassManager.h"

#include <cassert>
#include <vector>

namespace llvm {

class LLVMContext;
/// MLModelRunner interface: abstraction of a mechanism for evaluating a
/// TensorFlow "saved model".
/// NOTE: feature indices are expected to be consistent across all
/// MLModelRunners (pertaining to the same model), and also Loggers (see
/// TFUtils.h)
class MLModelRunner {
public:
  // Disallows copy and assign.
  MLModelRunner(const MLModelRunner &) = delete;
  MLModelRunner &operator=(const MLModelRunner &) = delete;
  virtual ~MLModelRunner() = default;

  /// Run the model and reinterpret its (untyped) output as a T.
  template <typename T> T evaluate() {
    return *reinterpret_cast<T *>(evaluateUntyped());
  }

  /// Typed access to the input buffer backing the given feature index.
  template <typename T, typename I> T *getTensor(I FeatureID) {
    return reinterpret_cast<T *>(
        getTensorUntyped(static_cast<size_t>(FeatureID)));
  }

  template <typename T, typename I> const T *getTensor(I FeatureID) const {
    return reinterpret_cast<const T *>(
        getTensorUntyped(static_cast<size_t>(FeatureID)));
  }

  void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; }
  const void *getTensorUntyped(size_t Index) const {
    return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
  }

  enum class Kind : int { Unknown, Release, Development, NoOp };
  Kind getKind() const { return Type; }
protected:
  MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs)
      : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) {
    assert(Type != Kind::Unknown);
  }
  virtual void *evaluateUntyped() = 0;

  /// Bind the buffer for the tensor at Index; if no buffer is supplied,
  /// allocate and own one sized according to Spec.
  void setUpBufferForTensor(size_t Index, const TensorSpec &Spec,
                            void *Buffer) {
    if (!Buffer) {
      OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize());
      Buffer = OwnedBuffers.back().data();
    }
    InputBuffers[Index] = Buffer;
  }

  LLVMContext &Ctx;
  const Kind Type;

private:
  std::vector<void *> InputBuffers;
  std::vector<std::vector<char>> OwnedBuffers;
};
} // namespace llvm

#endif // LLVM_ANALYSIS_MLMODELRUNNER_H

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
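
For orientation, here is a minimal sketch (not part of the header) of how a concrete runner might subclass MLModelRunner and how callers typically drive it. The names MockModelRunner and runOnce, and the choice of a scalar int64 input and output, are illustrative assumptions, not LLVM APIs.

// Illustrative sketch only: shows the subclass contract (protected ctor,
// setUpBufferForTensor, evaluateUntyped) and the caller-side pattern
// (getTensor<T> to fill inputs, evaluate<T> to read the decision).
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"

#include <cstdint>
#include <vector>

namespace {
using namespace llvm;

// Hypothetical runner that lets the base class own its input buffers and
// returns a fixed decision instead of evaluating a real model.
class MockModelRunner : public MLModelRunner {
public:
  MockModelRunner(LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs)
      : MLModelRunner(Ctx, Kind::NoOp, Inputs.size()) {
    // Passing a null buffer asks the base class to allocate and own one,
    // sized via TensorSpec::getTotalTensorBufferSize().
    for (size_t I = 0; I < Inputs.size(); ++I)
      setUpBufferForTensor(I, Inputs[I], /*Buffer=*/nullptr);
  }

private:
  void *evaluateUntyped() override {
    // A real runner would feed the bound input buffers to a model here.
    return &Decision;
  }
  int64_t Decision = 1;
};

// Typical call pattern against the MLModelRunner interface.
int64_t runOnce(MLModelRunner &Runner) {
  // Write a feature value into input tensor 0 (assumed scalar int64).
  *Runner.getTensor<int64_t>(0) = 42;
  // Evaluate and reinterpret the untyped output as an int64 decision.
  return Runner.evaluate<int64_t>();
}
} // namespace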