MLModelRunner.h 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
  15. #define LLVM_ANALYSIS_MLMODELRUNNER_H
  16. #include "llvm/Analysis/InlineModelFeatureMaps.h"
  17. #include "llvm/IR/LLVMContext.h"
  18. #include "llvm/IR/PassManager.h"
  19. namespace llvm {
  20. /// MLModelRunner interface: abstraction of a mechanism for evaluating a
  21. /// tensorflow "saved model".
  22. class MLModelRunner {
  23. public:
  24. // Disallows copy and assign.
  25. MLModelRunner(const MLModelRunner &) = delete;
  26. MLModelRunner &operator=(const MLModelRunner &) = delete;
  27. virtual ~MLModelRunner() = default;
  28. virtual bool run() = 0;
  29. virtual void setFeature(FeatureIndex Index, int64_t Value) = 0;
  30. virtual int64_t getFeature(int Index) const = 0;
  31. protected:
  32. MLModelRunner(LLVMContext &Ctx) : Ctx(Ctx) {}
  33. LLVMContext &Ctx;
  34. };
  35. } // namespace llvm
  36. #endif // LLVM_ANALYSIS_MLMODELRUNNER_H
  37. #ifdef __GNUC__
  38. #pragma GCC diagnostic pop
  39. #endif