//===- TrainingLogger.h - mlgo feature/reward logging ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The design goals of the logger are:
// - no dependencies that llvm doesn't already have.
// - support streaming, so that we don't need to buffer data during compilation
// - 0-decoding tensor values. Tensor values are potentially very large buffers
// of scalars. Because of their potentially large size, avoiding
// serialization/deserialization overhead is preferred.
//
// The simple logger produces an output of the form (each line item on its own
// line):
// - header: a json object describing the data that will follow.
// - context: e.g. function name, for regalloc, or "default" for module-wide
// optimizations like the inliner. This is the context to which the subsequent
// data corresponds.
// - observation number.
// - tensor values - raw bytes of the tensors, in the order given in the
// header. The values are in succession, i.e. no separator is found between
// successive tensor values. At the end, there is a new line character.
// - [score] - this is optional, and is present if it was present in the
// header. Currently, for final rewards, we output "0" scores after each
// observation, except for the last one.
// <repeat>
// The file should be read as binary, but the reason we use newlines is mostly
// ease of debugging: the log can be opened in a text editor and, while tensor
// values are inscrutable, at least the sequence of data can be easily
// observed. Of course, the buffer of tensor values could contain '\n' bytes.
// A reader should use the header information to know how much data to read
// for the tensor values, and not use line information for that.
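//
// A minimal sketch of that read step, assuming an std::istream named Input
// and the TensorSpecs already parsed from the json header (the names here
// are illustrative, not part of this interface):
//
//   for (const TensorSpec &Spec : FeatureSpecsFromHeader) {
//     std::vector<char> Buffer(Spec.getTotalTensorBufferSize());
//     // Read exactly this many bytes; they may well contain '\n'.
//     Input.read(Buffer.data(), Buffer.size());
//   }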
//
// An example reader, used for test, is available at
// Analysis/models/log_reader.py
//
// Example:
// {"features":[list of TensorSpecs], "score":<a tensor spec>}
// {"context": "aFunction"}
// {"observation": 0}
// <bytes>
// {"outcome": 0}
// <bytes for the tensor corresponding to the "score" spec in the header>
// {"observation": 1}
// ...
// {"context": "anotherFunction"}
// {"observation": 0}
// ...
//

#ifndef LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H
#define LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H

#include "llvm/Config/llvm-config.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>

namespace llvm {

/// Logging utility - given an ordered specification of features, and assuming
/// a scalar reward, allow logging feature values and rewards.
/// The assumption is that, for an event to be logged (i.e. a set of feature
/// values and a reward), the user calls the log* API for each feature exactly
/// once, providing the index matching the position in the feature spec list
/// provided at construction. For example, assuming the first feature's
/// element type is float, the second is int64, and the reward is float:
///
/// event 0:
///   startObservation()
///   logTensorValue(0, ...)
///   logTensorValue(1, ...)
///   ...
///   endObservation()
///   logReward(...)
/// event 1:
///   startObservation()
///   logTensorValue(0, ...)
///   logTensorValue(1, ...)
///   ...
///   endObservation()
///   logReward(...)
///
/// The log is streamed to the provided raw_ostream as it is produced, so no
/// separate "print" step is needed at the end. To log only a final reward,
/// see the note on "0" scores in the file header above and the sketch after
/// the class below.
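///
/// A concrete sketch of one observation (the file name, feature names, shapes
/// and values below are illustrative assumptions, not requirements; error
/// handling is elided):
///
/// \code
///   std::error_code EC;
///   auto Stream = std::make_unique<raw_fd_ostream>("training.log", EC);
///   std::vector<TensorSpec> Features{
///       TensorSpec::createSpec<float>("feature_0", {1}),
///       TensorSpec::createSpec<int64_t>("feature_1", {1})};
///   Logger Log(std::move(Stream), Features,
///              TensorSpec::createSpec<float>("reward", {1}),
///              /*IncludeReward=*/true);
///   Log.switchContext("aFunction");
///   Log.startObservation();
///   float F0 = 1.0f;
///   int64_t F1 = 42;
///   Log.logTensorValue(0, reinterpret_cast<const char *>(&F0));
///   Log.logTensorValue(1, reinterpret_cast<const char *>(&F1));
///   Log.endObservation();
///   Log.logReward(3.14f);
/// \endcode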
class Logger final {
  std::unique_ptr<raw_ostream> OS;
  const std::vector<TensorSpec> FeatureSpecs;
  const TensorSpec RewardSpec;
  const bool IncludeReward;
  StringMap<size_t> ObservationIDs;
  std::string CurrentContext;

  void writeHeader();
  void writeTensor(const TensorSpec &Spec, const char *RawData) {
    OS->write(RawData, Spec.getTotalTensorBufferSize());
  }
  void logRewardImpl(const char *RawData);

public:
  /// Construct a Logger. If IncludeReward is false, then logReward shouldn't
  /// be called, and the reward feature won't be printed out.
  /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have
  /// corresponding indices) with any MLModelRunner implementations
  /// corresponding to the model being trained/logged.
  Logger(std::unique_ptr<raw_ostream> OS,
         const std::vector<TensorSpec> &FeatureSpecs,
         const TensorSpec &RewardSpec, bool IncludeReward);

  /// Set the context (e.g. function name) for subsequent observations.
  void switchContext(StringRef Name);
  /// Mark the start / end of an observation for the current context.
  void startObservation();
  void endObservation();

  const std::string &currentContext() const { return CurrentContext; }

  /// Whether at least one observation was started for the current context.
  bool hasObservationInProgress() const {
    return ObservationIDs.find(CurrentContext) != ObservationIDs.end();
  }

  /// Log the reward for the current observation. T must match the element
  /// type of RewardSpec.
  template <typename T> void logReward(T Value) {
    logRewardImpl(reinterpret_cast<const char *>(&Value));
  }

  /// Log the raw bytes of the tensor at position FeatureID in the feature
  /// spec list given at construction; the byte count is taken from the
  /// corresponding TensorSpec.
  void logTensorValue(size_t FeatureID, const char *RawData) {
    writeTensor(FeatureSpecs[FeatureID], RawData);
  }
};
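
// One way to log only a final reward for a context, following the "0" scores
// convention described in the file header (a usage sketch, not a dedicated
// API; N, FinalReward, and the loop shape are illustrative):
//
//   for (size_t I = 0; I < N; ++I) {
//     Log.startObservation();
//     // ... logTensorValue calls for each feature ...
//     Log.endObservation();
//     Log.logReward(I + 1 == N ? FinalReward : 0.0f);
//   }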

} // namespace llvm

#endif // LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H