  1. //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Lower matrix intrinsics to vector operations.
  10. //
  11. // TODO:
  12. // * Improve fusion:
  13. // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
  14. // transposed.
  15. // * Improve cost-modeling, e.g. choose different number of rows/columns
  16. // for tiles, consider cost of copies on alias.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
  20. #include "llvm/ADT/GraphTraits.h"
  21. #include "llvm/ADT/PostOrderIterator.h"
  22. #include "llvm/ADT/SmallVector.h"
  23. #include "llvm/Analysis/AliasAnalysis.h"
  24. #include "llvm/Analysis/DomTreeUpdater.h"
  25. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  26. #include "llvm/Analysis/TargetTransformInfo.h"
  27. #include "llvm/Analysis/ValueTracking.h"
  28. #include "llvm/Analysis/VectorUtils.h"
  29. #include "llvm/IR/CFG.h"
  30. #include "llvm/IR/DataLayout.h"
  31. #include "llvm/IR/DebugInfoMetadata.h"
  32. #include "llvm/IR/Function.h"
  33. #include "llvm/IR/IRBuilder.h"
  34. #include "llvm/IR/Instructions.h"
  35. #include "llvm/IR/IntrinsicInst.h"
  36. #include "llvm/IR/MatrixBuilder.h"
  37. #include "llvm/IR/PatternMatch.h"
  38. #include "llvm/InitializePasses.h"
  39. #include "llvm/Pass.h"
  40. #include "llvm/Support/Alignment.h"
  41. #include "llvm/Support/CommandLine.h"
  42. #include "llvm/Support/Debug.h"
  43. #include "llvm/Transforms/Scalar.h"
  44. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  45. #include "llvm/Transforms/Utils/LoopUtils.h"
  46. #include "llvm/Transforms/Utils/MatrixUtils.h"
  47. using namespace llvm;
  48. using namespace PatternMatch;
  49. #define DEBUG_TYPE "lower-matrix-intrinsics"
  50. static cl::opt<bool>
  51. FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
  52. cl::desc("Enable/disable fusing matrix instructions."));
  53. // TODO: Allow and use non-square tiles.
  54. static cl::opt<unsigned> TileSize(
  55. "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
  56. cl::desc(
  57. "Tile size for matrix instruction fusion using square-shaped tiles."));
  58. static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
  59. cl::Hidden,
  60. cl::desc("Generate loop nest for tiling."));
  61. static cl::opt<bool> ForceFusion(
  62. "force-fuse-matrix", cl::init(false), cl::Hidden,
  63. cl::desc("Force matrix instruction fusion even if not profitable."));
  64. static cl::opt<bool> AllowContractEnabled(
  65. "matrix-allow-contract", cl::init(false), cl::Hidden,
  66. cl::desc("Allow the use of FMAs if available and profitable. This may "
  67. "result in different results, due to less rounding error."));
  68. enum class MatrixLayoutTy { ColumnMajor, RowMajor };
  69. static cl::opt<MatrixLayoutTy> MatrixLayout(
  70. "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
  71. cl::desc("Sets the default matrix layout"),
  72. cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
  73. "Use column-major layout"),
  74. clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
  75. "Use row-major layout")));
  76. /// Helper function to either return \p Scope, if it is a subprogram, or the
  77. /// subprogram attached to a local scope.
  78. static DISubprogram *getSubprogram(DIScope *Scope) {
  79. if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
  80. return Subprogram;
  81. return cast<DILocalScope>(Scope)->getSubprogram();
  82. }
  83. namespace {
  84. // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
  85. // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
  86. // assuming \p Stride elements between the starts of two consecutive vectors.
  87. // \p Stride must be >= \p NumElements.
  88. // For column-major matrixes, the function computes the address of a column
  89. // vector and \p NumElements must be set to the number of elements in a column
  90. // (= number of rows of the matrix). For row-major matrixes, the function
  91. // computes the address of a row vector and \p NumElements must be set to the
  92. // number of elements in a row (= number of columns of the matrix).
  93. //
  94. // Consider a 4x4 matrix in column-major layout like below
  95. //
  96. //      0      1      2      3
  97. // 0  v_0_0  v_0_1  v_0_2  v_0_3
  98. // 1  v_1_0  v_1_1  v_1_2  v_1_3
  99. // 2  v_2_0  v_2_1  v_2_2  v_2_3
  100. // 3  v_3_0  v_3_1  v_3_2  v_3_3
  101. // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
  102. // we need a pointer to the first element of the submatrix as base pointer.
  103. // Then we can use computeVectorAddr to compute the addresses for the columns
  104. // of the sub-matrix.
  105. //
  106. // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
  107. // -> just returns Base
  108. // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
  109. // -> returns Base + (1 * 4)
  110. // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
  111. // -> returns Base + (2 * 4)
  112. //
  113. // The graphic below illustrates the number of elements in a column (marked
  114. // with |) and the number of skipped elements (marked with {).
  115. //
  116. //         v_0_0  v_0_1 {v_0_2 {v_0_3
  117. //                Base   Col 1  Col 2
  118. //                  |      |      |
  119. //         v_1_0 |v_1_1 |v_1_2 |v_1_3
  120. //         v_2_0 |v_2_1 |v_2_2 |v_2_3
  121. //         v_3_0 {v_3_1 {v_3_2  v_3_3
  122. //
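// As a sketch of the IR this helper typically emits (names and types are
// illustrative, assuming typed pointers; with opaque pointers the final cast
// folds away):
//   %vec.start = mul i64 %vec.idx, %stride            ; skipped if the product is 0
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*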
  123. Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
  124. unsigned NumElements, Type *EltType,
  125. IRBuilder<> &Builder) {
  126. assert((!isa<ConstantInt>(Stride) ||
  127. cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
  128. "Stride must be >= the number of elements in the result vector.");
  129. unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
  130. // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  131. Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
  132. // Get pointer to the start of the selected vector. Skip GEP creation,
  133. // if we select vector 0.
  134. if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
  135. VecStart = BasePtr;
  136. else
  137. VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
  138. // Cast elementwise vector start pointer to a pointer to a vector
  139. // (EltType x NumElements)*.
  140. auto *VecType = FixedVectorType::get(EltType, NumElements);
  141. Type *VecPtrType = PointerType::get(VecType, AS);
  142. return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
  143. }
  144. /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
  145. ///
  146. /// Currently, the lowering for each matrix intrinsic is done as follows:
  147. /// 1. Propagate the shape information from intrinsics to connected
  148. /// instructions.
  149. /// 2. Lower instructions with shape information (assuming column-major layout).
  150. /// The lowering works similarly using row-major layout.
  151. /// 2.1. Get column vectors for each argument. If we already lowered the
  152. /// definition of an argument, use the produced column vectors directly.
  153. /// If not, split the operand vector containing an embedded matrix into
  154. /// a set of column vectors.
  155. /// 2.2. Lower the instruction in terms of column-major operations, which
  156. /// yields a set of column vectors containing the result matrix. Note that we
  157. /// lower all instructions that have shape information. Besides the
  158. /// intrinsics, this includes stores for example.
  159. /// 2.3. Update uses of the lowered instruction. If we have shape information
  160. /// for a user, there is nothing to do, as we will look up the result
  161. /// column matrix when lowering the user. For other uses, we embed the
  162. /// result matrix in a flat vector and update the use.
  163. /// 2.4. Cache the result column matrix for the instruction we lowered
  164. /// 3. After we lowered all instructions in a function, remove the now
  165. /// obsolete instructions.
  166. ///
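/// As an illustrative example (shapes and value names chosen for exposition,
/// not taken from a particular test), a 2x2 multiply such as
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// is lowered by splitting %a and %b into column vectors, emitting
/// multiply/add (or fmuladd) chains per result column, and embedding the
/// result back into a flat <4 x double> only for users without shape
/// information.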
  167. class LowerMatrixIntrinsics {
  168. Function &Func;
  169. const DataLayout &DL;
  170. const TargetTransformInfo &TTI;
  171. AliasAnalysis *AA;
  172. DominatorTree *DT;
  173. LoopInfo *LI;
  174. OptimizationRemarkEmitter *ORE;
  175. /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
  176. struct OpInfoTy {
  177. /// Number of stores emitted to generate this matrix.
  178. unsigned NumStores = 0;
  179. /// Number of loads emitted to generate this matrix.
  180. unsigned NumLoads = 0;
  181. /// Number of compute operations emitted to generate this matrix.
  182. unsigned NumComputeOps = 0;
  183. /// Most of the time transposes can be fused with matrix multiplies or can
  184. /// be folded away via algebraic simplifications. This is the number of
  185. /// transposes that we failed to make "free" via such optimizations.
  186. unsigned NumExposedTransposes = 0;
  187. OpInfoTy &operator+=(const OpInfoTy &RHS) {
  188. NumStores += RHS.NumStores;
  189. NumLoads += RHS.NumLoads;
  190. NumComputeOps += RHS.NumComputeOps;
  191. NumExposedTransposes += RHS.NumExposedTransposes;
  192. return *this;
  193. }
  194. };
  195. /// Wrapper class representing a matrix as a set of vectors, either in row or
  196. /// column major layout. All vectors must have the same vector type.
  197. class MatrixTy {
  198. SmallVector<Value *, 16> Vectors;
  199. OpInfoTy OpInfo;
  200. bool IsColumnMajor = true;
  201. public:
  202. MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  203. MatrixTy(ArrayRef<Value *> Vectors)
  204. : Vectors(Vectors.begin(), Vectors.end()),
  205. IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  206. MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
  207. : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
  208. unsigned D = isColumnMajor() ? NumColumns : NumRows;
  209. for (unsigned J = 0; J < D; ++J)
  210. addVector(UndefValue::get(FixedVectorType::get(
  211. EltTy, isColumnMajor() ? NumRows : NumColumns)));
  212. }
  213. Value *getVector(unsigned i) const { return Vectors[i]; }
  214. Value *getColumn(unsigned i) const {
  215. assert(isColumnMajor() && "only supported for column-major matrixes");
  216. return Vectors[i];
  217. }
  218. Value *getRow(unsigned i) const {
  219. assert(!isColumnMajor() && "only supported for row-major matrixes");
  220. return Vectors[i];
  221. }
  222. void setVector(unsigned i, Value *V) { Vectors[i] = V; }
  223. Type *getElementType() const { return getVectorTy()->getElementType(); }
  224. unsigned getNumVectors() const {
  225. if (isColumnMajor())
  226. return getNumColumns();
  227. return getNumRows();
  228. }
  229. unsigned getNumColumns() const {
  230. if (isColumnMajor())
  231. return Vectors.size();
  232. else {
  233. assert(Vectors.size() > 0 && "Cannot call getNumColumns without vectors");
  234. return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
  235. }
  236. }
  237. unsigned getNumRows() const {
  238. if (isColumnMajor()) {
  239. assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
  240. return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
  241. } else
  242. return Vectors.size();
  243. }
  244. void addVector(Value *V) { Vectors.push_back(V); }
  245. VectorType *getColumnTy() {
  246. assert(isColumnMajor() && "only supported for column-major matrixes");
  247. return getVectorTy();
  248. }
  249. VectorType *getVectorTy() const {
  250. return cast<VectorType>(Vectors[0]->getType());
  251. }
  252. iterator_range<SmallVector<Value *, 8>::iterator> columns() {
  253. assert(isColumnMajor() &&
  254. "columns() only supported for column-major matrixes");
  255. return make_range(Vectors.begin(), Vectors.end());
  256. }
  257. iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
  258. return make_range(Vectors.begin(), Vectors.end());
  259. }
  260. /// Embed the vectors of the matrix into a flat vector by concatenating
  261. /// them.
  262. Value *embedInVector(IRBuilder<> &Builder) const {
  263. return Vectors.size() == 1 ? Vectors[0]
  264. : concatenateVectors(Builder, Vectors);
  265. }
  266. MatrixTy &addNumLoads(unsigned N) {
  267. OpInfo.NumLoads += N;
  268. return *this;
  269. }
  270. void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
  271. MatrixTy &addNumStores(unsigned N) {
  272. OpInfo.NumStores += N;
  273. return *this;
  274. }
  275. MatrixTy &addNumExposedTransposes(unsigned N) {
  276. OpInfo.NumExposedTransposes += N;
  277. return *this;
  278. }
  279. MatrixTy &addNumComputeOps(unsigned N) {
  280. OpInfo.NumComputeOps += N;
  281. return *this;
  282. }
  283. unsigned getNumStores() const { return OpInfo.NumStores; }
  284. unsigned getNumLoads() const { return OpInfo.NumLoads; }
  285. unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
  286. const OpInfoTy &getOpInfo() const { return OpInfo; }
  287. bool isColumnMajor() const { return IsColumnMajor; }
  288. unsigned getStride() const {
  289. if (isColumnMajor())
  290. return getNumRows();
  291. return getNumColumns();
  292. }
  293. /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
  294. /// matrix is column-major, the result vector is extracted from a column
  295. /// vector, otherwise from a row vector.
  296. Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
  297. IRBuilder<> &Builder) const {
  298. Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
  299. return Builder.CreateShuffleVector(
  300. Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
  301. "block");
  302. }
  303. };
  304. struct ShapeInfo {
  305. unsigned NumRows;
  306. unsigned NumColumns;
  307. bool IsColumnMajor;
  308. ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
  309. : NumRows(NumRows), NumColumns(NumColumns),
  310. IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
  311. ShapeInfo(Value *NumRows, Value *NumColumns)
  312. : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
  313. cast<ConstantInt>(NumColumns)->getZExtValue()) {}
  314. bool operator==(const ShapeInfo &other) {
  315. return NumRows == other.NumRows && NumColumns == other.NumColumns;
  316. }
  317. bool operator!=(const ShapeInfo &other) { return !(*this == other); }
  318. /// Returns true if shape-information is defined, meaning both dimensions
  319. /// are != 0.
  320. operator bool() const {
  321. assert(NumRows == 0 || NumColumns != 0);
  322. return NumRows != 0;
  323. }
  324. unsigned getStride() const {
  325. if (IsColumnMajor)
  326. return NumRows;
  327. return NumColumns;
  328. }
  329. unsigned getNumVectors() const {
  330. if (IsColumnMajor)
  331. return NumColumns;
  332. return NumRows;
  333. }
  334. };
  335. /// Maps instructions to their shape information. The shape information
  336. /// describes the shape to be used while lowering. This matches the shape of
  337. /// the result value of the instruction, with the only exceptions being store
  338. /// instructions and the matrix_column_major_store intrinsics. For those, the
  339. /// shape information indicates that those instructions should be lowered
  340. /// using shape information as well. A ValueMap is used so that when
  341. /// sub-passes like optimizeTransposes perform RAUW the map stays
  342. /// up-to-date.
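/// For example (shapes chosen for illustration), a plain `store <4 x double>
/// %c, ptr %p` whose stored value is known to be a 2x2 matrix is mapped to
/// that 2x2 shape, so the store itself is lowered as per-column stores.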
  343. ValueMap<Value *, ShapeInfo> ShapeMap;
  344. /// List of instructions to remove. While lowering, we do not replace all
  345. /// users of a lowered instruction if shape information is available; those
  346. /// lowered instructions need to be removed after we finish lowering.
  347. SmallVector<Instruction *, 16> ToRemove;
  348. /// Map from instructions to their produced column matrix.
  349. MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
  350. private:
  351. static FastMathFlags getFastMathFlags(Instruction *Inst) {
  352. FastMathFlags FMF;
  353. if (isa<FPMathOperator>(*Inst))
  354. FMF = Inst->getFastMathFlags();
  355. FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
  356. return FMF;
  357. }
  358. public:
  359. LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
  360. AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
  361. OptimizationRemarkEmitter *ORE)
  362. : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
  363. LI(LI), ORE(ORE) {}
  364. unsigned getNumOps(Type *VT) {
  365. assert(isa<VectorType>(VT) && "Expected vector type");
  366. return getNumOps(VT->getScalarType(),
  367. cast<FixedVectorType>(VT)->getNumElements());
  368. }
  369. /// Is this the minimal version executed in the backend pipelines?
  370. bool isMinimal() const {
  371. return !DT;
  372. }
  373. /// Return the estimated number of vector ops required for an operation on
  374. /// \p VT * N.
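/// For example (assuming TTI reports 128-bit fixed-width vector registers),
/// an operation on <4 x double>, i.e. 256 bits, counts as
/// ceil(256 / 128) = 2 vector ops.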
  375. unsigned getNumOps(Type *ST, unsigned N) {
  376. return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
  377. double(TTI.getRegisterBitWidth(
  378. TargetTransformInfo::RGK_FixedWidthVector)
  379. .getFixedSize()));
  380. }
  381. /// Return the set of vectors that a matrix value is lowered to.
  382. ///
  383. /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
  384. /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
  385. /// into vectors.
  386. MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
  387. IRBuilder<> &Builder) {
  388. VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
  389. assert(VType && "MatrixVal must be a vector type");
  390. assert(cast<FixedVectorType>(VType)->getNumElements() ==
  391. SI.NumRows * SI.NumColumns &&
  392. "The vector size must match the number of matrix elements");
  393. // Check if we lowered MatrixVal using shape information. In that case,
  394. // return the existing matrix, if it matches the requested shape
  395. // information. If there is a mis-match, embed the result in a flat
  396. // vector and split it later.
  397. auto Found = Inst2ColumnMatrix.find(MatrixVal);
  398. if (Found != Inst2ColumnMatrix.end()) {
  399. MatrixTy &M = Found->second;
  400. // Return the found matrix, if its shape matches the requested shape
  401. // information
  402. if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
  403. return M;
  404. MatrixVal = M.embedInVector(Builder);
  405. }
  406. // Otherwise split MatrixVal.
  407. SmallVector<Value *, 16> SplitVecs;
  408. for (unsigned MaskStart = 0;
  409. MaskStart < cast<FixedVectorType>(VType)->getNumElements();
  410. MaskStart += SI.getStride()) {
  411. Value *V = Builder.CreateShuffleVector(
  412. MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
  413. "split");
  414. SplitVecs.push_back(V);
  415. }
  416. return {SplitVecs};
  417. }
  418. /// If \p V already has a known shape return false. Otherwise set the shape
  419. /// for instructions that support it.
  420. bool setShapeInfo(Value *V, ShapeInfo Shape) {
  421. assert(Shape && "Shape not set");
  422. if (isa<UndefValue>(V) || !supportsShapeInfo(V))
  423. return false;
  424. auto SIter = ShapeMap.find(V);
  425. if (SIter != ShapeMap.end()) {
  426. LLVM_DEBUG(dbgs() << " not overriding existing shape: "
  427. << SIter->second.NumRows << " "
  428. << SIter->second.NumColumns << " for " << *V << "\n");
  429. return false;
  430. }
  431. ShapeMap.insert({V, Shape});
  432. LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
  433. << " for " << *V << "\n");
  434. return true;
  435. }
  436. bool isUniformShape(Value *V) {
  437. Instruction *I = dyn_cast<Instruction>(V);
  438. if (!I)
  439. return true;
  440. switch (I->getOpcode()) {
  441. case Instruction::FAdd:
  442. case Instruction::FSub:
  443. case Instruction::FMul: // Scalar multiply.
  444. case Instruction::FNeg:
  445. case Instruction::Add:
  446. case Instruction::Mul:
  447. case Instruction::Sub:
  448. return true;
  449. default:
  450. return false;
  451. }
  452. }
  453. /// Returns true if shape information can be used for \p V. The supported
  454. /// instructions must match the instructions that can be lowered by this pass.
  455. bool supportsShapeInfo(Value *V) {
  456. Instruction *Inst = dyn_cast<Instruction>(V);
  457. if (!Inst)
  458. return false;
  459. IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
  460. if (II)
  461. switch (II->getIntrinsicID()) {
  462. case Intrinsic::matrix_multiply:
  463. case Intrinsic::matrix_transpose:
  464. case Intrinsic::matrix_column_major_load:
  465. case Intrinsic::matrix_column_major_store:
  466. return true;
  467. default:
  468. return false;
  469. }
  470. return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  471. }
  472. /// Propagate the shape information of instructions to their users.
  473. /// The work list contains instructions for which we can compute the shape,
  474. /// either based on the information provided by matrix intrinsics or known
  475. /// shapes of operands.
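/// For example, a matrix_multiply call with dimensions M x N x K is assigned
/// the result shape M x K, and a plain store of that result then inherits the
/// M x K shape from its stored operand (illustrative; see the matchers below).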
  476. SmallVector<Instruction *, 32>
  477. propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
  478. SmallVector<Instruction *, 32> NewWorkList;
  479. // Pop an element for which we are guaranteed to have at least one of the
  480. // operand shapes. Add the shape for this instruction and then add its users
  481. // to the work list.
  482. LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
  483. while (!WorkList.empty()) {
  484. Instruction *Inst = WorkList.pop_back_val();
  485. // New entry, set the value and insert operands
  486. bool Propagate = false;
  487. Value *MatrixA;
  488. Value *MatrixB;
  489. Value *M;
  490. Value *N;
  491. Value *K;
  492. if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
  493. m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
  494. m_Value(N), m_Value(K)))) {
  495. Propagate = setShapeInfo(Inst, {M, K});
  496. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
  497. m_Value(MatrixA), m_Value(M), m_Value(N)))) {
  498. // Flip dimensions.
  499. Propagate = setShapeInfo(Inst, {N, M});
  500. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
  501. m_Value(MatrixA), m_Value(), m_Value(),
  502. m_Value(), m_Value(M), m_Value(N)))) {
  503. Propagate = setShapeInfo(Inst, {N, M});
  504. } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
  505. m_Value(), m_Value(), m_Value(), m_Value(M),
  506. m_Value(N)))) {
  507. Propagate = setShapeInfo(Inst, {M, N});
  508. } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
  509. auto OpShape = ShapeMap.find(MatrixA);
  510. if (OpShape != ShapeMap.end())
  511. setShapeInfo(Inst, OpShape->second);
  512. continue;
  513. } else if (isUniformShape(Inst)) {
  514. // Find the first operand that has a known shape and use that.
  515. for (auto &Op : Inst->operands()) {
  516. auto OpShape = ShapeMap.find(Op.get());
  517. if (OpShape != ShapeMap.end()) {
  518. Propagate |= setShapeInfo(Inst, OpShape->second);
  519. break;
  520. }
  521. }
  522. }
  523. if (Propagate) {
  524. NewWorkList.push_back(Inst);
  525. for (auto *User : Inst->users())
  526. if (ShapeMap.count(User) == 0)
  527. WorkList.push_back(cast<Instruction>(User));
  528. }
  529. }
  530. return NewWorkList;
  531. }
  532. /// Propagate the shape to operands of instructions with shape information.
  533. /// \p Worklist contains the instruction for which we already know the shape.
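/// For example, if %c = matrix_multiply(%a, %b, M, N, K) already has a known
/// shape, %a is assigned M x N and %b is assigned N x K, and each operand is
/// added to the worklist if its shape was previously unknown.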
  534. SmallVector<Instruction *, 32>
  535. propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
  536. SmallVector<Instruction *, 32> NewWorkList;
  537. auto pushInstruction = [](Value *V,
  538. SmallVectorImpl<Instruction *> &WorkList) {
  539. Instruction *I = dyn_cast<Instruction>(V);
  540. if (I)
  541. WorkList.push_back(I);
  542. };
  543. // Pop an element with known shape. Traverse the operands; if an operand's
  544. // shape derives from the result shape and is still unknown, set it and add
  545. // the operand to the worklist.
  546. LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
  547. while (!WorkList.empty()) {
  548. Value *V = WorkList.pop_back_val();
  549. size_t BeforeProcessingV = WorkList.size();
  550. if (!isa<Instruction>(V))
  551. continue;
  552. Value *MatrixA;
  553. Value *MatrixB;
  554. Value *M;
  555. Value *N;
  556. Value *K;
  557. if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
  558. m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
  559. m_Value(N), m_Value(K)))) {
  560. if (setShapeInfo(MatrixA, {M, N}))
  561. pushInstruction(MatrixA, WorkList);
  562. if (setShapeInfo(MatrixB, {N, K}))
  563. pushInstruction(MatrixB, WorkList);
  564. } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
  565. m_Value(MatrixA), m_Value(M), m_Value(N)))) {
  566. // Flip dimensions.
  567. if (setShapeInfo(MatrixA, {M, N}))
  568. pushInstruction(MatrixA, WorkList);
  569. } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
  570. m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
  571. m_Value(M), m_Value(N)))) {
  572. if (setShapeInfo(MatrixA, {M, N})) {
  573. pushInstruction(MatrixA, WorkList);
  574. }
  575. } else if (isa<LoadInst>(V) ||
  576. match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
  577. // Nothing to do, no matrix input.
  578. } else if (isa<StoreInst>(V)) {
  579. // Nothing to do. We forward-propagated to this so we would just
  580. // backward propagate to an instruction with an already known shape.
  581. } else if (isUniformShape(V)) {
  582. // Propagate to all operands.
  583. ShapeInfo Shape = ShapeMap[V];
  584. for (Use &U : cast<Instruction>(V)->operands()) {
  585. if (setShapeInfo(U.get(), Shape))
  586. pushInstruction(U.get(), WorkList);
  587. }
  588. }
  589. // After we discovered new shape info for new instructions in the
  590. // worklist, we use their users as seeds for the next round of forward
  591. // propagation.
  592. for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
  593. for (User *U : WorkList[I]->users())
  594. if (isa<Instruction>(U) && V != U)
  595. NewWorkList.push_back(cast<Instruction>(U));
  596. }
  597. return NewWorkList;
  598. }
  599. /// Try moving transposes in order to fold them away or into multiplies.
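/// The rewrites applied below are, in order:
///   t(t(A))     -> A            (transpose of a transpose is a no-op)
///   t(A * B)    -> t(B) * t(A)  (sink transposes into multiplies)
///   t(A) * t(B) -> t(B * A)     (lift the transpose out of a TT multiply)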
  600. void optimizeTransposes() {
  601. auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
  602. // We need to remove Old from the ShapeMap otherwise RAUW will replace it
  603. // with New. We should only add New if it supportsShapeInfo, so we insert
  604. // it conditionally instead.
  605. auto S = ShapeMap.find(&Old);
  606. if (S != ShapeMap.end()) {
  607. ShapeMap.erase(S);
  608. if (supportsShapeInfo(New))
  609. ShapeMap.insert({New, S->second});
  610. }
  611. Old.replaceAllUsesWith(New);
  612. };
  613. // First sink all transposes inside matmuls, hoping that we end up with NN,
  614. // NT or TN variants.
  615. for (BasicBlock &BB : reverse(Func)) {
  616. for (auto II = BB.rbegin(); II != BB.rend();) {
  617. Instruction &I = *II;
  618. // We may remove II. By default continue on the next/prev instruction.
  619. ++II;
  620. // If we were to erase II, move again.
  621. auto EraseFromParent = [&II](Value *V) {
  622. auto *Inst = cast<Instruction>(V);
  623. if (Inst->use_empty()) {
  624. if (Inst == &*II) {
  625. ++II;
  626. }
  627. Inst->eraseFromParent();
  628. }
  629. };
  630. // If we're creating a new instruction, continue from there.
  631. Instruction *NewInst = nullptr;
  632. IRBuilder<> IB(&I);
  633. MatrixBuilder<IRBuilder<>> Builder(IB);
  634. Value *TA, *TAMA, *TAMB;
  635. ConstantInt *R, *K, *C;
  636. if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
  637. // Transpose of a transpose is a nop
  638. Value *TATA;
  639. if (match(TA,
  640. m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
  641. ReplaceAllUsesWith(I, TATA);
  642. EraseFromParent(&I);
  643. EraseFromParent(TA);
  644. }
  645. // (A * B)^t -> B^t * A^t
  646. // RxK KxC CxK KxR
  647. else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
  648. m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
  649. m_ConstantInt(K), m_ConstantInt(C)))) {
  650. Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
  651. C->getZExtValue(),
  652. TAMB->getName() + "_t");
  653. // We are being run after shape prop, add shape for newly created
  654. // instructions so that we lower them later.
  655. setShapeInfo(T0, {C, K});
  656. Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
  657. K->getZExtValue(),
  658. TAMA->getName() + "_t");
  659. setShapeInfo(T1, {K, R});
  660. NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
  661. K->getZExtValue(),
  662. R->getZExtValue(), "mmul");
  663. ReplaceAllUsesWith(I, NewInst);
  664. EraseFromParent(&I);
  665. EraseFromParent(TA);
  666. }
  667. }
  668. // If we replaced I with a new instruction, continue from there.
  669. if (NewInst)
  670. II = std::next(BasicBlock::reverse_iterator(NewInst));
  671. }
  672. }
  673. // If we have a TT matmul, lift the transpose. We may be able to fold it
  674. // into a consuming multiply.
  675. for (BasicBlock &BB : Func) {
  676. for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
  677. Instruction *I = &*II;
  678. // We may remove I.
  679. ++II;
  680. Value *A, *B, *AT, *BT;
  681. ConstantInt *R, *K, *C;
  682. // A^t * B^t -> (B * A)^t
  683. if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
  684. m_Value(A), m_Value(B), m_ConstantInt(R),
  685. m_ConstantInt(K), m_ConstantInt(C))) &&
  686. match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
  687. match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
  688. IRBuilder<> IB(&*I);
  689. MatrixBuilder<IRBuilder<>> Builder(IB);
  690. Value *M = Builder.CreateMatrixMultiply(
  691. BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
  692. setShapeInfo(M, {C, R});
  693. Instruction *NewInst = Builder.CreateMatrixTranspose(
  694. M, C->getZExtValue(), R->getZExtValue());
  695. ReplaceAllUsesWith(*I, NewInst);
  696. if (I->use_empty())
  697. I->eraseFromParent();
  698. if (A->use_empty())
  699. cast<Instruction>(A)->eraseFromParent();
  700. if (A != B && B->use_empty())
  701. cast<Instruction>(B)->eraseFromParent();
  702. }
  703. }
  704. }
  705. }
  706. bool Visit() {
  707. SmallVector<Instruction *, 32> WorkList;
  708. // Initially only the shape of matrix intrinsics is known.
  709. // Initialize the work list with ops carrying shape information.
  710. for (BasicBlock &BB : Func)
  711. for (Instruction &Inst : BB) {
  712. IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
  713. if (!II)
  714. continue;
  715. switch (II->getIntrinsicID()) {
  716. case Intrinsic::matrix_multiply:
  717. case Intrinsic::matrix_transpose:
  718. case Intrinsic::matrix_column_major_load:
  719. case Intrinsic::matrix_column_major_store:
  720. WorkList.push_back(&Inst);
  721. break;
  722. default:
  723. break;
  724. }
  725. }
  726. // Avoid unnecessary work if there are no matrix intrinsics in the function.
  727. if (WorkList.empty())
  728. return false;
  729. // Propagate shapes until nothing changes any longer.
  730. while (!WorkList.empty()) {
  731. WorkList = propagateShapeForward(WorkList);
  732. WorkList = propagateShapeBackward(WorkList);
  733. }
  734. if (!isMinimal()) {
  735. optimizeTransposes();
  736. LLVM_DEBUG({
  737. dbgs() << "Dump after matrix transpose optimization:\n";
  738. Func.dump();
  739. });
  740. }
  741. bool Changed = false;
  742. SmallVector<CallInst *, 16> MaybeFusableInsts;
  743. SmallVector<Instruction *, 16> MatrixInsts;
  744. // First, collect all instructions with shape information and candidates for
  745. // fusion (currently only matrix multiplies).
  746. ReversePostOrderTraversal<Function *> RPOT(&Func);
  747. for (auto *BB : RPOT)
  748. for (Instruction &I : *BB) {
  749. if (ShapeMap.find(&I) == ShapeMap.end())
  750. continue;
  751. if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
  752. MaybeFusableInsts.push_back(cast<CallInst>(&I));
  753. MatrixInsts.push_back(&I);
  754. }
  755. // Second, try to fuse candidates.
  756. SmallPtrSet<Instruction *, 16> FusedInsts;
  757. for (CallInst *CI : MaybeFusableInsts)
  758. LowerMatrixMultiplyFused(CI, FusedInsts);
  759. Changed = !FusedInsts.empty();
  760. // Third, lower remaining instructions with shape information.
  761. for (Instruction *Inst : MatrixInsts) {
  762. if (FusedInsts.count(Inst))
  763. continue;
  764. IRBuilder<> Builder(Inst);
  765. if (CallInst *CInst = dyn_cast<CallInst>(Inst))
  766. Changed |= VisitCallInst(CInst);
  767. Value *Op1;
  768. Value *Op2;
  769. if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
  770. Changed |= VisitBinaryOperator(BinOp);
  771. if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
  772. Changed |= VisitUnaryOperator(UnOp);
  773. if (match(Inst, m_Load(m_Value(Op1))))
  774. Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
  775. else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
  776. Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
  777. }
  778. if (ORE) {
  779. RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
  780. RemarkGen.emitRemarks();
  781. }
  782. // Delete the instructions backwards, as this reduces the likelihood of
  783. // having to update as many def-use and use-def chains.
  784. //
  785. // Because we add to ToRemove during fusion we can't guarantee that defs
  786. // are before uses. Change uses to undef temporarily as these should get
  787. // removed as well.
  788. //
  789. // For verification, we keep track of where we changed uses to undefs in
  790. // UndefedInsts and then check that we in fact remove them.
  791. SmallSet<Instruction *, 16> UndefedInsts;
  792. for (auto *Inst : reverse(ToRemove)) {
  793. for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
  794. if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
  795. UndefedInsts.insert(Undefed);
  796. U.set(UndefValue::get(Inst->getType()));
  797. }
  798. Inst->eraseFromParent();
  799. UndefedInsts.erase(Inst);
  800. }
  801. if (!UndefedInsts.empty()) {
  802. // If we didn't remove all undefed instructions, it's a hard error.
  803. dbgs() << "Undefed but present instructions:\n";
  804. for (auto *I : UndefedInsts)
  805. dbgs() << *I << "\n";
  806. llvm_unreachable("Undefed but instruction not removed");
  807. }
  808. return Changed;
  809. }
  810. /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  811. Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
  812. unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
  813. Type *EltPtrType = PointerType::get(EltType, AS);
  814. return Builder.CreatePointerCast(BasePtr, EltPtrType);
  815. }
  816. /// Replace intrinsic calls
  817. bool VisitCallInst(CallInst *Inst) {
  818. if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
  819. return false;
  820. switch (Inst->getCalledFunction()->getIntrinsicID()) {
  821. case Intrinsic::matrix_multiply:
  822. LowerMultiply(Inst);
  823. break;
  824. case Intrinsic::matrix_transpose:
  825. LowerTranspose(Inst);
  826. break;
  827. case Intrinsic::matrix_column_major_load:
  828. LowerColumnMajorLoad(Inst);
  829. break;
  830. case Intrinsic::matrix_column_major_store:
  831. LowerColumnMajorStore(Inst);
  832. break;
  833. default:
  834. return false;
  835. }
  836. return true;
  837. }
  838. /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  839. /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  840. /// ConstantInt, reduce the initial alignment based on the byte offset. For
  841. /// non-ConstantInt strides, return the common alignment of the initial
  842. /// alignment and the element size in bytes.
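/// For example (values for illustration): with an initial alignment of 16, a
/// double element type (8 bytes) and a constant stride of 3, column 1 starts
/// at byte offset 24, so the returned alignment is commonAlignment(16, 24) = 8.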
  843. Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
  844. MaybeAlign A) const {
  845. Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
  846. if (Idx == 0)
  847. return InitialAlign;
  848. TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
  849. if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
  850. uint64_t StrideInBytes =
  851. ConstStride->getZExtValue() * ElementSizeInBits / 8;
  852. return commonAlignment(InitialAlign, Idx * StrideInBytes);
  853. }
  854. return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  855. }
  856. /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  857. /// vectors.
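/// For example (shape for illustration): loading a column-major 4x2 double
/// matrix with a stride of 4 emits two aligned <4 x double> "col.load" loads,
/// at element offsets 0 and 4 from \p Ptr.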
  858. MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
  859. bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
  860. auto *VType = cast<VectorType>(Ty);
  861. Type *EltTy = VType->getElementType();
  862. Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
  863. Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
  864. MatrixTy Result;
  865. for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
  866. Value *GEP = computeVectorAddr(
  867. EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
  868. Stride, Shape.getStride(), EltTy, Builder);
  869. Value *Vector = Builder.CreateAlignedLoad(
  870. VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
  871. IsVolatile, "col.load");
  872. Result.addVector(Vector);
  873. }
  874. return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
  875. Result.getNumVectors());
  876. }
  877. /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  878. /// starting at \p MatrixPtr[I][J].
  879. MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
  880. ShapeInfo MatrixShape, Value *I, Value *J,
  881. ShapeInfo ResultShape, Type *EltTy,
  882. IRBuilder<> &Builder) {
  883. Value *Offset = Builder.CreateAdd(
  884. Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
  885. unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
  886. Value *EltPtr =
  887. Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
  888. Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
  889. auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
  890. ResultShape.NumColumns);
  891. Type *TilePtrTy = PointerType::get(TileTy, AS);
  892. Value *TilePtr =
  893. Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
  894. return loadMatrix(TileTy, TilePtr, Align,
  895. Builder.getInt64(MatrixShape.getStride()), IsVolatile,
  896. ResultShape, Builder);
  897. }
  898. /// Lower a load instruction with shape information.
  899. void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
  900. bool IsVolatile, ShapeInfo Shape) {
  901. IRBuilder<> Builder(Inst);
  902. finalizeLowering(Inst,
  903. loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
  904. Shape, Builder),
  905. Builder);
  906. }
  907. /// Lowers llvm.matrix.column.major.load.
  908. ///
  909. /// The intrinsic loads a matrix from memory using a stride between columns.
  910. void LowerColumnMajorLoad(CallInst *Inst) {
  911. assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
  912. "Intrinsic only supports column-major layout!");
  913. Value *Ptr = Inst->getArgOperand(0);
  914. Value *Stride = Inst->getArgOperand(1);
  915. LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
  916. cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
  917. {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  918. }
  919. /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
  920. /// MatrixPtr[I][J].
  921. void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
  922. MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
  923. Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
  924. Value *Offset = Builder.CreateAdd(
  925. Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
  926. unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
  927. Value *EltPtr =
  928. Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
  929. Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
  930. auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
  931. StoreVal.getNumColumns());
  932. Type *TilePtrTy = PointerType::get(TileTy, AS);
  933. Value *TilePtr =
  934. Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
  935. storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
  936. Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  937. }
  938. /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  939. /// vectors.
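/// For example (shape for illustration): storing a column-major 4x2 double
/// matrix with a stride of 4 emits two aligned <4 x double> stores, at element
/// offsets 0 and 4 from \p Ptr.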
  940. MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
  941. MaybeAlign MAlign, Value *Stride, bool IsVolatile,
  942. IRBuilder<> &Builder) {
  943. auto VType = cast<VectorType>(Ty);
  944. Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
  945. for (auto Vec : enumerate(StoreVal.vectors())) {
  946. Value *GEP = computeVectorAddr(
  947. EltPtr,
  948. Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
  949. Vec.index()),
  950. Stride, StoreVal.getStride(), VType->getElementType(), Builder);
  951. Builder.CreateAlignedStore(Vec.value(), GEP,
  952. getAlignForIndex(Vec.index(), Stride,
  953. VType->getElementType(),
  954. MAlign),
  955. IsVolatile);
  956. }
  957. return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
  958. StoreVal.getNumVectors());
  959. }
  960. /// Lower a store instruction with shape information.
  961. void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
  962. Value *Stride, bool IsVolatile, ShapeInfo Shape) {
  963. IRBuilder<> Builder(Inst);
  964. auto StoreVal = getMatrix(Matrix, Shape, Builder);
  965. finalizeLowering(Inst,
  966. storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
  967. IsVolatile, Builder),
  968. Builder);
  969. }
  970. /// Lowers llvm.matrix.column.major.store.
  971. ///
  972. /// The intrinsic stores a matrix back to memory using a stride between columns.
  973. void LowerColumnMajorStore(CallInst *Inst) {
  974. assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
  975. "Intrinsic only supports column-major layout!");
  976. Value *Matrix = Inst->getArgOperand(0);
  977. Value *Ptr = Inst->getArgOperand(1);
  978. Value *Stride = Inst->getArgOperand(2);
  979. LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
  980. cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
  981. {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  982. }
  983. // Set elements I..I+NumElts-1 to Block
  984. Value *insertVector(Value *Col, unsigned I, Value *Block,
  985. IRBuilder<> &Builder) {
  986. // First, bring Block to the same size as Col
  987. unsigned BlockNumElts =
  988. cast<FixedVectorType>(Block->getType())->getNumElements();
  989. unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
  990. assert(NumElts >= BlockNumElts && "Too few elements for current block");
  991. Block = Builder.CreateShuffleVector(
  992. Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
  993. // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
  994. // 8, 4, 5, 6
  995. SmallVector<int, 16> Mask;
  996. unsigned i;
  997. for (i = 0; i < I; i++)
  998. Mask.push_back(i);
  999. unsigned VecNumElts =
  1000. cast<FixedVectorType>(Col->getType())->getNumElements();
  1001. for (; i < I + BlockNumElts; i++)
  1002. Mask.push_back(i - I + VecNumElts);
  1003. for (; i < VecNumElts; i++)
  1004. Mask.push_back(i);
  1005. return Builder.CreateShuffleVector(Col, Block, Mask);
  1006. }
  1007. Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
  1008. IRBuilder<> &Builder, bool AllowContraction,
  1009. unsigned &NumComputeOps) {
  1010. NumComputeOps += getNumOps(A->getType());
  1011. if (!Sum)
  1012. return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
  1013. if (UseFPOp) {
  1014. if (AllowContraction) {
  1015. // Use fmuladd for floating point operations and let the backend decide
  1016. // if that's profitable.
  1017. Function *FMulAdd = Intrinsic::getDeclaration(
  1018. Func.getParent(), Intrinsic::fmuladd, A->getType());
  1019. return Builder.CreateCall(FMulAdd, {A, B, Sum});
  1020. }
  1021. NumComputeOps += getNumOps(A->getType());
  1022. Value *Mul = Builder.CreateFMul(A, B);
  1023. return Builder.CreateFAdd(Sum, Mul);
  1024. }
  1025. NumComputeOps += getNumOps(A->getType());
  1026. Value *Mul = Builder.CreateMul(A, B);
  1027. return Builder.CreateAdd(Sum, Mul);
  1028. }
  1029. /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  1030. /// users with shape information, there's nothing to do: they will use the
  1031. /// cached value when they are lowered. For other users, \p Matrix is
  1032. /// flattened and the uses are updated to use it. Also marks \p Inst for
  1033. /// deletion.
  1034. void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
  1035. IRBuilder<> &Builder) {
  1036. auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
  1037. (void)inserted;
  1038. assert(inserted.second && "multiple matrix lowering mapping");
  1039. ToRemove.push_back(Inst);
  1040. Value *Flattened = nullptr;
  1041. for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
  1042. if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
  1043. if (!Flattened)
  1044. Flattened = Matrix.embedInVector(Builder);
  1045. U.set(Flattened);
  1046. }
  1047. }
  1048. }
  1049. /// Compute \p Result += \p A * \p B for input matrices with left-associating
  1050. /// addition.
  1051. ///
  1052. /// We can fold a transpose into the operand that is used to extract scalars.
  1053. /// This is the first operand with row-major and the second with
  1054. /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
  1055. /// operand is transposed.
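/// As a sketch (sizes for illustration): for a column-major 8x8 result with
/// VF = 4, each result column J is computed in two 4-element row blocks; each
/// block accumulates, over K, a 4-element slice of column K of A multiplied by
/// a splat of element (K, J) of B.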
  1056. void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
  1057. const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
  1058. bool IsScalarMatrixTransposed, FastMathFlags FMF) {
  1059. const unsigned VF = std::max<unsigned>(
  1060. TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
  1061. .getFixedSize() /
  1062. Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
  1063. 1U);
  1064. unsigned R = Result.getNumRows();
  1065. unsigned C = Result.getNumColumns();
  1066. unsigned M = A.getNumColumns();
  1067. bool IsFP = Result.getElementType()->isFloatingPointTy();
  1068. assert(A.isColumnMajor() == B.isColumnMajor() &&
  1069. Result.isColumnMajor() == A.isColumnMajor() &&
  1070. "operands must agree on matrix layout");
  1071. unsigned NumComputeOps = 0;
  1072. Builder.setFastMathFlags(FMF);
  1073. if (A.isColumnMajor()) {
  1074. // Multiply columns from the first operand with scalars from the second
  1075. // operand. Then move along the K axes and accumulate the columns. With
  1076. // this the adds can be vectorized without reassociation.
  1077. for (unsigned J = 0; J < C; ++J) {
  1078. unsigned BlockSize = VF;
  1079. // If Result is zero, we don't need to accumulate in the K==0 iteration.
  1080. bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
  1081. for (unsigned I = 0; I < R; I += BlockSize) {
  1082. // Gradually lower the vectorization factor to cover the remainder.
  1083. while (I + BlockSize > R)
  1084. BlockSize /= 2;
  1085. Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
  1086. : nullptr;
  1087. for (unsigned K = 0; K < M; ++K) {
  1088. Value *L = A.extractVector(I, K, BlockSize, Builder);
  1089. Value *RH = Builder.CreateExtractElement(
  1090. B.getColumn(IsScalarMatrixTransposed ? K : J),
  1091. IsScalarMatrixTransposed ? J : K);
  1092. Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
  1093. Sum =
  1094. createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
  1095. IsFP, Builder, FMF.allowContract(), NumComputeOps);
  1096. }
  1097. Result.setVector(J,
  1098. insertVector(Result.getVector(J), I, Sum, Builder));
  1099. }
  1100. }
  1101. } else {
// Multiply rows from the second operand with scalars from the first
// operand. Then move along the K axis and accumulate the rows. With this
// the adds can be vectorized without reassociation.
  1105. for (unsigned I = 0; I < R; ++I) {
  1106. unsigned BlockSize = VF;
  1107. bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
  1108. for (unsigned J = 0; J < C; J += BlockSize) {
  1109. // Gradually lower the vectorization factor to cover the remainder.
  1110. while (J + BlockSize > C)
  1111. BlockSize /= 2;
  1112. Value *Sum = nullptr;
  1113. for (unsigned K = 0; K < M; ++K) {
  1114. Value *R = B.extractVector(K, J, BlockSize, Builder);
  1115. Value *LH = Builder.CreateExtractElement(
  1116. A.getVector(IsScalarMatrixTransposed ? K : I),
  1117. IsScalarMatrixTransposed ? I : K);
  1118. Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
  1119. Sum =
  1120. createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
  1121. IsFP, Builder, FMF.allowContract(), NumComputeOps);
  1122. }
  1123. Result.setVector(I,
  1124. insertVector(Result.getVector(I), J, Sum, Builder));
  1125. }
  1126. }
  1127. }
  1128. Result.addNumComputeOps(NumComputeOps);
  1129. }
/// Ensure that the memory in \p Load does not alias \p Store by potentially
/// copying it to a new location. Returns either the new location or the
/// original one if no copy is needed.
  1133. Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
  1134. CallInst *MatMul) {
  1135. MemoryLocation StoreLoc = MemoryLocation::get(Store);
  1136. MemoryLocation LoadLoc = MemoryLocation::get(Load);
  1137. // If we can statically determine noalias we're good.
  1138. if (AA->isNoAlias(LoadLoc, StoreLoc))
  1139. return Load->getPointerOperand();
// Create code to check if the memory locations of the Load and Store
// overlap and, if they do, copy Load's operand to a new buffer.
// First, create new blocks for the 2nd part of the check and for the copy.
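// The emitted control flow is roughly:
//   Check0: if (LoadBegin < StoreEnd) goto Check1 else goto Fusion
//   Check1: if (StoreBegin < LoadEnd) goto Copy   else goto Fusion
//   Copy:   memcpy the loaded range into a fresh alloca
//   Fusion: PHI selecting either the original pointer or the copy
// (named "alias_cont", "copy" and "no_alias" in the IR).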
  1143. BasicBlock *Check0 = MatMul->getParent();
  1144. // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
  1145. // DT. Manually collect dominator tree updates, to avoid unnecessary work,
  1146. // as we adjust Check0 and Check1's branches.
  1147. SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
  1148. for (BasicBlock *Succ : successors(Check0))
  1149. DTUpdates.push_back({DT->Delete, Check0, Succ});
  1150. BasicBlock *Check1 =
  1151. SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
  1152. nullptr, "alias_cont");
  1153. BasicBlock *Copy =
  1154. SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
  1155. nullptr, "copy");
  1156. BasicBlock *Fusion =
  1157. SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
  1158. nullptr, "no_alias");
  1159. // Check if the loaded memory location begins before the end of the store
  1160. // location. If the condition holds, they might overlap, otherwise they are
  1161. // guaranteed to not overlap.
  1162. IRBuilder<> Builder(MatMul);
  1163. Check0->getTerminator()->eraseFromParent();
  1164. Builder.SetInsertPoint(Check0);
  1165. Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
  1166. Value *StoreBegin = Builder.CreatePtrToInt(
  1167. const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
  1168. Value *StoreEnd = Builder.CreateAdd(
  1169. StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
  1170. "store.end", true, true);
  1171. Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
  1172. IntPtrTy, "load.begin");
  1173. Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
  1174. Fusion);
  1175. // Check if the store begins before the end of the load location. If the
  1176. // condition holds, they alias, otherwise they are guaranteed to not
  1177. // overlap.
  1178. Check1->getTerminator()->eraseFromParent();
  1179. Builder.SetInsertPoint(Check1, Check1->begin());
  1180. Value *LoadEnd = Builder.CreateAdd(
  1181. LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
  1182. "load.end", true, true);
  1183. Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
  1184. Fusion);
  1185. // Copy load operand to new alloca.
  1186. Builder.SetInsertPoint(Copy, Copy->begin());
  1187. auto *VT = cast<FixedVectorType>(Load->getType());
  1188. // Use an array type for the alloca, to avoid potentially huge alignment
  1189. // requirements for large vector types.
  1190. auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
  1191. AllocaInst *Alloca =
  1192. Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
  1193. Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
  1194. Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
  1195. Load->getAlign(), LoadLoc.Size.getValue());
  1196. Builder.SetInsertPoint(Fusion, Fusion->begin());
  1197. PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
  1198. PHI->addIncoming(Load->getPointerOperand(), Check0);
  1199. PHI->addIncoming(Load->getPointerOperand(), Check1);
  1200. PHI->addIncoming(BC, Copy);
  1201. // Adjust DT.
  1202. DTUpdates.push_back({DT->Insert, Check0, Check1});
  1203. DTUpdates.push_back({DT->Insert, Check0, Fusion});
  1204. DTUpdates.push_back({DT->Insert, Check1, Copy});
  1205. DTUpdates.push_back({DT->Insert, Check1, Fusion});
  1206. DT->applyUpdates(DTUpdates);
  1207. return PHI;
  1208. }
  1209. bool isFusionProfitable(CallInst *MatMul) {
  1210. if (ForceFusion)
  1211. return true;
  1212. ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  1213. ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
  1214. const unsigned R = LShape.NumRows;
  1215. const unsigned C = RShape.NumColumns;
  1216. const unsigned M = LShape.NumColumns;
  1217. auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
  1218. const unsigned VF = std::max<unsigned>(
  1219. TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
  1220. .getFixedSize() /
  1221. EltType->getPrimitiveSizeInBits().getFixedSize(),
  1222. 1U);
  1223. // Cost model for tiling
  1224. //
  1225. // For tiling to be beneficial, we need reuse either along the R or
  1226. // the C axis. We vectorize along the R axis so that means at least
  1227. // 3 elements.
  1228. // TODO: Also consider cost of copying if operands alias.
  1229. if (R <= VF && C == 1)
  1230. return false;
  1231. // Then we need enough elements to exceed the number of vector
  1232. // registers we have. Note that this is an oversimplification since
  1233. // fusing also takes some extra loads which may exceed the number of
  1234. // reloads necessary.
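// For example, with R = C = M = 8, VF = 4 and 16 vector registers this gives
// Op0Regs = Op1Regs = ceil(8 / 4) * 8 = 16, so 32 > 16 and fusion is
// considered profitable.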
  1235. unsigned Op0Regs = (R + VF - 1) / VF * M;
  1236. unsigned Op1Regs = (M + VF - 1) / VF * C;
  1237. return Op0Regs + Op1Regs >
  1238. TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  1239. }
  1240. MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
  1241. MatrixTy Res;
auto *ColumnType = FixedVectorType::get(EltType, R);
for (unsigned I = 0; I < C; ++I)
Res.addVector(ConstantAggregateZero::get(ColumnType));
  1245. return Res;
  1246. }
  1247. void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
  1248. Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
  1249. auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
  1250. // Create the main tiling loop nest.
  1251. TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
  1252. DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  1253. Instruction *InsertI = cast<Instruction>(MatMul);
  1254. BasicBlock *Start = InsertI->getParent();
  1255. BasicBlock *End =
  1256. SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
  1257. IRBuilder<> Builder(MatMul);
  1258. BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
  1259. Type *TileVecTy =
  1260. FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
  1261. MatrixTy TileResult;
  1262. // Insert in the inner loop header.
  1263. Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
  1264. // Create PHI nodes for the result columns to accumulate across iterations.
  1265. SmallVector<PHINode *, 4> ColumnPhis;
  1266. for (unsigned I = 0; I < TileSize; I++) {
  1267. auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
  1268. Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
  1269. TI.RowLoopHeader->getSingleSuccessor());
  1270. TileResult.addVector(Phi);
  1271. ColumnPhis.push_back(Phi);
  1272. }
  1273. // Insert in the inner loop body, which computes
  1274. // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
  1275. Builder.SetInsertPoint(InnerBody->getTerminator());
  1276. // Load tiles of the operands.
  1277. MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
  1278. {TileSize, TileSize}, EltType, Builder);
  1279. MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
  1280. {TileSize, TileSize}, EltType, Builder);
  1281. emitMatrixMultiply(TileResult, A, B, Builder, true, false,
  1282. getFastMathFlags(MatMul));
  1283. // Store result after the inner loop is done.
  1284. Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
  1285. storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
  1286. Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
  1287. TI.CurrentRow, TI.CurrentCol, EltType, Builder);
  1288. for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
  1289. ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
  1290. // Force unrolling of a few iterations of the inner loop, to make sure there
  1291. // is enough work per iteration.
  1292. // FIXME: The unroller should make this decision directly instead, but
  1293. // currently the cost-model is not up to the task.
  1294. unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
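// For example, with TileSize 4 and a 64-column LHS this requests an unroll
// count of min(10, 64 / 4) = 10 for the inner loop.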
  1295. addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
  1296. "llvm.loop.unroll.count", InnerLoopUnrollCount);
  1297. }
  1298. void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
  1299. StoreInst *Store,
  1300. SmallPtrSetImpl<Instruction *> &FusedInsts) {
  1301. assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
  1302. "Tiling only supported for column-major matrixes at the moment!");
  1303. if (!isFusionProfitable(MatMul))
  1304. return;
  1305. ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  1306. ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
  1307. const unsigned R = LShape.NumRows;
  1308. const unsigned C = RShape.NumColumns;
  1309. const unsigned M = LShape.NumColumns;
  1310. auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
  1311. Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
  1312. Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
  1313. Value *CPtr = Store->getPointerOperand();
  1314. if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
  1315. createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
  1316. else {
  1317. IRBuilder<> Builder(Store);
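// Fully unrolled tiling: iterate over TileSize x TileSize tiles of the
// result (I, J), accumulate over tiles of the inner dimension (K) and load
// the required sub-matrices of both operands on demand. Edge tiles use the
// smaller TileR/TileC/TileM sizes computed below.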
  1318. for (unsigned J = 0; J < C; J += TileSize)
  1319. for (unsigned I = 0; I < R; I += TileSize) {
  1320. const unsigned TileR = std::min(R - I, unsigned(TileSize));
  1321. const unsigned TileC = std::min(C - J, unsigned(TileSize));
  1322. MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
  1323. for (unsigned K = 0; K < M; K += TileSize) {
  1324. const unsigned TileM = std::min(M - K, unsigned(TileSize));
  1325. MatrixTy A =
  1326. loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
  1327. LShape, Builder.getInt64(I), Builder.getInt64(K),
  1328. {TileR, TileM}, EltType, Builder);
  1329. MatrixTy B =
  1330. loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
  1331. RShape, Builder.getInt64(K), Builder.getInt64(J),
  1332. {TileM, TileC}, EltType, Builder);
  1333. emitMatrixMultiply(Res, A, B, Builder, true, false,
  1334. getFastMathFlags(MatMul));
  1335. }
  1336. storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
  1337. Builder.getInt64(I), Builder.getInt64(J), EltType,
  1338. Builder);
  1339. }
  1340. }
  1341. // Mark eliminated instructions as fused and remove them.
  1342. FusedInsts.insert(Store);
  1343. FusedInsts.insert(MatMul);
  1344. Store->eraseFromParent();
  1345. MatMul->eraseFromParent();
  1346. if (LoadOp0->hasNUses(0)) {
  1347. FusedInsts.insert(LoadOp0);
  1348. LoadOp0->eraseFromParent();
  1349. }
  1350. if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
  1351. FusedInsts.insert(LoadOp1);
  1352. LoadOp1->eraseFromParent();
  1353. }
  1354. }
  1355. /// Try to lower matrix multiply chains by fusing operations.
  1356. ///
  1357. /// Call finalizeLowering on lowered instructions. Instructions that are
  1358. /// completely eliminated by fusion are added to \p FusedInsts.
  1359. void LowerMatrixMultiplyFused(CallInst *MatMul,
  1360. SmallPtrSetImpl<Instruction *> &FusedInsts) {
  1361. if (!FuseMatrix || !DT)
  1362. return;
  1363. assert(AA && LI && "Analyses should be available");
  1364. Value *A = MatMul->getArgOperand(0);
  1365. Value *B = MatMul->getArgOperand(1);
  1366. // We can fold the transpose into the operand that is used to fetch scalars.
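// With column-major layout scalars are extracted from the second operand, so
// a transpose of B can be folded by reading its untransposed input directly;
// with row-major layout the same holds for the first operand.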
  1367. Value *T;
  1368. if (MatrixLayout == MatrixLayoutTy::ColumnMajor
  1369. ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
  1370. : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
  1371. IRBuilder<> Builder(MatMul);
  1372. auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
  1373. ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  1374. ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
  1375. const unsigned R = LShape.NumRows;
  1376. const unsigned M = LShape.NumColumns;
  1377. const unsigned C = RShape.NumColumns;
  1378. MatrixTy MA;
  1379. MatrixTy MB;
  1380. Value *Transpose;
  1381. if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
  1382. MA = getMatrix(A, ShapeInfo(R, M), Builder);
  1383. MB = getMatrix(T, ShapeInfo(C, M), Builder);
  1384. Transpose = B;
  1385. } else {
  1386. MA = getMatrix(T, ShapeInfo(R, M), Builder);
  1387. MB = getMatrix(B, ShapeInfo(C, M), Builder);
  1388. Transpose = A;
  1389. }
  1390. // Initialize the output
  1391. MatrixTy Result(R, C, EltType);
  1392. emitMatrixMultiply(Result, MA, MB, Builder, false, true,
  1393. getFastMathFlags(MatMul));
  1394. FusedInsts.insert(MatMul);
  1395. if (Transpose->hasOneUse()) {
  1396. FusedInsts.insert(cast<Instruction>(Transpose));
  1397. ToRemove.push_back(cast<Instruction>(Transpose));
  1398. // TODO: add a fake entry for the folded instruction so that this is
  1399. // included in the expression in the remark.
  1400. Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
  1401. }
  1402. finalizeLowering(MatMul, Result, Builder);
  1403. return;
  1404. }
  1405. if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
  1406. return;
  1407. // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
  1408. // since the single store user will be lowered as part of this.
  1409. auto *LoadOp0 = dyn_cast<LoadInst>(A);
  1410. auto *LoadOp1 = dyn_cast<LoadInst>(B);
  1411. auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
  1412. if (LoadOp0 && LoadOp1 && Store) {
  1413. // The store address must dominate the MatMul instruction, otherwise
  1414. // we create invalid IR.
  1415. SetVector<Value *> WorkList;
  1416. WorkList.insert(Store->getOperand(1));
  1417. SmallVector<Instruction *> ToHoist;
  1418. for (unsigned I = 0; I != WorkList.size(); ++I) {
  1419. Value *Current = WorkList[I];
  1420. auto *CurrI = dyn_cast<Instruction>(Current);
  1421. if (!CurrI)
  1422. continue;
  1423. if (isa<PHINode>(CurrI))
  1424. return;
  1425. if (DT->dominates(CurrI, MatMul))
  1426. continue;
  1427. if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
  1428. return;
  1429. ToHoist.push_back(CurrI);
  1430. WorkList.insert(CurrI->op_begin(), CurrI->op_end());
  1431. }
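// Hoist the address computation above the multiply, moving dominating
// definitions first so every instruction is placed after its operands.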
  1432. sort(ToHoist, [this](Instruction *A, Instruction *B) {
  1433. return DT->dominates(A, B);
  1434. });
  1435. for (Instruction *I : ToHoist)
  1436. I->moveBefore(MatMul);
  1437. emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
  1438. return;
  1439. }
  1440. }
  1441. /// Lowers llvm.matrix.multiply.
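/// Illustrative example (names and types chosen for exposition): a call like
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// multiplies two 2x2 matrices of doubles and is lowered to vector
/// fmul/fadd (or fmuladd) sequences by emitMatrixMultiply.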
  1442. void LowerMultiply(CallInst *MatMul) {
  1443. IRBuilder<> Builder(MatMul);
  1444. auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
  1445. ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
  1446. ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
  1447. const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
  1448. const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
  1449. assert(Lhs.getElementType() == Rhs.getElementType() &&
  1450. "Matrix multiply argument element types do not match.");
  1451. const unsigned R = LShape.NumRows;
  1452. const unsigned C = RShape.NumColumns;
  1453. assert(LShape.NumColumns == RShape.NumRows);
  1454. // Initialize the output
  1455. MatrixTy Result(R, C, EltType);
  1456. assert(Lhs.getElementType() == Result.getElementType() &&
  1457. "Matrix multiply result element type does not match arguments.");
  1458. emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
  1459. getFastMathFlags(MatMul));
  1460. finalizeLowering(MatMul, Result, Builder);
  1461. }
  1462. /// Lowers llvm.matrix.transpose.
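/// Illustrative example (names and types chosen for exposition): a call like
///   %t = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %m,
///            i32 2, i32 3)
/// transposes a 2x3 matrix; with column-major layout the 3 input vectors of 2
/// elements become 2 result vectors of 3 elements, built element-wise below.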
  1463. void LowerTranspose(CallInst *Inst) {
  1464. MatrixTy Result;
  1465. IRBuilder<> Builder(Inst);
  1466. Value *InputVal = Inst->getArgOperand(0);
  1467. VectorType *VectorTy = cast<VectorType>(InputVal->getType());
  1468. ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
  1469. MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
  1470. const unsigned NewNumVecs =
  1471. InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
  1472. const unsigned NewNumElts =
  1473. InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
  1474. for (unsigned I = 0; I < NewNumVecs; ++I) {
  1475. // Build a single result vector. First initialize it.
  1476. Value *ResultVector = UndefValue::get(
  1477. FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
// Go through the old elements and insert them into the resulting vector.
  1479. for (auto J : enumerate(InputMatrix.vectors())) {
  1480. Value *Elt = Builder.CreateExtractElement(J.value(), I);
  1481. // Row and column indices are transposed.
  1482. ResultVector =
  1483. Builder.CreateInsertElement(ResultVector, Elt, J.index());
  1484. }
  1485. Result.addVector(ResultVector);
  1486. }
  1487. // TODO: Improve estimate of operations needed for transposes. Currently we
  1488. // just count the insertelement/extractelement instructions, but do not
  1489. // account for later simplifications/combines.
  1490. finalizeLowering(
  1491. Inst,
  1492. Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
  1493. .addNumExposedTransposes(1),
  1494. Builder);
  1495. }
  1496. /// Lower load instructions, if shape information is available.
  1497. bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
  1498. auto I = ShapeMap.find(Inst);
  1499. if (I == ShapeMap.end())
  1500. return false;
  1501. LowerLoad(Inst, Ptr, Inst->getAlign(),
  1502. Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
  1503. I->second);
  1504. return true;
  1505. }
  1506. bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
  1507. IRBuilder<> &Builder) {
  1508. auto I = ShapeMap.find(StoredVal);
  1509. if (I == ShapeMap.end())
  1510. return false;
  1511. LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
  1512. Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
  1513. I->second);
  1514. return true;
  1515. }
  1516. /// Lower binary operators, if shape information is available.
  1517. bool VisitBinaryOperator(BinaryOperator *Inst) {
  1518. auto I = ShapeMap.find(Inst);
  1519. if (I == ShapeMap.end())
  1520. return false;
  1521. Value *Lhs = Inst->getOperand(0);
  1522. Value *Rhs = Inst->getOperand(1);
  1523. IRBuilder<> Builder(Inst);
  1524. ShapeInfo &Shape = I->second;
  1525. MatrixTy Result;
  1526. MatrixTy A = getMatrix(Lhs, Shape, Builder);
  1527. MatrixTy B = getMatrix(Rhs, Shape, Builder);
  1528. assert(A.isColumnMajor() == B.isColumnMajor() &&
  1529. Result.isColumnMajor() == A.isColumnMajor() &&
  1530. "operands must agree on matrix layout");
  1531. Builder.setFastMathFlags(getFastMathFlags(Inst));
  1532. // Helper to perform binary op on vectors.
  1533. auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
  1534. switch (Inst->getOpcode()) {
  1535. case Instruction::Add:
  1536. return Builder.CreateAdd(LHS, RHS);
  1537. case Instruction::Mul:
  1538. return Builder.CreateMul(LHS, RHS);
  1539. case Instruction::Sub:
  1540. return Builder.CreateSub(LHS, RHS);
  1541. case Instruction::FAdd:
  1542. return Builder.CreateFAdd(LHS, RHS);
  1543. case Instruction::FMul:
  1544. return Builder.CreateFMul(LHS, RHS);
  1545. case Instruction::FSub:
  1546. return Builder.CreateFSub(LHS, RHS);
  1547. default:
  1548. llvm_unreachable("Unsupported binary operator for matrix");
  1549. }
  1550. };
  1551. for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
  1552. Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
  1553. finalizeLowering(Inst,
  1554. Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
  1555. Result.getNumVectors()),
  1556. Builder);
  1557. return true;
  1558. }
  1559. /// Lower unary operators, if shape information is available.
  1560. bool VisitUnaryOperator(UnaryOperator *Inst) {
  1561. auto I = ShapeMap.find(Inst);
  1562. if (I == ShapeMap.end())
  1563. return false;
  1564. Value *Op = Inst->getOperand(0);
  1565. IRBuilder<> Builder(Inst);
  1566. ShapeInfo &Shape = I->second;
  1567. MatrixTy Result;
  1568. MatrixTy M = getMatrix(Op, Shape, Builder);
  1569. Builder.setFastMathFlags(getFastMathFlags(Inst));
  1570. // Helper to perform unary op on vectors.
  1571. auto BuildVectorOp = [&Builder, Inst](Value *Op) {
  1572. switch (Inst->getOpcode()) {
  1573. case Instruction::FNeg:
  1574. return Builder.CreateFNeg(Op);
  1575. default:
  1576. llvm_unreachable("Unsupported unary operator for matrix");
  1577. }
  1578. };
  1579. for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
  1580. Result.addVector(BuildVectorOp(M.getVector(I)));
  1581. finalizeLowering(Inst,
  1582. Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
  1583. Result.getNumVectors()),
  1584. Builder);
  1585. return true;
  1586. }
/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linearized by starting at an expression leaf and
/// linearizing bottom up.
  1590. struct ExprLinearizer {
  1591. unsigned LengthToBreak = 100;
  1592. std::string Str;
  1593. raw_string_ostream Stream;
  1594. unsigned LineLength = 0;
  1595. const DataLayout &DL;
  1596. /// Mapping from instructions to matrixes. It is used to identify
  1597. /// matrix instructions.
  1598. const MapVector<Value *, MatrixTy> &Inst2Matrix;
  1599. /// Mapping from values to the leaves of all expressions that the value is
  1600. /// part of.
  1601. const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
  1602. /// Set of matrix expressions in the scope of a given DISubprogram.
  1603. const SmallSetVector<Value *, 32> &ExprsInSubprogram;
  1604. /// Leaf node of the expression to linearize.
  1605. Value *Leaf;
  1606. /// Used to keep track of sub-expressions that get reused while linearizing
  1607. /// the expression. Re-used sub-expressions are marked as (reused).
  1608. SmallPtrSet<Value *, 8> ReusedExprs;
  1609. ExprLinearizer(const DataLayout &DL,
  1610. const MapVector<Value *, MatrixTy> &Inst2Matrix,
  1611. const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
  1612. const SmallSetVector<Value *, 32> &ExprsInSubprogram,
  1613. Value *Leaf)
  1614. : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
  1615. ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
  1616. void indent(unsigned N) {
  1617. LineLength += N;
  1618. for (unsigned i = 0; i < N; i++)
  1619. Stream << " ";
  1620. }
  1621. void lineBreak() {
  1622. Stream << "\n";
  1623. LineLength = 0;
  1624. }
  1625. void maybeIndent(unsigned Indent) {
  1626. if (LineLength >= LengthToBreak)
  1627. lineBreak();
  1628. if (LineLength == 0)
  1629. indent(Indent);
  1630. }
  1631. void write(StringRef S) {
  1632. LineLength += S.size();
  1633. Stream << S;
  1634. }
  1635. Value *getUnderlyingObjectThroughLoads(Value *V) {
  1636. if (Value *Ptr = getPointerOperand(V))
  1637. return getUnderlyingObjectThroughLoads(Ptr);
  1638. else if (V->getType()->isPointerTy())
  1639. return getUnderlyingObject(V);
  1640. return V;
  1641. }
  1642. /// Returns true if \p V is a matrix value in the given subprogram.
  1643. bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
/// If \p V is a matrix value, print its shape as NumRows x NumColumns to
  1645. /// \p SS.
  1646. void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
  1647. auto M = Inst2Matrix.find(V);
  1648. if (M == Inst2Matrix.end())
  1649. SS << "unknown";
  1650. else {
  1651. SS << M->second.getNumRows();
  1652. SS << "x";
  1653. SS << M->second.getNumColumns();
  1654. }
  1655. }
  1656. /// Write the called function name. Handles calls to llvm.matrix.*
  1657. /// specially: we write the name, followed by the dimensions of the input
  1658. /// matrixes, followed by the scalar type name.
  1659. void writeFnName(CallInst *CI) {
  1660. if (!CI->getCalledFunction())
  1661. write("<no called fn>");
  1662. else {
  1663. StringRef Name = CI->getCalledFunction()->getName();
  1664. if (!Name.startswith("llvm.matrix")) {
  1665. write(Name);
  1666. return;
  1667. }
  1668. auto *II = cast<IntrinsicInst>(CI);
  1669. write(Intrinsic::getBaseName(II->getIntrinsicID())
  1670. .drop_front(StringRef("llvm.matrix.").size()));
  1671. write(".");
  1672. std::string Tmp;
  1673. raw_string_ostream SS(Tmp);
  1674. switch (II->getIntrinsicID()) {
  1675. case Intrinsic::matrix_multiply:
  1676. prettyPrintMatrixType(II->getOperand(0), SS);
  1677. SS << ".";
  1678. prettyPrintMatrixType(II->getOperand(1), SS);
  1679. SS << "." << *II->getType()->getScalarType();
  1680. break;
  1681. case Intrinsic::matrix_transpose:
  1682. prettyPrintMatrixType(II->getOperand(0), SS);
  1683. SS << "." << *II->getType()->getScalarType();
  1684. break;
  1685. case Intrinsic::matrix_column_major_load:
  1686. prettyPrintMatrixType(II, SS);
  1687. SS << "." << *II->getType()->getScalarType();
  1688. break;
  1689. case Intrinsic::matrix_column_major_store:
  1690. prettyPrintMatrixType(II->getOperand(0), SS);
  1691. SS << "." << *II->getOperand(0)->getType()->getScalarType();
  1692. break;
  1693. default:
  1694. llvm_unreachable("Unhandled case");
  1695. }
  1696. SS.flush();
  1697. write(Tmp);
  1698. }
  1699. }
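/// Return the number of trailing shape arguments of the matrix intrinsic
/// \p CI. These are skipped when printing the call's operands.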
  1700. unsigned getNumShapeArgs(CallInst *CI) const {
  1701. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
  1702. switch (II->getIntrinsicID()) {
  1703. case Intrinsic::matrix_multiply:
  1704. return 3;
  1705. case Intrinsic::matrix_transpose:
  1706. return 2;
  1707. case Intrinsic::matrix_column_major_load:
  1708. case Intrinsic::matrix_column_major_store:
  1709. return 3;
  1710. default:
  1711. return 0;
  1712. }
  1713. }
  1714. return 0;
  1715. }
/// Special printing for values: for pointers, we print whether they refer to
/// an external (function) address or a stack address; for other values we
/// either print the constant or "scalar"/"matrix".
  1719. void write(Value *V) {
  1720. V = getUnderlyingObjectThroughLoads(V);
  1721. if (V->getType()->isPointerTy()) {
  1722. if (isa<AllocaInst>(V)) {
  1723. Stream << "stack addr";
  1724. LineLength += StringRef("stack addr").size();
  1725. } else {
  1726. Stream << "addr";
  1727. LineLength += StringRef("addr").size();
  1728. }
  1729. if (!V->getName().empty()) {
  1730. Stream << " %" << V->getName() << "";
  1731. LineLength += V->getName().size() + 2;
  1732. }
  1733. return;
  1734. }
  1735. std::string Tmp;
  1736. raw_string_ostream TmpStream(Tmp);
  1737. if (auto *CI = dyn_cast<ConstantInt>(V))
  1738. TmpStream << CI->getValue();
  1739. else if (isa<Constant>(V))
  1740. TmpStream << "constant";
  1741. else {
  1742. if (isMatrix(V))
  1743. TmpStream << "matrix";
  1744. else
  1745. TmpStream << "scalar";
  1746. }
  1747. TmpStream.flush();
  1748. Tmp = std::string(StringRef(Tmp).trim());
  1749. LineLength += Tmp.size();
  1750. Stream << Tmp;
  1751. }
  1752. /// Linearize expression \p Expr starting at an indentation of \p Indent.
  1753. /// Expressions that are re-used multiple times are prefixed with (reused)
  1754. /// at the re-used root instruction.
  1755. void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
  1756. bool ParentShared) {
  1757. auto *I = cast<Instruction>(Expr);
  1758. maybeIndent(Indent);
  1759. SmallVector<Value *, 8> Ops;
  1760. // Is Expr shared with other expression leaves?
  1761. bool ExprShared = false;
  1762. // Deal with shared subtrees. Mark them as shared, if required.
  1763. if (!ParentShared) {
  1764. auto SI = Shared.find(Expr);
  1765. assert(SI != Shared.end() && SI->second.count(Leaf));
  1766. for (Value *S : SI->second) {
  1767. if (S == Leaf)
  1768. continue;
  1769. DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
  1770. write("shared with remark at line " + std::to_string(DL.getLine()) +
  1771. " column " + std::to_string(DL.getCol()) + " (");
  1772. }
  1773. ExprShared = SI->second.size() > 1;
  1774. }
  1775. bool Reused = !ReusedExprs.insert(Expr).second;
  1776. if (Reused && !ParentReused)
  1777. write("(reused) ");
  1778. if (auto *CI = dyn_cast<CallInst>(I)) {
  1779. writeFnName(CI);
  1780. Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
  1781. } else if (isa<BitCastInst>(Expr)) {
  1782. // Special case bitcasts, which are used to materialize matrixes from
  1783. // non-matrix ops.
  1784. write("matrix");
  1785. return;
  1786. } else {
  1787. Ops.append(I->value_op_begin(), I->value_op_end());
  1788. write(std::string(I->getOpcodeName()));
  1789. }
  1790. write(std::string("("));
  1791. unsigned NumOpsToBreak = 1;
  1792. if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
  1793. NumOpsToBreak = 2;
  1794. for (Value *Op : Ops) {
  1795. if (Ops.size() > NumOpsToBreak)
  1796. lineBreak();
  1797. maybeIndent(Indent + 1);
  1798. if (isMatrix(Op))
  1799. linearizeExpr(Op, Indent + 1, Reused, ExprShared);
  1800. else
  1801. write(Op);
  1802. if (Op != Ops.back())
  1803. write(", ");
  1804. }
  1805. write(")");
  1806. }
  1807. const std::string &getResult() {
  1808. Stream.flush();
  1809. return Str;
  1810. }
  1811. };
  1812. /// Generate remarks for matrix operations in a function. To generate remarks
  1813. /// for matrix expressions, the following approach is used:
  1814. /// 1. Use the inlined-at debug information to group matrix operations to the
  1815. /// DISubprograms they are contained in.
  1816. /// 2. Collect leaves of matrix expressions (done in
/// RemarkGenerator::getExpressionLeaves) for each subprogram-expression
/// mapping. Leaves are lowered matrix instructions without other matrix
/// users (like stores) in the current subprogram.
/// 3. For each leaf, create a remark containing a linearized version of the
  1821. /// matrix expression. The expression is linearized by a recursive
  1822. /// bottom-up traversal of the matrix operands, starting at a leaf. Note
  1823. /// that multiple leaves can share sub-expressions. Shared subexpressions
  1824. /// are explicitly marked as shared().
  1825. struct RemarkGenerator {
  1826. const MapVector<Value *, MatrixTy> &Inst2Matrix;
  1827. OptimizationRemarkEmitter &ORE;
  1828. Function &Func;
  1829. const DataLayout &DL;
  1830. RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
  1831. OptimizationRemarkEmitter &ORE, Function &Func)
  1832. : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
  1833. DL(Func.getParent()->getDataLayout()) {}
  1834. /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
  1835. /// instructions in Inst2Matrix returning void or without any users in
  1836. /// \p ExprsInSubprogram. Currently that should only include stores.
  1837. SmallVector<Value *, 4>
  1838. getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
  1839. SmallVector<Value *, 4> Leaves;
  1840. for (auto *Expr : ExprsInSubprogram)
  1841. if (Expr->getType()->isVoidTy() ||
  1842. !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
  1843. return ExprsInSubprogram.count(U);
  1844. }))
  1845. Leaves.push_back(Expr);
  1846. return Leaves;
  1847. }
  1848. /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
  1849. /// to all visited expressions in \p Shared. Limit the matrix operations to
  1850. /// the ones in \p ExprsInSubprogram.
  1851. void collectSharedInfo(Value *Leaf, Value *V,
  1852. const SmallSetVector<Value *, 32> &ExprsInSubprogram,
  1853. DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
  1854. if (!ExprsInSubprogram.count(V))
  1855. return;
  1856. auto I = Shared.insert({V, {}});
  1857. I.first->second.insert(Leaf);
  1858. for (Value *Op : cast<Instruction>(V)->operand_values())
  1859. collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
  1860. }
/// Calculate the exclusive and shared op counts for the expression starting
/// at \p Root. Expressions used multiple times are counted once.
  1863. /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
  1864. std::pair<OpInfoTy, OpInfoTy>
  1865. sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
  1866. const SmallSetVector<Value *, 32> &ExprsInSubprogram,
  1867. DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
  1868. if (!ExprsInSubprogram.count(Root))
  1869. return {};
  1870. // Already counted this expression. Stop.
  1871. if (!ReusedExprs.insert(Root).second)
  1872. return {};
  1873. OpInfoTy SharedCount;
  1874. OpInfoTy Count;
  1875. auto I = Shared.find(Root);
  1876. auto CM = Inst2Matrix.find(Root);
  1877. if (I->second.size() == 1)
  1878. Count = CM->second.getOpInfo();
  1879. else
  1880. SharedCount = CM->second.getOpInfo();
  1881. for (Value *Op : cast<Instruction>(Root)->operand_values()) {
  1882. auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
  1883. Count += C.first;
  1884. SharedCount += C.second;
  1885. }
  1886. return {Count, SharedCount};
  1887. }
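// A generated remark looks roughly like (illustrative values):
//   Lowered with 4 stores, 2 loads, 16 compute ops, 0 exposed transposes
// followed by the linearized expression produced by ExprLinearizer.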
  1888. void emitRemarks() {
  1889. if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
  1890. return;
// Map matrix operations to their containing subprograms by traversing
  1892. // the inlinedAt chain. If the function does not have a DISubprogram, we
  1893. // only map them to the containing function.
  1894. MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
  1895. for (auto &KV : Inst2Matrix) {
  1896. if (Func.getSubprogram()) {
  1897. auto *I = cast<Instruction>(KV.first);
  1898. DILocation *Context = I->getDebugLoc();
  1899. while (Context) {
  1900. auto I =
  1901. Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
  1902. I.first->second.push_back(KV.first);
  1903. Context = DebugLoc(Context).getInlinedAt();
  1904. }
  1905. } else {
  1906. auto I = Subprog2Exprs.insert({nullptr, {}});
  1907. I.first->second.push_back(KV.first);
  1908. }
  1909. }
  1910. for (auto &KV : Subprog2Exprs) {
  1911. SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
  1912. KV.second.end());
  1913. auto Leaves = getExpressionLeaves(ExprsInSubprogram);
  1914. DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
  1915. for (Value *Leaf : Leaves)
  1916. collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
  1917. // Generate remarks for each leaf.
  1918. for (auto *L : Leaves) {
  1919. DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
  1920. DILocation *Context = cast<Instruction>(L)->getDebugLoc();
  1921. while (Context) {
  1922. if (getSubprogram(Context->getScope()) == KV.first) {
  1923. Loc = Context;
  1924. break;
  1925. }
  1926. Context = DebugLoc(Context).getInlinedAt();
  1927. }
  1928. SmallPtrSet<Value *, 8> ReusedExprs;
  1929. OpInfoTy Counts, SharedCounts;
  1930. std::tie(Counts, SharedCounts) =
  1931. sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
  1932. OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
  1933. cast<Instruction>(L)->getParent());
  1934. Rem << "Lowered with ";
  1935. Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
  1936. << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
  1937. << ore::NV("NumComputeOps", Counts.NumComputeOps)
  1938. << " compute ops, "
  1939. << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
  1940. << " exposed transposes";
  1941. if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
  1942. SharedCounts.NumComputeOps > 0) {
  1943. Rem << ",\nadditionally "
  1944. << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
  1945. << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
  1946. << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
  1947. << " compute ops"
  1948. << " are shared with other expressions";
  1949. }
  1950. Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
  1951. ORE.emit(Rem);
  1952. }
  1953. }
  1954. }
  1955. std::string
  1956. linearize(Value *L,
  1957. const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
  1958. const SmallSetVector<Value *, 32> &ExprsInSubprogram,
  1959. const DataLayout &DL) {
  1960. ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
  1961. Lin.linearizeExpr(L, 0, false, false);
  1962. return Lin.getResult();
  1963. }
  1964. };
  1965. };
  1966. } // namespace
  1967. PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
  1968. FunctionAnalysisManager &AM) {
  1969. auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  1970. OptimizationRemarkEmitter *ORE = nullptr;
  1971. AAResults *AA = nullptr;
  1972. DominatorTree *DT = nullptr;
  1973. LoopInfo *LI = nullptr;
  1974. if (!Minimal) {
  1975. ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  1976. AA = &AM.getResult<AAManager>(F);
  1977. DT = &AM.getResult<DominatorTreeAnalysis>(F);
  1978. LI = &AM.getResult<LoopAnalysis>(F);
  1979. }
  1980. LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  1981. if (LMT.Visit()) {
  1982. PreservedAnalyses PA;
  1983. if (!Minimal) {
  1984. PA.preserve<LoopAnalysis>();
  1985. PA.preserve<DominatorTreeAnalysis>();
  1986. }
  1987. return PA;
  1988. }
  1989. return PreservedAnalyses::all();
  1990. }
  1991. void LowerMatrixIntrinsicsPass::printPipeline(
  1992. raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  1993. static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
  1994. OS, MapClassName2PassName);
  1995. OS << "<";
  1996. if (Minimal)
  1997. OS << "minimal";
  1998. OS << ">";
  1999. }
  2000. namespace {
  2001. class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
  2002. public:
  2003. static char ID;
  2004. LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
  2005. initializeLowerMatrixIntrinsicsLegacyPassPass(
  2006. *PassRegistry::getPassRegistry());
  2007. }
  2008. bool runOnFunction(Function &F) override {
  2009. auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  2010. auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  2011. auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
  2012. auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  2013. auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  2014. LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
  2015. bool C = LMT.Visit();
  2016. return C;
  2017. }
  2018. void getAnalysisUsage(AnalysisUsage &AU) const override {
  2019. AU.addRequired<TargetTransformInfoWrapperPass>();
  2020. AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
  2021. AU.addRequired<AAResultsWrapperPass>();
  2022. AU.addRequired<DominatorTreeWrapperPass>();
  2023. AU.addPreserved<DominatorTreeWrapperPass>();
  2024. AU.addRequired<LoopInfoWrapperPass>();
  2025. AU.addPreserved<LoopInfoWrapperPass>();
  2026. }
  2027. };
  2028. } // namespace
  2029. static const char pass_name[] = "Lower the matrix intrinsics";
  2030. char LowerMatrixIntrinsicsLegacyPass::ID = 0;
  2031. INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
  2032. false, false)
  2033. INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
  2034. INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
  2035. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  2036. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  2037. INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
  2038. false, false)
  2039. Pass *llvm::createLowerMatrixIntrinsicsPass() {
  2040. return new LowerMatrixIntrinsicsLegacyPass();
  2041. }
  2042. namespace {
  2043. /// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE, such as tiling, are disabled.
/// This is used to lower matrix intrinsics if the main lowering pass is not
/// run, for example at -O0.
  2047. class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
  2048. public:
  2049. static char ID;
  2050. LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
  2051. initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
  2052. *PassRegistry::getPassRegistry());
  2053. }
  2054. bool runOnFunction(Function &F) override {
  2055. auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  2056. LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
  2057. bool C = LMT.Visit();
  2058. return C;
  2059. }
  2060. void getAnalysisUsage(AnalysisUsage &AU) const override {
  2061. AU.addRequired<TargetTransformInfoWrapperPass>();
  2062. AU.setPreservesCFG();
  2063. }
  2064. };
  2065. } // namespace
  2066. static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
  2067. char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
  2068. INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
  2069. "lower-matrix-intrinsics-minimal", pass_name_minimal,
  2070. false, false)
  2071. INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
  2072. "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
  2073. false)
  2074. Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  2075. return new LowerMatrixIntrinsicsMinimalLegacyPass();
  2076. }