- //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- // Lower matrix intrinsics to vector operations.
- //
- // TODO:
- // * Improve fusion:
- // * Support more cases, e.g. multiply-add, multiply-sub, operands/results
- // transposed.
- // * Improve cost-modeling, e.g. choose different number of rows/columns
- //    for tiles, consider cost of copies on alias.
- //
- //===----------------------------------------------------------------------===//
- #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
- #include "llvm/ADT/PostOrderIterator.h"
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/Analysis/AliasAnalysis.h"
- #include "llvm/Analysis/DomTreeUpdater.h"
- #include "llvm/Analysis/LoopInfo.h"
- #include "llvm/Analysis/OptimizationRemarkEmitter.h"
- #include "llvm/Analysis/TargetTransformInfo.h"
- #include "llvm/Analysis/ValueTracking.h"
- #include "llvm/Analysis/VectorUtils.h"
- #include "llvm/IR/CFG.h"
- #include "llvm/IR/DataLayout.h"
- #include "llvm/IR/DebugInfoMetadata.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/MatrixBuilder.h"
- #include "llvm/IR/PatternMatch.h"
- #include "llvm/InitializePasses.h"
- #include "llvm/Pass.h"
- #include "llvm/Support/Alignment.h"
- #include "llvm/Support/CommandLine.h"
- #include "llvm/Support/Debug.h"
- #include "llvm/Transforms/Scalar.h"
- #include "llvm/Transforms/Utils/BasicBlockUtils.h"
- #include "llvm/Transforms/Utils/LoopUtils.h"
- #include "llvm/Transforms/Utils/MatrixUtils.h"
- #include <cmath>
- using namespace llvm;
- using namespace PatternMatch;
- #define DEBUG_TYPE "lower-matrix-intrinsics"
- static cl::opt<bool>
- FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
- cl::desc("Enable/disable fusing matrix instructions."));
- // TODO: Allow and use non-square tiles.
- static cl::opt<unsigned> TileSize(
- "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
- cl::desc(
- "Tile size for matrix instruction fusion using square-shaped tiles."));
- static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
- cl::Hidden,
- cl::desc("Generate loop nest for tiling."));
- static cl::opt<bool> ForceFusion(
- "force-fuse-matrix", cl::init(false), cl::Hidden,
- cl::desc("Force matrix instruction fusion even if not profitable."));
- static cl::opt<bool> AllowContractEnabled(
- "matrix-allow-contract", cl::init(false), cl::Hidden,
- cl::desc("Allow the use of FMAs if available and profitable. This may "
- "result in different results, due to less rounding error."));
- enum class MatrixLayoutTy { ColumnMajor, RowMajor };
- static cl::opt<MatrixLayoutTy> MatrixLayout(
- "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
- cl::desc("Sets the default matrix layout"),
- cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
- "Use column-major layout"),
- clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
- "Use row-major layout")));
- static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
- cl::init(false));
- /// Helper function to either return \p Scope, if it is a subprogram, or the
- /// subprogram attached to a local scope.
- static DISubprogram *getSubprogram(DIScope *Scope) {
- if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
- return Subprogram;
- return cast<DILocalScope>(Scope)->getSubprogram();
- }
- /// Erase \p V from \p BB and move \p II forward to avoid invalidating
- /// iterators.
- static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
- BasicBlock &BB) {
- auto *Inst = cast<Instruction>(V);
- // Still used, don't erase.
- if (!Inst->use_empty())
- return;
- if (II != BB.rend() && Inst == &*II)
- ++II;
- Inst->eraseFromParent();
- }
- /// Return true if V is a splat of a value (which is used when multiplying a
- /// matrix with a scalar).
- static bool isSplat(Value *V) {
- if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
- return SV->isZeroEltSplat();
- return false;
- }
- /// Match any mul operation (fp or integer).
- template <typename LTy, typename RTy>
- auto m_AnyMul(const LTy &L, const RTy &R) {
- return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
- }
- /// Match any add operation (fp or integer).
- template <typename LTy, typename RTy>
- auto m_AnyAdd(const LTy &L, const RTy &R) {
- return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
- }
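- // For example, m_AnyMul(m_Value(X), m_Value(Y)) matches both `mul` and `fmul`
- // instructions, so the transpose rewrites below handle integer and
- // floating-point matrices uniformly.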
- namespace {
- // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
- // the start address of vector \p VecIdx with type (\p EltType x \p NumElements),
- // assuming \p Stride elements between the starts of two consecutive vectors.
- // \p Stride must be >= \p NumElements.
- // For column-major matrixes, the function computes the address of a column
- // vector and \p NumElements must be set to the number of elements in a column
- // (= number of rows of the matrix). For row-major matrixes, the function
- // computes the address of a row vector and \p NumElements must be set to the
- // number of elements in a row (= number of columns of the matrix).
- //
- // Consider a 4x4 matrix in column-major layout like below
- //
- // 0 1 2 3
- // 0 v_0_0 v_0_1 v_0_2 v_0_3
- // 1 v_1_0 v_1_1 v_1_2 v_1_3
- // 2 v_2_0 v_2_1 v_2_2 v_2_3
- // 3 v_3_0 v_3_1 v_3_2 v_3_3
- // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
- // we need a pointer to the first element of the submatrix as base pointer.
- // Then we can use computeVectorAddr to compute the addresses for the columns
- // of the sub-matrix.
- //
- // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
- // -> just returns Base
- // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
- // -> returns Base + (1 * 4)
- // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
- // -> returns Base + (2 * 4)
- //
- // The graphic below illustrates the number of elements in a column (marked
- // with |) and the number of skipped elements (marked with {).
- //
- // v_0_0 v_0_1 {v_0_2 {v_0_3
- // Base Col 1 Col 2
- // | | |
- // v_1_0 |v_1_1 |v_1_2 |v_1_3
- // v_2_0 |v_2_1 |v_2_2 |v_2_3
- // v_3_0 {v_3_1 {v_3_2 v_3_3
- //
- Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
- unsigned NumElements, Type *EltType,
- IRBuilder<> &Builder) {
- assert((!isa<ConstantInt>(Stride) ||
- cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
- "Stride must be >= the number of elements in the result vector.");
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- // Compute the start of the vector with index VecIdx as VecIdx * Stride.
- Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
- // Get pointer to the start of the selected vector. Skip GEP creation,
- // if we select vector 0.
- if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
- VecStart = BasePtr;
- else
- VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
- // Cast elementwise vector start pointer to a pointer to a vector
- // (EltType x NumElements)*.
- auto *VecType = FixedVectorType::get(EltType, NumElements);
- Type *VecPtrType = PointerType::get(VecType, AS);
- return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
- }
- /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
- ///
- /// Currently, the lowering for each matrix intrinsic is done as follows:
- /// 1. Propagate the shape information from intrinsics to connected
- /// instructions.
- /// 2. Lower instructions with shape information (assuming column-major layout).
- /// The lowering works similarly using row-major layout.
- /// 2.1. Get column vectors for each argument. If we already lowered the
- /// definition of an argument, use the produced column vectors directly.
- /// If not, split the operand vector containing an embedded matrix into
- ///      a set of column vectors.
- /// 2.2. Lower the instruction in terms of column major operations, which
- ///      yields a set of column vectors containing the result matrix. Note that we
- /// lower all instructions that have shape information. Besides the
- /// intrinsics, this includes stores for example.
- /// 2.3. Update uses of the lowered instruction. If we have shape information
- /// for a user, there is nothing to do, as we will look up the result
- /// column matrix when lowering the user. For other uses, we embed the
- /// result matrix in a flat vector and update the use.
- /// 2.4. Cache the result column matrix for the instruction we lowered.
- /// 3. After we lowered all instructions in a function, remove the now
- /// obsolete instructions.
- ///
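- /// For example (illustrative IR), a 2x2 double multiply
- ///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
- ///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
- /// is lowered to per-column shuffles of %a and %b, vector multiplies and adds
- /// (or fmuladd calls) per result column, and a concatenation back into a flat
- /// <4 x double> only for users without shape information.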
- class LowerMatrixIntrinsics {
- Function &Func;
- const DataLayout &DL;
- const TargetTransformInfo &TTI;
- AliasAnalysis *AA;
- DominatorTree *DT;
- LoopInfo *LI;
- OptimizationRemarkEmitter *ORE;
- /// Contains estimates of the number of operations (loads, stores, compute)
- /// required to lower a matrix operation.
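- /// The counts are accumulated while lowering and consumed by the remark
- /// generator when emitting optimization remarks.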
- struct OpInfoTy {
- /// Number of stores emitted to generate this matrix.
- unsigned NumStores = 0;
- /// Number of loads emitted to generate this matrix.
- unsigned NumLoads = 0;
- /// Number of compute operations emitted to generate this matrix.
- unsigned NumComputeOps = 0;
- /// Most of the time transposes can be fused with matrix multiplies or can
- /// be folded away via algebraic simplifications. This is the number of
- /// transposes that we failed to make "free" via such optimizations.
- unsigned NumExposedTransposes = 0;
- OpInfoTy &operator+=(const OpInfoTy &RHS) {
- NumStores += RHS.NumStores;
- NumLoads += RHS.NumLoads;
- NumComputeOps += RHS.NumComputeOps;
- NumExposedTransposes += RHS.NumExposedTransposes;
- return *this;
- }
- };
- /// Wrapper class representing a matrix as a set of vectors, either in row or
- /// column major layout. All vectors must have the same vector type.
- class MatrixTy {
- SmallVector<Value *, 16> Vectors;
- OpInfoTy OpInfo;
- bool IsColumnMajor = true;
- public:
- MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
- MatrixTy(ArrayRef<Value *> Vectors)
- : Vectors(Vectors.begin(), Vectors.end()),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
- MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
- : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
- unsigned D = isColumnMajor() ? NumColumns : NumRows;
- for (unsigned J = 0; J < D; ++J)
- addVector(UndefValue::get(FixedVectorType::get(
- EltTy, isColumnMajor() ? NumRows : NumColumns)));
- }
- Value *getVector(unsigned i) const { return Vectors[i]; }
- Value *getColumn(unsigned i) const {
- assert(isColumnMajor() && "only supported for column-major matrixes");
- return Vectors[i];
- }
- Value *getRow(unsigned i) const {
- assert(!isColumnMajor() && "only supported for row-major matrixes");
- return Vectors[i];
- }
- void setVector(unsigned i, Value *V) { Vectors[i] = V; }
- Type *getElementType() const { return getVectorTy()->getElementType(); }
- unsigned getNumVectors() const {
- if (isColumnMajor())
- return getNumColumns();
- return getNumRows();
- }
- unsigned getNumColumns() const {
- if (isColumnMajor())
- return Vectors.size();
- else {
- assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
- }
- }
- unsigned getNumRows() const {
- if (isColumnMajor()) {
- assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
- } else
- return Vectors.size();
- }
- void addVector(Value *V) { Vectors.push_back(V); }
- VectorType *getColumnTy() {
- assert(isColumnMajor() && "only supported for column-major matrixes");
- return getVectorTy();
- }
- VectorType *getVectorTy() const {
- return cast<VectorType>(Vectors[0]->getType());
- }
- iterator_range<SmallVector<Value *, 8>::iterator> columns() {
- assert(isColumnMajor() &&
- "columns() only supported for column-major matrixes");
- return make_range(Vectors.begin(), Vectors.end());
- }
- iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
- return make_range(Vectors.begin(), Vectors.end());
- }
- /// Embed the vectors of the matrix into a flat vector by concatenating
- /// them.
- Value *embedInVector(IRBuilder<> &Builder) const {
- return Vectors.size() == 1 ? Vectors[0]
- : concatenateVectors(Builder, Vectors);
- }
- MatrixTy &addNumLoads(unsigned N) {
- OpInfo.NumLoads += N;
- return *this;
- }
- void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
- MatrixTy &addNumStores(unsigned N) {
- OpInfo.NumStores += N;
- return *this;
- }
- MatrixTy &addNumExposedTransposes(unsigned N) {
- OpInfo.NumExposedTransposes += N;
- return *this;
- }
- MatrixTy &addNumComputeOps(unsigned N) {
- OpInfo.NumComputeOps += N;
- return *this;
- }
- unsigned getNumStores() const { return OpInfo.NumStores; }
- unsigned getNumLoads() const { return OpInfo.NumLoads; }
- unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
- const OpInfoTy &getOpInfo() const { return OpInfo; }
- bool isColumnMajor() const { return IsColumnMajor; }
- unsigned getStride() const {
- if (isColumnMajor())
- return getNumRows();
- return getNumColumns();
- }
- /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
- /// matrix is column-major, the result vector is extracted from a column
- /// vector, otherwise from a row vector.
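- /// For example, on a column-major 4x4 matrix, extractVector(1, 2, 2, Builder)
- /// shuffles column 2 with mask <1, 2>, i.e. it returns the elements at rows
- /// 1 and 2 of that column.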
- Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
- IRBuilder<> &Builder) const {
- Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
- assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
- NumElts &&
- "Extracted vector will contain poison values");
- return Builder.CreateShuffleVector(
- Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
- "block");
- }
- };
- struct ShapeInfo {
- unsigned NumRows;
- unsigned NumColumns;
- bool IsColumnMajor;
- ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
- ShapeInfo(Value *NumRows, Value *NumColumns)
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
- bool operator==(const ShapeInfo &other) {
- return NumRows == other.NumRows && NumColumns == other.NumColumns;
- }
- bool operator!=(const ShapeInfo &other) { return !(*this == other); }
- /// Returns true if shape-information is defined, meaning both dimensions
- /// are != 0.
- operator bool() const {
- assert(NumRows == 0 || NumColumns != 0);
- return NumRows != 0;
- }
- unsigned getStride() const {
- if (IsColumnMajor)
- return NumRows;
- return NumColumns;
- }
- unsigned getNumVectors() const {
- if (IsColumnMajor)
- return NumColumns;
- return NumRows;
- }
- /// Returns the transposed shape.
- ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
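- // For example, a column-major ShapeInfo(4, 2) has a stride of 4 (elements
- // per column), 2 vectors (columns), and t() == ShapeInfo(2, 4).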
- };
- /// Maps instructions to their shape information. The shape information
- /// describes the shape to be used while lowering. This matches the shape of
- /// the result value of the instruction, with the only exceptions being store
- /// instructions and the matrix_column_major_store intrinsics. For those, the
- /// shape information describes the shape of the matrix operand they store.
- /// A ValueMap is used so that when sub-passes like optimizeTransposes
- /// perform RAUW the map stays up-to-date.
- ValueMap<Value *, ShapeInfo> ShapeMap;
- /// List of instructions to remove. While lowering, we do not replace all
- /// users of a lowered instruction if shape information is available; those
- /// users need to be removed after we have finished lowering.
- SmallVector<Instruction *, 16> ToRemove;
- /// Map from instructions to their produced column matrix.
- MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
- private:
- static FastMathFlags getFastMathFlags(Instruction *Inst) {
- FastMathFlags FMF;
- if (isa<FPMathOperator>(*Inst))
- FMF = Inst->getFastMathFlags();
- FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
- return FMF;
- }
- public:
- LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
- AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
- OptimizationRemarkEmitter *ORE)
- : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
- LI(LI), ORE(ORE) {}
- unsigned getNumOps(Type *VT) {
- assert(isa<VectorType>(VT) && "Expected vector type");
- return getNumOps(VT->getScalarType(),
- cast<FixedVectorType>(VT)->getNumElements());
- }
- /// Is this the minimal version executed in the backend pipelines?
- bool isMinimal() const {
- return !DT;
- }
- /// Return the estimated number of vector ops required for an operation on
- /// \p VT * N.
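- /// For example, with 256-bit vector registers, a double scalar type and
- /// N = 8 (8 * 64 = 512 bits), this returns 2.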
- unsigned getNumOps(Type *ST, unsigned N) {
- return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
- double(TTI.getRegisterBitWidth(
- TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedValue()));
- }
- /// Return the set of vectors that a matrix value is lowered to.
- ///
- /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
- /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
- /// into vectors.
- MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
- IRBuilder<> &Builder) {
- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
- assert(VType && "MatrixVal must be a vector type");
- assert(cast<FixedVectorType>(VType)->getNumElements() ==
- SI.NumRows * SI.NumColumns &&
- "The vector size must match the number of matrix elements");
- // Check if we lowered MatrixVal using shape information. In that case,
- // return the existing matrix, if it matches the requested shape
- // information. If there is a mis-match, embed the result in a flat
- // vector and split it later.
- auto Found = Inst2ColumnMatrix.find(MatrixVal);
- if (Found != Inst2ColumnMatrix.end()) {
- MatrixTy &M = Found->second;
- // Return the found matrix, if its shape matches the requested shape
- // information
- if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
- return M;
- MatrixVal = M.embedInVector(Builder);
- }
- // Otherwise split MatrixVal.
- SmallVector<Value *, 16> SplitVecs;
- for (unsigned MaskStart = 0;
- MaskStart < cast<FixedVectorType>(VType)->getNumElements();
- MaskStart += SI.getStride()) {
- Value *V = Builder.CreateShuffleVector(
- MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
- "split");
- SplitVecs.push_back(V);
- }
- return {SplitVecs};
- }
- /// If \p V already has a known shape return false. Otherwise set the shape
- /// for instructions that support it.
- bool setShapeInfo(Value *V, ShapeInfo Shape) {
- assert(Shape && "Shape not set");
- if (isa<UndefValue>(V) || !supportsShapeInfo(V))
- return false;
- auto SIter = ShapeMap.find(V);
- if (SIter != ShapeMap.end()) {
- LLVM_DEBUG(dbgs() << " not overriding existing shape: "
- << SIter->second.NumRows << " "
- << SIter->second.NumColumns << " for " << *V << "\n");
- return false;
- }
- ShapeMap.insert({V, Shape});
- LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
- << " for " << *V << "\n");
- return true;
- }
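- /// Returns true for values whose result has the same shape as their matrix
- /// operands (elementwise binary operators and fneg), so shape information
- /// can be propagated through them unchanged. Non-instruction values are
- /// treated as uniform as well.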
- bool isUniformShape(Value *V) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return true;
- switch (I->getOpcode()) {
- case Instruction::FAdd:
- case Instruction::FSub:
- case Instruction::FMul: // Scalar multiply.
- case Instruction::FNeg:
- case Instruction::Add:
- case Instruction::Mul:
- case Instruction::Sub:
- return true;
- default:
- return false;
- }
- }
- /// Returns true if shape information can be used for \p V. The supported
- /// instructions must match the instructions that can be lowered by this pass.
- bool supportsShapeInfo(Value *V) {
- Instruction *Inst = dyn_cast<Instruction>(V);
- if (!Inst)
- return false;
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
- if (II)
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- return true;
- default:
- return false;
- }
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
- }
- /// Propagate the shape information of instructions to their users.
- /// The work list contains instructions for which we can compute the shape,
- /// either based on the information provided by matrix intrinsics or known
- /// shapes of operands.
- SmallVector<Instruction *, 32>
- propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
- SmallVector<Instruction *, 32> NewWorkList;
- // Pop an element for which we are guaranteed to have at least one of the
- // operand shapes. Add the shape for this instruction and then add its
- // users to the work list.
- LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
- while (!WorkList.empty()) {
- Instruction *Inst = WorkList.pop_back_val();
- // New entry, set the value and insert operands
- bool Propagate = false;
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- Propagate = setShapeInfo(Inst, {M, K});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(), m_Value(M), m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
- m_Value(), m_Value(), m_Value(), m_Value(M),
- m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {M, N});
- } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
- auto OpShape = ShapeMap.find(MatrixA);
- if (OpShape != ShapeMap.end())
- setShapeInfo(Inst, OpShape->second);
- continue;
- } else if (isUniformShape(Inst)) {
- // Find the first operand that has a known shape and use that.
- for (auto &Op : Inst->operands()) {
- auto OpShape = ShapeMap.find(Op.get());
- if (OpShape != ShapeMap.end()) {
- Propagate |= setShapeInfo(Inst, OpShape->second);
- break;
- }
- }
- }
- if (Propagate) {
- NewWorkList.push_back(Inst);
- for (auto *User : Inst->users())
- if (ShapeMap.count(User) == 0)
- WorkList.push_back(cast<Instruction>(User));
- }
- }
- return NewWorkList;
- }
- /// Propagate the shape to operands of instructions with shape information.
- /// \p Worklist contains the instructions for which we already know the shape.
- SmallVector<Instruction *, 32>
- propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
- SmallVector<Instruction *, 32> NewWorkList;
- auto pushInstruction = [](Value *V,
- SmallVectorImpl<Instruction *> &WorkList) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (I)
- WorkList.push_back(I);
- };
- // Pop an element with known shape. Traverse its operands; if an operand's
- // shape derives from the result shape and is still unknown, set it and add
- // the operand to the worklist.
- LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
- while (!WorkList.empty()) {
- Value *V = WorkList.pop_back_val();
- size_t BeforeProcessingV = WorkList.size();
- if (!isa<Instruction>(V))
- continue;
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- if (setShapeInfo(MatrixA, {M, N}))
- pushInstruction(MatrixA, WorkList);
- if (setShapeInfo(MatrixB, {N, K}))
- pushInstruction(MatrixB, WorkList);
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- if (setShapeInfo(MatrixA, {M, N}))
- pushInstruction(MatrixA, WorkList);
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
- m_Value(M), m_Value(N)))) {
- if (setShapeInfo(MatrixA, {M, N})) {
- pushInstruction(MatrixA, WorkList);
- }
- } else if (isa<LoadInst>(V) ||
- match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
- // Nothing to do, no matrix input.
- } else if (isa<StoreInst>(V)) {
- // Nothing to do. We forward-propagated to this so we would just
- // backward propagate to an instruction with an already known shape.
- } else if (isUniformShape(V)) {
- // Propagate to all operands.
- ShapeInfo Shape = ShapeMap[V];
- for (Use &U : cast<Instruction>(V)->operands()) {
- if (setShapeInfo(U.get(), Shape))
- pushInstruction(U.get(), WorkList);
- }
- }
- // After we discovered new shape info for new instructions in the
- // worklist, we use their users as seeds for the next round of forward
- // propagation.
- for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
- for (User *U : WorkList[I]->users())
- if (isa<Instruction>(U) && V != U)
- NewWorkList.push_back(cast<Instruction>(U));
- }
- return NewWorkList;
- }
- /// (Op0 op Op1)^T -> Op0^T op Op1^T
- /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
- /// them on both sides of \p Operation.
- Instruction *distributeTransposes(
- Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
- MatrixBuilder &Builder,
- function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
- Operation) {
- Value *T0 = Builder.CreateMatrixTranspose(
- Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
- // We run after shape propagation, so add shape information for newly
- // created instructions so that we lower them later.
- setShapeInfo(T0, Shape0.t());
- Value *T1 = Builder.CreateMatrixTranspose(
- Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
- setShapeInfo(T1, Shape1.t());
- return Operation(T0, Shape0.t(), T1, Shape1.t());
- }
- void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
- // We need to remove Old from the ShapeMap, otherwise RAUW will replace it
- // with New. We should only add New if it supports shape information, so we
- // insert it conditionally instead.
- auto S = ShapeMap.find(&Old);
- if (S != ShapeMap.end()) {
- ShapeMap.erase(S);
- if (supportsShapeInfo(New))
- ShapeMap.insert({New, S->second});
- }
- Old.replaceAllUsesWith(New);
- }
- /// Sink a top-level transpose inside matmuls and adds.
- /// This creates and erases instructions as needed, and returns the newly
- /// created instruction while updating the iterator to avoid invalidation. If
- /// this returns nullptr, no new instruction was created.
- Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
- BasicBlock &BB = *I.getParent();
- IRBuilder<> IB(&I);
- MatrixBuilder Builder(IB);
- Value *TA, *TAMA, *TAMB;
- ConstantInt *R, *K, *C;
- if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
- return nullptr;
- // Transpose of a transpose is a nop
- Value *TATA;
- if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
- updateShapeAndReplaceAllUsesWith(I, TATA);
- eraseFromParentAndMove(&I, II, BB);
- eraseFromParentAndMove(TA, II, BB);
- return nullptr;
- }
- // k^T -> k
- if (isSplat(TA)) {
- updateShapeAndReplaceAllUsesWith(I, TA);
- eraseFromParentAndMove(&I, II, BB);
- return nullptr;
- }
- // (A * B)^t -> B^t * A^t
- // RxK KxC CxK KxR
- if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
- m_ConstantInt(K), m_ConstantInt(C)))) {
- auto NewInst = distributeTransposes(
- TAMB, {K, C}, TAMA, {R, K}, Builder,
- [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
- Shape0.NumColumns,
- Shape1.NumColumns, "mmul");
- });
- updateShapeAndReplaceAllUsesWith(I, NewInst);
- eraseFromParentAndMove(&I, II, BB);
- eraseFromParentAndMove(TA, II, BB);
- return NewInst;
- }
- // Same as above, but with a mul, which occurs when multiplied
- // with a scalar.
- // (A * k)^t -> A^t * k
- // R x C RxC
- if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
- (isSplat(TAMA) || isSplat(TAMB))) {
- IRBuilder<> LocalBuilder(&I);
- // We know that the transposed operand is of shape RxC.
- // And when multiplied with a scalar, the shape is preserved.
- auto NewInst = distributeTransposes(
- TAMA, {R, C}, TAMB, {R, C}, Builder,
- [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- bool IsFP = I.getType()->isFPOrFPVectorTy();
- auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
- : LocalBuilder.CreateMul(T0, T1, "mmul");
- auto *Result = cast<Instruction>(Mul);
- setShapeInfo(Result, Shape0);
- return Result;
- });
- updateShapeAndReplaceAllUsesWith(I, NewInst);
- eraseFromParentAndMove(&I, II, BB);
- eraseFromParentAndMove(TA, II, BB);
- return NewInst;
- }
- // (A + B)^t -> A^t + B^t
- // RxC RxC CxR CxR
- if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
- IRBuilder<> LocalBuilder(&I);
- auto NewInst = distributeTransposes(
- TAMA, {R, C}, TAMB, {R, C}, Builder,
- [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- auto *FAdd =
- cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
- setShapeInfo(FAdd, Shape0);
- return FAdd;
- });
- updateShapeAndReplaceAllUsesWith(I, NewInst);
- eraseFromParentAndMove(&I, II, BB);
- eraseFromParentAndMove(TA, II, BB);
- return NewInst;
- }
- return nullptr;
- }
- void liftTranspose(Instruction &I) {
- // Erase dead Instructions after lifting transposes from binops.
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
- if (T.use_empty())
- T.eraseFromParent();
- if (A->use_empty())
- cast<Instruction>(A)->eraseFromParent();
- if (A != B && B->use_empty())
- cast<Instruction>(B)->eraseFromParent();
- };
- Value *A, *B, *AT, *BT;
- ConstantInt *R, *K, *C;
- // A^t * B^t -> (B * A)^t
- if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(A), m_Value(B), m_ConstantInt(R),
- m_ConstantInt(K), m_ConstantInt(C))) &&
- match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
- match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
- IRBuilder<> IB(&I);
- MatrixBuilder Builder(IB);
- Value *M = Builder.CreateMatrixMultiply(
- BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
- setShapeInfo(M, {C, R});
- Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
- R->getZExtValue());
- updateShapeAndReplaceAllUsesWith(I, NewInst);
- CleanupBinOp(I, A, B);
- }
- // A^t + B^t -> (A + B)^t
- else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
- match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
- match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
- IRBuilder<> Builder(&I);
- Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
- setShapeInfo(Add, {C, R});
- MatrixBuilder MBuilder(Builder);
- Instruction *NewInst = MBuilder.CreateMatrixTranspose(
- Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
- updateShapeAndReplaceAllUsesWith(I, NewInst);
- CleanupBinOp(I, A, B);
- }
- }
- /// Try moving transposes in order to fold them away or into multiplies.
- void optimizeTransposes() {
- // First sink all transposes inside matmuls and adds, hoping that we end up
- // with NN, NT or TN variants.
- for (BasicBlock &BB : reverse(Func)) {
- for (auto II = BB.rbegin(); II != BB.rend();) {
- Instruction &I = *II;
- // We may remove II. By default continue on the next/prev instruction.
- ++II;
- if (Instruction *NewInst = sinkTranspose(I, II))
- II = std::next(BasicBlock::reverse_iterator(NewInst));
- }
- }
- // If we have a TT matmul or a TT add, lift the transpose. We may be able
- // to fold it into a consuming multiply or add.
- for (BasicBlock &BB : Func) {
- for (Instruction &I : llvm::make_early_inc_range(BB)) {
- liftTranspose(I);
- }
- }
- }
- bool Visit() {
- SmallVector<Instruction *, 32> WorkList;
- // Initially only the shape of matrix intrinsics is known.
- // Initialize the work list with ops carrying shape information.
- for (BasicBlock &BB : Func)
- for (Instruction &Inst : BB) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
- if (!II)
- continue;
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- WorkList.push_back(&Inst);
- break;
- default:
- break;
- }
- }
- // Avoid unnecessary work if there are no matrix intrinsics in the function.
- if (WorkList.empty())
- return false;
- // Propagate shapes until nothing changes any longer.
- while (!WorkList.empty()) {
- WorkList = propagateShapeForward(WorkList);
- WorkList = propagateShapeBackward(WorkList);
- }
- if (!isMinimal()) {
- optimizeTransposes();
- if (PrintAfterTransposeOpt) {
- dbgs() << "Dump after matrix transpose optimization:\n";
- Func.print(dbgs());
- }
- }
- bool Changed = false;
- SmallVector<CallInst *, 16> MaybeFusableInsts;
- SmallVector<Instruction *, 16> MatrixInsts;
- // First, collect all instructions with shape information and candidates for
- // fusion (currently only matrix multiplies).
- ReversePostOrderTraversal<Function *> RPOT(&Func);
- for (auto *BB : RPOT)
- for (Instruction &I : *BB) {
- if (ShapeMap.find(&I) == ShapeMap.end())
- continue;
- if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
- MaybeFusableInsts.push_back(cast<CallInst>(&I));
- MatrixInsts.push_back(&I);
- }
- // Second, try to fuse candidates.
- SmallPtrSet<Instruction *, 16> FusedInsts;
- for (CallInst *CI : MaybeFusableInsts)
- LowerMatrixMultiplyFused(CI, FusedInsts);
- Changed = !FusedInsts.empty();
- // Third, lower remaining instructions with shape information.
- for (Instruction *Inst : MatrixInsts) {
- if (FusedInsts.count(Inst))
- continue;
- IRBuilder<> Builder(Inst);
- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= VisitCallInst(CInst);
- Value *Op1;
- Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- Changed |= VisitBinaryOperator(BinOp);
- if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- Changed |= VisitUnaryOperator(UnOp);
- if (match(Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
- else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
- }
- if (ORE) {
- RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
- RemarkGen.emitRemarks();
- }
- // Delete the instructions backwards, as this reduces the number of
- // def-use and use-def chains that need to be updated.
- //
- // Because we add to ToRemove during fusion we can't guarantee that defs
- // are before uses. Change uses to poison temporarily as these should get
- // removed as well.
- //
- // For verification, we keep track of where we changed uses to poison in
- // PoisonedInsts and then check that we in fact remove them.
- SmallSet<Instruction *, 16> PoisonedInsts;
- for (auto *Inst : reverse(ToRemove)) {
- for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
- PoisonedInsts.insert(Poisoned);
- U.set(PoisonValue::get(Inst->getType()));
- }
- Inst->eraseFromParent();
- PoisonedInsts.erase(Inst);
- }
- if (!PoisonedInsts.empty()) {
- // If we didn't remove all poisoned instructions, it's a hard error.
- dbgs() << "Poisoned but present instructions:\n";
- for (auto *I : PoisonedInsts)
- dbgs() << *I << "\n";
- llvm_unreachable("Poisoned but instruction not removed");
- }
- return Changed;
- }
- /// Turns \p BasePtr into an elementwise pointer to \p EltType.
- Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- Type *EltPtrType = PointerType::get(EltType, AS);
- return Builder.CreatePointerCast(BasePtr, EltPtrType);
- }
- /// Replace intrinsic calls
- bool VisitCallInst(CallInst *Inst) {
- if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
- return false;
- switch (Inst->getCalledFunction()->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- LowerMultiply(Inst);
- break;
- case Intrinsic::matrix_transpose:
- LowerTranspose(Inst);
- break;
- case Intrinsic::matrix_column_major_load:
- LowerColumnMajorLoad(Inst);
- break;
- case Intrinsic::matrix_column_major_store:
- LowerColumnMajorStore(Inst);
- break;
- default:
- return false;
- }
- return true;
- }
- /// Compute the alignment for a column/row \p Idx with \p Stride between them.
- /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
- /// ConstantInt, reduce the initial alignment based on the byte offset. For
- /// non-ConstantInt strides, return the common alignment of the initial
- /// alignment and the element size in bytes.
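- /// For example, with an initial alignment of 16, double elements and a
- /// constant stride of 3, vector 1 starts at byte offset 24, so the returned
- /// alignment is commonAlignment(16, 24) = 8.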
- Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
- MaybeAlign A) const {
- Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
- if (Idx == 0)
- return InitialAlign;
- TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
- if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
- uint64_t StrideInBytes =
- ConstStride->getZExtValue() * ElementSizeInBits / 8;
- return commonAlignment(InitialAlign, Idx * StrideInBytes);
- }
- return commonAlignment(InitialAlign, ElementSizeInBits / 8);
- }
- /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
- /// vectors.
- MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
- bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
- auto *VType = cast<VectorType>(Ty);
- Type *EltTy = VType->getElementType();
- Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
- Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
- MatrixTy Result;
- for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
- Value *GEP = computeVectorAddr(
- EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
- Stride, Shape.getStride(), EltTy, Builder);
- Value *Vector = Builder.CreateAlignedLoad(
- VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
- IsVolatile, "col.load");
- Result.addVector(Vector);
- }
- return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors());
- }
- /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
- /// starting at \p MatrixPtr[I][J].
- MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
- ShapeInfo MatrixShape, Value *I, Value *J,
- ShapeInfo ResultShape, Type *EltTy,
- IRBuilder<> &Builder) {
- Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
- auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
- ResultShape.NumColumns);
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- return loadMatrix(TileTy, TilePtr, Align,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile,
- ResultShape, Builder);
- }
- /// Lower a load instruction with shape information.
- void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
- bool IsVolatile, ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- finalizeLowering(Inst,
- loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
- Shape, Builder),
- Builder);
- }
- /// Lowers llvm.matrix.column.major.load.
- ///
- /// The intrinsic loads a matrix from memory using a stride between columns.
- void LowerColumnMajorLoad(CallInst *Inst) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Intrinsic only supports column-major layout!");
- Value *Ptr = Inst->getArgOperand(0);
- Value *Stride = Inst->getArgOperand(1);
- LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
- cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
- {Inst->getArgOperand(3), Inst->getArgOperand(4)});
- }
- /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
- /// MatrixPtr[I][J].
- void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
- MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
- Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
- Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
- auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
- StoreVal.getNumColumns());
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
- }
- /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
- /// vectors.
- MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
- MaybeAlign MAlign, Value *Stride, bool IsVolatile,
- IRBuilder<> &Builder) {
- auto VType = cast<VectorType>(Ty);
- Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- for (auto Vec : enumerate(StoreVal.vectors())) {
- Value *GEP = computeVectorAddr(
- EltPtr,
- Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
- Vec.index()),
- Stride, StoreVal.getStride(), VType->getElementType(), Builder);
- Builder.CreateAlignedStore(Vec.value(), GEP,
- getAlignForIndex(Vec.index(), Stride,
- VType->getElementType(),
- MAlign),
- IsVolatile);
- }
- return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
- StoreVal.getNumVectors());
- }
- /// Lower a store instruction with shape information.
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
- Value *Stride, bool IsVolatile, ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto StoreVal = getMatrix(Matrix, Shape, Builder);
- finalizeLowering(Inst,
- storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
- IsVolatile, Builder),
- Builder);
- }
- /// Lowers llvm.matrix.column.major.store.
- ///
- /// The intrinsic stores a matrix back to memory using a stride between columns.
- void LowerColumnMajorStore(CallInst *Inst) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Intrinsic only supports column-major layout!");
- Value *Matrix = Inst->getArgOperand(0);
- Value *Ptr = Inst->getArgOperand(1);
- Value *Stride = Inst->getArgOperand(2);
- LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
- cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
- {Inst->getArgOperand(4), Inst->getArgOperand(5)});
- }
- // Insert Block into Col, overwriting elements I..I+BlockNumElts-1.
- Value *insertVector(Value *Col, unsigned I, Value *Block,
- IRBuilder<> &Builder) {
- // First, bring Block to the same size as Col
- unsigned BlockNumElts =
- cast<FixedVectorType>(Block->getType())->getNumElements();
- unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
- assert(NumElts >= BlockNumElts && "Too few elements for current block");
- Block = Builder.CreateShuffleVector(
- Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
- // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
- // 8, 4, 5, 6
- SmallVector<int, 16> Mask;
- unsigned i;
- for (i = 0; i < I; i++)
- Mask.push_back(i);
- unsigned VecNumElts =
- cast<FixedVectorType>(Col->getType())->getNumElements();
- for (; i < I + BlockNumElts; i++)
- Mask.push_back(i - I + VecNumElts);
- for (; i < VecNumElts; i++)
- Mask.push_back(i);
- return Builder.CreateShuffleVector(Col, Block, Mask);
- }
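- /// Emit \p Sum + \p A * \p B, either as separate multiply and add
- /// instructions or as a single fmuladd call when \p UseFPOp and
- /// \p AllowContraction are set. A null \p Sum yields just the product.
- /// \p NumComputeOps is incremented by the estimated number of vector ops
- /// emitted.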
- Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
- IRBuilder<> &Builder, bool AllowContraction,
- unsigned &NumComputeOps) {
- NumComputeOps += getNumOps(A->getType());
- if (!Sum)
- return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
- if (UseFPOp) {
- if (AllowContraction) {
- // Use fmuladd for floating point operations and let the backend decide
- // if that's profitable.
- Function *FMulAdd = Intrinsic::getDeclaration(
- Func.getParent(), Intrinsic::fmuladd, A->getType());
- return Builder.CreateCall(FMulAdd, {A, B, Sum});
- }
- NumComputeOps += getNumOps(A->getType());
- Value *Mul = Builder.CreateFMul(A, B);
- return Builder.CreateFAdd(Sum, Mul);
- }
- NumComputeOps += getNumOps(A->getType());
- Value *Mul = Builder.CreateMul(A, B);
- return Builder.CreateAdd(Sum, Mul);
- }
- /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
- /// users with shape information, there's nothing to do: they will use the
- /// cached value when they are lowered. For other users, \p Matrix is
- /// flattened and the uses are updated to use it. Also marks \p Inst for
- /// deletion.
- void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
- IRBuilder<> &Builder) {
- auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
- (void)inserted;
- assert(inserted.second && "multiple matrix lowering mapping");
- ToRemove.push_back(Inst);
- Value *Flattened = nullptr;
- for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
- }
- }
- }
- /// Compute \p Result += \p A * \p B for input matrices with left-associating
- /// addition.
- ///
- /// We can fold a transpose into the operand that is used to extract scalars.
- /// This is the first operand for row-major and the second for column-major
- /// layout. If \p IsScalarMatrixTransposed we assume the appropriate
- /// operand is transposed.
- void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
- const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
- bool IsScalarMatrixTransposed, FastMathFlags FMF) {
- const unsigned VF = std::max<unsigned>(
- TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedValue() /
- Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
- 1U);
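- // E.g. with 256-bit vector registers and double elements, VF is 4 lanes.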
- unsigned R = Result.getNumRows();
- unsigned C = Result.getNumColumns();
- unsigned M = A.getNumColumns();
- bool IsFP = Result.getElementType()->isFloatingPointTy();
- assert(A.isColumnMajor() == B.isColumnMajor() &&
- Result.isColumnMajor() == A.isColumnMajor() &&
- "operands must agree on matrix layout");
- unsigned NumComputeOps = 0;
- Builder.setFastMathFlags(FMF);
- if (A.isColumnMajor()) {
- // Multiply columns from the first operand with scalars from the second
- // operand. Then move along the K axes and accumulate the columns. With
- // this the adds can be vectorized without reassociation.
- for (unsigned J = 0; J < C; ++J) {
- unsigned BlockSize = VF;
- // If Result is zero, we don't need to accumulate in the K==0 iteration.
- bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
- for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
- Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
- : nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *L = A.extractVector(I, K, BlockSize, Builder);
- Value *RH = Builder.CreateExtractElement(
- B.getColumn(IsScalarMatrixTransposed ? K : J),
- IsScalarMatrixTransposed ? J : K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
- Sum =
- createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
- IsFP, Builder, FMF.allowContract(), NumComputeOps);
- }
- Result.setVector(J,
- insertVector(Result.getVector(J), I, Sum, Builder));
- }
- }
- } else {
- // Multiply rows from the second operand with scalars from the first
- // operand. Then move along the K axis and accumulate the rows. With this
- // the adds can be vectorized without reassociation.
- for (unsigned I = 0; I < R; ++I) {
- unsigned BlockSize = VF;
- bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
- for (unsigned J = 0; J < C; J += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (J + BlockSize > C)
- BlockSize /= 2;
- Value *Sum = nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *R = B.extractVector(K, J, BlockSize, Builder);
- Value *LH = Builder.CreateExtractElement(
- A.getVector(IsScalarMatrixTransposed ? K : I),
- IsScalarMatrixTransposed ? I : K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
- Sum =
- createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
- IsFP, Builder, FMF.allowContract(), NumComputeOps);
- }
- Result.setVector(I,
- insertVector(Result.getVector(I), J, Sum, Builder));
- }
- }
- }
- Result.addNumComputeOps(NumComputeOps);
- }
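- // Illustrative walk-through of the column-major path: for R = 4, C = 2, M = 3
- // and VF = 4, each result column J is built from one 4-wide block. For
- // K = 0..2 we extract the 4-wide chunk of column K of A, splat element (K, J)
- // of B and multiply-accumulate into the running Sum, i.e. three createMulAdd
- // steps per result column, with no cross-lane reassociation required.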
- /// Ensure that the memory in \p Load does not alias \p Store by potentially
- /// copying it to a new location. Returns the new location if a copy was
- /// needed, otherwise the original one.
- Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
- CallInst *MatMul) {
- MemoryLocation StoreLoc = MemoryLocation::get(Store);
- MemoryLocation LoadLoc = MemoryLocation::get(Load);
- // If we can statically determine noalias we're good.
- if (AA->isNoAlias(LoadLoc, StoreLoc))
- return Load->getPointerOperand();
- // Create code to check if the memory locations of the Load and Store
- // overlap and if they do, copy Load's operand to a new buffer.
- // First, create new blocks for the second part of the check and the copy.
- BasicBlock *Check0 = MatMul->getParent();
- // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
- // DT. Manually collect dominator tree updates, to avoid unnecessary work,
- // as we adjust Check0 and Check1's branches.
- SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
- for (BasicBlock *Succ : successors(Check0))
- DTUpdates.push_back({DT->Delete, Check0, Succ});
- BasicBlock *Check1 =
- SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
- nullptr, "alias_cont");
- BasicBlock *Copy =
- SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
- nullptr, "copy");
- BasicBlock *Fusion =
- SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
- nullptr, "no_alias");
- // Check if the loaded memory location begins before the end of the store
- // location. If the condition holds, they might overlap, otherwise they are
- // guaranteed to not overlap.
- IRBuilder<> Builder(MatMul);
- Check0->getTerminator()->eraseFromParent();
- Builder.SetInsertPoint(Check0);
- Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
- Value *StoreBegin = Builder.CreatePtrToInt(
- const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
- Value *StoreEnd = Builder.CreateAdd(
- StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
- "store.end", true, true);
- Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
- IntPtrTy, "load.begin");
- Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
- Fusion);
- // Check if the store begins before the end of the load location. If the
- // condition holds, they alias, otherwise they are guaranteed to not
- // overlap.
- Check1->getTerminator()->eraseFromParent();
- Builder.SetInsertPoint(Check1, Check1->begin());
- Value *LoadEnd = Builder.CreateAdd(
- LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
- "load.end", true, true);
- Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
- Fusion);
- // Copy load operand to new alloca.
- Builder.SetInsertPoint(Copy, Copy->begin());
- auto *VT = cast<FixedVectorType>(Load->getType());
- // Use an array type for the alloca, to avoid potentially huge alignment
- // requirements for large vector types.
- auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
- AllocaInst *Alloca =
- Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
- Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
- Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
- Load->getAlign(), LoadLoc.Size.getValue());
- Builder.SetInsertPoint(Fusion, Fusion->begin());
- PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
- PHI->addIncoming(Load->getPointerOperand(), Check0);
- PHI->addIncoming(Load->getPointerOperand(), Check1);
- PHI->addIncoming(BC, Copy);
- // Adjust DT.
- DTUpdates.push_back({DT->Insert, Check0, Check1});
- DTUpdates.push_back({DT->Insert, Check0, Fusion});
- DTUpdates.push_back({DT->Insert, Check1, Copy});
- DTUpdates.push_back({DT->Insert, Check1, Fusion});
- DT->applyUpdates(DTUpdates);
- return PHI;
- }
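- // The emitted runtime check has roughly the following shape (block names as
- // created above, value names illustrative):
- //   Check0:     br (load.begin < store.end), %alias_cont, %no_alias
- //   alias_cont: br (store.begin < load.end), %copy, %no_alias
- //   copy:       %buf = alloca [N x elt]; memcpy(%buf, load ptr); br %no_alias
- //   no_alias:   %ptr = phi [orig, Check0], [orig, alias_cont], [%buf, copy]
- // The PHI result is then used as the known non-aliasing pointer.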
- bool isFusionProfitable(CallInst *MatMul) {
- if (ForceFusion)
- return true;
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- const unsigned VF = std::max<unsigned>(
- TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedValue() /
- EltType->getPrimitiveSizeInBits().getFixedValue(),
- 1U);
- // Cost model for tiling
- //
- // For tiling to be beneficial, we need reuse either along the R or
- // the C axis. We vectorize along the R axis so that means at least
- // 3 elements.
- // TODO: Also consider cost of copying if operands alias.
- if (R <= VF && C == 1)
- return false;
- // Then we need enough elements to exceed the number of vector
- // registers we have. Note that this is an oversimplification since
- // fusing also takes some extra loads which may exceed the number of
- // reloads necessary.
- unsigned Op0Regs = (R + VF - 1) / VF * M;
- unsigned Op1Regs = (M + VF - 1) / VF * C;
- return Op0Regs + Op1Regs >
- TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
- }
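- // Worked example (illustrative numbers): with 128-bit vector registers and
- // float elements, VF = 4. For R = C = M = 8, Op0Regs = ceil(8/4) * 8 = 16 and
- // Op1Regs = ceil(8/4) * 8 = 16; on a target with 16 vector registers,
- // 32 > 16, so fusion is considered profitable. A tall-skinny case such as
- // R = 4, C = 1 bails out early via the first check.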
- MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
- MatrixTy Res;
- auto *ColumnType = FixedVectorType::get(EltType, R);
- for (unsigned I = 0; I < C; ++I)
- Res.addVector(ConstantAggregateZero::get(ColumnType));
- return Res;
- }
- void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
- Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- // Create the main tiling loop nest.
- TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
- Instruction *InsertI = cast<Instruction>(MatMul);
- BasicBlock *Start = InsertI->getParent();
- BasicBlock *End =
- SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
- IRBuilder<> Builder(MatMul);
- BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
- Type *TileVecTy =
- FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
- MatrixTy TileResult;
- // Insert in the inner loop header.
- Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
- // Create PHI nodes for the result columns to accumulate across iterations.
- SmallVector<PHINode *, 4> ColumnPhis;
- for (unsigned I = 0; I < TileSize; I++) {
- auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
- Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
- TI.RowLoop.Header->getSingleSuccessor());
- TileResult.addVector(Phi);
- ColumnPhis.push_back(Phi);
- }
- // Insert in the inner loop body, which computes
- // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
- Builder.SetInsertPoint(InnerBody->getTerminator());
- // Load tiles of the operands.
- MatrixTy A =
- loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
- {TileSize, TileSize}, EltType, Builder);
- MatrixTy B =
- loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
- {TileSize, TileSize}, EltType, Builder);
- emitMatrixMultiply(TileResult, A, B, Builder, true, false,
- getFastMathFlags(MatMul));
- // Store result after the inner loop is done.
- Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
- storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
- Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
- TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
- for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
- ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
- // Force unrolling of a few iterations of the inner loop, to make sure there
- // is enough work per iteration.
- // FIXME: The unroller should make this decision directly instead, but
- // currently the cost-model is not up to the task.
- unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
- addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
- "llvm.loop.unroll.count", InnerLoopUnrollCount);
- }
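- // The generated loop nest has roughly this structure (see TileInfo), with the
- // accumulator PHIs living in the inner loop header and the store placed in
- // the row loop latch:
- //   for (J = 0; J < Cols; J += TileSize)       // ColumnLoop
- //     for (I = 0; I < Rows; I += TileSize)     // RowLoop
- //       Res = 0
- //       for (K = 0; K < Inner; K += TileSize)  // KLoop
- //         Res += A[I:I+T, K:K+T] * B[K:K+T, J:J+T]
- //       store Res to C[I:I+T, J:J+T]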
- void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
- StoreInst *Store,
- SmallPtrSetImpl<Instruction *> &FusedInsts) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Tiling only supported for column-major matrixes at the moment!");
- if (!isFusionProfitable(MatMul))
- return;
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
- Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
- Value *CPtr = Store->getPointerOperand();
- if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
- createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
- else {
- IRBuilder<> Builder(Store);
- for (unsigned J = 0; J < C; J += TileSize)
- for (unsigned I = 0; I < R; I += TileSize) {
- const unsigned TileR = std::min(R - I, unsigned(TileSize));
- const unsigned TileC = std::min(C - J, unsigned(TileSize));
- MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
- for (unsigned K = 0; K < M; K += TileSize) {
- const unsigned TileM = std::min(M - K, unsigned(TileSize));
- MatrixTy A =
- loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
- LShape, Builder.getInt64(I), Builder.getInt64(K),
- {TileR, TileM}, EltType, Builder);
- MatrixTy B =
- loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
- RShape, Builder.getInt64(K), Builder.getInt64(J),
- {TileM, TileC}, EltType, Builder);
- emitMatrixMultiply(Res, A, B, Builder, true, false,
- getFastMathFlags(MatMul));
- }
- storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
- Builder.getInt64(I), Builder.getInt64(J), EltType,
- Builder);
- }
- }
- // Mark eliminated instructions as fused and remove them.
- FusedInsts.insert(Store);
- FusedInsts.insert(MatMul);
- Store->eraseFromParent();
- MatMul->eraseFromParent();
- if (LoadOp0->hasNUses(0)) {
- FusedInsts.insert(LoadOp0);
- LoadOp0->eraseFromParent();
- }
- if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
- FusedInsts.insert(LoadOp1);
- LoadOp1->eraseFromParent();
- }
- }
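- // E.g. with R = C = M = 8 and a tile size of 4 (illustrative), the unrolled
- // path above emits 2 x 2 result tiles; each tile accumulates over 2 K-tiles,
- // so 8 tile multiplies in total, each operating on 4x4 sub-matrices loaded
- // directly from the (non-aliasing) operand pointers.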
- /// Try to lower matrix multiply chains by fusing operations.
- ///
- /// Call finalizeLowering on lowered instructions. Instructions that are
- /// completely eliminated by fusion are added to \p FusedInsts.
- void LowerMatrixMultiplyFused(CallInst *MatMul,
- SmallPtrSetImpl<Instruction *> &FusedInsts) {
- if (!FuseMatrix || !DT)
- return;
- assert(AA && LI && "Analyses should be available");
- Value *A = MatMul->getArgOperand(0);
- Value *B = MatMul->getArgOperand(1);
- // We can fold the transpose into the operand that is used to fetch scalars.
- Value *T;
- if (MatrixLayout == MatrixLayoutTy::ColumnMajor
- ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
- : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
- IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const unsigned R = LShape.NumRows;
- const unsigned M = LShape.NumColumns;
- const unsigned C = RShape.NumColumns;
- MatrixTy MA;
- MatrixTy MB;
- Value *Transpose;
- if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
- MA = getMatrix(A, ShapeInfo(R, M), Builder);
- MB = getMatrix(T, ShapeInfo(C, M), Builder);
- Transpose = B;
- } else {
- MA = getMatrix(T, ShapeInfo(R, M), Builder);
- MB = getMatrix(B, ShapeInfo(C, M), Builder);
- Transpose = A;
- }
- // Initialize the output
- MatrixTy Result(R, C, EltType);
- emitMatrixMultiply(Result, MA, MB, Builder, false, true,
- getFastMathFlags(MatMul));
- FusedInsts.insert(MatMul);
- if (Transpose->hasOneUse()) {
- FusedInsts.insert(cast<Instruction>(Transpose));
- ToRemove.push_back(cast<Instruction>(Transpose));
- // TODO: add a fake entry for the folded instruction so that this is
- // included in the expression in the remark.
- Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
- }
- finalizeLowering(MatMul, Result, Builder);
- return;
- }
- if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
- return;
- // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
- // since the single store user will be lowered as part of this.
- auto *LoadOp0 = dyn_cast<LoadInst>(A);
- auto *LoadOp1 = dyn_cast<LoadInst>(B);
- auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
- if (LoadOp0 && LoadOp1 && Store) {
- // The store address must dominate the MatMul instruction, otherwise
- // we create invalid IR.
- SetVector<Value *> WorkList;
- WorkList.insert(Store->getOperand(1));
- SmallVector<Instruction *> ToHoist;
- for (unsigned I = 0; I != WorkList.size(); ++I) {
- Value *Current = WorkList[I];
- auto *CurrI = dyn_cast<Instruction>(Current);
- if (!CurrI)
- continue;
- if (isa<PHINode>(CurrI))
- return;
- if (DT->dominates(CurrI, MatMul))
- continue;
- if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
- return;
- ToHoist.push_back(CurrI);
- WorkList.insert(CurrI->op_begin(), CurrI->op_end());
- }
- sort(ToHoist, [this](Instruction *A, Instruction *B) {
- return DT->dominates(A, B);
- });
- for (Instruction *I : ToHoist)
- I->moveBefore(MatMul);
- emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
- return;
- }
- }
- /// Lowers llvm.matrix.multiply.
- void LowerMultiply(CallInst *MatMul) {
- IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
- const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
- assert(Lhs.getElementType() == Rhs.getElementType() &&
- "Matrix multiply argument element types do not match.");
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- assert(LShape.NumColumns == RShape.NumRows);
- // Initialize the output
- MatrixTy Result(R, C, EltType);
- assert(Lhs.getElementType() == Result.getElementType() &&
- "Matrix multiply result element type does not match arguments.");
- emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
- getFastMathFlags(MatMul));
- finalizeLowering(MatMul, Result, Builder);
- }
- /// Lowers llvm.matrix.transpose.
- void LowerTranspose(CallInst *Inst) {
- MatrixTy Result;
- IRBuilder<> Builder(Inst);
- Value *InputVal = Inst->getArgOperand(0);
- VectorType *VectorTy = cast<VectorType>(InputVal->getType());
- ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
- MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
- const unsigned NewNumVecs =
- InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
- const unsigned NewNumElts =
- InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
- for (unsigned I = 0; I < NewNumVecs; ++I) {
- // Build a single result vector. First initialize it.
- Value *ResultVector = PoisonValue::get(
- FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
- // Go through the old elements and insert them into the resulting vector.
- for (auto J : enumerate(InputMatrix.vectors())) {
- Value *Elt = Builder.CreateExtractElement(J.value(), I);
- // Row and column indices are transposed.
- ResultVector =
- Builder.CreateInsertElement(ResultVector, Elt, J.index());
- }
- Result.addVector(ResultVector);
- }
- // TODO: Improve estimate of operations needed for transposes. Currently we
- // just count the insertelement/extractelement instructions, but do not
- // account for later simplifications/combines.
- finalizeLowering(
- Inst,
- Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
- .addNumExposedTransposes(1),
- Builder);
- }
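- // E.g. transposing a column-major 2x3 matrix, lowered as three <2 x float>
- // columns, produces two <3 x float> result columns; result column I gathers
- // element I of every input column via extractelement/insertelement pairs
- // (names illustrative):
- //   %e = extractelement <2 x float> %in.col.J, i64 I
- //   %r = insertelement <3 x float> %r.prev, float %e, i64 J
- // The op-count estimate above simply counts these element moves.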
- /// Lower load instructions, if shape information is available.
- bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- if (I == ShapeMap.end())
- return false;
- LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
- }
- bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
- IRBuilder<> &Builder) {
- auto I = ShapeMap.find(StoredVal);
- if (I == ShapeMap.end())
- return false;
- LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
- }
- /// Lower binary operators, if shape information is available.
- bool VisitBinaryOperator(BinaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- if (I == ShapeMap.end())
- return false;
- Value *Lhs = Inst->getOperand(0);
- Value *Rhs = Inst->getOperand(1);
- IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
- MatrixTy Result;
- MatrixTy A = getMatrix(Lhs, Shape, Builder);
- MatrixTy B = getMatrix(Rhs, Shape, Builder);
- assert(A.isColumnMajor() == B.isColumnMajor() &&
- Result.isColumnMajor() == A.isColumnMajor() &&
- "operands must agree on matrix layout");
- Builder.setFastMathFlags(getFastMathFlags(Inst));
- // Helper to perform binary op on vectors.
- auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
- switch (Inst->getOpcode()) {
- case Instruction::Add:
- return Builder.CreateAdd(LHS, RHS);
- case Instruction::Mul:
- return Builder.CreateMul(LHS, RHS);
- case Instruction::Sub:
- return Builder.CreateSub(LHS, RHS);
- case Instruction::FAdd:
- return Builder.CreateFAdd(LHS, RHS);
- case Instruction::FMul:
- return Builder.CreateFMul(LHS, RHS);
- case Instruction::FSub:
- return Builder.CreateFSub(LHS, RHS);
- default:
- llvm_unreachable("Unsupported binary operator for matrix");
- }
- };
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
- Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
- return true;
- }
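- // E.g. an fadd of two 2x2 float matrices lowered column-major becomes two
- // <2 x float> fadd instructions, one per column vector, and contributes
- // 2 * getNumOps(<2 x float>) to the compute-op count.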
- /// Lower unary operators, if shape information is available.
- bool VisitUnaryOperator(UnaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- if (I == ShapeMap.end())
- return false;
- Value *Op = Inst->getOperand(0);
- IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
- MatrixTy Result;
- MatrixTy M = getMatrix(Op, Shape, Builder);
- Builder.setFastMathFlags(getFastMathFlags(Inst));
- // Helper to perform unary op on vectors.
- auto BuildVectorOp = [&Builder, Inst](Value *Op) {
- switch (Inst->getOpcode()) {
- case Instruction::FNeg:
- return Builder.CreateFNeg(Op);
- default:
- llvm_unreachable("Unsupported unary operator for matrix");
- }
- };
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
- Result.addVector(BuildVectorOp(M.getVector(I)));
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
- return true;
- }
- /// Helper to linearize a matrix expression tree into a string. Currently
- /// matrix expressions are linearized by starting at an expression leaf and
- /// linearizing bottom up.
- struct ExprLinearizer {
- unsigned LengthToBreak = 100;
- std::string Str;
- raw_string_ostream Stream;
- unsigned LineLength = 0;
- const DataLayout &DL;
- /// Mapping from instructions to matrixes. It is used to identify
- /// matrix instructions.
- const MapVector<Value *, MatrixTy> &Inst2Matrix;
- /// Mapping from values to the leaves of all expressions that the value is
- /// part of.
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
- /// Set of matrix expressions in the scope of a given DISubprogram.
- const SmallSetVector<Value *, 32> &ExprsInSubprogram;
- /// Leaf node of the expression to linearize.
- Value *Leaf;
- /// Used to keep track of sub-expressions that get reused while linearizing
- /// the expression. Re-used sub-expressions are marked as (reused).
- SmallPtrSet<Value *, 8> ReusedExprs;
- ExprLinearizer(const DataLayout &DL,
- const MapVector<Value *, MatrixTy> &Inst2Matrix,
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- Value *Leaf)
- : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
- ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
- void indent(unsigned N) {
- LineLength += N;
- for (unsigned i = 0; i < N; i++)
- Stream << " ";
- }
- void lineBreak() {
- Stream << "\n";
- LineLength = 0;
- }
- void maybeIndent(unsigned Indent) {
- if (LineLength >= LengthToBreak)
- lineBreak();
- if (LineLength == 0)
- indent(Indent);
- }
- void write(StringRef S) {
- LineLength += S.size();
- Stream << S;
- }
- Value *getUnderlyingObjectThroughLoads(Value *V) {
- if (Value *Ptr = getPointerOperand(V))
- return getUnderlyingObjectThroughLoads(Ptr);
- else if (V->getType()->isPointerTy())
- return getUnderlyingObject(V);
- return V;
- }
- /// Returns true if \p V is a matrix value in the given subprogram.
- bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
- /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
- /// \p SS.
- void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
- auto M = Inst2Matrix.find(V);
- if (M == Inst2Matrix.end())
- SS << "unknown";
- else {
- SS << M->second.getNumRows();
- SS << "x";
- SS << M->second.getNumColumns();
- }
- }
- /// Write the called function name. Handles calls to llvm.matrix.*
- /// specially: we write the name, followed by the dimensions of the input
- /// matrixes, followed by the scalar type name.
- void writeFnName(CallInst *CI) {
- if (!CI->getCalledFunction())
- write("<no called fn>");
- else {
- StringRef Name = CI->getCalledFunction()->getName();
- if (!Name.startswith("llvm.matrix")) {
- write(Name);
- return;
- }
- auto *II = cast<IntrinsicInst>(CI);
- write(Intrinsic::getBaseName(II->getIntrinsicID())
- .drop_front(StringRef("llvm.matrix.").size()));
- write(".");
- std::string Tmp;
- raw_string_ostream SS(Tmp);
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << ".";
- prettyPrintMatrixType(II->getOperand(1), SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_transpose:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_column_major_load:
- prettyPrintMatrixType(II, SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_column_major_store:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << "." << *II->getOperand(0)->getType()->getScalarType();
- break;
- default:
- llvm_unreachable("Unhandled case");
- }
- SS.flush();
- write(Tmp);
- }
- }
- unsigned getNumShapeArgs(CallInst *CI) const {
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- return 3;
- case Intrinsic::matrix_transpose:
- return 2;
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- return 3;
- default:
- return 0;
- }
- }
- return 0;
- }
- /// Special printing for values: for pointers, we print whether they refer to
- /// an external (function) address or a stack address; for other values we
- /// either print the constant or "scalar"/"matrix".
- void write(Value *V) {
- V = getUnderlyingObjectThroughLoads(V);
- if (V->getType()->isPointerTy()) {
- if (isa<AllocaInst>(V)) {
- Stream << "stack addr";
- LineLength += StringRef("stack addr").size();
- } else {
- Stream << "addr";
- LineLength += StringRef("addr").size();
- }
- if (!V->getName().empty()) {
- Stream << " %" << V->getName() << "";
- LineLength += V->getName().size() + 2;
- }
- return;
- }
- std::string Tmp;
- raw_string_ostream TmpStream(Tmp);
- if (auto *CI = dyn_cast<ConstantInt>(V))
- TmpStream << CI->getValue();
- else if (isa<Constant>(V))
- TmpStream << "constant";
- else {
- if (isMatrix(V))
- TmpStream << "matrix";
- else
- TmpStream << "scalar";
- }
- TmpStream.flush();
- Tmp = std::string(StringRef(Tmp).trim());
- LineLength += Tmp.size();
- Stream << Tmp;
- }
- /// Linearize expression \p Expr starting at an indentation of \p Indent.
- /// Expressions that are re-used multiple times are prefixed with (reused)
- /// at the re-used root instruction.
- void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
- bool ParentShared) {
- auto *I = cast<Instruction>(Expr);
- maybeIndent(Indent);
- SmallVector<Value *, 8> Ops;
- // Is Expr shared with other expression leaves?
- bool ExprShared = false;
- // Deal with shared subtrees. Mark them as shared, if required.
- if (!ParentShared) {
- auto SI = Shared.find(Expr);
- assert(SI != Shared.end() && SI->second.count(Leaf));
- for (Value *S : SI->second) {
- if (S == Leaf)
- continue;
- DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
- write("shared with remark at line " + std::to_string(DL.getLine()) +
- " column " + std::to_string(DL.getCol()) + " (");
- }
- ExprShared = SI->second.size() > 1;
- }
- bool Reused = !ReusedExprs.insert(Expr).second;
- if (Reused && !ParentReused)
- write("(reused) ");
- if (auto *CI = dyn_cast<CallInst>(I)) {
- writeFnName(CI);
- Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
- } else if (isa<BitCastInst>(Expr)) {
- // Special case bitcasts, which are used to materialize matrixes from
- // non-matrix ops.
- write("matrix");
- return;
- } else {
- Ops.append(I->value_op_begin(), I->value_op_end());
- write(std::string(I->getOpcodeName()));
- }
- write(std::string("("));
- unsigned NumOpsToBreak = 1;
- if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
- NumOpsToBreak = 2;
- for (Value *Op : Ops) {
- if (Ops.size() > NumOpsToBreak)
- lineBreak();
- maybeIndent(Indent + 1);
- if (isMatrix(Op))
- linearizeExpr(Op, Indent + 1, Reused, ExprShared);
- else
- write(Op);
- if (Op != Ops.back())
- write(", ");
- }
- write(")");
- }
- const std::string &getResult() {
- Stream.flush();
- return Str;
- }
- };
- /// Generate remarks for matrix operations in a function. To generate remarks
- /// for matrix expressions, the following approach is used:
- /// 1. Use the inlined-at debug information to group matrix operations to the
- /// DISubprograms they are contained in.
- /// 2. Collect leaves of matrix expressions (done in
- /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
- /// mapping. Leaves are lowered matrix instructions without other matrix
- /// users (like stores) in the current subprogram.
- /// 3. For each leaf, create a remark containing a linearized version of the
- /// matrix expression. The expression is linearized by a recursive
- /// bottom-up traversal of the matrix operands, starting at a leaf. Note
- /// that multiple leaves can share sub-expressions. Shared subexpressions
- /// are explicitly marked as shared().
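- ///
- /// For example, for a load/multiply/store chain the emitted remark looks
- /// roughly like this (counts and line breaks are illustrative and depend on
- /// the shapes and the lowering):
- ///   remark: test.c:35:3: Lowered with 2 stores, 4 loads, 16 compute ops,
- ///   0 exposed transposes
- ///   store(
- ///    multiply.2x2.2x2.double(load(addr %A), load(addr %B)),
- ///    addr %C)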
- struct RemarkGenerator {
- const MapVector<Value *, MatrixTy> &Inst2Matrix;
- OptimizationRemarkEmitter &ORE;
- Function &Func;
- const DataLayout &DL;
- RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
- OptimizationRemarkEmitter &ORE, Function &Func)
- : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
- DL(Func.getParent()->getDataLayout()) {}
- /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
- /// instructions in Inst2Matrix returning void or without any users in
- /// \p ExprsInSubprogram. Currently that should only include stores.
- SmallVector<Value *, 4>
- getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
- SmallVector<Value *, 4> Leaves;
- for (auto *Expr : ExprsInSubprogram)
- if (Expr->getType()->isVoidTy() ||
- !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
- return ExprsInSubprogram.count(U);
- }))
- Leaves.push_back(Expr);
- return Leaves;
- }
- /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
- /// to all visited expressions in \p Shared. Limit the matrix operations to
- /// the ones in \p ExprsInSubprogram.
- void collectSharedInfo(Value *Leaf, Value *V,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
- if (!ExprsInSubprogram.count(V))
- return;
- auto I = Shared.insert({V, {}});
- I.first->second.insert(Leaf);
- for (Value *Op : cast<Instruction>(V)->operand_values())
- collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
- }
- /// Calculate the number of exclusive and shared op counts for expression
- /// starting at \p V. Expressions used multiple times are counted once.
- /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
- std::pair<OpInfoTy, OpInfoTy>
- sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
- if (!ExprsInSubprogram.count(Root))
- return {};
- // Already counted this expression. Stop.
- if (!ReusedExprs.insert(Root).second)
- return {};
- OpInfoTy SharedCount;
- OpInfoTy Count;
- auto I = Shared.find(Root);
- auto CM = Inst2Matrix.find(Root);
- if (I->second.size() == 1)
- Count = CM->second.getOpInfo();
- else
- SharedCount = CM->second.getOpInfo();
- for (Value *Op : cast<Instruction>(Root)->operand_values()) {
- auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
- Count += C.first;
- SharedCount += C.second;
- }
- return {Count, SharedCount};
- }
- void emitRemarks() {
- if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
- return;
- // Map matrix operations to their containing subprograms, by traversing
- // the inlinedAt chain. If the function does not have a DISubprogram, we
- // only map them to the containing function.
- MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
- for (const auto &KV : Inst2Matrix) {
- if (Func.getSubprogram()) {
- auto *I = cast<Instruction>(KV.first);
- DILocation *Context = I->getDebugLoc();
- while (Context) {
- auto I =
- Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
- I.first->second.push_back(KV.first);
- Context = DebugLoc(Context).getInlinedAt();
- }
- } else {
- auto I = Subprog2Exprs.insert({nullptr, {}});
- I.first->second.push_back(KV.first);
- }
- }
- for (auto &KV : Subprog2Exprs) {
- SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
- KV.second.end());
- auto Leaves = getExpressionLeaves(ExprsInSubprogram);
- DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
- for (Value *Leaf : Leaves)
- collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
- // Generate remarks for each leaf.
- for (auto *L : Leaves) {
- DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
- DILocation *Context = cast<Instruction>(L)->getDebugLoc();
- while (Context) {
- if (getSubprogram(Context->getScope()) == KV.first) {
- Loc = Context;
- break;
- }
- Context = DebugLoc(Context).getInlinedAt();
- }
- SmallPtrSet<Value *, 8> ReusedExprs;
- OpInfoTy Counts, SharedCounts;
- std::tie(Counts, SharedCounts) =
- sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
- OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
- cast<Instruction>(L)->getParent());
- Rem << "Lowered with ";
- Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
- << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
- << ore::NV("NumComputeOps", Counts.NumComputeOps)
- << " compute ops, "
- << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
- << " exposed transposes";
- if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
- SharedCounts.NumComputeOps > 0) {
- Rem << ",\nadditionally "
- << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
- << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
- << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
- << " compute ops"
- << " are shared with other expressions";
- }
- Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
- ORE.emit(Rem);
- }
- }
- }
- std::string
- linearize(Value *L,
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- const DataLayout &DL) {
- ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
- Lin.linearizeExpr(L, 0, false, false);
- return Lin.getResult();
- }
- };
- };
- } // namespace
- PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- auto &TTI = AM.getResult<TargetIRAnalysis>(F);
- OptimizationRemarkEmitter *ORE = nullptr;
- AAResults *AA = nullptr;
- DominatorTree *DT = nullptr;
- LoopInfo *LI = nullptr;
- if (!Minimal) {
- ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- AA = &AM.getResult<AAManager>(F);
- DT = &AM.getResult<DominatorTreeAnalysis>(F);
- LI = &AM.getResult<LoopAnalysis>(F);
- }
- LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
- if (LMT.Visit()) {
- PreservedAnalyses PA;
- if (!Minimal) {
- PA.preserve<LoopAnalysis>();
- PA.preserve<DominatorTreeAnalysis>();
- }
- return PA;
- }
- return PreservedAnalyses::all();
- }
- void LowerMatrixIntrinsicsPass::printPipeline(
- raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
- static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
- OS, MapClassName2PassName);
- OS << "<";
- if (Minimal)
- OS << "minimal";
- OS << ">";
- }
- namespace {
- class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
- public:
- static char ID;
- LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
- bool C = LMT.Visit();
- return C;
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- }
- };
- } // namespace
- static const char pass_name[] = "Lower the matrix intrinsics";
- char LowerMatrixIntrinsicsLegacyPass::ID = 0;
- INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
- INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
- INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
- INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
- INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
- INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
- Pass *llvm::createLowerMatrixIntrinsicsPass() {
- return new LowerMatrixIntrinsicsLegacyPass();
- }
- namespace {
- /// A lightweight version of the matrix lowering pass that only requires TTI.
- /// Advanced features that require DT, AA or ORE like tiling are disabled. This
- /// is used to lower matrix intrinsics if the main lowering pass is not run, for
- /// example with -O0.
- class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
- public:
- static char ID;
- LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
- bool C = LMT.Visit();
- return C;
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
- }
- };
- } // namespace
- static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
- char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
- INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal,
- false, false)
- INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
- false)
- Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
- return new LowerMatrixIntrinsicsMinimalLegacyPass();
- }