LowerMatrixIntrinsics.cpp

  1. //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Lower matrix intrinsics to vector operations.
  10. //
  11. // TODO:
  12. // * Improve fusion:
  13. // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
  14. // transposed.
  15. // * Improve cost-modeling, e.g. choose different number of rows/columns
  16. // columns for tiles, consider cost of copies on alias.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
  20. #include "llvm/ADT/PostOrderIterator.h"
  21. #include "llvm/ADT/SmallVector.h"
  22. #include "llvm/Analysis/AliasAnalysis.h"
  23. #include "llvm/Analysis/DomTreeUpdater.h"
  24. #include "llvm/Analysis/LoopInfo.h"
  25. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  26. #include "llvm/Analysis/TargetTransformInfo.h"
  27. #include "llvm/Analysis/ValueTracking.h"
  28. #include "llvm/Analysis/VectorUtils.h"
  29. #include "llvm/IR/CFG.h"
  30. #include "llvm/IR/DataLayout.h"
  31. #include "llvm/IR/DebugInfoMetadata.h"
  32. #include "llvm/IR/Function.h"
  33. #include "llvm/IR/IRBuilder.h"
  34. #include "llvm/IR/Instructions.h"
  35. #include "llvm/IR/IntrinsicInst.h"
  36. #include "llvm/IR/MatrixBuilder.h"
  37. #include "llvm/IR/PatternMatch.h"
  38. #include "llvm/InitializePasses.h"
  39. #include "llvm/Pass.h"
  40. #include "llvm/Support/Alignment.h"
  41. #include "llvm/Support/CommandLine.h"
  42. #include "llvm/Support/Debug.h"
  43. #include "llvm/Transforms/Scalar.h"
  44. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  45. #include "llvm/Transforms/Utils/LoopUtils.h"
  46. #include "llvm/Transforms/Utils/MatrixUtils.h"
  47. #include <cmath>
  48. using namespace llvm;
  49. using namespace PatternMatch;
  50. #define DEBUG_TYPE "lower-matrix-intrinsics"
  51. static cl::opt<bool>
  52. FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
  53. cl::desc("Enable/disable fusing matrix instructions."));
  54. // TODO: Allow and use non-square tiles.
  55. static cl::opt<unsigned> TileSize(
  56. "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
  57. cl::desc(
  58. "Tile size for matrix instruction fusion using square-shaped tiles."));
  59. static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
  60. cl::Hidden,
  61. cl::desc("Generate loop nest for tiling."));
  62. static cl::opt<bool> ForceFusion(
  63. "force-fuse-matrix", cl::init(false), cl::Hidden,
  64. cl::desc("Force matrix instruction fusion even if not profitable."));
  65. static cl::opt<bool> AllowContractEnabled(
  66. "matrix-allow-contract", cl::init(false), cl::Hidden,
  67. cl::desc("Allow the use of FMAs if available and profitable. This may "
  68. "result in different results, due to less rounding error."));
  69. enum class MatrixLayoutTy { ColumnMajor, RowMajor };
  70. static cl::opt<MatrixLayoutTy> MatrixLayout(
  71. "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
  72. cl::desc("Sets the default matrix layout"),
  73. cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
  74. "Use column-major layout"),
  75. clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
  76. "Use row-major layout")));
  77. static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
  78. cl::init(false));
  79. /// Helper function to either return Scope, if it is a subprogram, or the
  80. /// subprogram attached to a local scope.
  81. static DISubprogram *getSubprogram(DIScope *Scope) {
  82. if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
  83. return Subprogram;
  84. return cast<DILocalScope>(Scope)->getSubprogram();
  85. }
  86. /// Erase \p V from \p BB and move \p II forward to avoid invalidating
  87. /// iterators.
  88. static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
  89. BasicBlock &BB) {
  90. auto *Inst = cast<Instruction>(V);
  91. // Still used, don't erase.
  92. if (!Inst->use_empty())
  93. return;
  94. if (II != BB.rend() && Inst == &*II)
  95. ++II;
  96. Inst->eraseFromParent();
  97. }
  98. /// Return true if V is a splat of a value (which is used when multiplying a
  99. /// matrix with a scalar).
  100. static bool isSplat(Value *V) {
  101. if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
  102. return SV->isZeroEltSplat();
  103. return false;
  104. }
  105. /// Match any mul operation (fp or integer).
  106. template <typename LTy, typename RTy>
  107. auto m_AnyMul(const LTy &L, const RTy &R) {
  108. return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
  109. }
  110. /// Match any add operation (fp or integer).
  111. template <typename LTy, typename RTy>
  112. auto m_AnyAdd(const LTy &L, const RTy &R) {
  113. return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
  114. }
  115. namespace {
  116. // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
  117. // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
  118. // assuming \p Stride elements between the starts of two consecutive vectors.
  119. // \p Stride must be >= \p NumElements.
  120. // For column-major matrixes, the function computes the address of a column
  121. // vector and \p NumElements must be set to the number of elements in a column
  122. // (= number of rows of the matrix). For row-major matrixes, the function
  123. // computes the address of a row vector and \p NumElements must be set to the
  124. // number of elements in a row (= number of columns of the matrix).
  125. //
  126. // Consider a 4x4 matrix in column-major layout like below
  127. //
  128. // 0 1 2 3
  129. // 0 v_0_0 v_0_1 v_0_2 v_0_3
  130. // 1 v_1_0 v_1_1 v_1_2 v_1_3
  131. // 2 v_2_0 v_2_1 v_2_2 v_2_3
  132. // 3 v_3_0 v_3_1 v_3_2 v_3_3
  133. // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
  134. // we need a pointer to the first element of the submatrix as base pointer.
  135. // Then we can use computeVectorAddr to compute the addresses for the columns
  136. // of the sub-matrix.
  137. //
  138. // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
  139. // -> just returns Base
  140. // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
  141. // -> returns Base + (1 * 4)
  142. // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
  143. // -> returns Base + (2 * 4)
  144. //
  145. // The graphic below illustrates the number of elements in a column (marked
  146. // with |) and the number of skipped elements (marked with }).
  147. //
  148. //      v_0_0  v_0_1 {v_0_2 {v_0_3
  149. //             Base   Col 1  Col 2
  150. //               |      |      |
  151. //      v_1_0 |v_1_1 |v_1_2 |v_1_3
  152. //      v_2_0 |v_2_1 |v_2_2 |v_2_3
  153. //      v_3_0 {v_3_1 {v_3_2  v_3_3
  154. //
  155. Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
  156. unsigned NumElements, Type *EltType,
  157. IRBuilder<> &Builder) {
  158. assert((!isa<ConstantInt>(Stride) ||
  159. cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
  160. "Stride must be >= the number of elements in the result vector.");
  161. unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
  162. // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  163. Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
  164. // Get pointer to the start of the selected vector. Skip GEP creation,
  165. // if we select vector 0.
  166. if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
  167. VecStart = BasePtr;
  168. else
  169. VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
  170. // Cast elementwise vector start pointer to a pointer to a vector
  171. // (EltType x NumElements)*.
  172. auto *VecType = FixedVectorType::get(EltType, NumElements);
  173. Type *VecPtrType = PointerType::get(VecType, AS);
  174. return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
  175. }
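// Illustrative sketch only (hypothetical value names, not emitted verbatim):
// for double elements, a runtime vector index %vecidx, a stride of 4 and
// NumElements == 2, the sequence built above conceptually looks like
//   %vec.start = mul i64 %vecidx, 4
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*
// When VecIdx and Stride are both constants the multiply folds away, and
// vector 0 skips the GEP entirely.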
  176. /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
  177. ///
  178. /// Currently, the lowering for each matrix intrinsic is done as follows:
  179. /// 1. Propagate the shape information from intrinsics to connected
  180. /// instructions.
  181. /// 2. Lower instructions with shape information (assuming column-major layout).
  182. /// The lowering works similarly using row-major layout.
  183. /// 2.1. Get column vectors for each argument. If we already lowered the
  184. /// definition of an argument, use the produced column vectors directly.
  185. /// If not, split the operand vector containing an embedded matrix into
  186. /// a set of column vectors.
  187. /// 2.2. Lower the instruction in terms of column major operations, which
  188. /// yields a set of column vectors containing the result matrix. Note that we
  189. /// lower all instructions that have shape information. Besides the
  190. /// intrinsics, this includes stores for example.
  191. /// 2.3. Update uses of the lowered instruction. If we have shape information
  192. /// for a user, there is nothing to do, as we will look up the result
  193. /// column matrix when lowering the user. For other uses, we embed the
  194. /// result matrix in a flat vector and update the use.
  195. /// 2.4. Cache the result column matrix for the instruction we lowered
  196. /// 3. After we lowered all instructions in a function, remove the now
  197. /// obsolete instructions.
  198. ///
  199. class LowerMatrixIntrinsics {
  200. Function &Func;
  201. const DataLayout &DL;
  202. const TargetTransformInfo &TTI;
  203. AliasAnalysis *AA;
  204. DominatorTree *DT;
  205. LoopInfo *LI;
  206. OptimizationRemarkEmitter *ORE;
  207. /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
  208. struct OpInfoTy {
  209. /// Number of stores emitted to generate this matrix.
  210. unsigned NumStores = 0;
  211. /// Number of loads emitted to generate this matrix.
  212. unsigned NumLoads = 0;
  213. /// Number of compute operations emitted to generate this matrix.
  214. unsigned NumComputeOps = 0;
  215. /// Most of the time transposes can be fused with matrix multiplies or can
  216. /// be folded away via algebraic simplifications. This is the number of
  217. /// transposes that we failed to make "free" via such optimizations.
  218. unsigned NumExposedTransposes = 0;
  219. OpInfoTy &operator+=(const OpInfoTy &RHS) {
  220. NumStores += RHS.NumStores;
  221. NumLoads += RHS.NumLoads;
  222. NumComputeOps += RHS.NumComputeOps;
  223. NumExposedTransposes += RHS.NumExposedTransposes;
  224. return *this;
  225. }
  226. };
  227. /// Wrapper class representing a matrix as a set of vectors, either in row or
  228. /// column major layout. All vectors must have the same vector type.
  229. class MatrixTy {
  230. SmallVector<Value *, 16> Vectors;
  231. OpInfoTy OpInfo;
  232. bool IsColumnMajor = true;
  233. public:
  234. MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  235. MatrixTy(ArrayRef<Value *> Vectors)
  236. : Vectors(Vectors.begin(), Vectors.end()),
  237. IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  238. MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
  239. : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
  240. unsigned D = isColumnMajor() ? NumColumns : NumRows;
  241. for (unsigned J = 0; J < D; ++J)
  242. addVector(UndefValue::get(FixedVectorType::get(
  243. EltTy, isColumnMajor() ? NumRows : NumColumns)));
  244. }
  245. Value *getVector(unsigned i) const { return Vectors[i]; }
  246. Value *getColumn(unsigned i) const {
  247. assert(isColumnMajor() && "only supported for column-major matrixes");
  248. return Vectors[i];
  249. }
  250. Value *getRow(unsigned i) const {
  251. assert(!isColumnMajor() && "only supported for row-major matrixes");
  252. return Vectors[i];
  253. }
  254. void setVector(unsigned i, Value *V) { Vectors[i] = V; }
  255. Type *getElementType() const { return getVectorTy()->getElementType(); }
  256. unsigned getNumVectors() const {
  257. if (isColumnMajor())
  258. return getNumColumns();
  259. return getNumRows();
  260. }
  261. unsigned getNumColumns() const {
  262. if (isColumnMajor())
  263. return Vectors.size();
  264. else {
  265. assert(Vectors.size() > 0 && "Cannot call getNumColumns without vectors");
  266. return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
  267. }
  268. }
  269. unsigned getNumRows() const {
  270. if (isColumnMajor()) {
  271. assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
  272. return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
  273. } else
  274. return Vectors.size();
  275. }
  276. void addVector(Value *V) { Vectors.push_back(V); }
  277. VectorType *getColumnTy() {
  278. assert(isColumnMajor() && "only supported for column-major matrixes");
  279. return getVectorTy();
  280. }
  281. VectorType *getVectorTy() const {
  282. return cast<VectorType>(Vectors[0]->getType());
  283. }
  284. iterator_range<SmallVector<Value *, 8>::iterator> columns() {
  285. assert(isColumnMajor() &&
  286. "columns() only supported for column-major matrixes");
  287. return make_range(Vectors.begin(), Vectors.end());
  288. }
  289. iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
  290. return make_range(Vectors.begin(), Vectors.end());
  291. }
  292. /// Embed the vectors of the matrix into a flat vector by concatenating
  293. /// them.
  294. Value *embedInVector(IRBuilder<> &Builder) const {
  295. return Vectors.size() == 1 ? Vectors[0]
  296. : concatenateVectors(Builder, Vectors);
  297. }
  298. MatrixTy &addNumLoads(unsigned N) {
  299. OpInfo.NumLoads += N;
  300. return *this;
  301. }
  302. void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
  303. MatrixTy &addNumStores(unsigned N) {
  304. OpInfo.NumStores += N;
  305. return *this;
  306. }
  307. MatrixTy &addNumExposedTransposes(unsigned N) {
  308. OpInfo.NumExposedTransposes += N;
  309. return *this;
  310. }
  311. MatrixTy &addNumComputeOps(unsigned N) {
  312. OpInfo.NumComputeOps += N;
  313. return *this;
  314. }
  315. unsigned getNumStores() const { return OpInfo.NumStores; }
  316. unsigned getNumLoads() const { return OpInfo.NumLoads; }
  317. unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
  318. const OpInfoTy &getOpInfo() const { return OpInfo; }
  319. bool isColumnMajor() const { return IsColumnMajor; }
  320. unsigned getStride() const {
  321. if (isColumnMajor())
  322. return getNumRows();
  323. return getNumColumns();
  324. }
  325. /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
  326. /// matrix is column-major, the result vector is extracted from a column
  327. /// vector, otherwise from a row vector.
  328. Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
  329. IRBuilder<> &Builder) const {
  330. Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
  331. assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
  332. NumElts &&
  333. "Extracted vector will contain poison values");
  334. return Builder.CreateShuffleVector(
  335. Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
  336. "block");
  337. }
  338. };
  339. struct ShapeInfo {
  340. unsigned NumRows;
  341. unsigned NumColumns;
  342. bool IsColumnMajor;
  343. ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
  344. : NumRows(NumRows), NumColumns(NumColumns),
  345. IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  346. ShapeInfo(Value *NumRows, Value *NumColumns)
  347. : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
  348. cast<ConstantInt>(NumColumns)->getZExtValue()) {}
  349. bool operator==(const ShapeInfo &other) {
  350. return NumRows == other.NumRows && NumColumns == other.NumColumns;
  351. }
  352. bool operator!=(const ShapeInfo &other) { return !(*this == other); }
  353. /// Returns true if shape-information is defined, meaning both dimensions
  354. /// are != 0.
  355. operator bool() const {
  356. assert(NumRows == 0 || NumColumns != 0);
  357. return NumRows != 0;
  358. }
  359. unsigned getStride() const {
  360. if (IsColumnMajor)
  361. return NumRows;
  362. return NumColumns;
  363. }
  364. unsigned getNumVectors() const {
  365. if (IsColumnMajor)
  366. return NumColumns;
  367. return NumRows;
  368. }
  369. /// Returns the transposed shape.
  370. ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  371. };
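// Example (for illustration only): a column-major 4x2 matrix has
// ShapeInfo{NumRows = 4, NumColumns = 2}; getStride() == 4 (elements per
// column vector), getNumVectors() == 2 (one vector per column), and t()
// yields the 2x4 shape of its transpose.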
  372. /// Maps instructions to their shape information. The shape information
  373. /// describes the shape to be used while lowering. This matches the shape of
  374. /// the result value of the instruction, with the only exceptions being store
  375. /// instructions and the matrix_column_major_store intrinsics. For those, the
  376. /// shape information indicates that those instructions should be lowered
  377. /// using shape information as well. A ValueMap is used so that when
  378. /// sub-passes like optimizeTransposes perform RAUW the map stays
  379. /// up-to-date.
  380. ValueMap<Value *, ShapeInfo> ShapeMap;
  381. /// List of instructions to remove. While lowering, we do not replace all
  382. /// users of a lowered instruction if shape information is available; those
  383. /// users need to be removed after we finish lowering.
  384. SmallVector<Instruction *, 16> ToRemove;
  385. /// Map from instructions to their produced column matrix.
  386. MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
  387. private:
  388. static FastMathFlags getFastMathFlags(Instruction *Inst) {
  389. FastMathFlags FMF;
  390. if (isa<FPMathOperator>(*Inst))
  391. FMF = Inst->getFastMathFlags();
  392. FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
  393. return FMF;
  394. }
  395. public:
  396. LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
  397. AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
  398. OptimizationRemarkEmitter *ORE)
  399. : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
  400. LI(LI), ORE(ORE) {}
  401. unsigned getNumOps(Type *VT) {
  402. assert(isa<VectorType>(VT) && "Expected vector type");
  403. return getNumOps(VT->getScalarType(),
  404. cast<FixedVectorType>(VT)->getNumElements());
  405. }
  406. /// Is this the minimal version executed in the backend pipelines?
  407. bool isMinimal() const {
  408. return !DT;
  409. }
  410. /// Return the estimated number of vector ops required for an operation on
  411. /// \p VT * N.
  412. unsigned getNumOps(Type *ST, unsigned N) {
  413. return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
  414. double(TTI.getRegisterBitWidth(
  415. TargetTransformInfo::RGK_FixedWidthVector)
  416. .getFixedValue()));
  417. }
  418. /// Return the set of vectors that a matrix value is lowered to.
  419. ///
  420. /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
  421. /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
  422. /// into vectors.
  423. MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
  424. IRBuilder<> &Builder) {
  425. VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
  426. assert(VType && "MatrixVal must be a vector type");
  427. assert(cast<FixedVectorType>(VType)->getNumElements() ==
  428. SI.NumRows * SI.NumColumns &&
  429. "The vector size must match the number of matrix elements");
  430. // Check if we lowered MatrixVal using shape information. In that case,
  431. // return the existing matrix, if it matches the requested shape
  432. // information. If there is a mis-match, embed the result in a flat
  433. // vector and split it later.
  434. auto Found = Inst2ColumnMatrix.find(MatrixVal);
  435. if (Found != Inst2ColumnMatrix.end()) {
  436. MatrixTy &M = Found->second;
  437. // Return the found matrix, if its shape matches the requested shape
  438. // information
  439. if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
  440. return M;
  441. MatrixVal = M.embedInVector(Builder);
  442. }
  443. // Otherwise split MatrixVal.
  444. SmallVector<Value *, 16> SplitVecs;
  445. for (unsigned MaskStart = 0;
  446. MaskStart < cast<FixedVectorType>(VType)->getNumElements();
  447. MaskStart += SI.getStride()) {
  448. Value *V = Builder.CreateShuffleVector(
  449. MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
  450. "split");
  451. SplitVecs.push_back(V);
  452. }
  453. return {SplitVecs};
  454. }
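// Splitting sketch (illustrative): a flat <8 x double> carrying a
// column-major 2 x 4 matrix has stride 2, so getMatrix emits four
// <2 x double> "split" shuffles, one per column, unless the value was
// already lowered and its vectors are cached in Inst2ColumnMatrix.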
  455. /// If \p V already has a known shape return false. Otherwise set the shape
  456. /// for instructions that support it.
  457. bool setShapeInfo(Value *V, ShapeInfo Shape) {
  458. assert(Shape && "Shape not set");
  459. if (isa<UndefValue>(V) || !supportsShapeInfo(V))
  460. return false;
  461. auto SIter = ShapeMap.find(V);
  462. if (SIter != ShapeMap.end()) {
  463. LLVM_DEBUG(dbgs() << " not overriding existing shape: "
  464. << SIter->second.NumRows << " "
  465. << SIter->second.NumColumns << " for " << *V << "\n");
  466. return false;
  467. }
  468. ShapeMap.insert({V, Shape});
  469. LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
  470. << " for " << *V << "\n");
  471. return true;
  472. }
  473. bool isUniformShape(Value *V) {
  474. Instruction *I = dyn_cast<Instruction>(V);
  475. if (!I)
  476. return true;
  477. switch (I->getOpcode()) {
  478. case Instruction::FAdd:
  479. case Instruction::FSub:
  480. case Instruction::FMul: // Scalar multiply.
  481. case Instruction::FNeg:
  482. case Instruction::Add:
  483. case Instruction::Mul:
  484. case Instruction::Sub:
  485. return true;
  486. default:
  487. return false;
  488. }
  489. }
  490. /// Returns true if shape information can be used for \p V. The supported
  491. /// instructions must match the instructions that can be lowered by this pass.
  492. bool supportsShapeInfo(Value *V) {
  493. Instruction *Inst = dyn_cast<Instruction>(V);
  494. if (!Inst)
  495. return false;
  496. IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
  497. if (II)
  498. switch (II->getIntrinsicID()) {
  499. case Intrinsic::matrix_multiply:
  500. case Intrinsic::matrix_transpose:
  501. case Intrinsic::matrix_column_major_load:
  502. case Intrinsic::matrix_column_major_store:
  503. return true;
  504. default:
  505. return false;
  506. }
  507. return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  508. }
  509. /// Propagate the shape information of instructions to their users.
  510. /// The work list contains instructions for which we can compute the shape,
  511. /// either based on the information provided by matrix intrinsics or known
  512. /// shapes of operands.
  513. SmallVector<Instruction *, 32>
  514. propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
  515. SmallVector<Instruction *, 32> NewWorkList;
  516. // Pop an element for which we are guaranteed to have at least one of the
  517. // operand shapes. Add the shape for this and then add users to the work
  518. // list.
  519. LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
  520. while (!WorkList.empty()) {
  521. Instruction *Inst = WorkList.pop_back_val();
  522. // New entry, set the value and insert operands
  523. bool Propagate = false;
  524. Value *MatrixA;
  525. Value *MatrixB;
  526. Value *M;
  527. Value *N;
  528. Value *K;
  529. if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
  530. m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
  531. m_Value(N), m_Value(K)))) {
  532. Propagate = setShapeInfo(Inst, {M, K});
  533. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
  534. m_Value(MatrixA), m_Value(M), m_Value(N)))) {
  535. // Flip dimensions.
  536. Propagate = setShapeInfo(Inst, {N, M});
  537. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
  538. m_Value(MatrixA), m_Value(), m_Value(),
  539. m_Value(), m_Value(M), m_Value(N)))) {
  540. Propagate = setShapeInfo(Inst, {N, M});
  541. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
  542. m_Value(), m_Value(), m_Value(), m_Value(M),
  543. m_Value(N)))) {
  544. Propagate = setShapeInfo(Inst, {M, N});
  545. } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
  546. auto OpShape = ShapeMap.find(MatrixA);
  547. if (OpShape != ShapeMap.end())
  548. setShapeInfo(Inst, OpShape->second);
  549. continue;
  550. } else if (isUniformShape(Inst)) {
  551. // Find the first operand that has a known shape and use that.
  552. for (auto &Op : Inst->operands()) {
  553. auto OpShape = ShapeMap.find(Op.get());
  554. if (OpShape != ShapeMap.end()) {
  555. Propagate |= setShapeInfo(Inst, OpShape->second);
  556. break;
  557. }
  558. }
  559. }
  560. if (Propagate) {
  561. NewWorkList.push_back(Inst);
  562. for (auto *User : Inst->users())
  563. if (ShapeMap.count(User) == 0)
  564. WorkList.push_back(cast<Instruction>(User));
  565. }
  566. }
  567. return NewWorkList;
  568. }
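// Forward propagation sketch (hypothetical IR, arguments elided): given
//   %c = call <8 x double> @llvm.matrix.multiply(..., i32 2, i32 4, i32 4)
//   %d = fadd <8 x double> %c, %e
//   store <8 x double> %d, ...
// the multiply seeds shape 2x4 for %c; the fadd has uniform shape and
// inherits 2x4 from %c; the store is then recorded with the same shape so it
// can later be lowered one column vector at a time.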
  569. /// Propagate the shape to operands of instructions with shape information.
  570. /// \p WorkList contains the instructions for which we already know the shape.
  571. SmallVector<Instruction *, 32>
  572. propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
  573. SmallVector<Instruction *, 32> NewWorkList;
  574. auto pushInstruction = [](Value *V,
  575. SmallVectorImpl<Instruction *> &WorkList) {
  576. Instruction *I = dyn_cast<Instruction>(V);
  577. if (I)
  578. WorkList.push_back(I);
  579. };
  580. // Pop an element with a known shape. Traverse its operands; if an operand's
  581. // shape derives from the result shape and is unknown, set it and add the
  582. // operand to the worklist.
  583. LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
  584. while (!WorkList.empty()) {
  585. Value *V = WorkList.pop_back_val();
  586. size_t BeforeProcessingV = WorkList.size();
  587. if (!isa<Instruction>(V))
  588. continue;
  589. Value *MatrixA;
  590. Value *MatrixB;
  591. Value *M;
  592. Value *N;
  593. Value *K;
  594. if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
  595. m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
  596. m_Value(N), m_Value(K)))) {
  597. if (setShapeInfo(MatrixA, {M, N}))
  598. pushInstruction(MatrixA, WorkList);
  599. if (setShapeInfo(MatrixB, {N, K}))
  600. pushInstruction(MatrixB, WorkList);
  601. } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
  602. m_Value(MatrixA), m_Value(M), m_Value(N)))) {
  603. // Flip dimensions.
  604. if (setShapeInfo(MatrixA, {M, N}))
  605. pushInstruction(MatrixA, WorkList);
  606. } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
  607. m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
  608. m_Value(M), m_Value(N)))) {
  609. if (setShapeInfo(MatrixA, {M, N})) {
  610. pushInstruction(MatrixA, WorkList);
  611. }
  612. } else if (isa<LoadInst>(V) ||
  613. match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
  614. // Nothing to do, no matrix input.
  615. } else if (isa<StoreInst>(V)) {
  616. // Nothing to do. We forward-propagated to this so we would just
  617. // backward propagate to an instruction with an already known shape.
  618. } else if (isUniformShape(V)) {
  619. // Propagate to all operands.
  620. ShapeInfo Shape = ShapeMap[V];
  621. for (Use &U : cast<Instruction>(V)->operands()) {
  622. if (setShapeInfo(U.get(), Shape))
  623. pushInstruction(U.get(), WorkList);
  624. }
  625. }
  626. // After we discovered new shape info for new instructions in the
  627. // worklist, we use their users as seeds for the next round of forward
  628. // propagation.
  629. for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
  630. for (User *U : WorkList[I]->users())
  631. if (isa<Instruction>(U) && V != U)
  632. NewWorkList.push_back(cast<Instruction>(U));
  633. }
  634. return NewWorkList;
  635. }
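// Backward propagation sketch (hypothetical IR): if only
//   %t = call <8 x double> @llvm.matrix.transpose(<8 x double> %x, i32 4, i32 2)
// has a known shape, the operand %x is assigned shape 4x2 and pushed onto the
// worklist; users of such newly-shaped values are then collected as seeds for
// the next round of forward propagation.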
  636. /// (Op0 op Op1)^T -> Op0^T op Op1^T
  637. /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  638. /// them on both sides of \p Operation.
  639. Instruction *distributeTransposes(
  640. Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
  641. MatrixBuilder &Builder,
  642. function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
  643. Operation) {
  644. Value *T0 = Builder.CreateMatrixTranspose(
  645. Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
  646. // We are being run after shape prop, add shape for newly created
  647. // instructions so that we lower them later.
  648. setShapeInfo(T0, Shape0.t());
  649. Value *T1 = Builder.CreateMatrixTranspose(
  650. Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
  651. setShapeInfo(T1, Shape1.t());
  652. return Operation(T0, Shape0.t(), T1, Shape1.t());
  653. }
  654. void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
  655. // We need to remove Old from the ShapeMap otherwise RAUW will replace it
  656. // with New. We should only add New if it supportsShapeInfo, so we insert
  657. // it conditionally instead, reading the old shape before erasing its entry.
  658. auto S = ShapeMap.find(&Old);
  659. if (S != ShapeMap.end()) {
  660. if (supportsShapeInfo(New))
  661. ShapeMap.insert({New, S->second});
  662. ShapeMap.erase(&Old);
  663. }
  664. Old.replaceAllUsesWith(New);
  665. }
  666. /// Sink a top-level transpose inside matmuls and adds.
  667. /// This creates and erases instructions as needed, and returns the newly
  668. /// created instruction while updating the iterator to avoid invalidation. If
  669. /// this returns nullptr, no new instruction was created.
  670. Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
  671. BasicBlock &BB = *I.getParent();
  672. IRBuilder<> IB(&I);
  673. MatrixBuilder Builder(IB);
  674. Value *TA, *TAMA, *TAMB;
  675. ConstantInt *R, *K, *C;
  676. if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
  677. m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
  678. return nullptr;
  679. // Transpose of a transpose is a nop
  680. Value *TATA;
  681. if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
  682. updateShapeAndReplaceAllUsesWith(I, TATA);
  683. eraseFromParentAndMove(&I, II, BB);
  684. eraseFromParentAndMove(TA, II, BB);
  685. return nullptr;
  686. }
  687. // k^T -> k
  688. if (isSplat(TA)) {
  689. updateShapeAndReplaceAllUsesWith(I, TA);
  690. eraseFromParentAndMove(&I, II, BB);
  691. return nullptr;
  692. }
  693. // (A * B)^t -> B^t * A^t
  694. // RxK KxC CxK KxR
  695. if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
  696. m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
  697. m_ConstantInt(K), m_ConstantInt(C)))) {
  698. auto NewInst = distributeTransposes(
  699. TAMB, {K, C}, TAMA, {R, K}, Builder,
  700. [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
  701. return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
  702. Shape0.NumColumns,
  703. Shape1.NumColumns, "mmul");
  704. });
  705. updateShapeAndReplaceAllUsesWith(I, NewInst);
  706. eraseFromParentAndMove(&I, II, BB);
  707. eraseFromParentAndMove(TA, II, BB);
  708. return NewInst;
  709. }
  710. // Same as above, but with a mul, which occurs when multiplied
  711. // with a scalar.
  712. // (A * k)^t -> A^t * k
  713. // R x C RxC
  714. if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
  715. (isSplat(TAMA) || isSplat(TAMB))) {
  716. IRBuilder<> LocalBuilder(&I);
  717. // We know that the transposed operand is of shape RxC.
  718. // And when multiplied with a scalar, the shape is preserved.
  719. auto NewInst = distributeTransposes(
  720. TAMA, {R, C}, TAMB, {R, C}, Builder,
  721. [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
  722. bool IsFP = I.getType()->isFPOrFPVectorTy();
  723. auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
  724. : LocalBuilder.CreateMul(T0, T1, "mmul");
  725. auto *Result = cast<Instruction>(Mul);
  726. setShapeInfo(Result, Shape0);
  727. return Result;
  728. });
  729. updateShapeAndReplaceAllUsesWith(I, NewInst);
  730. eraseFromParentAndMove(&I, II, BB);
  731. eraseFromParentAndMove(TA, II, BB);
  732. return NewInst;
  733. }
  734. // (A + B)^t -> A^t + B^t
  735. // RxC RxC CxR CxR
  736. if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
  737. IRBuilder<> LocalBuilder(&I);
  738. auto NewInst = distributeTransposes(
  739. TAMA, {R, C}, TAMB, {R, C}, Builder,
  740. [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
  741. auto *FAdd =
  742. cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
  743. setShapeInfo(FAdd, Shape0);
  744. return FAdd;
  745. });
  746. updateShapeAndReplaceAllUsesWith(I, NewInst);
  747. eraseFromParentAndMove(&I, II, BB);
  748. eraseFromParentAndMove(TA, II, BB);
  749. return NewInst;
  750. }
  751. return nullptr;
  752. }
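// Sinking sketch (hypothetical names): for
//   %ab  = multiply(%a, %b, R, K, C)    ; R x C
//   %abt = transpose(%ab, R, C)         ; C x R
// sinkTranspose materializes transposes of the operands and swaps them:
//   %b_t  = transpose(%b)               ; C x K
//   %a_t  = transpose(%a)               ; K x R
//   %mmul = multiply(%b_t, %a_t, C, K, R)
// so later lifting/fusion sees the transposes on the multiply operands
// rather than on its result.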
  753. void liftTranspose(Instruction &I) {
  754. // Erase dead Instructions after lifting transposes from binops.
  755. auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
  756. if (T.use_empty())
  757. T.eraseFromParent();
  758. if (A->use_empty())
  759. cast<Instruction>(A)->eraseFromParent();
  760. if (A != B && B->use_empty())
  761. cast<Instruction>(B)->eraseFromParent();
  762. };
  763. Value *A, *B, *AT, *BT;
  764. ConstantInt *R, *K, *C;
  765. // A^t * B^t -> (B * A)^t
  766. if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
  767. m_Value(A), m_Value(B), m_ConstantInt(R),
  768. m_ConstantInt(K), m_ConstantInt(C))) &&
  769. match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
  770. match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
  771. IRBuilder<> IB(&I);
  772. MatrixBuilder Builder(IB);
  773. Value *M = Builder.CreateMatrixMultiply(
  774. BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
  775. setShapeInfo(M, {C, R});
  776. Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
  777. R->getZExtValue());
  778. updateShapeAndReplaceAllUsesWith(I, NewInst);
  779. CleanupBinOp(I, A, B);
  780. }
  781. // A^t + B^t -> (A + B)^t
  782. else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
  783. match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
  784. m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
  785. match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
  786. m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
  787. IRBuilder<> Builder(&I);
  788. Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
  789. setShapeInfo(Add, {C, R});
  790. MatrixBuilder MBuilder(Builder);
  791. Instruction *NewInst = MBuilder.CreateMatrixTranspose(
  792. Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
  793. updateShapeAndReplaceAllUsesWith(I, NewInst);
  794. CleanupBinOp(I, A, B);
  795. }
  796. }
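// Lifting sketch: when both multiply operands are transposes,
//   multiply(transpose(%a), transpose(%b), R, K, C)
// is rewritten to
//   transpose(multiply(%b, %a, C, K, R))
// leaving a single transpose on the result, which a consuming transpose or a
// later sinking round may cancel entirely.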
  797. /// Try moving transposes in order to fold them away or into multiplies.
  798. void optimizeTransposes() {
  799. // First sink all transposes inside matmuls and adds, hoping that we end up
  800. // with NN, NT or TN variants.
  801. for (BasicBlock &BB : reverse(Func)) {
  802. for (auto II = BB.rbegin(); II != BB.rend();) {
  803. Instruction &I = *II;
  804. // We may remove II. By default continue on the next/prev instruction.
  805. ++II;
  806. if (Instruction *NewInst = sinkTranspose(I, II))
  807. II = std::next(BasicBlock::reverse_iterator(NewInst));
  808. }
  809. }
  810. // If we have a TT matmul or a TT add, lift the transpose. We may be able
  811. // to fold into consuming multiply or add.
  812. for (BasicBlock &BB : Func) {
  813. for (Instruction &I : llvm::make_early_inc_range(BB)) {
  814. liftTranspose(I);
  815. }
  816. }
  817. }
  818. bool Visit() {
  819. SmallVector<Instruction *, 32> WorkList;
  820. // Initially only the shape of matrix intrinsics is known.
  821. // Initialize the work list with ops carrying shape information.
  822. for (BasicBlock &BB : Func)
  823. for (Instruction &Inst : BB) {
  824. IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
  825. if (!II)
  826. continue;
  827. switch (II->getIntrinsicID()) {
  828. case Intrinsic::matrix_multiply:
  829. case Intrinsic::matrix_transpose:
  830. case Intrinsic::matrix_column_major_load:
  831. case Intrinsic::matrix_column_major_store:
  832. WorkList.push_back(&Inst);
  833. break;
  834. default:
  835. break;
  836. }
  837. }
  838. // Avoid unnecessary work if there are no matrix intrinsics in the function.
  839. if (WorkList.empty())
  840. return false;
  841. // Propagate shapes until nothing changes any longer.
  842. while (!WorkList.empty()) {
  843. WorkList = propagateShapeForward(WorkList);
  844. WorkList = propagateShapeBackward(WorkList);
  845. }
  846. if (!isMinimal()) {
  847. optimizeTransposes();
  848. if (PrintAfterTransposeOpt) {
  849. dbgs() << "Dump after matrix transpose optimization:\n";
  850. Func.print(dbgs());
  851. }
  852. }
  853. bool Changed = false;
  854. SmallVector<CallInst *, 16> MaybeFusableInsts;
  855. SmallVector<Instruction *, 16> MatrixInsts;
  856. // First, collect all instructions with shape information and candidates for
  857. // fusion (currently only matrix multiplies).
  858. ReversePostOrderTraversal<Function *> RPOT(&Func);
  859. for (auto *BB : RPOT)
  860. for (Instruction &I : *BB) {
  861. if (ShapeMap.find(&I) == ShapeMap.end())
  862. continue;
  863. if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
  864. MaybeFusableInsts.push_back(cast<CallInst>(&I));
  865. MatrixInsts.push_back(&I);
  866. }
  867. // Second, try to fuse candidates.
  868. SmallPtrSet<Instruction *, 16> FusedInsts;
  869. for (CallInst *CI : MaybeFusableInsts)
  870. LowerMatrixMultiplyFused(CI, FusedInsts);
  871. Changed = !FusedInsts.empty();
  872. // Third, lower remaining instructions with shape information.
  873. for (Instruction *Inst : MatrixInsts) {
  874. if (FusedInsts.count(Inst))
  875. continue;
  876. IRBuilder<> Builder(Inst);
  877. if (CallInst *CInst = dyn_cast<CallInst>(Inst))
  878. Changed |= VisitCallInst(CInst);
  879. Value *Op1;
  880. Value *Op2;
  881. if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
  882. Changed |= VisitBinaryOperator(BinOp);
  883. if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
  884. Changed |= VisitUnaryOperator(UnOp);
  885. if (match(Inst, m_Load(m_Value(Op1))))
  886. Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
  887. else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
  888. Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
  889. }
  890. if (ORE) {
  891. RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
  892. RemarkGen.emitRemarks();
  893. }
  894. // Delete the instructions backwards, as it has a reduced likelihood of
  895. // having to update as many def-use and use-def chains.
  896. //
  897. // Because we add to ToRemove during fusion we can't guarantee that defs
  898. // are before uses. Change uses to poison temporarily as these should get
  899. // removed as well.
  900. //
  901. // For verification, we keep track of where we changed uses to poison in
  902. // PoisonedInsts and then check that we in fact remove them.
  903. SmallSet<Instruction *, 16> PoisonedInsts;
  904. for (auto *Inst : reverse(ToRemove)) {
  905. for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
  906. if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
  907. PoisonedInsts.insert(Poisoned);
  908. U.set(PoisonValue::get(Inst->getType()));
  909. }
  910. Inst->eraseFromParent();
  911. PoisonedInsts.erase(Inst);
  912. }
  913. if (!PoisonedInsts.empty()) {
  914. // If we didn't remove all poisoned instructions, it's a hard error.
  915. dbgs() << "Poisoned but present instructions:\n";
  916. for (auto *I : PoisonedInsts)
  917. dbgs() << *I << "\n";
  918. llvm_unreachable("Poisoned but instruction not removed");
  919. }
  920. return Changed;
  921. }
  922. /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  923. Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
  924. unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
  925. Type *EltPtrType = PointerType::get(EltType, AS);
  926. return Builder.CreatePointerCast(BasePtr, EltPtrType);
  927. }
  928. /// Replace intrinsic calls
  929. bool VisitCallInst(CallInst *Inst) {
  930. if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
  931. return false;
  932. switch (Inst->getCalledFunction()->getIntrinsicID()) {
  933. case Intrinsic::matrix_multiply:
  934. LowerMultiply(Inst);
  935. break;
  936. case Intrinsic::matrix_transpose:
  937. LowerTranspose(Inst);
  938. break;
  939. case Intrinsic::matrix_column_major_load:
  940. LowerColumnMajorLoad(Inst);
  941. break;
  942. case Intrinsic::matrix_column_major_store:
  943. LowerColumnMajorStore(Inst);
  944. break;
  945. default:
  946. return false;
  947. }
  948. return true;
  949. }
  950. /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  951. /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  952. /// ConstantInt, reduce the initial alignment based on the byte offset. For
  953. /// non-ConstantInt strides, return the common alignment of the initial
  954. /// alignment and the element size in bytes.
  955. Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
  956. MaybeAlign A) const {
  957. Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
  958. if (Idx == 0)
  959. return InitialAlign;
  960. TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
  961. if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
  962. uint64_t StrideInBytes =
  963. ConstStride->getZExtValue() * ElementSizeInBits / 8;
  964. return commonAlignment(InitialAlign, Idx * StrideInBytes);
  965. }
  966. return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  967. }
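// Worked example (assumed values): with an initial alignment of 16, double
// elements (8 bytes) and a constant stride of 3 elements, column 1 starts at
// byte offset 1 * 3 * 8 = 24 and commonAlignment(16, 24) yields 8. With a
// non-constant stride only the element size is known, so every non-zero
// index gets commonAlignment(16, 8) == 8.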
  968. /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  969. /// vectors.
  970. MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
  971. bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
  972. auto *VType = cast<VectorType>(Ty);
  973. Type *EltTy = VType->getElementType();
  974. Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
  975. Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
  976. MatrixTy Result;
  977. for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
  978. Value *GEP = computeVectorAddr(
  979. EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
  980. Stride, Shape.getStride(), EltTy, Builder);
  981. Value *Vector = Builder.CreateAlignedLoad(
  982. VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
  983. IsVolatile, "col.load");
  984. Result.addVector(Vector);
  985. }
  986. return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
  987. Result.getNumVectors());
  988. }
  989. /// Loads a sub-matrix with shape \p ResultShape from a larger matrix with
  990. /// shape \p MatrixShape, starting at \p MatrixPtr[I][J].
  991. MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
  992. ShapeInfo MatrixShape, Value *I, Value *J,
  993. ShapeInfo ResultShape, Type *EltTy,
  994. IRBuilder<> &Builder) {
  995. Value *Offset = Builder.CreateAdd(
  996. Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
  997. unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
  998. Value *EltPtr =
  999. Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
  1000. Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
  1001. auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
  1002. ResultShape.NumColumns);
  1003. Type *TilePtrTy = PointerType::get(TileTy, AS);
  1004. Value *TilePtr =
  1005. Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
  1006. return loadMatrix(TileTy, TilePtr, Align,
  1007. Builder.getInt64(MatrixShape.getStride()), IsVolatile,
  1008. ResultShape, Builder);
  1009. }
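// Addressing note (illustrative): for a column-major outer matrix, element
// (I, J) sits at linear offset J * Stride + I with Stride ==
// MatrixShape.getStride() (the number of rows), which is exactly the Offset
// computed above; the recursive loadMatrix call then walks the tile's columns
// using that same outer stride.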
  1010. /// Lower a load instruction with shape information.
  1011. void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
  1012. bool IsVolatile, ShapeInfo Shape) {
  1013. IRBuilder<> Builder(Inst);
  1014. finalizeLowering(Inst,
  1015. loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
  1016. Shape, Builder),
  1017. Builder);
  1018. }
  1019. /// Lowers llvm.matrix.column.major.load.
  1020. ///
  1021. /// The intrinsic loads a matrix from memory using a stride between columns.
  1022. void LowerColumnMajorLoad(CallInst *Inst) {
  1023. assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
  1024. "Intrinsic only supports column-major layout!");
  1025. Value *Ptr = Inst->getArgOperand(0);
  1026. Value *Stride = Inst->getArgOperand(1);
  1027. LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
  1028. cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
  1029. {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  1030. }
  1031. /// Stores a sub-matrix \p StoreVal into a larger matrix with shape
  1032. /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  1033. void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
  1034. MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
  1035. Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
  1036. Value *Offset = Builder.CreateAdd(
  1037. Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
  1038. unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
  1039. Value *EltPtr =
  1040. Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
  1041. Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
  1042. auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
  1043. StoreVal.getNumColumns());
  1044. Type *TilePtrTy = PointerType::get(TileTy, AS);
  1045. Value *TilePtr =
  1046. Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
  1047. storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
  1048. Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  1049. }
  1050. /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  1051. /// vectors.
  1052. MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
  1053. MaybeAlign MAlign, Value *Stride, bool IsVolatile,
  1054. IRBuilder<> &Builder) {
  1055. auto VType = cast<VectorType>(Ty);
  1056. Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
  1057. for (auto Vec : enumerate(StoreVal.vectors())) {
  1058. Value *GEP = computeVectorAddr(
  1059. EltPtr,
  1060. Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
  1061. Vec.index()),
  1062. Stride, StoreVal.getStride(), VType->getElementType(), Builder);
  1063. Builder.CreateAlignedStore(Vec.value(), GEP,
  1064. getAlignForIndex(Vec.index(), Stride,
  1065. VType->getElementType(),
  1066. MAlign),
  1067. IsVolatile);
  1068. }
  1069. return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
  1070. StoreVal.getNumVectors());
  1071. }
  1072. /// Lower a store instruction with shape information.
  1073. void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
  1074. Value *Stride, bool IsVolatile, ShapeInfo Shape) {
  1075. IRBuilder<> Builder(Inst);
  1076. auto StoreVal = getMatrix(Matrix, Shape, Builder);
  1077. finalizeLowering(Inst,
  1078. storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
  1079. IsVolatile, Builder),
  1080. Builder);
  1081. }
  1082. /// Lowers llvm.matrix.column.major.store.
  1083. ///
  1084. /// The intrinsic store a matrix back memory using a stride between columns.
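  ///
  /// As an illustrative sketch only (exact mangling may differ), a call such as
  ///   call void @llvm.matrix.column.major.store.v9f64.i64(
  ///            <9 x double> %m, ptr %dst, i64 %stride, i1 false, i32 3, i32 3)
  /// writes a non-volatile 3 x 3 matrix of doubles to %dst, one column at a
  /// time, %stride elements apart.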
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  /// Set elements I..I+N-1 of \p Col to \p Block, where N is the number of
  /// elements in \p Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
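
  /// Emit either Sum + A * B or, when \p Sum is null, just A * B. As an
  /// illustrative sketch (assuming <4 x double> operands), the contracting FP
  /// path below emits a single call like
  ///   %acc = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
  ///                                                <4 x double> %b,
  ///                                                <4 x double> %sum)
  /// while the non-contracting path emits a separate fmul followed by an fadd.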
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
  ///
  /// We can fold a transpose into the operand that is used to extract scalars.
  /// This is the first operand with row-major layout and the second with
  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
  /// operand is transposed.
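  ///
  /// A rough sketch of the column-major lowering below (names are illustrative,
  /// not actual code):
  ///   for J in 0..C:                     // each result column
  ///     for I in 0..R step BlockSize:    // vectorized row blocks
  ///       for K in 0..M:                 // accumulate along the inner dim
  ///         Sum += A[I..I+BlockSize, K] * splat(B[K, J])
  /// so the adds stay in order and need no reassociation to vectorize.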
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();
    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }

  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. Returns the new location if a copy was
  /// needed, otherwise the original one.
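  ///
  /// Conceptually (an illustrative sketch of the blocks created below, not
  /// actual output):
  ///   entry:      if (load.begin < store.end)  goto alias_cont; else no_alias
  ///   alias_cont: if (store.begin < load.end)  goto copy;       else no_alias
  ///   copy:       memcpy the loaded operand into a fresh alloca
  ///   no_alias:   phi of the original pointer (from entry/alias_cont) or the
  ///               copy's address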
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the second part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
    Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());

    Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(BC, Copy);

    // Adjust DT.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }

  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling
    //
    // For tiling to be beneficial, we need reuse either along the R or
    // the C axis. We vectorize along the R axis so that means at least
    // 3 elements.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector
    // registers we have. Note that this is an oversimplification since
    // fusing also takes some extra loads which may exceed the number of
    // reloads necessary.
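    // As an illustrative example (numbers are hypothetical): for an 8x8 by 8x8
    // multiply with VF == 4, Op0Regs = ceil(8/4) * 8 = 16 and
    // Op1Regs = ceil(8/4) * 8 = 16, so with 16 available vector registers the
    // total of 32 exceeds the register budget and fusion is considered
    // profitable.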
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }

  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure there
    // is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch scalars.
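    // For example (purely illustrative): with column-major layout, lowering
    // multiply(%a, transpose(%t)) fetches its scalars from the second operand,
    // so instead of materializing the transpose we read %t directly with the
    // row/column indices swapped (IsScalarMatrixTransposed below).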
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
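  ///
  /// For illustration only (the exact type mangling may differ), a call such as
  ///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
  ///            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)
  /// multiplies a 2x3 matrix %a by a 3x2 matrix %b, yielding a 2x2 result.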
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
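  ///
  /// For illustration only (exact mangling may differ), a call such as
  ///   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %m,
  ///                                                       i32 2, i32 3)
  /// treats %m as a 2x3 matrix and produces its 3x2 transpose.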
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
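  ///
  /// As a purely illustrative example (names and shapes are hypothetical), a
  /// multiply of two loaded matrixes feeding a store may linearize to:
  ///   store(
  ///    multiply.2x6.6x2.double(
  ///     load(addr %A),
  ///     load(addr %B)),
  ///    addr %C)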
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
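  ///
  /// The emitted message reads roughly as follows (numbers purely
  /// illustrative):
  ///   remark: example.c:10:12: Lowered with 6 stores, 6 loads, 24 compute ops,
  ///   0 exposed transposes
  /// followed by the linearized expression for the leaf.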
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression starting
    /// at \p Root. Expressions used multiple times are counted once. Limit the
    /// matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}
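
// Print the pass options in the form used by -passes pipeline strings; for
// example (illustrative), the minimal variant renders as
// "lower-matrix-intrinsics<minimal>".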
void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << "<";
  if (Minimal)
    OS << "minimal";
  OS << ">";
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE, like tiling, are disabled.
/// This is used to lower matrix intrinsics if the main lowering pass is not
/// run, for example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}