MatrixBuilder.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file defines the MatrixBuilder class, which is used as a convenient way
  15. // to lower matrix operations to LLVM IR.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #ifndef LLVM_IR_MATRIXBUILDER_H
  19. #define LLVM_IR_MATRIXBUILDER_H
  20. #include "llvm/IR/Constant.h"
  21. #include "llvm/IR/Constants.h"
  22. #include "llvm/IR/IRBuilder.h"
  23. #include "llvm/IR/InstrTypes.h"
  24. #include "llvm/IR/Instruction.h"
  25. #include "llvm/IR/IntrinsicInst.h"
  26. #include "llvm/IR/Type.h"
  27. #include "llvm/IR/Value.h"
  28. #include "llvm/Support/Alignment.h"
  29. namespace llvm {
  30. class Function;
  31. class Twine;
  32. class Module;
  33. class MatrixBuilder {
  34. IRBuilderBase &B;
  35. Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
  36. std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
  37. Value *RHS) {
  38. assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
  39. "One of the operands must be a matrix (embedded in a vector)");
  40. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  41. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  42. "LHS Assumed to be fixed width");
  43. RHS = B.CreateVectorSplat(
  44. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  45. "scalar.splat");
  46. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  47. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  48. "RHS Assumed to be fixed width");
  49. LHS = B.CreateVectorSplat(
  50. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  51. "scalar.splat");
  52. }
  53. return {LHS, RHS};
  54. }
  55. public:
  56. MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
  57. /// Create a column major, strided matrix load.
  58. /// \p EltTy - Matrix element type
  59. /// \p DataPtr - Start address of the matrix read
  60. /// \p Rows - Number of rows in matrix (must be a constant)
  61. /// \p Columns - Number of columns in matrix (must be a constant)
  62. /// \p Stride - Space between columns
  63. CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
  64. Value *Stride, bool IsVolatile, unsigned Rows,
  65. unsigned Columns, const Twine &Name = "") {
  66. auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
  67. Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
  68. B.getInt32(Columns)};
  69. Type *OverloadedTypes[] = {RetType, Stride->getType()};
  70. Function *TheFn = Intrinsic::getDeclaration(
  71. getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
  72. CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  73. Attribute AlignAttr =
  74. Attribute::getWithAlignment(Call->getContext(), Alignment);
  75. Call->addParamAttr(0, AlignAttr);
  76. return Call;
  77. }
  78. /// Create a column major, strided matrix store.
  79. /// \p Matrix - Matrix to store
  80. /// \p Ptr - Pointer to write back to
  81. /// \p Stride - Space between columns
  82. CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
  83. Value *Stride, bool IsVolatile,
  84. unsigned Rows, unsigned Columns,
  85. const Twine &Name = "") {
  86. Value *Ops[] = {Matrix, Ptr,
  87. Stride, B.getInt1(IsVolatile),
  88. B.getInt32(Rows), B.getInt32(Columns)};
  89. Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
  90. Function *TheFn = Intrinsic::getDeclaration(
  91. getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
  92. CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  93. Attribute AlignAttr =
  94. Attribute::getWithAlignment(Call->getContext(), Alignment);
  95. Call->addParamAttr(1, AlignAttr);
  96. return Call;
  97. }
  98. /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
  99. /// rows and \p Columns columns.
  100. CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
  101. unsigned Columns, const Twine &Name = "") {
  102. auto *OpType = cast<VectorType>(Matrix->getType());
  103. auto *ReturnType =
  104. FixedVectorType::get(OpType->getElementType(), Rows * Columns);
  105. Type *OverloadedTypes[] = {ReturnType};
  106. Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
  107. Function *TheFn = Intrinsic::getDeclaration(
  108. getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
  109. return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  110. }
  111. /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
  112. /// RHS.
  113. CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
  114. unsigned LHSColumns, unsigned RHSColumns,
  115. const Twine &Name = "") {
  116. auto *LHSType = cast<VectorType>(LHS->getType());
  117. auto *RHSType = cast<VectorType>(RHS->getType());
  118. auto *ReturnType =
  119. FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
  120. Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
  121. B.getInt32(RHSColumns)};
  122. Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
  123. Function *TheFn = Intrinsic::getDeclaration(
  124. getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
  125. return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  126. }
  127. /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
  128. /// ColumnIdx).
  129. Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
  130. Value *ColumnIdx, unsigned NumRows) {
  131. return B.CreateInsertElement(
  132. Matrix, NewVal,
  133. B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
  134. ColumnIdx->getType(), NumRows)),
  135. RowIdx));
  136. }
  137. /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
  138. /// matrixes.
  139. Value *CreateAdd(Value *LHS, Value *RHS) {
  140. assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  141. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  142. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  143. "LHS Assumed to be fixed width");
  144. RHS = B.CreateVectorSplat(
  145. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  146. "scalar.splat");
  147. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  148. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  149. "RHS Assumed to be fixed width");
  150. LHS = B.CreateVectorSplat(
  151. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  152. "scalar.splat");
  153. }
  154. return cast<VectorType>(LHS->getType())
  155. ->getElementType()
  156. ->isFloatingPointTy()
  157. ? B.CreateFAdd(LHS, RHS)
  158. : B.CreateAdd(LHS, RHS);
  159. }
  160. /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
  161. /// point matrixes.
  162. Value *CreateSub(Value *LHS, Value *RHS) {
  163. assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  164. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  165. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  166. "LHS Assumed to be fixed width");
  167. RHS = B.CreateVectorSplat(
  168. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  169. "scalar.splat");
  170. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  171. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  172. "RHS Assumed to be fixed width");
  173. LHS = B.CreateVectorSplat(
  174. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  175. "scalar.splat");
  176. }
  177. return cast<VectorType>(LHS->getType())
  178. ->getElementType()
  179. ->isFloatingPointTy()
  180. ? B.CreateFSub(LHS, RHS)
  181. : B.CreateSub(LHS, RHS);
  182. }
  183. /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
  184. /// RHS.
  185. Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
  186. std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
  187. if (LHS->getType()->getScalarType()->isFloatingPointTy())
  188. return B.CreateFMul(LHS, RHS);
  189. return B.CreateMul(LHS, RHS);
  190. }
  191. /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
  192. /// IsUnsigned indicates whether UDiv or SDiv should be used.
  193. Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
  194. assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
  195. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  196. "LHS Assumed to be fixed width");
  197. RHS =
  198. B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
  199. RHS, "scalar.splat");
  200. return cast<VectorType>(LHS->getType())
  201. ->getElementType()
  202. ->isFloatingPointTy()
  203. ? B.CreateFDiv(LHS, RHS)
  204. : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
  205. }
  206. /// Create an assumption that \p Idx is less than \p NumElements.
  207. void CreateIndexAssumption(Value *Idx, unsigned NumElements,
  208. Twine const &Name = "") {
  209. Value *NumElts =
  210. B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
  211. auto *Cmp = B.CreateICmpULT(Idx, NumElts);
  212. if (isa<ConstantInt>(Cmp))
  213. assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
  214. else
  215. B.CreateAssumption(Cmp);
  216. }
  217. /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
  218. /// a matrix with \p NumRows embedded in a vector.
  219. Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
  220. Twine const &Name = "") {
  221. unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
  222. ColumnIdx->getType()->getScalarSizeInBits());
  223. Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
  224. RowIdx = B.CreateZExt(RowIdx, IntTy);
  225. ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
  226. Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
  227. return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
  228. }
  229. };
  230. } // end namespace llvm
  231. #endif // LLVM_IR_MATRIXBUILDER_H
  232. #ifdef __GNUC__
  233. #pragma GCC diagnostic pop
  234. #endif