#pragma once #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #endif //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H #include "llvm/Config/llvm-config.h" #ifdef LLVM_HAVE_TF_API #include "llvm/Analysis/MLModelRunner.h" #include "llvm/Analysis/Utils/TFUtils.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" namespace llvm { /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs /// to dynamically load and evaluate a TF SavedModel /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is /// sacrificed for ease of use while training. class ModelUnderTrainingRunner final : public MLModelRunner { public: // Disallows copy and assign. ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete; ModelUnderTrainingRunner & operator=(const ModelUnderTrainingRunner &) = delete; const std::vector &outputLoggedFeatureSpecs() const { return OutputSpecs; } const Optional & lastEvaluationResult() const { return LastEvaluationResult; } static bool classof(const MLModelRunner *R) { return R->getKind() == MLModelRunner::Kind::Development; } static std::unique_ptr createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, const std::vector &InputSpecs, StringRef OutputSpecsPathOverride = ""); private: ModelUnderTrainingRunner(LLVMContext &Ctx, const std::string &ModelPath, const std::vector &InputSpecs, const std::vector &OutputSpecs); std::unique_ptr Evaluator; const std::vector OutputSpecs; Optional LastEvaluationResult; void *evaluateUntyped() override; void *getTensorUntyped(size_t Index) override; bool isValid() const { return !!Evaluator; } }; } // namespace llvm #endif // define(LLVM_HAVE_TF_API) #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H #ifdef __GNUC__ #pragma GCC diagnostic pop #endif