#pragma once #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #endif //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H #define LLVM_ANALYSIS_MLMODELRUNNER_H #include "llvm/Analysis/TensorSpec.h" #include "llvm/IR/PassManager.h" namespace llvm { class LLVMContext; /// MLModelRunner interface: abstraction of a mechanism for evaluating a /// tensorflow "saved model". /// NOTE: feature indices are expected to be consistent all accross /// MLModelRunners (pertaining to the same model), and also Loggers (see /// TFUtils.h) class MLModelRunner { public: // Disallows copy and assign. MLModelRunner(const MLModelRunner &) = delete; MLModelRunner &operator=(const MLModelRunner &) = delete; virtual ~MLModelRunner() = default; template T evaluate() { return *reinterpret_cast(evaluateUntyped()); } template T *getTensor(I FeatureID) { return reinterpret_cast( getTensorUntyped(static_cast(FeatureID))); } template const T *getTensor(I FeatureID) const { return reinterpret_cast( getTensorUntyped(static_cast(FeatureID))); } void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } const void *getTensorUntyped(size_t Index) const { return (const_cast(this))->getTensorUntyped(Index); } enum class Kind : int { Unknown, Release, Development, NoOp }; Kind getKind() const { return Type; } protected: MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { assert(Type != Kind::Unknown); } virtual void *evaluateUntyped() = 0; void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, void *Buffer) { if (!Buffer) { OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); Buffer = OwnedBuffers.back().data(); } InputBuffers[Index] = Buffer; } LLVMContext &Ctx; const Kind Type; private: std::vector InputBuffers; std::vector> OwnedBuffers; }; } // namespace llvm #endif // LLVM_ANALYSIS_MLMODELRUNNER_H #ifdef __GNUC__ #pragma GCC diagnostic pop #endif