TensorSpec.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Implementation file for the abstraction of a tensor type, and JSON loading
  10. // utils.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Config/config.h"
  14. #include "llvm/ADT/Twine.h"
  15. #include "llvm/Analysis/TensorSpec.h"
  16. #include "llvm/Support/CommandLine.h"
  17. #include "llvm/Support/Debug.h"
  18. #include "llvm/Support/JSON.h"
  19. #include "llvm/Support/ManagedStatic.h"
  20. #include "llvm/Support/raw_ostream.h"
  21. #include <array>
  22. #include <cassert>
  23. #include <numeric>
  24. using namespace llvm;
  25. namespace llvm {
  26. #define TFUTILS_GETDATATYPE_IMPL(T, E) \
  27. template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
  28. SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
  29. #undef TFUTILS_GETDATATYPE_IMPL
  30. static std::array<std::string, static_cast<size_t>(TensorType::Total)>
  31. TensorTypeNames{"INVALID",
  32. #define TFUTILS_GETNAME_IMPL(T, _) #T,
  33. SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
  34. #undef TFUTILS_GETNAME_IMPL
  35. };
  36. StringRef toString(TensorType TT) {
  37. return TensorTypeNames[static_cast<size_t>(TT)];
  38. }
  39. void TensorSpec::toJSON(json::OStream &OS) const {
  40. OS.object([&]() {
  41. OS.attribute("name", name());
  42. OS.attribute("type", toString(type()));
  43. OS.attribute("port", port());
  44. OS.attributeArray("shape", [&]() {
  45. for (size_t D : shape())
  46. OS.value(static_cast<int64_t>(D));
  47. });
  48. });
  49. }
  50. TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
  51. size_t ElementSize, const std::vector<int64_t> &Shape)
  52. : Name(Name), Port(Port), Type(Type), Shape(Shape),
  53. ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
  54. std::multiplies<int64_t>())),
  55. ElementSize(ElementSize) {}
  56. std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
  57. const json::Value &Value) {
  58. auto EmitError =
  59. [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
  60. std::string S;
  61. llvm::raw_string_ostream OS(S);
  62. OS << Value;
  63. Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
  64. return std::nullopt;
  65. };
  66. // FIXME: accept a Path as a parameter, and use it for error reporting.
  67. json::Path::Root Root("tensor_spec");
  68. json::ObjectMapper Mapper(Value, Root);
  69. if (!Mapper)
  70. return EmitError("Value is not a dict");
  71. std::string TensorName;
  72. int TensorPort = -1;
  73. std::string TensorType;
  74. std::vector<int64_t> TensorShape;
  75. if (!Mapper.map<std::string>("name", TensorName))
  76. return EmitError("'name' property not present or not a string");
  77. if (!Mapper.map<std::string>("type", TensorType))
  78. return EmitError("'type' property not present or not a string");
  79. if (!Mapper.map<int>("port", TensorPort))
  80. return EmitError("'port' property not present or not an int");
  81. if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
  82. return EmitError("'shape' property not present or not an int array");
  83. #define PARSE_TYPE(T, E) \
  84. if (TensorType == #T) \
  85. return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  86. SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
  87. #undef PARSE_TYPE
  88. return std::nullopt;
  89. }
  90. } // namespace llvm