//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <cassert>
#include <optional>
using namespace llvm;
  21. namespace {
  22. struct LoggedFeatureSpec {
  23. TensorSpec Spec;
  24. std::optional<std::string> LoggingName;
  25. };
  26. std::optional<std::vector<LoggedFeatureSpec>>
  27. loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
  28. StringRef ModelPath, StringRef SpecFileOverride) {
  29. SmallVector<char, 128> OutputSpecsPath;
  30. StringRef FileName = SpecFileOverride;
  31. if (FileName.empty()) {
  32. llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
  33. FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  34. }
  35. auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  36. if (!BufferOrError) {
  37. Ctx.emitError("Error opening output specs file: " + FileName + " : " +
  38. BufferOrError.getError().message());
  39. return std::nullopt;
  40. }
  41. auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  42. if (!ParsedJSONValues) {
  43. Ctx.emitError("Could not parse specs file: " + FileName);
  44. return std::nullopt;
  45. }
  46. auto ValuesArray = ParsedJSONValues->getAsArray();
  47. if (!ValuesArray) {
  48. Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
  49. "logging_name:<name>} dictionaries");
  50. return std::nullopt;
  51. }
  52. std::vector<LoggedFeatureSpec> Ret;
  53. for (const auto &Value : *ValuesArray)
  54. if (const auto *Obj = Value.getAsObject())
  55. if (const auto *SpecPart = Obj->get("tensor_spec"))
  56. if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
  57. if (auto LoggingName = Obj->getString("logging_name")) {
  58. if (!TensorSpec->isElementType<int64_t>() &&
  59. !TensorSpec->isElementType<int32_t>() &&
  60. !TensorSpec->isElementType<float>()) {
  61. Ctx.emitError(
  62. "Only int64, int32, and float tensors are supported. "
  63. "Found unsupported type for tensor named " +
  64. TensorSpec->name());
  65. return std::nullopt;
  66. }
  67. Ret.push_back({*TensorSpec, LoggingName->str()});
  68. }
  69. if (ValuesArray->size() != Ret.size()) {
  70. Ctx.emitError(
  71. "Unable to parse output spec. It should be a json file containing an "
  72. "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
  73. "with a json object describing a TensorSpec; and a 'logging_name' key, "
  74. "which is a string to use as name when logging this tensor in the "
  75. "training log.");
  76. return std::nullopt;
  77. }
  78. if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
  79. Ctx.emitError("The first output spec must describe the decision tensor, "
  80. "and must have the logging_name " +
  81. StringRef(ExpectedDecisionName));
  82. return std::nullopt;
  83. }
  84. return Ret;
  85. }
  86. } // namespace
  87. ModelUnderTrainingRunner::ModelUnderTrainingRunner(
  88. LLVMContext &Ctx, const std::string &ModelPath,
  89. const std::vector<TensorSpec> &InputSpecs,
  90. const std::vector<TensorSpec> &OutputSpecs,
  91. const std::vector<TensorSpec> &ExtraOutputsForLogging)
  92. : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
  93. OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
  94. Evaluator =
  95. std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
  96. if (!Evaluator || !Evaluator->isValid()) {
  97. Ctx.emitError("Failed to create saved model evaluator");
  98. Evaluator.reset();
  99. return;
  100. }
  101. for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
  102. setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
  103. }
  104. }
  105. void *ModelUnderTrainingRunner::evaluateUntyped() {
  106. LastEvaluationResult = Evaluator->evaluate();
  107. if (!LastEvaluationResult.has_value()) {
  108. Ctx.emitError("Error evaluating model.");
  109. return nullptr;
  110. }
  111. return LastEvaluationResult->getUntypedTensorValue(0);
  112. }
  113. std::unique_ptr<ModelUnderTrainingRunner>
  114. ModelUnderTrainingRunner::createAndEnsureValid(
  115. LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
  116. const std::vector<TensorSpec> &InputSpecs,
  117. StringRef OutputSpecsPathOverride) {
  118. if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
  119. OutputSpecsPathOverride)) {
  120. std::unique_ptr<ModelUnderTrainingRunner> MUTR;
  121. std::vector<TensorSpec> OutputSpecs;
  122. std::vector<TensorSpec> ExtraOutputsForLogging;
  123. append_range(OutputSpecs,
  124. map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
  125. return LFS.Spec;
  126. }));
  127. append_range(ExtraOutputsForLogging,
  128. map_range(drop_begin(*MaybeOutputSpecs),
  129. [](const LoggedFeatureSpec &LFS) {
  130. return TensorSpec(LFS.LoggingName
  131. ? *LFS.LoggingName
  132. : LFS.Spec.name(),
  133. LFS.Spec);
  134. }));
  135. MUTR.reset(new ModelUnderTrainingRunner(
  136. Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
  137. if (MUTR && MUTR->isValid())
  138. return MUTR;
  139. Ctx.emitError("Could not load or create model evaluator.");
  140. return nullptr;
  141. }
  142. Ctx.emitError("Could not load the policy model from the provided path");
  143. return nullptr;
  144. }
#endif // defined(LLVM_HAVE_TFLITE)