//===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with tensorflow C APIs.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <numeric>

using namespace llvm;

using google::protobuf::Message;
using google::protobuf::TextFormat;

static cl::opt<bool>
    ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
                     cl::desc("Output textual (human-readable) protobuf."));

namespace {

using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
using TFSessionOptionsPtr =
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
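
/// One-time initialization of the tensorflow library. Wrapping the call to
/// TF_InitMain in a ManagedStatic gives us lazy, thread-safe, exactly-once
/// initialization the first time ensureInitTF() is called.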
struct TFInitializer {
  TFInitializer() {
    assert(!IsInitialized && "TFInitializer should be constructed only once");
    int Argc = 1;
    const char *Name = "";
    const char **NamePtr = &Name;
    TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
    IsInitialized = true;
  }
  bool IsInitialized = false;
};

llvm::ManagedStatic<TFInitializer> TFLibInitializer;

bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
TFGraphPtr createTFGraph() {
  return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
}

TFStatusPtr createTFStatus() {
  return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
}

TFSessionOptionsPtr createTFSessionOptions() {
  return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
}

void serialize(const Message &SE, std::string *OutStr) {
  if (ProtobufTextMode) {
    TextFormat::PrintToString(SE, OutStr);
  } else {
    *OutStr = SE.SerializeAsString();
  }
}
} // namespace

namespace llvm {
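/// Owns the TF_Tensor objects produced by an evaluation and deletes them when
/// the result goes out of scope.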
class EvaluationResultImpl {
public:
  EvaluationResultImpl(size_t OutputSize)
      : OutputSize(OutputSize), Output(OutputSize) {}

  ~EvaluationResultImpl() {
    for (auto *P : Output)
      if (P)
        TF_DeleteTensor(P);
  }

  EvaluationResultImpl(const EvaluationResultImpl &) = delete;
  EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
  std::vector<TF_Tensor *> &getOutput() { return Output; }

private:
  const size_t OutputSize;
  std::vector<TF_Tensor *> Output;
};
size_t TensorSpec::getElementByteSize() const {
  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
}

TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
                       const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())) {}
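
/// Parse a TensorSpec from a JSON dictionary. The "type" value is a
/// stringified C++ element type (one of TFUTILS_SUPPORTED_TYPES), e.g.
/// (illustrative values):
///   {"name": "feature", "port": 0, "type": "int64_t", "shape": [1, 4]}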
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                           const json::Value &Value) {
  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return None;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return None;
}
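
/// Load the output specs from <ModelPath>/output_spec.json, or from
/// SpecFileOverride if one is given. The file is a JSON array of the form
/// (the "decision" name is illustrative; "{...}" is a TensorSpec dict as
/// accepted by getTensorSpecFromJSON):
///   [{"logging_name": "decision", "tensor_spec": {...}}, ...]
/// The first entry must describe the decision tensor and carry the
/// ExpectedDecisionName as its logging_name.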
Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return None;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return None;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return None;
  }

  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return None;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return None;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return None;
  }
  return Ret;
}
class TFModelEvaluatorImpl {
public:
  TFModelEvaluatorImpl(StringRef SavedModelPath,
                       const std::vector<TensorSpec> &InputSpecs,
                       function_ref<TensorSpec(size_t)> GetOutputSpecs,
                       size_t OutputSpecsSize, const char *Tags);

  bool isValid() const { return IsValid; }
  size_t OutputSize() const { return OutputFeed.size(); }

  void evaluate(TF_Tensor **Output, TF_Status *Status) {
    TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
                  Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
                  nullptr, 0, nullptr, Status);
  }

  void initInput(size_t Index, TF_DataType Type,
                 const std::vector<int64_t> &Dimensions);
  const std::vector<TF_Tensor *> &getInput() const { return Input; }

  ~TFModelEvaluatorImpl();

private:
  /// The objects necessary for carrying out an evaluation of the SavedModel.
  /// They are expensive to set up, and we maintain them across all the
  /// evaluations of the model.
  TF_Session *Session = nullptr;
  TFGraphPtr Graph;
  TFSessionOptionsPtr Options;

  /// The specification of the input nodes.
  std::vector<TF_Output> InputFeed;

  /// The input tensors. They must match, by index, the corresponding InputFeed
  /// value. We set up the tensors once and just mutate their scalars before
  /// each evaluation. The input tensors keep their value after an evaluation.
  std::vector<TF_Tensor *> Input;

  /// The specification of the output nodes. When evaluating, the tensors in
  /// the output tensor vector must match, by index, the corresponding element
  /// in the OutputFeed.
  std::vector<TF_Output> OutputFeed;

  void invalidate() { IsValid = false; }

  bool IsValid = true;

  /// Reusable utility for ensuring we can bind the requested Name to a node in
  /// the SavedModel Graph.
  bool checkReportAndInvalidate(const TF_Output &Output,
                                const TensorSpec &OutputSpec);
};
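
/// Accumulates logged tensor values as tensorflow::FeatureList protobufs, one
/// per logged feature (plus, optionally, one for the reward), and serializes
/// them all into a tensorflow::SequenceExample on flush().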
class LoggerDataImpl {
  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
  const TensorSpec RewardSpec;
  const bool IncludeReward;

  std::vector<tensorflow::FeatureList> FeatureLists;
  tensorflow::FeatureList Reward;

  bool isSelfConsistent(const tensorflow::SequenceExample &SE,
                        size_t NrRecords) const {
    bool Ret = true;
    for (const auto &TSpecs : LoggedFeatureSpecs) {
      const auto &Name = TSpecs.getLoggingName();
      const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
      if (NrRecords != static_cast<size_t>(FL.size())) {
        dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
               << NrRecords << " got " << FL.size() << "\n";
        Ret = false;
      }
    }
    if (IncludeReward && static_cast<size_t>(SE.feature_lists()
                                                 .feature_list()
                                                 .at(RewardSpec.name())
                                                 .feature()
                                                 .size()) != NrRecords) {
      dbgs() << "[TF-UTILS]: reward is missing records.\n";
      Ret = false;
    }
    return Ret;
  }

  void transferLog(tensorflow::SequenceExample &SE) {
    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
    if (IncludeReward)
      (*FL)[RewardSpec.name()] = std::move(Reward);
    assert(FeatureLists.size() == LoggedFeatureSpecs.size());
    for (size_t I = 0; I < FeatureLists.size(); ++I) {
      const auto &LFS = LoggedFeatureSpecs[I];
      (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
    }
  }

public:
  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
                 const TensorSpec &RewardSpec, bool IncludeReward)
      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}

  // Flush the logged info to a string and clear the log contents.
  void flush(std::string *Str) {
    size_t NrRecords = getNrRecords();
    (void)NrRecords;
    tensorflow::SequenceExample SE;
    transferLog(SE);
    assert(isSelfConsistent(SE, NrRecords));
    serialize(SE, Str);
  }

  char *addNewTensor(size_t FeatureID) {
    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
    if (Spec.isElementType<float>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_float_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0.0);
      return reinterpret_cast<char *>(RF->mutable_data());
    } else if (Spec.isElementType<int32_t>() ||
               Spec.isElementType<int64_t>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_int64_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0);
      return reinterpret_cast<char *>(RF->mutable_data());
    }
    llvm_unreachable("Unsupported tensor type.");
  }

  template <typename T> void logReward(T Value) {
    assert(IncludeReward);
    if (RewardSpec.isElementType<float>())
      Reward.add_feature()->mutable_float_list()->add_value(Value);
    else if (RewardSpec.isElementType<int32_t>() ||
             RewardSpec.isElementType<int64_t>())
      Reward.add_feature()->mutable_int64_list()->add_value(Value);
    else
      llvm_unreachable("Unsupported tensor type.");
  }

  size_t getNrRecords() const {
    return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
  }
};
} // namespace llvm
TFModelEvaluatorImpl::TFModelEvaluatorImpl(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags = "serve")
    : Graph(createTFGraph()), Options(createTFSessionOptions()),
      InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
      OutputFeed(OutputSpecsSize) {
  if (!ensureInitTF()) {
    errs() << "Tensorflow should have been initialized";
    return;
  }
  auto Status = createTFStatus();

  Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
                                         SavedModelPath.str().c_str(), &Tags, 1,
                                         Graph.get(), nullptr, Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    invalidate();
  }
  for (size_t I = 0; I < InputSpecs.size(); ++I) {
    auto &InputSpec = InputSpecs[I];
    InputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
        InputSpec.port()};
    if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
      return;
    initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
              InputSpec.shape());
  }
  for (size_t I = 0; I < OutputSpecsSize; ++I) {
    auto OutputSpec = GetOutputSpecs(I);
    OutputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
        OutputSpec.port()};
    if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
      return;
  }
}
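
// A typical client-side sequence looks like the following sketch
// (ModelPath, InputSpecs, OutputSpecs, SomeFeature, and UseDecision are
// illustrative names, not part of this API):
//   TFModelEvaluator Evaluator(ModelPath, InputSpecs, OutputSpecs);
//   if (Evaluator.isValid()) {
//     *static_cast<int64_t *>(Evaluator.getUntypedInput(0)) = SomeFeature;
//     if (auto Result = Evaluator.evaluate())
//       UseDecision(Result->getUntypedTensorValue(0));
//   }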
TFModelEvaluator::TFModelEvaluator(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags)
    : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
                                    OutputSpecsSize, Tags)) {
  if (!Impl->isValid())
    Impl.reset();
}

TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<TensorSpec> &InputSpecs,
                                   const std::vector<TensorSpec> &OutputSpecs,
                                   const char *Tags)
    : TFModelEvaluator(
          SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
          OutputSpecs.size(), Tags) {}
TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
  for (auto *T : Input) {
    TF_DeleteTensor(T);
  }
  if (Session == nullptr)
    return;
  auto Status = createTFStatus();
  TF_DeleteSession(Session, Status.get());
  Session = nullptr;
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
    errs() << "Could not delete TF session";
}

bool TFModelEvaluatorImpl::checkReportAndInvalidate(
    const TF_Output &Output, const TensorSpec &OutputSpec) {
  if (Output.oper)
    return true;
  errs() << "Could not find TF_Output named: " + OutputSpec.name();
  IsValid = false;
  return IsValid;
}
Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
  if (!isValid())
    return None;
  std::unique_ptr<EvaluationResultImpl> Ret =
      std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
  auto Status = createTFStatus();
  Impl->evaluate(Ret->getOutput().data(), Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    Impl.reset();
    return None;
  }
  return EvaluationResult(std::move(Ret));
}
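
// Allocate a zero-initialized input tensor. The buffer size is the element
// size of Type times the product of all Dimensions.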
void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
                                     const std::vector<int64_t> &Dimensions) {
  int64_t TotalSize = TF_DataTypeSize(Type);
  for (auto &D : Dimensions)
    TotalSize *= D;

  Input[Index] =
      TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
  std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
}

void *TFModelEvaluator::getUntypedInput(size_t Index) {
  return TF_TensorData(Impl->getInput()[Index]);
}
TFModelEvaluator::EvaluationResult::EvaluationResult(
    std::unique_ptr<EvaluationResultImpl> Impl)
    : Impl(std::move(Impl)) {}

TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
    : Impl(std::move(Other.Impl)) {}

TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
  Impl = std::move(Other.Impl);
  return *this;
}

void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
  return TF_TensorData(Impl->getOutput()[Index]);
}

const void *
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
  return TF_TensorData(Impl->getOutput()[Index]);
}

#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> int TensorSpec::getDataType<T>() { return E; }

TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}
Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
               const TensorSpec &RewardSpec, bool IncludeReward)
    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
      IncludeReward(IncludeReward),
      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
                                                  IncludeReward)) {}

Logger::~Logger() {}

#define LOG_REWARD(NAME, TYPE)                                                 \
  void Logger::log##NAME##Reward(TYPE Value) {                                 \
    assert(IncludeReward);                                                     \
    LoggerData->logReward(Value);                                              \
  }

LOG_REWARD(Float, float)
LOG_REWARD(Int32, int32_t)
LOG_REWARD(Int64, int64_t)
#undef LOG_REWARD

#define LOG_FINAL_REWARD(NAME, TYPE)                                           \
  void Logger::log##NAME##FinalReward(TYPE Value) {                            \
    assert(RewardSpec.isElementType<TYPE>());                                  \
    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
      log##NAME##Reward(0);                                                    \
    log##NAME##Reward(Value);                                                  \
  }

LOG_FINAL_REWARD(Float, float)
LOG_FINAL_REWARD(Int32, int32_t)
LOG_FINAL_REWARD(Int64, int64_t)
#undef LOG_FINAL_REWARD
void Logger::logFloatValue(size_t FeatureID, const float *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}
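
// Copy one tensor's worth of raw data into a freshly-added record. int32
// values are widened to int64, since tf.train Feature protos only carry
// int64, float, and bytes lists.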
void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
  const auto &Spec = FeatureSpecs[FeatureID].Spec;
  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
  if (Spec.isElementType<int32_t>())
    for (size_t I = 0; I < Spec.getElementCount(); ++I)
      (reinterpret_cast<int64_t *>(Buff))[I] =
          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
    std::memcpy(Buff, RawData,
                Spec.getElementCount() * Spec.getElementByteSize());
  else
    llvm_unreachable("Unsupported tensor type");
}

char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
  return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
}
void Logger::flush(std::string *Str) { LoggerData->flush(Str); }

void Logger::flush(raw_ostream &OS) {
  std::string Buff;
  LoggerData->flush(&Buff);
  OS << Buff;
}
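
// Aggregate all the loggers' serialized SequenceExamples into a single
// protobuf Struct, keyed by logger name. Each value is base64-encoded unless
// textual logging (-tfutils-text-log) was requested.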
void Logger::flushLogs(raw_ostream &OS,
                       const StringMap<std::unique_ptr<Logger>> &Loggers) {
  google::protobuf::Struct Msg;
  for (const auto &NamedLogger : Loggers) {
    const auto &Logger = NamedLogger.second;
    std::string Unencoded;
    if (Logger->LoggerData->getNrRecords() > 0)
      Logger->flush(&Unencoded);

    (*Msg.mutable_fields())[NamedLogger.first().str()]
        .mutable_string_value()
        ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
  }

  std::string OutStr;
  serialize(Msg, &OutStr);
  OS << OutStr;
}
#endif // defined(LLVM_HAVE_TF_API)