ModelUnderTrainingRunner.h 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  15. #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/iterator_range.h"
  #include "llvm/Analysis/TensorSpec.h"
  #include "llvm/Config/llvm-config.h"
  #include <cassert>
  20. #ifdef LLVM_HAVE_TFLITE
  21. #include "llvm/Analysis/MLModelRunner.h"
  22. #include "llvm/Analysis/Utils/TFUtils.h"
  23. #include "llvm/IR/LLVMContext.h"
  24. #include "llvm/IR/PassManager.h"
  25. namespace llvm {
  26. /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs
  27. /// to dynamically load and evaluate a TF SavedModel
  28. /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is
  29. /// sacrificed for ease of use while training.
  30. class ModelUnderTrainingRunner final : public MLModelRunner {
  31. public:
  32. // Disallows copy and assign.
  33. ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
  34. ModelUnderTrainingRunner &
  35. operator=(const ModelUnderTrainingRunner &) = delete;
  36. const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
  37. return ExtraOutputsForLogging;
  38. }
  39. const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
  40. return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
  41. }
  42. const std::optional<TFModelEvaluator::EvaluationResult> &
  43. lastEvaluationResult() const {
  44. return LastEvaluationResult;
  45. }
  46. static bool classof(const MLModelRunner *R) {
  47. return R->getKind() == MLModelRunner::Kind::Development;
  48. }
  49. static std::unique_ptr<ModelUnderTrainingRunner>
  50. createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
  51. StringRef DecisionName,
  52. const std::vector<TensorSpec> &InputSpecs,
  53. StringRef OutputSpecsPathOverride = "");
  54. ModelUnderTrainingRunner(
  55. LLVMContext &Ctx, const std::string &ModelPath,
  56. const std::vector<TensorSpec> &InputSpecs,
  57. const std::vector<TensorSpec> &OutputSpecs,
  58. const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
  59. bool isValid() const { return !!Evaluator; }
  60. private:
  61. std::unique_ptr<TFModelEvaluator> Evaluator;
  62. const std::vector<TensorSpec> OutputSpecs;
  63. const std::vector<TensorSpec> ExtraOutputsForLogging;
  64. std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
  65. void *evaluateUntyped() override;
  66. };
  67. } // namespace llvm
  68. #endif // define(LLVM_HAVE_TFLITE)
  69. #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
  70. #ifdef __GNUC__
  71. #pragma GCC diagnostic pop
  72. #endif