Math.h 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_CODEGEN_PBQP_MATH_H
  14. #define LLVM_CODEGEN_PBQP_MATH_H
  15. #include "llvm/ADT/Hashing.h"
  16. #include "llvm/ADT/STLExtras.h"
  17. #include <algorithm>
  18. #include <cassert>
  19. #include <functional>
  20. #include <memory>
  21. namespace llvm {
  22. namespace PBQP {
  23. using PBQPNum = float;
  24. /// PBQP Vector class.
  25. class Vector {
  26. friend hash_code hash_value(const Vector &);
  27. public:
  28. /// Construct a PBQP vector of the given size.
  29. explicit Vector(unsigned Length)
  30. : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {}
  31. /// Construct a PBQP vector with initializer.
  32. Vector(unsigned Length, PBQPNum InitVal)
  33. : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {
  34. std::fill(Data.get(), Data.get() + Length, InitVal);
  35. }
  36. /// Copy construct a PBQP vector.
  37. Vector(const Vector &V)
  38. : Length(V.Length), Data(std::make_unique<PBQPNum []>(Length)) {
  39. std::copy(V.Data.get(), V.Data.get() + Length, Data.get());
  40. }
  41. /// Move construct a PBQP vector.
  42. Vector(Vector &&V)
  43. : Length(V.Length), Data(std::move(V.Data)) {
  44. V.Length = 0;
  45. }
  46. /// Comparison operator.
  47. bool operator==(const Vector &V) const {
  48. assert(Length != 0 && Data && "Invalid vector");
  49. if (Length != V.Length)
  50. return false;
  51. return std::equal(Data.get(), Data.get() + Length, V.Data.get());
  52. }
  53. /// Return the length of the vector
  54. unsigned getLength() const {
  55. assert(Length != 0 && Data && "Invalid vector");
  56. return Length;
  57. }
  58. /// Element access.
  59. PBQPNum& operator[](unsigned Index) {
  60. assert(Length != 0 && Data && "Invalid vector");
  61. assert(Index < Length && "Vector element access out of bounds.");
  62. return Data[Index];
  63. }
  64. /// Const element access.
  65. const PBQPNum& operator[](unsigned Index) const {
  66. assert(Length != 0 && Data && "Invalid vector");
  67. assert(Index < Length && "Vector element access out of bounds.");
  68. return Data[Index];
  69. }
  70. /// Add another vector to this one.
  71. Vector& operator+=(const Vector &V) {
  72. assert(Length != 0 && Data && "Invalid vector");
  73. assert(Length == V.Length && "Vector length mismatch.");
  74. std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(),
  75. std::plus<PBQPNum>());
  76. return *this;
  77. }
  78. /// Returns the index of the minimum value in this vector
  79. unsigned minIndex() const {
  80. assert(Length != 0 && Data && "Invalid vector");
  81. return std::min_element(Data.get(), Data.get() + Length) - Data.get();
  82. }
  83. private:
  84. unsigned Length;
  85. std::unique_ptr<PBQPNum []> Data;
  86. };
  87. /// Return a hash_value for the given vector.
  88. inline hash_code hash_value(const Vector &V) {
  89. unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get());
  90. unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length);
  91. return hash_combine(V.Length, hash_combine_range(VBegin, VEnd));
  92. }
  93. /// Output a textual representation of the given vector on the given
  94. /// output stream.
  95. template <typename OStream>
  96. OStream& operator<<(OStream &OS, const Vector &V) {
  97. assert((V.getLength() != 0) && "Zero-length vector badness.");
  98. OS << "[ " << V[0];
  99. for (unsigned i = 1; i < V.getLength(); ++i)
  100. OS << ", " << V[i];
  101. OS << " ]";
  102. return OS;
  103. }
  104. /// PBQP Matrix class
  105. class Matrix {
  106. private:
  107. friend hash_code hash_value(const Matrix &);
  108. public:
  109. /// Construct a PBQP Matrix with the given dimensions.
  110. Matrix(unsigned Rows, unsigned Cols) :
  111. Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
  112. }
  113. /// Construct a PBQP Matrix with the given dimensions and initial
  114. /// value.
  115. Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
  116. : Rows(Rows), Cols(Cols),
  117. Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
  118. std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
  119. }
  120. /// Copy construct a PBQP matrix.
  121. Matrix(const Matrix &M)
  122. : Rows(M.Rows), Cols(M.Cols),
  123. Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
  124. std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
  125. }
  126. /// Move construct a PBQP matrix.
  127. Matrix(Matrix &&M)
  128. : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
  129. M.Rows = M.Cols = 0;
  130. }
  131. /// Comparison operator.
  132. bool operator==(const Matrix &M) const {
  133. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  134. if (Rows != M.Rows || Cols != M.Cols)
  135. return false;
  136. return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
  137. }
  138. /// Return the number of rows in this matrix.
  139. unsigned getRows() const {
  140. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  141. return Rows;
  142. }
  143. /// Return the number of cols in this matrix.
  144. unsigned getCols() const {
  145. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  146. return Cols;
  147. }
  148. /// Matrix element access.
  149. PBQPNum* operator[](unsigned R) {
  150. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  151. assert(R < Rows && "Row out of bounds.");
  152. return Data.get() + (R * Cols);
  153. }
  154. /// Matrix element access.
  155. const PBQPNum* operator[](unsigned R) const {
  156. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  157. assert(R < Rows && "Row out of bounds.");
  158. return Data.get() + (R * Cols);
  159. }
  160. /// Returns the given row as a vector.
  161. Vector getRowAsVector(unsigned R) const {
  162. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  163. Vector V(Cols);
  164. for (unsigned C = 0; C < Cols; ++C)
  165. V[C] = (*this)[R][C];
  166. return V;
  167. }
  168. /// Returns the given column as a vector.
  169. Vector getColAsVector(unsigned C) const {
  170. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  171. Vector V(Rows);
  172. for (unsigned R = 0; R < Rows; ++R)
  173. V[R] = (*this)[R][C];
  174. return V;
  175. }
  176. /// Matrix transpose.
  177. Matrix transpose() const {
  178. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  179. Matrix M(Cols, Rows);
  180. for (unsigned r = 0; r < Rows; ++r)
  181. for (unsigned c = 0; c < Cols; ++c)
  182. M[c][r] = (*this)[r][c];
  183. return M;
  184. }
  185. /// Add the given matrix to this one.
  186. Matrix& operator+=(const Matrix &M) {
  187. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  188. assert(Rows == M.Rows && Cols == M.Cols &&
  189. "Matrix dimensions mismatch.");
  190. std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
  191. Data.get(), std::plus<PBQPNum>());
  192. return *this;
  193. }
  194. Matrix operator+(const Matrix &M) {
  195. assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
  196. Matrix Tmp(*this);
  197. Tmp += M;
  198. return Tmp;
  199. }
  200. private:
  201. unsigned Rows, Cols;
  202. std::unique_ptr<PBQPNum []> Data;
  203. };
  204. /// Return a hash_code for the given matrix.
  205. inline hash_code hash_value(const Matrix &M) {
  206. unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get());
  207. unsigned *MEnd =
  208. reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols));
  209. return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd));
  210. }
  211. /// Output a textual representation of the given matrix on the given
  212. /// output stream.
  213. template <typename OStream>
  214. OStream& operator<<(OStream &OS, const Matrix &M) {
  215. assert((M.getRows() != 0) && "Zero-row matrix badness.");
  216. for (unsigned i = 0; i < M.getRows(); ++i)
  217. OS << M.getRowAsVector(i) << "\n";
  218. return OS;
  219. }
  220. template <typename Metadata>
  221. class MDVector : public Vector {
  222. public:
  223. MDVector(const Vector &v) : Vector(v), md(*this) {}
  224. MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { }
  225. const Metadata& getMetadata() const { return md; }
  226. private:
  227. Metadata md;
  228. };
  229. template <typename Metadata>
  230. inline hash_code hash_value(const MDVector<Metadata> &V) {
  231. return hash_value(static_cast<const Vector&>(V));
  232. }
  233. template <typename Metadata>
  234. class MDMatrix : public Matrix {
  235. public:
  236. MDMatrix(const Matrix &m) : Matrix(m), md(*this) {}
  237. MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { }
  238. const Metadata& getMetadata() const { return md; }
  239. private:
  240. Metadata md;
  241. };
  242. template <typename Metadata>
  243. inline hash_code hash_value(const MDMatrix<Metadata> &M) {
  244. return hash_value(static_cast<const Matrix&>(M));
  245. }
  246. } // end namespace PBQP
  247. } // end namespace llvm
  248. #endif // LLVM_CODEGEN_PBQP_MATH_H
  249. #ifdef __GNUC__
  250. #pragma GCC diagnostic pop
  251. #endif