//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "polly/Support/GICHelper.h"
#include "polly/Support/ISLTools.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("A throughput of the processor floating-point arithmetic units "
             "expressed in the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(32768), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(262144), cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size,
// represents the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of an Intel Core i7-3820
// (Sandy Bridge) in case the parameters are not specified or not provided by
// the TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information)."),
    cl::Hidden, cl::init(-1), cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the "
             "macro-kernel, by Nr, the parameter of the micro-kernel"),
    cl::Hidden, cl::init(256), cl::cat(PollyCategory));

static cl::opt<bool>
    PMBasedTCOpts("polly-tc-opt",
                  cl::desc("Perform optimizations of tensor contractions based "
                           "on pattern matching"),
                  cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<bool>
    PMBasedMMMOpts("polly-matmul-opt",
                   cl::desc("Perform optimizations of matrix multiplications "
                            "based on pattern matching"),
                   cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> OptComputeOut(
    "polly-tc-dependences-computeout",
    cl::desc("Bound the dependence analysis by a maximal amount of "
             "computational steps (0 means no bound)"),
    cl::Hidden, cl::init(500000), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {
/// Parameters of the micro kernel.
///
/// Parameters that determine the size of the rank-1 (i.e., outer product)
/// update used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr;
  int Nr;
};

/// Parameters of the macro kernel.
///
/// Parameters that determine the sizes of the blocks of the partitioned
/// matrices used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc;
  int Nc;
  int Kc;
};
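
// For illustration only, a sketch (following the BLIS approach referenced in
// getMacroKernelParams below) of where these parameters appear in a blocked
// computation of C[i][j] += A[i][k] * B[k][j]:
//
//   for (jc = 0; jc < N; jc += Nc)          // loop over column blocks of B
//     for (pc = 0; pc < K; pc += Kc)        // loop over blocks of the
//                                           // contracted dimension
//       for (ic = 0; ic < M; ic += Mc)      // loop over row blocks of A
//         for (jr = 0; jr < Nc; jr += Nr)   // micro-kernel loop
//           for (ir = 0; ir < Mc; ir += Mr) // micro-kernel loop
//             for (k = 0; k < Kc; k++)
//               // Mr x Nr rank-1 update of a block of C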

/// Parameters of the matrix multiplication operands.
///
/// Parameters, which describe access relations that represent operands of the
/// matrix multiplication.
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  int i = -1;
  int j = -1;
  int k = -1;
};
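
// For illustration only: given the canonical pattern
//
//   for (i = 0; i < N; i++)
//     for (j = 0; j < M; j++)
//       for (k = 0; k < P; k++)
//         C[i][j] += A[i][k] * B[k][j];
//
// WriteToC and ReadFromC describe the write to and the read from C[i][j],
// A and B describe the reads of A[i][k] and B[k][j], and i, j, and k hold the
// positions of the corresponding loop dimensions in the schedule space.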

/// Parameters of the tensor contraction operands.
///
/// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined
/// as the set of scalar elements indexed by the set of indices u0 ... ud−1,
///
/// T ≡ {Anu0...nud−1 ∈ R | (u0,...,ud−1) ∈ Nu0 x ... x Nud−1}.
///
/// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively.
/// Let the free and the contracted indices of the tensor A be grouped into
/// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly,
/// the free and the contracted indices of B are grouped into bundles
/// J = j0..js−1 and P and the free indices of C are grouped into
/// bundles I and J.
///
/// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
/// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
/// where ∑ is a summation over all contracted indices of P,
/// α, β ∈ R, Npi is the length of the tensor dimension that corresponds
/// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are
/// accesses to tensors A, B, C, respectively,
/// and shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of
/// the enclosed indices.
///
/// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP
/// statement by loop distribution, which is done by the isl scheduler.
/// If β is not equal to one, the optimization of TC of Polly requires
/// such a transformation.
///
/// TCInfoTy contains parameters, which describe access relations that
/// represent operands of the tensor contraction.
struct TCInfoTy {
  /// @{
  /// Memory accesses that represent reading from tensors, which are operands
  /// of the tensor contraction.
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  /// @}

  /// @{
  /// Memory accesses that represent reading from and writing into the tensor,
  /// which contains the result of the tensor contraction.
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  /// @}

  /// @{
  /// Input dimensions of the schedule space, which represent free
  /// indices of tensors.
  SmallDenseSet<int> I;
  SmallDenseSet<int> J;
  /// @}

  /// Input dimensions of the schedule space, which represent contracted
  /// indices of tensors.
  SmallDenseSet<int> P;

  /// @{
  /// Sizes of tensor dimensions for corresponding input dimensions of
  /// the schedule space. The size of a tensor dimension can be larger than
  /// the size of the corresponding input dimension of the schedule space.
  /// Strictly speaking, this does not correspond to a tensor contraction.
  /// However, such a pattern will still be optimized by the transformation.
  SmallVector<int> DimensionSizes;
  SmallVector<int> ADimensions;
  SmallVector<int> BDimensions;
  SmallVector<int> CDimensions;
  /// @}

  /// @{
  /// Permutations of indices of I, J, and P, which describe operands of
  /// the tensor contraction and its result.
  SmallVector<int> OrderedI;
  SmallVector<int> OrderedJ;
  SmallVector<int> OrderedP;
  /// @}
};

/// Create an isl::union_set, which describes the option of the form
/// [isolate[] -> unroll[x]].
///
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
  isl::space Space = isl::space(Ctx, 0, 0, 1);
  isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
  isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
  isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
  return UnrollIsolatedSetOption.wrap();
}

/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
/// have type @p DimType.
///
/// @param Map     The isl map to be modified.
/// @param DimType The type of the dimensions.
/// @param DstPos  The first dimension.
/// @param SrcPos  The second dimension.
/// @return        The modified map.
static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
                                  unsigned DstPos, unsigned SrcPos) {
  assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) &&
         SrcPos < unsignedFromIslSize(Map.dim(DimType)));
  if (DstPos == SrcPos)
    return Map;
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  if (!DimId.is_null())
    Map = Map.set_tuple_id(DimType, DimId);
  if (!FreeDimId.is_null())
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}
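
// For illustration only: applied to the output dimensions of the map
// { S[i, j, k] -> [i, j, k] }, permuteDimensions(Map, isl::dim::out, 0, 2)
// yields { S[i, j, k] -> [k, j, i] }, i.e., the two requested dimensions are
// swapped while all remaining dimensions keep their positions.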

/// Check the form of the access relation.
///
/// Check that the access relation @p AccMap has the form M[i][j], where i
/// is a @p FirstPos and j is a @p SecondPos.
///
/// @param Domain    The domain of the SCoP statement the access belongs to.
/// @param AccMap    The access relation to be checked.
/// @param FirstPos  The index of the input dimension that is mapped to
///                  the first output dimension.
/// @param SecondPos The index of the input dimension that is mapped to the
///                  second output dimension.
/// @return          True in case @p AccMap has the expected form and false,
///                  otherwise.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
                               int &SecondPos) {
  isl::space Space = AccMap.get_space();
  isl::map Universe = isl::map::universe(Space);

  if (unsignedFromIslSize(Space.dim(isl::dim::out)) != 2)
    return false;

  // MatMul has the form:
  //   for (i = 0; i < N; i++)
  //     for (j = 0; j < M; j++)
  //       for (k = 0; k < P; k++)
  //         C[i, j] += A[i, k] * B[k, j]
  //
  // Permutation of three outer loops: 3! = 6 possibilities.
  int FirstDims[] = {0, 0, 1, 1, 2, 2};
  int SecondDims[] = {1, 2, 2, 0, 0, 1};
  for (int i = 0; i < 6; i += 1) {
    auto PossibleMatMul =
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);

    AccMap = AccMap.intersect_domain(Domain);
    PossibleMatMul = PossibleMatMul.intersect_domain(Domain);

    // If AccMap spans entire domain (Non-partial write),
    // compute FirstPos and SecondPos.
    // If AccMap != PossibleMatMul here (the two maps have been gisted at
    // this point), it means that the writes are not complete, or in other
    // words, it is a Partial write and Partial writes must be rejected.
    if (AccMap.is_equal(PossibleMatMul)) {
      if (FirstPos != -1 && FirstPos != FirstDims[i])
        continue;
      FirstPos = FirstDims[i];
      if (SecondPos != -1 && SecondPos != SecondDims[i])
        continue;
      SecondPos = SecondDims[i];
      return true;
    }
  }

  return false;
}
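
// For illustration only: for a statement with the domain
// { S[i, j, k] : 0 <= i, j, k < 1024 } and the access relation
// { S[i, j, k] -> A[i, k] }, isMatMulOperandAcc() succeeds and sets
// FirstPos = 0 and SecondPos = 2 (assuming both were -1 on entry).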

/// Does the memory access represent a non-scalar operand of the matrix
/// multiplication.
///
/// Check that the memory access @p MemAccess is the read access to a non-scalar
/// operand of the matrix multiplication or its result.
///
/// @param MemAccess The memory access to be checked.
/// @param MMI       Parameters of the matrix multiplication operands.
/// @return          True in case the memory access represents the read access
///                  to a non-scalar operand of the matrix multiplication and
///                  false, otherwise.
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
                                        MatMulInfoTy &MMI) {
  if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
    return false;
  auto AccMap = MemAccess->getLatestAccessRelation();
  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
    MMI.ReadFromC = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
    MMI.A = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
    MMI.B = MemAccess;
    return true;
  }
  return false;
}

/// Check accesses to operands of the matrix multiplication.
///
/// Check that accesses of the SCoP statement, which corresponds to
/// the partial schedule @p PartialSchedule, are scalar in terms of loops
/// containing the matrix multiplication, in case they do not represent
/// accesses to the non-scalar operands of the matrix multiplication or
/// its result.
///
/// @param PartialSchedule The partial schedule of the SCoP statement.
/// @param MMI             Parameters of the matrix multiplication operands.
/// @return                True in case the corresponding SCoP statement
///                        represents matrix multiplication and false,
///                        otherwise.
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
                                    MatMulInfoTy &MMI) {
  auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
  unsigned OutDimNum = unsignedFromIslSize(PartialSchedule.range_tuple_dim());
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  auto MapI =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
  auto MapJ =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
  auto MapK =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);
  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
    auto *MemAccessPtr = *MemA;
    if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
        !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) &&
        !(MemAccessPtr->isStrideZero(MapI) &&
          MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK)))
      return false;
  }
  return true;
}

/// Check for dependencies corresponding to the matrix multiplication.
///
/// Check that there is only a true dependence of the form
/// S(..., k, ...) -> S(..., k + 1, ...), where S is the SCoP statement
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
/// to the dependency produced by the matrix multiplication.
///
/// @param Schedule The schedule of the SCoP statement.
/// @param D        The SCoP dependencies.
/// @param Pos      The parameter to describe an acceptable true dependence.
///                 In case it has a negative value, try to determine its
///                 acceptable value.
/// @return True in case dependencies correspond to the matrix multiplication
///         and false, otherwise.
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
  if (!Red.is_null())
    Dep = Dep.unite(Red);
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  auto Deltas = Dep.extract_map(Space).deltas();
  int DeltasDimNum = unsignedFromIslSize(Deltas.dim(isl::dim::set));
  for (int i = 0; i < DeltasDimNum; i++) {
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}
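
// For illustration only: for C[i][j] += A[i][k] * B[k][j] the only loop-carried
// RAW dependence within the statement is { S[i, j, k] -> S[i, j, k + 1] }, so
// the delta set is { [0, 0, 1] } and containsOnlyMatMulDep() determines
// Pos = 2 (the position of the k dimension).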

/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no other
///    loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///                        to check.
/// @param D               The SCoP dependencies.
/// @param MMI             Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  if (Stmt->size() <= 1)
    return false;

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute FirstDim and SecondDim dimensions of the Node.
///
/// @param Node      The band node to be modified.
/// @param FirstDim  The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.at(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.at(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

/// Create the BLIS micro-kernel.
///
/// Apply register tiling with the tile sizes Mr and Nr taken from
/// @p MicroKernelParams and interchange the two innermost modified dimensions.
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node              The schedule node to be modified.
/// @param MacroKernelParams Parameters of the macro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMacroKernel(isl::schedule_node Node,
                  MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;

  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);

  return Node.child(0).child(0);
}
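
// For illustration only: with a three-dimensional band [i, j, k], the first
// level tiling produces the tile dimensions [it, jt, kt]; the two subsequent
// permutations reorder them to [jt, kt, it], so the tile loops iterate over
// blocks of Nc columns of B, Kc elements of the contracted dimension, and Mc
// rows of A, matching the jc, pc, and ic loops of the BLIS approach.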

/// Get the size of the widest type of the matrix multiplication operands
/// in bytes, including alignment padding.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bytes, including alignment padding.
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bits.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bits.
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// release more registers for entries of multiplied matrices.
///
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI,
                                                MatMulInfoTy MMI) {
  assert(TTI && "The target transform info should be provided.");

  // Nvec - Number of double-precision floating-point numbers that can be held
  // by a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;

  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
  int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {Mr, Nr};
}
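
// Worked example (for illustration, assuming 256-bit vector registers and
// double-precision operands): Nvec = 256 / 64 = 4. With the default
// LatencyVectorFma = 8 and ThroughputVectorFma = 1, we get
// Nvec * Latency * Throughput = 32, so Nr = ceil(sqrt(32) / 4) * 4 = 8 and
// Mr = ceil(32 / 8) = 4, i.e., the micro-kernel computes a 4 x 8 block of C.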

/// Determine parameters of the target cache.
///
/// @param TTI Target Transform Info.
static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
  if (FirstCacheLevelSize == -1) {
    if (TTI->getCacheSize(L1DCache))
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).value();
    else
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
  }
  if (SecondCacheLevelSize == -1) {
    if (TTI->getCacheSize(L2DCache))
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).value();
    else
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
  }
  if (FirstCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L1DCache))
      FirstCacheLevelAssociativity =
          TTI->getCacheAssociativity(L1DCache).value();
    else
      FirstCacheLevelAssociativity =
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
  }
  if (SecondCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L2DCache))
      SecondCacheLevelAssociativity =
          TTI->getCacheAssociativity(L2DCache).value();
    else
      SecondCacheLevelAssociativity =
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
  }
}

/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should be ideally kept in cache between
/// iterations. Since parameters of the macro kernel determine sizes of these
/// blocks, there are upper and lower bounds on these parameters.
///
/// @param TTI               Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI               Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     MatMulInfoTy MMI) {
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // determining all the parameters of a macro-kernel requires information
  // about the first two levels of a cache. The code below also checks that the
  // associativity degree of each cache level is greater than two. Otherwise,
  // another algorithm for determination of the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {1, 1, 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {1, 1, 1};
  int Car = floor(
      (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can be computed to be zero since it is floor to int.
  // On Mac OS, division by 0 does not raise a signal. This causes negative
  // tile sizes to be computed. Prevent the division by Cac == 0 (which happens
  // when Car == 0) by returning early if this happens.
  if (Car == 0)
    return {1, 1, 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero");
  return {Mc, Nc, Kc};
}
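
// Worked example (for illustration, using the Sandy Bridge defaults above,
// double-precision operands, i.e., ElementSize = 8, and Mr = 4 and Nr = 8 as
// in the example after getMicroKernelParams):
// Car = floor(7 / (1 + 8.0 / 4)) = 2,
// Kc = (2 * 32768) / (4 * 8 * 8) = 256,
// Cac = (256 * 8 * 8) / 262144.0 = 0.0625,
// Mc = floor((8 - 2) / 0.0625) = 96, and Nc = 256 * 8 = 2048.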

/// Create an access relation that is specific to
/// the matrix multiplication pattern.
///
/// Create an access relation of the following form:
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
/// where I is @p FirstDim, J is @p SecondDim.
///
/// It can be used, for example, to create relations that help to consecutively
/// access elements of operands of a matrix multiplication after creation of
/// the BLIS micro and macro kernels.
///
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
///
/// Subsequently, the described access relation is applied to the range of
/// @p MapOldIndVar, which is used to map original induction variables to
/// the ones, which are produced by schedule transformations. It helps to
/// define relations using a new space and, at the same time, keep them
/// in the original one.
///
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param FirstDim, SecondDim The input dimensions that are used to define
///        the specified access relation.
/// @return The specified access relation.
static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
                                unsigned SecondDim) {
  auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3);
  auto AccessRel = isl::map::universe(AccessRelSpace);
  AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
  AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
  AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
  return MapOldIndVar.apply_range(AccessRel);
}
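
// For illustration only: with FirstDim = 3 and SecondDim = 7 (the values used
// for Packed_B below), the intermediate access relation has the form
// [O0, ..., O8] -> [O3, O5, O7]; applying it to the range of MapOldIndVar
// produces a map from the original induction variables to the corresponding
// elements of the packed array.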

/// Graft an extension node with the extension @p ExtensionMap before @p Node.
static isl::schedule_node createExtensionNode(isl::schedule_node Node,
                                              isl::map ExtensionMap) {
  auto Extension = isl::union_map(ExtensionMap);
  auto NewNode = isl::schedule_node::from_extension(Extension);
  return Node.graft_before(NewNode);
}

/// Create the packed array Packed_B, add a copy statement that fills it from
/// B, and redirect the access to B to the packed array.
static isl::schedule_node optimizePackedB(isl::schedule_node Node,
                                          ScopStmt *Stmt, isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  Scop *S = Stmt->getParent();
  isl::set Domain = Stmt->getDomain();

  // Create packed array.
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  ScopArrayInfo *PackedB =
      S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B",
                             {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from B to PackedB.
  isl::map AccRelB = MMI.B->getLatestAccessRelation();
  isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7);
  AccRelPackedB =
      AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId());

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain);
  MMI.B->setNewAccessRelation(AccRelPackedB);

  unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
  assert(Dim >= 2);

  // Insert into the schedule tree.
  isl::map ExtMap = MapOldIndVar.project_out(isl::dim::out, 2, Dim - 2);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId());
  return createExtensionNode(Node, ExtMap);
}

/// Create the packed array Packed_A, add a copy statement that fills it from
/// A, and redirect the access to A to the packed array.
static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *,
                                          isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  isl::set Domain = Stmt->getDomain();
  isl::id DomainId = Domain.get_tuple_id();

  // Create the packed array.
  unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Mr;
  ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo(
      MMI.A->getElementType(), "Packed_A",
      {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from A to PackedA.
  isl::map AccRelA = MMI.A->getLatestAccessRelation();
  isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6);
  AccRelPackedA =
      AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId());
  // { MemrefA[] -> PackedA[] }
  isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA);

  // Compute the domain for the copy statement.
  // Construct the copy statement domain out of the 3 outermost scatter
  // dimensions (to match the 3 band nodes surrounding the extension node) and
  // the array elements to copy (one statement instance per array element).
  // { Scatter[] }
  isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range();
  // { Scatter[] -> OutermostScatter[] }
  isl::map OuterDomainMap =
      makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6);
  // { Scatter[] -> MemrefA[] }
  isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA);
  // { Scatter[] -> CopyStmt[] }
  isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom);
  // { CopyStmt[] }
  isl::set CopyDomain = DomainTranslator.range();

  // Translate the access relations to the new domain.
  // { CopyStmt[] -> MemrefA[] }
  CopyFrom = CopyFrom.apply_domain(DomainTranslator);
  // { CopyStmt[] -> PackedA[] }
  isl::map CopyTo = CopyFrom.apply_range(PackedATranslator);

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt =
      Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain);
  MMI.A->setNewAccessRelation(AccRelPackedA);

  // Insert into the schedule tree.
  // { Scatter[] -> CopyStmt[] }
  isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true);
  ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2);
  return createExtensionNode(Node, ExtScatterCopy);
}

/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
/// transformation that requires introducing a new array, copying data
/// to the array, and changing memory access locations to reference the array.
/// It can be used to ensure that elements of the new array are read with
/// in-stride access, aligned to cache line boundaries, and preloaded into
/// certain cache levels.
///
/// As an example let us consider the packing of the array A that would help
/// to read its elements with in-stride access. An access to the array A
/// is represented by an access relation that has the form
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
/// the form S[i, j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read with in-stride access, we
/// add a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                                 to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  Node = Node.parent().parent().parent().parent().parent().parent();
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2));

  Node = Node.child(0);
  Node =
      optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  Node = Node.child(0);
  Node =
      optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  return Node.child(0).child(0).child(0).child(0).child(0);
}

/// Get a relation mapping induction variables produced by schedule
/// transformations to the original ones.
///
/// @param Node The schedule node produced as the result of creation
///        of the BLIS kernels.
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
///        to be taken into account.
/// @return The relation mapping original induction variables to the ones
///         produced by schedule transformation.
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
/// @see getMacroKernelParams
static isl::map
getInductionVariablesSubstitution(isl::schedule_node Node,
                                  MicroKernelParamsTy MicroKernelParams,
                                  MacroKernelParamsTy MacroKernelParams) {
  auto Child = Node.child(0);
  auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
  auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
  unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
  if (Dim > 9u)
    return MapOldIndVar.project_out(isl::dim::out, 0, Dim - 9);
  return MapOldIndVar;
}

/// Isolate a set of partial tile prefixes and unroll the isolated part.
///
/// The set should ensure that it contains only partial tile prefixes that have
/// exactly Mr x Nr iterations of the two innermost loops produced by
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
/// the micro-kernel.
///
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
/// innermost loops, using the SLP vectorizer.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @return The modified isl_schedule_node.
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.child(0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
  unsigned Dims = unsignedFromIslSize(Prefix.tuple_dim());
  assert(Dims >= 1);
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);

  isl::union_set IsolateOption =
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
  isl::ctx Ctx = Node.ctx();
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(Prefix, 3);
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  Node = Node.child(0).child(0).child(0);
  return Node;
}

/// Insert "Loop Vectorizer Disabled" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @return The modified isl_schedule_node.
static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
  auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr);
  return Node.insert_mark(Id).child(0);
}

/// Restore the initial ordering of dimensions of the band node.
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  if (Node.get_schedule_depth().release() != 0 ||
      (unsignedFromIslSize(isl::set(Domain).tuple_dim()) !=
       unsignedFromIslSize(Node.as<isl::schedule_node_band>().n_member())))
    return Node;
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
}

/// Optimize the SCoP statement that implements the matrix multiplication
/// described by @p MMI: create the BLIS macro- and micro-kernels and, if
/// macro-kernel tiling was applied, isolate and unroll the innermost loops and
/// apply the packing transformation to the operands.
static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (MapOldIndVar.is_null())
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}

/// Check if this node contains a partial schedule that could
/// probably be optimized with analytical modeling.
///
/// isMatrMultPattern tries to determine whether the following conditions
/// are true:
/// 1. the partial schedule contains only one statement.
/// 2. there are exactly three input dimensions.
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
///    interchange loops (switch the variable used in the inner loop to
///    the outer loop).
/// 4. all memory accesses of the statement, except for the last one, are
///    read memory accesses, and the last one is a write memory access.
/// 5. all subscripts of the last memory access of the statement don't
///    contain the variable used in the inner loop.
/// If this is the case, we could try to use an approach that is similar to
/// the one used to get close-to-peak performance of matrix multiplications.
///
/// @param Node The node to check.
/// @param D    The SCoP dependencies.
/// @param MMI  Parameters of the matrix multiplication operands.
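///
/// A minimal kernel of this form (illustrative only; the loop bounds and
/// array names are assumptions):
///
///   for (i = 0; i < 1024; i++)
///     for (j = 0; j < 1024; j++)
///       for (k = 0; k < 1024; k++)
///         C[i][j] += A[i][k] * B[k][j];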
static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D,
                              MatMulInfoTy &MMI) {
  auto PartialSchedule = isl::manage(
      isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
  if (isl_schedule_node_band_n_member(Node.get()) < 3 ||
      Node.get_schedule_depth().release() != 0 ||
      isl_union_map_n_map(PartialSchedule.get()) != 1)
    return false;
  auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
  return containsMatrMult(NewPartialSchedule, D, MMI);
}

/// Get the dimension size.
///
/// Return the size of the dimension @p Pos, which is obtained from @p SAI.
/// Return -1 in the case of the first dimension of a multi-dimensional array,
/// since the ScopArrayInfo class does not carry size information.
///
/// @param SAI The information about the array.
/// @param Pos The position of the dimension.
/// @return The size of the dimension.
static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) {
  if (Pos == 0)
    return -1;
  const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Pos);
  assert(SCEVDimSize);
  auto *ConstantDimSize = dyn_cast<const SCEVConstant>(SCEVDimSize);
  assert(ConstantDimSize);
  auto *IntDimSize = dyn_cast<ConstantInt>(ConstantDimSize->getValue());
  assert(IntDimSize);
  return IntDimSize->getSExtValue();
}

/// Check whether the access relation has the specified form.
///
/// Check that the access relation @p AccMap has the form T[I0, …, In], where
/// indexes I0, …, In are specified by @p Dimensions.
///
/// @param Domain     The domain of the access relation.
/// @param AccMap     The access relation to be checked.
/// @param Dimensions The permutation of the subset of the input dimensions.
/// @return True if @p AccMap has the expected form and false,
///         otherwise.
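///
/// For example (illustrative), with Dimensions = {2, 0} the expected form is
/// Stmt[i0, i1, i2] -> T[i2, i0].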
static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap,
                               ArrayRef<int> Dimensions) {
  isl::space Space = AccMap.get_space();
  if (unsignedFromIslSize(Space.dim(isl::dim::out)) != Dimensions.size())
    return false;

  // Create an access relation of the following form:
  // [I0, …, Im] -> [Il, …, In], where indexes
  // Il, …, In are specified by @p Dimensions.
  isl::map PossibleTensor = isl::map::universe(Space);
  unsigned DimInSize = unsignedFromIslSize(Space.dim(isl::dim::in));
  for (unsigned i = 0; i < Dimensions.size(); i++) {
    const int InPos = Dimensions[i];
    if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0))
      return false;
    PossibleTensor =
        PossibleTensor.equate(isl::dim::in, InPos, isl::dim::out, i);
  }

  AccMap = AccMap.intersect_domain(Domain);
  PossibleTensor = PossibleTensor.intersect_domain(Domain);

  // If AccMap != PossibleTensor here (both maps have been restricted to
  // @p Domain at this point), it means that the writes are not complete; in
  // other words, it is a partial write, and partial writes must be rejected.
  return AccMap.is_equal(PossibleTensor);
}

/// Check whether the access represents the tensor contraction operand.
///
/// Check that the access relation @p AccMap has the form T[i1, …, in].
/// Obtained indexes i1, …, in, their sizes and their permutation are stored
/// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively.
///
/// @param Domain         The domain of the access relation.
/// @param AccMap         The access relation to be checked.
/// @param IndexSet       The subset of the input dimensions.
/// @param DimensionSizes Sizes of the input dimensions of @p Dimensions.
/// @param Dimensions     The permutation of the subset of the input dimensions.
/// @return True if @p AccMap has the expected form and false,
///         otherwise.
static bool isTCOperandAcc(isl::set Domain, isl::map AccMap,
                           SmallDenseSet<int> &IndexSet,
                           SmallVectorImpl<int> &DimensionSizes,
                           SmallVectorImpl<int> &Dimensions) {
  isl::id Id = AccMap.get_tuple_id(isl::dim::out);
  const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id);
  assert(SAI && "AccMap should represent memory access");

  // Fix values of output dimensions with respect to their positions.
  // In the case of the tensor contraction, values of output dimensions are
  // fixed and form a permutation of a subset of values of input dimensions.
  //
  // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents
  // the operand of the tensor contraction, we get the following map by fixing
  // the output dimensions Stmt[1][j][0] -> A[0][1].
  //
  // We store the permutation of the subset of the input dimensions {2, 0} into
  // @p Dimensions.
  //
  // The obtained permutation and the isCorrectAccessMap function are used to
  // check whether the access relation @p AccMap represents the tensor
  // contraction operand. For example, in the case of
  // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and,
  // consequently, {1, 0}, which is rejected by isCorrectAccessMap,
  // since it corresponds to Stmt[i][j][k] -> A[j][i].
  isl::map CheckMap = isl::manage(AccMap.copy());
  unsigned OutDimNum = unsignedFromIslSize(CheckMap.dim(isl::dim::out));
  for (unsigned i = 0; i < OutDimNum; i++)
    CheckMap = CheckMap.fix_si(isl::dim::out, i, i);

  // Try to obtain the permutation and sizes of corresponding input dimensions.
  Dimensions.assign(OutDimNum, -1);
  for (unsigned i : rangeIslSize(0, CheckMap.dim(isl::dim::in))) {
    isl::val Val = getConstant(CheckMap, isl::dim::in, i);
    if (!Val.is_int())
      continue;
    int OutPos = -1;
    llvm::APInt ValAPInt = APIntFromVal(Val);
    if (ValAPInt.isSignedIntN(32))
      OutPos = ValAPInt.getSExtValue();
    if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) ||
        IndexSet.count(i))
      return false;
    IndexSet.insert(i);
    Dimensions[OutPos] = i;
    if (DimensionSizes[i] <= 0)
      DimensionSizes[i] = getDimSize(SAI, OutPos);
  }

  return isCorrectAccessMap(Domain, AccMap, Dimensions);
}

/// Find the intersection of two sets.
///
/// Find the intersection of the set @p A and the set @p B.
///
/// @param A, B Sets to intersect.
/// @return The set intersection.
static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A,
                                    const SmallDenseSet<int> &B) {
  SmallDenseSet<int> Intersection = A;
  set_intersect(Intersection, B);
  return Intersection;
}

/// Check whether the set is a superset.
///
/// Check that the set @p A is a superset of @p B.
///
/// @param A, B Sets to be checked.
/// @return True if the set A is a superset of B.
static bool isSuperset(const SmallDenseSet<int> &A,
                       const SmallDenseSet<int> &B) {
  return intersect(A, B).size() == B.size();
}

/// Find the union of two sets.
///
/// Find the union of the set @p A and the set @p B.
///
/// @param A, B Sets to unite.
/// @return The set union.
static SmallDenseSet<int> unite(const SmallDenseSet<int> &A,
                                const SmallDenseSet<int> &B) {
  SmallDenseSet<int> Union = A;
  set_union(Union, B);
  return Union;
}

/// Determine the access that writes to the tensor, which contains
/// the result of the tensor contraction.
///
/// @param Domain        The domain of the statement.
/// @param Stmt          The statement, which writes to memory.
/// @param TCI           The information about the tensor contraction.
/// @param IandJIndexSet The set, which contains free indexes of tensors.
/// @return The determined MemoryAccess, or nullptr if there is no necessary
///         access within the SCoP.
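///
/// For example (illustrative), for the kernel
///   C[i][j] += A[i][l][w] * B[w][j][l]
/// the write access is Stmt[i, j, l, w] -> C[i, j], so @p IandJIndexSet
/// becomes {0, 1} and TCI.CDimensions becomes {0, 1}.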
static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt,
                                    TCInfoTy &TCI,
                                    SmallDenseSet<int> &IandJIndexSet) {
  TCI.WriteToC = nullptr;
  SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
  for (MemoryAccess *MemA : reverse(Accesses)) {
    // A TC-like kernel does not contain scalar write memory accesses.
    if (!MemA->isLatestArrayKind())
      return nullptr;
    // The last memory access should be a write memory access.
    if (!MemA->isWrite())
      return nullptr;
    isl::map AccMap = MemA->getLatestAccessRelation();
    if (!isTCOperandAcc(Domain, AccMap, IandJIndexSet, TCI.DimensionSizes,
                        TCI.CDimensions))
      return nullptr;
    return MemA;
  }
  return nullptr;
}

/// Determine an access, which reads elements of an operand of the tensor
/// contraction.
///
/// @param MemAccessPtr  The access, which reads elements of the tensor.
/// @param IndexSet      The set, which contains indexes of the tensors.
/// @param IandJIndexSet The set, which contains free indexes of tensors.
/// @param Dimensions    The permutation of the subset of the input dimensions.
/// @param TCI           The information about the tensor contraction.
/// @return True if the memory access @p MemAccessPtr corresponds
///         to the tensor contraction.
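///
/// For example (illustrative), for the kernel
///   C[i][j] += A[i][l][w] * B[w][j][l]
/// the bundles are I = {i}, J = {j}, and P = {l, w}: the first call binds
/// TCI.A to the access to A and derives I and J from its index set, while
/// the second call binds TCI.B after checking that the index set of B
/// equals J ∪ P.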
static bool setReadAccess(MemoryAccess *MemAccessPtr,
                          const SmallDenseSet<int> &IndexSet,
                          const SmallDenseSet<int> &IandJIndexSet,
                          ArrayRef<int> Dimensions, TCInfoTy &TCI) {
  if (!TCI.A) {
    // Probably IndexSet is a union of I and P sets.
    if (!isSuperset(IndexSet, TCI.P))
      return false;

    // Obtain the set I.
    TCI.I = set_difference(IndexSet, TCI.P);
    if (!isSuperset(IandJIndexSet, TCI.I))
      return false;

    // Obtain the set J.
    TCI.J = set_difference(IandJIndexSet, TCI.I);

    // Set the first operand of the tensor contraction.
    TCI.A = MemAccessPtr;
    llvm::replace(TCI.ADimensions, TCI.ADimensions.begin(),
                  TCI.ADimensions.end(), Dimensions.begin(), Dimensions.end());
    return true;
  }

  if (!TCI.B) {
    // IndexSet should be a union of J and P sets.
    if (unite(TCI.P, TCI.J) != IndexSet)
      return false;

    // Set the second operand of the tensor contraction.
    TCI.B = MemAccessPtr;
    llvm::replace(TCI.BDimensions, TCI.BDimensions.begin(),
                  TCI.BDimensions.end(), Dimensions.begin(), Dimensions.end());
    return true;
  }

  return false;
}

/// Check that all memory accesses of the statement, except for the last
/// one, are read memory accesses, which read elements of operands of the
/// tensor contraction and its result.
///
/// @param Domain        The domain of the statement.
/// @param Stmt          The statement, which writes to memory.
/// @param TCI           The information about the tensor contraction.
/// @param IandJIndexSet The set, which contains free indexes of tensors.
/// @return True if all read memory accesses of the statement @p Stmt
///         correspond to the tensor contraction.
static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI,
                            SmallDenseSet<int> &IandJIndexSet) {
  TCI.A = nullptr;
  TCI.B = nullptr;
  TCI.ReadFromC = nullptr;
  SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
    MemoryAccess *MemAccessPtr = *MemA;

    // All memory accesses, except for the last one, should be read memory
    // accesses.
    if (MemAccessPtr->isWrite())
      return false;

    isl::map AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!MemAccessPtr->isLatestArrayKind()) {
      // Check whether the scalar read memory access is not partial.
      if (!Domain.is_subset(AccMap.domain()))
        return false;
      continue;
    }

    // There is only one memory access, which reads elements of the result of
    // the tensor contraction.
    if (AccMap.is_equal(TCI.WriteToC->getLatestAccessRelation())) {
      if (TCI.ReadFromC)
        return false;
      TCI.ReadFromC = MemAccessPtr;
      continue;
    }

    SmallVector<int> Dimensions;
    SmallDenseSet<int> IndexSet;
    if (!isTCOperandAcc(Domain, AccMap, IndexSet, TCI.DimensionSizes,
                        Dimensions))
      return false;

    if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI))
      return false;
  }

  // Check that there are read memory accesses, which read elements of operands
  // of the tensor contraction and its result.
  return TCI.ReadFromC && TCI.A && TCI.B;
}

/// Check accesses to operands of the tensor contraction.
///
/// Check that accesses of the SCoP statement, which corresponds to
/// the partial schedule @p PartialSchedule, represent accesses
/// to the non-scalar operands of the tensor contraction.
///
/// @param Domain          The domain of the SCoP statement.
/// @param PartialSchedule The partial schedule of the SCoP statement.
/// @param TCI             Parameters of the tensor contraction operands.
/// @return True if the corresponding SCoP statement
///         represents tensor contraction and false,
///         otherwise.
static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule,
                              TCInfoTy &TCI) {
  isl::id InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  // In region statements, the order of memory access execution is not
  // predictable at compile-time.
  if ((Stmt->size() <= 1) || Stmt->isRegionStmt())
    return false;

  unsigned DimNum = unsignedFromIslSize(PartialSchedule.dim(isl::dim::in));
  TCI.DimensionSizes.resize(DimNum);
  SmallDenseSet<int> IandJIndexSet;

  TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet);
  if (!TCI.WriteToC)
    return false;

  if (intersect(IandJIndexSet, TCI.P).size() != 0)
    return false;

  return setReadAccesses(Domain, Stmt, TCI, IandJIndexSet);
}

/// Check that the dependency corresponds to the tensor contraction carried
/// over loop dimension @p Dim.
///
/// Check that the dependency has the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement. For this purpose, we analyze the set @p DepDelta, which
/// represents the differences between image elements and domain elements of
/// the corresponding map.
///
/// @param DepDelta    The set contains the differences between image elements
///                    and corresponding domain elements of the map, which
///                    represents the dependency.
/// @param Dim         The position of the index ki.
/// @param BoundDeltas In the case of indexes of ki, the difference between
///                    image elements and corresponding domain elements
///                    corresponds to the difference between lexicographic
///                    minimum and lexicographic maximum of the corresponding
///                    dimension of the domain of the statement.
/// @param IndexSet    Obtained indexes ki, which describe the dependency.
/// @return True if dependencies correspond to the tensor contraction
///         and false, otherwise.
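///
/// For example (illustrative), for the kernel from the documentation of
/// areDepsOverCompleteDomain below, the dependency carried over l has
/// DepDelta = (0, 0, 1, -63): the delta of w equals min(w) - max(w), so
/// adding BoundDeltas = (1023, 1023, 63, 63) yields zero in that position.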
static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim,
                                      isl::pw_multi_aff BoundDeltas,
                                      const SmallDenseSet<int> &IndexSet) {
  isl::space Space = DepDelta.get_space();
  isl::set Superset = isl::set::universe(Space);
  for (unsigned i = 0; i < Dim; i += 1)
    Superset = Superset.fix_si(isl::dim::set, i, 0);
  Superset = Superset.fix_si(isl::dim::set, Dim, 1);

  // Check that the difference between the image element and the domain element
  // is equal to one in the case of the index ki. Image elements and
  // corresponding domain elements should be equal in the case of positions,
  // which are lower than the specified position.
  if (!DepDelta.is_subset(Superset))
    return false;

  // Compute a set, which is used to analyze how values of
  // the domain are related to the map that describes the dependency.
  isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(DepDelta.copy());
  BoundDeltas = BoundDeltas.add(isl::manage(DepDeltaPW));
  isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(BoundDeltas.release());
  isl::set Complement = isl::manage(ComplementRawSet);

  for (unsigned i : rangeIslSize(Dim + 1, DepDelta.dim(isl::dim::set))) {
    if (!IndexSet.count(i)) {
      // Check the difference between the image element and the domain element
      // in the case of indexes, which do not describe the dependency.
      if (DepDelta.plain_get_val_if_fixed(isl::dim::set, i).is_zero())
        continue;
      return false;
    }

    // In the case of other indexes, which describe the dependency,
    // the difference between the image element and the domain element
    // should be equal to the difference between lexicographic minimum and
    // lexicographic maximum of the domain of the statement.
    if (!Complement.plain_get_val_if_fixed(isl::dim::set, i).is_zero())
      return false;
  }
  return true;
}

/// Check whether dependencies are over the complete domain.
///
/// In the case of the tensor contraction RAW, WAW, WAR dependencies
/// have the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement. Consequently, the domain of the dependencies
/// can be described as
/// Domain ∖ (Domain ∩ S(…, max(kn), …) ∩ S(…, max(k(i + 1)), …)),
/// where Domain is the domain of the statement S.
///
/// For example, in the case of the following tensor contraction,
/// corresponding domains will have the following form.
///
/// An example of the tensor contraction:
/// for (i = 0; i < 1024; i++)
///   for (j = 0; j < 1024; j++)
///     for (l = 0; l < 64; ++l)
///       for (w = 0; w < 64; ++w)
///         C[i][j] += A[i][l][w] * B[w][j][l];
///
/// The domain of the statement:
/// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and
///                       i1 >= 0 and i1 <= 1023 and
///                       i2 >= 0 and i2 <= 63 and
///                       i3 >= 0 and i3 <= 63 }
///
/// The domain of the dependencies:
/// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and
///                        i1 >= 0 and i1 <= 1023 and
///                        i2 >= 0 and i2 <= 63 and
///                        i3 >= 0 and i3 <= 62) or
///                       (i3 = 63 and i0 >= 0 and i0 <= 1023 and
///                        i1 >= 0 and i1 <= 1023 and
///                        i2 >= 0 and i2 <= 62) }
///
/// @param Domain      The domain of the statement.
/// @param DepsForStmt RAW and RED dependencies for the statement.
/// @param UpperBound  The lexicographic maximum of the elements in
///                    the @p Domain.
/// @param IndexSet    Obtained indexes ki, which describe the dependencies.
/// @return True if dependencies are over the complete domain
///         and false, otherwise.
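///
/// For the example above, the reduced domain (the domain with the indexes
/// from @p IndexSet fixed to their maxima) is
/// { S[i0, i1, 63, 63] : 0 <= i0 <= 1023 and 0 <= i1 <= 1023 },
/// and the dependencies have to cover exactly Domain minus this set
/// (illustrative restatement of the check below).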
static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt,
                                      isl::pw_multi_aff UpperBound,
                                      SmallDenseSet<int> &IndexSet) {
  isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(UpperBound.copy());
  isl::set UpperBoundSet = isl::manage(UpperBoundRawSet);

  isl::set DomainRed = isl::manage(Domain.copy());
  for (const auto It : IndexSet) {
    isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(isl::dim::set, It);
    if (FixedVal.is_nan())
      return false;
    DomainRed = isl::manage(
        isl_set_fix_val(DomainRed.copy(), isl_dim_set, It, FixedVal.release()));
  }
  return DepsForStmt.domain().intersect(Domain).is_equal(
      Domain.subtract(DomainRed));
}

/// Check that dependencies correspond to the tensor contraction.
///
/// Check that there are only true dependencies of the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement represented by @p Schedule. Such dependencies are produced by
/// the tensor contraction. Obtained indexes ki are stored into @p IndexSet.
///
/// The form of anti and output dependencies is specified implicitly by
/// the form of the SCoP statement, which is checked by subsequent analysis.
///
/// @param Schedule The schedule of the SCoP statement.
/// @param D        The SCoP dependencies.
/// @param Domain   The domain of the statement.
/// @param IndexSet Obtained indexes ki, which describe the dependencies.
/// @return True if dependencies correspond to the tensor contraction
///         and false, otherwise.
static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D,
                               SmallDenseSet<int> &IndexSet, isl::set Domain) {
  IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut);

  isl::union_map Dep =
      D->getDependences(Dependences::TYPE_RAW | Dependences::TYPE_RED);

  isl::space DomainSpace = Schedule.get_space().domain();
  isl::space Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  isl::map DepsForStmt = Dep.extract_map(Space);
  isl::set DepDeltas = DepsForStmt.deltas();
  isl::size DeltasDimNum = DepDeltas.dim(isl::dim::set);
  isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
  isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff();
  isl::pw_multi_aff BoundDeltas = UpperBound.sub(LowerBound);

  for (int i : reverse(rangeIslSize(0, DeltasDimNum))) {
    // In the case of the tensor contraction, the difference between image
    // elements and domain elements lies on a hyperplane where a dimension
    // has the fixed value one.
    isl::set Intersection = DepDeltas.fix_si(isl::dim::set, i, 1);
    if (Intersection.is_empty())
      continue;
    if (!isReductionCarriedOverDim(Intersection, i, BoundDeltas, IndexSet))
      return false;
    IndexSet.insert(i);
    DepDeltas = DepDeltas.subtract(Intersection);
  }

  // In the case of the tensor contraction, all dependencies should have
  // the previously described form.
  if ((unsignedFromIslSize(DeltasDimNum) == 0) || !DepDeltas.is_empty())
    return false;

  return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet);
}

/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsTCInfoTy tries to determine whether the following conditions
/// are true:
///
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)),
///    where S is the SCoP statement under consideration and shuffle(I, J)
///    is a permutation of indexes of sets I and J.
/// 2. There are only true dependencies of the form
///    S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
///    S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
///    statement represented by @p PartialSchedule and ki are indexes of the
///    set P.
/// 3. SCoP contains an arbitrary number of reads from constants and only three
///    access relations, MA2, MA3, and MA4 that represent reading from memory
///    and have the form
///    S(..., I, ..., P, ...) -> M(shuffle(I, P)),
///    S(..., P, ..., J, ...) -> M(shuffle(J, P)),
///    S(...) -> M(shuffle(I, J)), respectively.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///                        to check.
/// @param D               The SCoP dependencies.
/// @param TCI             Parameters of the tensor contraction operands.
/// @param Domain          The domain of the statement.
/// @return True if dependencies and memory accesses correspond to the tensor
///         contraction and false, otherwise.
static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D,
                             TCInfoTy &TCI, isl::set Domain) {
  if (!containsOnlyTcDeps(PartialSchedule, D, TCI.P, Domain))
    return false;

  // TODO: handle cases of scalar multiplication if needed.
  if (TCI.P.size() == 0)
    return false;

  if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI))
    return false;

  // TODO: handle cases of GEMV if needed.
  if ((TCI.I.size() == 0) || (TCI.J.size() == 0))
    return false;

  return true;
}

/// Check if this node contains a partial schedule that could
/// probably be optimized with analytical modeling.
///
/// isTCPattern is used to determine whether the SCoP represents a TC-like
/// kernel [1], which is a perfectly nested set of loops, with a data usage
/// pattern that is similar to that produced by the tensor contraction.
///
/// A TC-like kernel can be defined as follows:
///
/// 1. It satisfies the requirements of the polyhedral model.
/// 2. Without loss of generality, it contains three nonempty bundles of
///    one-dimensional for-loops with induction variables that are grouped into
///    bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they
///    are incremented by one.
/// 3. The innermost loop body can be represented as a statement of the form
///    C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)),
///    C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)),
///    C(shuffle(I, J)) are accesses to tensors A, B, C, respectively,
///    shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the
///    enclosed indices, and E is an expression that contains reads from
///    the tensors A, B, C, and an arbitrary number of reads from constants
///    with respect to bundles I, J, and P.
///
/// TC can be considered as a particular case of a TC-like kernel.
///
/// The order of loops with indexes from P should be preserved. Otherwise,
/// isTCPattern should check if a commutative operation is used.
///
/// isTCPattern performs the following steps to check whether the SCoP
/// corresponds to a definition of a TC-like kernel:
///
/// 1. Checks that the node is the innermost band node.
/// 2. Checks that the partial schedule contains only one statement.
/// 3. Checks that all ancestors of the node contain all band nodes for
///    the statement and only mark nodes interleave such band nodes. This
///    corresponds to a straightforward implementation of TC.
/// 4. Analyzes the dependencies to determine contraction dimensions.
/// 5. Checks that the last memory access modeling an array represents writing
///    to the result of the TC-like kernel.
/// 6. Checks that the SCoP contains only three access relations that represent
///    reading of the operands of the TC-like kernel and an arbitrary number of
///    reads from constants.
///
/// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor
///       Operations: A Compiler-Oriented Approach // ACM Transactions on
///       Architecture and Code Optimization (TACO). 2018.
///       Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029.
///
/// If this is the case, we could logically represent tensors as matrices and
/// apply algorithms, which are used to get close-to-peak performance of
/// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS).
///
/// @param Node The node to check.
/// @param D    The SCoP dependencies.
/// @param TCI  Parameters of the tensor contraction operands.
static bool isTCPattern(isl::schedule_node Node, const Dependences *D,
                        TCInfoTy &TCI) {
  Node = Node.child(0);
  isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map();
  isl::union_set Domain = Node.domain();
  Node = Node.parent();

  // The partial schedule should contain only one statement.
  // TODO: This constraint should not be intrinsic to the algorithm.
  if (isl_union_set_n_set(Domain.get()) != 1)
    return false;

  isl_schedule_node_type NodeType = isl_schedule_node_get_type(Node.get());

  // Check that all ancestors of the node contain all band nodes for
  // the statement, which represents the TC-like kernel, and only mark nodes
  // interleave such band nodes. This corresponds to a straightforward
  // implementation of TC with/without DeLICM applied.
  //
  // For example, this covers the matrix multiplication pattern after a full
  // run of -polly-optree and -polly-delicm, where the write access is not
  // through the original memory access, but through a PHI node that was
  // delicmed. Subsequently, such band nodes will be replaced by a single band
  // node.
  //
  // The corresponding schedule can be the following, where Stmt_for_body8
  // contains the matrix multiplication:
  //
  // domain: "{ Stmt_for_body8[i0, i1, i2] : 0 <= i0 <= 1599 and
  //                                         0 <= i1 <= 1799 and
  //                                         0 <= i2 <= 2199;
  //            Stmt_for_body3[i0, i1] : 0 <= i0 <= 1599 and
  //                                     0 <= i1 <= 1799;
  //            Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and
  //                                          0 <= i1 <= 1799 }"
  // child:
  //   sequence:
  //   - filter: "{ Stmt_for_body3[i0, i1] }"
  //     child:
  //       schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] },
  //                   { Stmt_for_body3[i0, i1] -> [(i1)] }]"
  //       permutable: 1
  //       coincident: [ 1, 1 ]
  //   - filter: "{ Stmt_for_body3_last[i0, i1] }"
  //     child:
  //       schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] },
  //                   { Stmt_for_body3_last[i0, i1] -> [(i1)] }]"
  //       permutable: 1
  //       coincident: [ 1, 1 ]
  //   - filter: "{ Stmt_for_body8[i0, i1, i2] }"
  //     child:
  //       schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] },
  //                   { Stmt_for_body8[i0, i1, i2] -> [(i1)] },
  //                   { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]"
  //       permutable: 1
  //       coincident: [ 1, 1, 0 ]
  //
  while (NodeType != isl_schedule_node_domain) {
    if (NodeType == isl_schedule_node_filter) {
      if (!Node.parent().isa<isl::schedule_node_sequence>() ||
          !Node.parent().parent().isa<isl::schedule_node_domain>())
        return false;
      break;
    }

    if ((NodeType != isl_schedule_node_band) &&
        (NodeType != isl_schedule_node_mark))
      return false;

    Node = Node.parent();
    NodeType = isl_schedule_node_get_type(Node.get());
  }

  isl::map PartialScheduleMap = isl::map::from_union_map(PartialSchedule);
  return containsTCInfoTy(PartialScheduleMap, D, TCI, isl::set(Domain));
}

} // namespace

isl::schedule_node
polly::tryOptimizeMatMulPattern(isl::schedule_node Node,
                                const llvm::TargetTransformInfo *TTI,
                                const Dependences *D) {
  TCInfoTy TCI;
  if (PMBasedTCOpts && isTCPattern(Node, D, TCI))
    LLVM_DEBUG(dbgs() << "The tensor contraction pattern was detected\n");
  MatMulInfoTy MMI;
  if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) {
    LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
    return optimizeMatMulPattern(Node, TTI, MMI);
  }
  return {};
}
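
// A hedged usage sketch (the actual caller lives in the schedule-tree
// optimizer; the surrounding names here are assumptions, not part of this
// file):
//
//   isl::schedule_node Opt = polly::tryOptimizeMatMulPattern(Node, TTI, D);
//   if (!Opt.is_null())
//     Node = Opt; // the pattern was detected and the schedule was rewritten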