MatrixBuilder.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file defines the MatrixBuilder class, which is used as a convenient way
  15. // to lower matrix operations to LLVM IR.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #ifndef LLVM_IR_MATRIXBUILDER_H
  19. #define LLVM_IR_MATRIXBUILDER_H
  20. #include "llvm/IR/Constant.h"
  21. #include "llvm/IR/Constants.h"
  22. #include "llvm/IR/IRBuilder.h"
  23. #include "llvm/IR/InstrTypes.h"
  24. #include "llvm/IR/Instruction.h"
  25. #include "llvm/IR/IntrinsicInst.h"
  26. #include "llvm/IR/Type.h"
  27. #include "llvm/IR/Value.h"
  28. #include "llvm/Support/Alignment.h"
  29. namespace llvm {
  30. class Function;
  31. class Twine;
  32. class Module;
  33. template <class IRBuilderTy> class MatrixBuilder {
  34. IRBuilderTy &B;
  35. Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
  36. std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
  37. Value *RHS) {
  38. assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
  39. "One of the operands must be a matrix (embedded in a vector)");
  40. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  41. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  42. "LHS Assumed to be fixed width");
  43. RHS = B.CreateVectorSplat(
  44. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  45. "scalar.splat");
  46. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  47. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  48. "RHS Assumed to be fixed width");
  49. LHS = B.CreateVectorSplat(
  50. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  51. "scalar.splat");
  52. }
  53. return {LHS, RHS};
  54. }
  55. public:
  56. MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
  57. /// Create a column major, strided matrix load.
  58. /// \p DataPtr - Start address of the matrix read
  59. /// \p Rows - Number of rows in matrix (must be a constant)
  60. /// \p Columns - Number of columns in matrix (must be a constant)
  61. /// \p Stride - Space between columns
  62. CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
  63. Value *Stride, bool IsVolatile, unsigned Rows,
  64. unsigned Columns, const Twine &Name = "") {
  65. // Deal with the pointer
  66. PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
  67. Type *EltTy = PtrTy->getPointerElementType();
  68. auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
  69. Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
  70. B.getInt32(Columns)};
  71. Type *OverloadedTypes[] = {RetType, Stride->getType()};
  72. Function *TheFn = Intrinsic::getDeclaration(
  73. getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
  74. CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  75. Attribute AlignAttr =
  76. Attribute::getWithAlignment(Call->getContext(), Alignment);
  77. Call->addParamAttr(0, AlignAttr);
  78. return Call;
  79. }
  80. /// Create a column major, strided matrix store.
  81. /// \p Matrix - Matrix to store
  82. /// \p Ptr - Pointer to write back to
  83. /// \p Stride - Space between columns
  84. CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
  85. Value *Stride, bool IsVolatile,
  86. unsigned Rows, unsigned Columns,
  87. const Twine &Name = "") {
  88. Value *Ops[] = {Matrix, Ptr,
  89. Stride, B.getInt1(IsVolatile),
  90. B.getInt32(Rows), B.getInt32(Columns)};
  91. Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
  92. Function *TheFn = Intrinsic::getDeclaration(
  93. getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
  94. CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  95. Attribute AlignAttr =
  96. Attribute::getWithAlignment(Call->getContext(), Alignment);
  97. Call->addParamAttr(1, AlignAttr);
  98. return Call;
  99. }
  100. /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
  101. /// rows and \p Columns columns.
  102. CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
  103. unsigned Columns, const Twine &Name = "") {
  104. auto *OpType = cast<VectorType>(Matrix->getType());
  105. auto *ReturnType =
  106. FixedVectorType::get(OpType->getElementType(), Rows * Columns);
  107. Type *OverloadedTypes[] = {ReturnType};
  108. Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
  109. Function *TheFn = Intrinsic::getDeclaration(
  110. getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
  111. return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  112. }
  113. /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
  114. /// RHS.
  115. CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
  116. unsigned LHSColumns, unsigned RHSColumns,
  117. const Twine &Name = "") {
  118. auto *LHSType = cast<VectorType>(LHS->getType());
  119. auto *RHSType = cast<VectorType>(RHS->getType());
  120. auto *ReturnType =
  121. FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
  122. Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
  123. B.getInt32(RHSColumns)};
  124. Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
  125. Function *TheFn = Intrinsic::getDeclaration(
  126. getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
  127. return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  128. }
  129. /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
  130. /// ColumnIdx).
  131. Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
  132. Value *ColumnIdx, unsigned NumRows) {
  133. return B.CreateInsertElement(
  134. Matrix, NewVal,
  135. B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
  136. ColumnIdx->getType(), NumRows)),
  137. RowIdx));
  138. }
  139. /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
  140. /// matrixes.
  141. Value *CreateAdd(Value *LHS, Value *RHS) {
  142. assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  143. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  144. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  145. "LHS Assumed to be fixed width");
  146. RHS = B.CreateVectorSplat(
  147. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  148. "scalar.splat");
  149. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  150. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  151. "RHS Assumed to be fixed width");
  152. LHS = B.CreateVectorSplat(
  153. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  154. "scalar.splat");
  155. }
  156. return cast<VectorType>(LHS->getType())
  157. ->getElementType()
  158. ->isFloatingPointTy()
  159. ? B.CreateFAdd(LHS, RHS)
  160. : B.CreateAdd(LHS, RHS);
  161. }
  162. /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
  163. /// point matrixes.
  164. Value *CreateSub(Value *LHS, Value *RHS) {
  165. assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
  166. if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
  167. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  168. "LHS Assumed to be fixed width");
  169. RHS = B.CreateVectorSplat(
  170. cast<VectorType>(LHS->getType())->getElementCount(), RHS,
  171. "scalar.splat");
  172. } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
  173. assert(!isa<ScalableVectorType>(RHS->getType()) &&
  174. "RHS Assumed to be fixed width");
  175. LHS = B.CreateVectorSplat(
  176. cast<VectorType>(RHS->getType())->getElementCount(), LHS,
  177. "scalar.splat");
  178. }
  179. return cast<VectorType>(LHS->getType())
  180. ->getElementType()
  181. ->isFloatingPointTy()
  182. ? B.CreateFSub(LHS, RHS)
  183. : B.CreateSub(LHS, RHS);
  184. }
  185. /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
  186. /// RHS.
  187. Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
  188. std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
  189. if (LHS->getType()->getScalarType()->isFloatingPointTy())
  190. return B.CreateFMul(LHS, RHS);
  191. return B.CreateMul(LHS, RHS);
  192. }
  193. /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
  194. /// IsUnsigned indicates whether UDiv or SDiv should be used.
  195. Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
  196. assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
  197. assert(!isa<ScalableVectorType>(LHS->getType()) &&
  198. "LHS Assumed to be fixed width");
  199. RHS =
  200. B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
  201. RHS, "scalar.splat");
  202. return cast<VectorType>(LHS->getType())
  203. ->getElementType()
  204. ->isFloatingPointTy()
  205. ? B.CreateFDiv(LHS, RHS)
  206. : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
  207. }
  208. /// Create an assumption that \p Idx is less than \p NumElements.
  209. void CreateIndexAssumption(Value *Idx, unsigned NumElements,
  210. Twine const &Name = "") {
  211. Value *NumElts =
  212. B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
  213. auto *Cmp = B.CreateICmpULT(Idx, NumElts);
  214. if (auto *ConstCond = dyn_cast<ConstantInt>(Cmp))
  215. assert(ConstCond->isOne() && "Index must be valid!");
  216. else
  217. B.CreateAssumption(Cmp);
  218. }
  219. /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
  220. /// a matrix with \p NumRows embedded in a vector.
  221. Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
  222. Twine const &Name = "") {
  223. unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
  224. ColumnIdx->getType()->getScalarSizeInBits());
  225. Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
  226. RowIdx = B.CreateZExt(RowIdx, IntTy);
  227. ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
  228. Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
  229. return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
  230. }
  231. };
  232. } // end namespace llvm
  233. #endif // LLVM_IR_MATRIXBUILDER_H
  234. #ifdef __GNUC__
  235. #pragma GCC diagnostic pop
  236. #endif