#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif

//===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#ifndef LLVM_ANALYSIS_TENSORSPEC_H
#define LLVM_ANALYSIS_TENSORSPEC_H

#include "llvm/Config/llvm-config.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>

namespace llvm {
/// TensorSpec encapsulates the specification of a tensor: its dimensions, or
/// "shape" (row-major), its type (see TensorSpec::getDataType specializations
/// for supported types), its name and port (see "TensorFlow: Large-Scale
/// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2:
/// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
///
/// Known tensor types. The left part is the C type, the right is a name we
/// can use to identify the type (to implement TensorSpec equality checks), and
/// to use, if needed, when mapping to an underlying evaluator's type system.
/// The main requirement is that the C type we use has the same size and
/// encoding (e.g. endian-ness) as the one used by the evaluator.
#define SUPPORTED_TENSOR_TYPES(M)                                              \
  M(float, Float)                                                              \
  M(double, Double)                                                            \
  M(int8_t, Int8)                                                              \
  M(uint8_t, UInt8)                                                            \
  M(int16_t, Int16)                                                            \
  M(uint16_t, UInt16)                                                          \
  M(int32_t, Int32)                                                            \
  M(uint32_t, UInt32)                                                          \
  M(int64_t, Int64)                                                            \
  M(uint64_t, UInt64)

enum class TensorType {
  Invalid,
#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
  SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
#undef _TENSOR_TYPE_ENUM_MEMBERS
  Total
};
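
// Illustrative note (not part of the interface): after preprocessing, the
// X-macro above is expected to yield one enumerator per supported type, i.e.
// roughly
//   enum class TensorType { Invalid, Float, Double, Int8, UInt8, Int16,
//                           UInt16, Int32, UInt32, Int64, UInt64, Total };
// so every SUPPORTED_TENSOR_TYPES entry has a matching TensorType member.
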
class TensorSpec final {
public:
  template <typename T>
  static TensorSpec createSpec(const std::string &Name,
                               const std::vector<int64_t> &Shape,
                               int Port = 0) {
    return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
  }

  const std::string &name() const { return Name; }
  int port() const { return Port; }
  TensorType type() const { return Type; }
  const std::vector<int64_t> &shape() const { return Shape; }

  bool operator==(const TensorSpec &Other) const {
    return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
           Shape == Other.Shape;
  }

  bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }

  /// Get the number of elements in a tensor with this shape.
  size_t getElementCount() const { return ElementCount; }
  /// Get the size, in bytes, of one element.
  size_t getElementByteSize() const { return ElementSize; }
  /// Get the total size of a memory buffer needed to store the whole tensor.
  size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; }

  template <typename T> bool isElementType() const {
    return getDataType<T>() == Type;
  }

  TensorSpec(const std::string &NewName, const TensorSpec &Other)
      : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize,
                   Other.Shape) {}

  void toJSON(json::OStream &OS) const;

private:
  TensorSpec(const std::string &Name, int Port, TensorType Type,
             size_t ElementSize, const std::vector<int64_t> &Shape);

  template <typename T> static TensorType getDataType();

  std::string Name;
  int Port = 0;
  TensorType Type = TensorType::Invalid;
  std::vector<int64_t> Shape;
  size_t ElementCount = 0;
  size_t ElementSize = 0;
};
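
// Example (illustrative only, not part of the interface): describing a 2x3
// float tensor named "input" on the default port. The expected values below
// follow from the accessors declared above (element count as the product of
// the shape, element size as sizeof(T) passed by createSpec).
//
//   TensorSpec Spec = TensorSpec::createSpec<float>("input", {2, 3});
//   assert(Spec.isElementType<float>());
//   assert(Spec.getElementCount() == 6);        // 2 * 3 elements
//   assert(Spec.getElementByteSize() == sizeof(float));
//   assert(Spec.getTotalTensorBufferSize() == 6 * sizeof(float));
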
/// Construct a TensorSpec from a JSON dictionary of the form:
/// { "name": <string>,
///   "port": <int>,
///   "type": <string. Use LLVM's types, e.g. float, double, int64_t>,
///   "shape": <array of ints> }
/// For the "type" field, see the C++ primitive types used in
/// SUPPORTED_TENSOR_TYPES.
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value);
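
// Example (illustrative): parsing a spec from a JSON dictionary of the form
// documented above. Given the std::optional return type, a malformed
// dictionary is expected to yield an empty optional; the LLVMContext is
// presumably available for diagnostics.
//
//   LLVMContext Ctx;
//   Expected<json::Value> Parsed = json::parse(
//       R"({"name": "input", "port": 0, "type": "float", "shape": [2, 3]})");
//   if (Parsed) {
//     if (std::optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Parsed))
//       /* use Spec->name(), Spec->type(), Spec->shape(), ... */;
//   } else
//     consumeError(Parsed.takeError());
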
#define TFUTILS_GETDATATYPE_DEF(T, Name)                                       \
  template <> TensorType TensorSpec::getDataType<T>();
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
#undef TFUTILS_GETDATATYPE_DEF
} // namespace llvm
#endif // LLVM_ANALYSIS_TENSORSPEC_H

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif