ModelUnderTrainingRunner.h 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  15. #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  16. #include "llvm/Config/llvm-config.h"
  17. #ifdef LLVM_HAVE_TF_API
  18. #include "llvm/Analysis/MLModelRunner.h"
  19. #include "llvm/Analysis/Utils/TFUtils.h"
  20. #include "llvm/IR/LLVMContext.h"
  21. #include "llvm/IR/PassManager.h"
  22. namespace llvm {
  23. /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs
  24. /// to dynamically load and evaluate a TF SavedModel
  25. /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is
  26. /// sacrificed for ease of use while training.
  27. class ModelUnderTrainingRunner final : public MLModelRunner {
  28. public:
  29. // Disallows copy and assign.
  30. ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
  31. ModelUnderTrainingRunner &
  32. operator=(const ModelUnderTrainingRunner &) = delete;
  33. const std::vector<LoggedFeatureSpec> &outputLoggedFeatureSpecs() const {
  34. return OutputSpecs;
  35. }
  36. const Optional<TFModelEvaluator::EvaluationResult> &
  37. lastEvaluationResult() const {
  38. return LastEvaluationResult;
  39. }
  40. static bool classof(const MLModelRunner *R) {
  41. return R->getKind() == MLModelRunner::Kind::Development;
  42. }
  43. static std::unique_ptr<ModelUnderTrainingRunner>
  44. createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
  45. StringRef DecisionName,
  46. const std::vector<TensorSpec> &InputSpecs,
  47. StringRef OutputSpecsPathOverride = "");
  48. private:
  49. ModelUnderTrainingRunner(LLVMContext &Ctx, const std::string &ModelPath,
  50. const std::vector<TensorSpec> &InputSpecs,
  51. const std::vector<LoggedFeatureSpec> &OutputSpecs);
  52. std::unique_ptr<TFModelEvaluator> Evaluator;
  53. const std::vector<LoggedFeatureSpec> OutputSpecs;
  54. Optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
  55. void *evaluateUntyped() override;
  56. void *getTensorUntyped(size_t Index) override;
  57. bool isValid() const { return !!Evaluator; }
  58. };
  59. } // namespace llvm
  60. #endif // define(LLVM_HAVE_TF_API)
  61. #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  62. #ifdef __GNUC__
  63. #pragma GCC diagnostic pop
  64. #endif