//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "polly/Support/ISLTools.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("The throughput of the processor's floating-point arithmetic "
             "units expressed in the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(32768), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(262144), cl::ZeroOrMore, cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size,
// represents the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of an Intel Core i7-3820
// (Sandy Bridge) in case the parameters are not specified or not provided by
// the TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information)."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the "
             "macro-kernel, by Nr, the parameter of the micro-kernel."),
    cl::Hidden, cl::init(256), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {

/// Parameters of the micro kernel.
///
/// Parameters that determine the size of the rank-1 (i.e., outer product)
/// update used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr;
  int Nr;
};

/// Parameters of the macro kernel.
///
/// Parameters that determine the sizes of the blocks of partitioned matrices
/// used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc;
  int Nc;
  int Kc;
};

/// Parameters of the matrix multiplication operands.
///
/// Parameters that describe the access relations representing the operands of
/// the matrix multiplication.
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  int i = -1;
  int j = -1;
  int k = -1;
};
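
// For orientation (a sketch, not code that is generated verbatim): after the
// transformation, C += A * B roughly follows the BLIS loop structure, where
// Mc, Nc, and Kc select cache-sized blocks and Mr x Nr is the register-tiled
// micro-kernel:
//
//   for (jc = 0; jc < N; jc += Nc)     // blocks of columns of B and C
//     for (pc = 0; pc < K; pc += Kc)   // blocks of the reduction dimension
//       for (ic = 0; ic < M; ic += Mc) // blocks of rows of A and C
//         for (jr = 0; jr < Nc; jr += Nr)
//           for (ir = 0; ir < Mc; ir += Mr)
//             for (kr = 0; kr < Kc; kr++) // micro-kernel: rank-1 updates
//               C[ic+ir .. ic+ir+Mr-1][jc+jr .. jc+jr+Nr-1] +=
//                   A[ic+ir .. ic+ir+Mr-1][pc+kr] *
//                   B[pc+kr][jc+jr .. jc+jr+Nr-1];
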
/// Create an isl::union_set, which describes the option of the form
/// [isolate[] -> unroll[x]].
///
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
  isl::space Space = isl::space(Ctx, 0, 0, 1);
  isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
  isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
  isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
  return UnrollIsolatedSetOption.wrap();
}

/// Permute two dimensions of an isl map.
///
/// Permute the @p DstPos and @p SrcPos dimensions of the isl map @p Map that
/// have the type @p DimType.
///
/// @param Map     The isl map to be modified.
/// @param DimType The type of the dimensions.
/// @param DstPos  The first dimension.
/// @param SrcPos  The second dimension.
/// @return        The modified map.
static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
                                  unsigned DstPos, unsigned SrcPos) {
  assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) &&
         SrcPos < unsignedFromIslSize(Map.dim(DimType)));
  if (DstPos == SrcPos)
    return Map;
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  if (!DimId.is_null())
    Map = Map.set_tuple_id(DimType, DimId);
  if (!FreeDimId.is_null())
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}
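
// Worked example (illustration only): permuting the output dimensions 0 and 2
// of { S[i, j, k] -> [i, j, k] } via permuteDimensions(Map, isl::dim::out, 0,
// 2) yields { S[i, j, k] -> [k, j, i] }; the tuple ids are preserved.
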
/// Check the form of the access relation.
///
/// Check that the access relation @p AccMap has the form M[i][j], where i
/// corresponds to the input dimension @p FirstPos and j to the input
/// dimension @p SecondPos.
///
/// @param Domain    The domain of the SCoP statement.
/// @param AccMap    The access relation to be checked.
/// @param FirstPos  The index of the input dimension that is mapped to
///                  the first output dimension.
/// @param SecondPos The index of the input dimension that is mapped to the
///                  second output dimension.
/// @return          True in case @p AccMap has the expected form and false,
///                  otherwise.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
                               int &SecondPos) {
  isl::space Space = AccMap.get_space();
  isl::map Universe = isl::map::universe(Space);

  if (unsignedFromIslSize(Space.dim(isl::dim::out)) != 2)
    return false;

  // MatMul has the form:
  // for (i = 0; i < N; i++)
  //   for (j = 0; j < M; j++)
  //     for (k = 0; k < P; k++)
  //       C[i, j] += A[i, k] * B[k, j]
  //
  // Permutation of three outer loops: 3! = 6 possibilities.
  int FirstDims[] = {0, 0, 1, 1, 2, 2};
  int SecondDims[] = {1, 2, 2, 0, 0, 1};
  for (int i = 0; i < 6; i += 1) {
    auto PossibleMatMul =
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);

    AccMap = AccMap.intersect_domain(Domain);
    PossibleMatMul = PossibleMatMul.intersect_domain(Domain);

    // If AccMap spans the entire domain (i.e., it is not a partial write),
    // compute FirstPos and SecondPos.
    // If AccMap != PossibleMatMul here (the two maps have been gisted at
    // this point), it means that the writes are not complete; in other
    // words, it is a partial write, and partial writes must be rejected.
    if (AccMap.is_equal(PossibleMatMul)) {
      if (FirstPos != -1 && FirstPos != FirstDims[i])
        continue;
      FirstPos = FirstDims[i];
      if (SecondPos != -1 && SecondPos != SecondDims[i])
        continue;
      SecondPos = SecondDims[i];
      return true;
    }
  }

  return false;
}
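
// Example (illustration only): for the access relation
// { S[i, j, k] -> A[i, k] } over a full rectangular domain, the check
// succeeds with FirstPos = 0 and SecondPos = 2. If FirstPos or SecondPos is
// already fixed (non-negative), only the matching permutation is accepted.
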
/// Does the memory access represent a non-scalar operand of the matrix
/// multiplication?
///
/// Check that the memory access @p MemAccess is a read access to a non-scalar
/// operand of the matrix multiplication or to its result.
///
/// @param MemAccess The memory access to be checked.
/// @param MMI       Parameters of the matrix multiplication operands.
/// @return          True in case the memory access represents the read access
///                  to a non-scalar operand of the matrix multiplication and
///                  false, otherwise.
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
                                        MatMulInfoTy &MMI) {
  if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
    return false;
  auto AccMap = MemAccess->getLatestAccessRelation();
  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
    MMI.ReadFromC = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
    MMI.A = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
    MMI.B = MemAccess;
    return true;
  }
  return false;
}

/// Check accesses to operands of the matrix multiplication.
///
/// Check that the accesses of the SCoP statement, which corresponds to
/// the partial schedule @p PartialSchedule, are scalar in terms of the loops
/// containing the matrix multiplication, in case they do not represent
/// accesses to the non-scalar operands of the matrix multiplication or
/// to its result.
///
/// @param PartialSchedule The partial schedule of the SCoP statement.
/// @param MMI             Parameters of the matrix multiplication operands.
/// @return                True in case the corresponding SCoP statement
///                        represents matrix multiplication and false,
///                        otherwise.
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
                                    MatMulInfoTy &MMI) {
  auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
  unsigned OutDimNum = unsignedFromIslSize(PartialSchedule.range_tuple_dim());
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  auto MapI =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
  auto MapJ =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
  auto MapK =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
    auto *MemAccessPtr = *MemA;
    if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
        !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) &&
        !(MemAccessPtr->isStrideZero(MapI) &&
          MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK)))
      return false;
  }
  return true;
}

/// Check for dependencies corresponding to the matrix multiplication.
///
/// Check that there is only a true dependence of the form
/// S(..., k, ...) -> S(..., k + 1, ...), where S is the SCoP statement
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
/// to the dependency produced by the matrix multiplication.
///
/// @param Schedule The schedule of the SCoP statement.
/// @param D        The SCoP dependencies.
/// @param Pos      The parameter to describe an acceptable true dependence.
///                 In case it has a negative value, try to determine its
///                 acceptable value.
/// @return True in case dependencies correspond to the matrix multiplication
///         and false, otherwise.
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
  if (!Red.is_null())
    Dep = Dep.unite(Red);
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  auto Deltas = Dep.extract_map(Space).deltas();
  int DeltasDimNum = unsignedFromIslSize(Deltas.dim(isl::dim::set));
  for (int i = 0; i < DeltasDimNum; i++) {
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}
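
// Example (illustration only): for C[i][j] += A[i][k] * B[k][j], the only
// loop-carried true dependence is S[i, j, k] -> S[i, j, k + 1]; its distance
// vector (the "deltas" above) is (0, 0, 1), so the check succeeds with
// Pos = 2.
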
/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
///    other loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of the loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///                        to check.
/// @param D               The SCoP dependencies.
/// @param MMI             Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  if (Stmt->size() <= 1)
    return false;

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute FirstDim and SecondDim dimensions of the Node.
///
/// @param Node      The band node to be modified.
/// @param FirstDim  The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.at(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.at(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

/// Create the BLIS micro-kernel.
///
/// We create the BLIS micro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of the two innermost
/// modified dimensions. The Mr and Nr fields of MicroKernelParams are used
/// as tile sizes.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of the two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node              The schedule node to be modified.
/// @param MacroKernelParams Parameters of the macro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMacroKernel(isl::schedule_node Node,
                  MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);

  // Mark the outermost loop as parallelizable.
  Node = Node.as<isl::schedule_node_band>().member_set_coincident(0, true);

  return Node.child(0).child(0);
}
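
// Example (illustration only): for a three-dimensional band ordered
// [Mc-tile, Nc-tile, Kc-tile] after tiling, the first interchange swaps the
// Nc and Kc tile loops and the second swaps the Mc and Nc tile loops, so the
// resulting tile-loop order is [Nc, Kc, Mc]. This matches the BLIS loop
// structure sketched above, with the Nc loop outermost.
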
/// Get the size of the widest type of the matrix multiplication operands
/// in bytes, including alignment padding.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bytes, including alignment padding.
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bits.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bits.
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// release more registers for entries of the multiplied matrices.
///
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static struct MicroKernelParamsTy
getMicroKernelParams(const TargetTransformInfo *TTI, MatMulInfoTy MMI) {
  assert(TTI && "The target transform info should be provided.");

  // Nvec - Number of double-precision floating-point numbers that can be held
  // by a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;
  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
  int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {Mr, Nr};
}
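
// Worked example (illustration only, assuming 256-bit vector registers and
// double-precision operands): Nvec = 256 / 64 = 4. With the default
// LatencyVectorFma = 8 and ThroughputVectorFma = 1, the latency-throughput
// product is Nvec * 8 * 1 = 32, so Nr = ceil(sqrt(32) / 4) * 4 = 8 and
// Mr = ceil(32 / 8) = 4, i.e., a 4 x 8 micro-kernel tile of C.
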
/// Determine parameters of the target cache.
///
/// @param TTI Target Transform Info.
static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
  if (FirstCacheLevelSize == -1) {
    if (TTI->getCacheSize(L1DCache).hasValue())
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).getValue();
    else
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
  }
  if (SecondCacheLevelSize == -1) {
    if (TTI->getCacheSize(L2DCache).hasValue())
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).getValue();
    else
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
  }
  if (FirstCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L1DCache).hasValue())
      FirstCacheLevelAssociativity =
          TTI->getCacheAssociativity(L1DCache).getValue();
    else
      FirstCacheLevelAssociativity =
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
  }
  if (SecondCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L2DCache).hasValue())
      SecondCacheLevelAssociativity =
          TTI->getCacheAssociativity(L2DCache).getValue();
    else
      SecondCacheLevelAssociativity =
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
  }
}

/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should ideally be kept in cache between
/// iterations. Since the parameters of the macro kernel determine the sizes
/// of these blocks, there are upper and lower bounds on these parameters.
///
/// @param TTI               Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI               Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static struct MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     MatMulInfoTy MMI) {
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // determining all the parameters of a macro-kernel requires information
  // about the first two levels of a cache. It also requires that the
  // associativity degree of each cache level is greater than two. Otherwise,
  // another algorithm for determining the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {1, 1, 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {1, 1, 1};
  int Car = floor(
      (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can be zero since it is floored to an int. On macOS, division by zero
  // does not raise a signal, which lets negative tile sizes be computed:
  // Car == 0 implies Kc == 0 and hence Cac == 0, so the later division by Cac
  // would divide by zero. Prevent this by returning early if Car is zero.
  if (Car == 0)
    return {1, 1, 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero");
  return {Mc, Nc, Kc};
}
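
// Worked example (illustration only, using the Sandy Bridge defaults above
// and the 4 x 8 micro-kernel with 8-byte elements): Car = floor(7 / 3) = 2;
// Kc = (2 * 32768) / (4 * 8 * 8) = 256; Cac = (256 * 8 * 8) / 262144 =
// 0.0625; Mc = floor(6 / 0.0625) = 96; Nc = 256 * 8 = 2048.
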
/// Create an access relation that is specific to
/// the matrix multiplication pattern.
///
/// Create an access relation of the following form:
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ],
/// where I is @p FirstDim and J is @p SecondDim.
///
/// It can be used, for example, to create relations that help to access
/// elements of operands of a matrix multiplication consecutively after the
/// creation of the BLIS micro and macro kernels.
///
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
///
/// Subsequently, the described access relation is applied to the range of
/// @p MapOldIndVar, which is used to map the original induction variables to
/// the ones produced by the schedule transformations. It helps to define
/// relations using a new space and, at the same time, keep them
/// in the original one.
///
/// @param MapOldIndVar The relation, which maps the original induction
///                     variables to the ones produced by the schedule
///                     transformations.
/// @param FirstDim, SecondDim The input dimensions that are used to define
///                     the specified access relation.
/// @return The specified access relation.
static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
                                unsigned SecondDim) {
  auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3);
  auto AccessRel = isl::map::universe(AccessRelSpace);
  AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
  AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
  AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
  return MapOldIndVar.apply_range(AccessRel);
}
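
// Example (illustration only): with FirstDim = 3 and SecondDim = 7, as used
// below for Packed_B, the relation is
// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [O3, O5, O7]; applied to the range
// of MapOldIndVar, it yields the three subscripts of the packed array.
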
static isl::schedule_node createExtensionNode(isl::schedule_node Node,
                                              isl::map ExtensionMap) {
  auto Extension = isl::union_map(ExtensionMap);
  auto NewNode = isl::schedule_node::from_extension(Extension);
  return Node.graft_before(NewNode);
}

static isl::schedule_node optimizePackedB(isl::schedule_node Node,
                                          ScopStmt *Stmt, isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  Scop *S = Stmt->getParent();
  isl::set Domain = Stmt->getDomain();

  // Create packed array.
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  ScopArrayInfo *PackedB =
      S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B",
                             {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from B to PackedB.
  isl::map AccRelB = MMI.B->getLatestAccessRelation();
  isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7);
  AccRelPackedB =
      AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId());

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain);
  MMI.B->setNewAccessRelation(AccRelPackedB);

  unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
  assert(Dim >= 2);

  // Insert into the schedule tree.
  isl::map ExtMap = MapOldIndVar.project_out(isl::dim::out, 2, Dim - 2);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId());
  return createExtensionNode(Node, ExtMap);
}

static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *,
                                          isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  isl::set Domain = Stmt->getDomain();
  isl::id DomainId = Domain.get_tuple_id();

  // Create the packed array.
  unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Mr;
  ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo(
      MMI.A->getElementType(), "Packed_A",
      {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from A to PackedA.
  isl::map AccRelA = MMI.A->getLatestAccessRelation();
  isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6);
  AccRelPackedA =
      AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId());
  // { MemrefA[] -> PackedA[] }
  isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA);

  // Compute the domain for the copy statement.
  // Construct the copy statement domain out of the 3 outermost scatter
  // dimensions (to match the 3 band nodes surrounding the extension node) and
  // the array elements to copy (one statement instance per array element).
  // { Scatter[] }
  isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range();
  // { Scatter[] -> OutermostScatter[] }
  isl::map OuterDomainMap =
      makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6);
  // { Scatter[] -> MemrefA[] }
  isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA);
  // { Scatter[] -> CopyStmt[] }
  isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom);
  // { CopyStmt[] }
  isl::set CopyDomain = DomainTranslator.range();

  // Translate the access relations to the new domain.
  // { CopyStmt[] -> MemrefA[] }
  CopyFrom = CopyFrom.apply_domain(DomainTranslator);
  // { CopyStmt[] -> PackedA[] }
  isl::map CopyTo = CopyFrom.apply_range(PackedATranslator);

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt =
      Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain);
  MMI.A->setNewAccessRelation(AccRelPackedA);

  // Insert into the schedule tree.
  // { Scatter[] -> CopyStmt[] }
  isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true);
  ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2);
  return createExtensionNode(Node, ExtScatterCopy);
}

/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
/// transformation that requires the introduction of a new array, copying data
/// to the array, and changing memory access locations to reference the array.
/// It can be used to ensure that elements of the new array are read with
/// in-stride accesses, aligned to cache line boundaries, and preloaded into
/// certain cache levels.
///
/// As an example, let us consider the packing of the array A that would help
/// to read its elements with in-stride access. An access to the array A
/// is represented by an access relation that has the form
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
/// the form S[i, j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read with in-stride access, we
/// add a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node         The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps the original induction
///                     variables to the ones produced by the schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                     to be taken into account.
/// @param MMI          Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  Node = Node.parent().parent().parent().parent().parent().parent();
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2));

  Node = Node.child(0);
  Node =
      optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  Node = Node.child(0);
  Node =
      optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  return Node.child(0).child(0).child(0).child(0).child(0);
}

/// Get a relation mapping the original induction variables to the ones
/// produced by schedule transformations.
///
/// @param Node The schedule node produced as the result of creation
///             of the BLIS kernels.
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
///             to be taken into account.
/// @return The relation mapping the original induction variables to the ones
///         produced by the schedule transformations.
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
/// @see getMacroKernelParams
static isl::map
getInductionVariablesSubstitution(isl::schedule_node Node,
                                  MicroKernelParamsTy MicroKernelParams,
                                  MacroKernelParamsTy MacroKernelParams) {
  auto Child = Node.child(0);
  auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
  auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
  unsigned Dim = unsignedFromIslSize(MapOldIndVar.range_tuple_dim());
  if (Dim > 9u)
    return MapOldIndVar.project_out(isl::dim::out, 0, Dim - 9);
  return MapOldIndVar;
}

/// Isolate a set of partial tile prefixes and unroll the isolated part.
///
/// The set should ensure that it contains only partial tile prefixes that have
/// exactly Mr x Nr iterations of the two innermost loops produced by
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
/// the micro-kernel.
///
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
/// innermost loops, using the SLP vectorizer.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @return The modified isl_schedule_node.
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 struct MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.child(0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
  unsigned Dims = unsignedFromIslSize(Prefix.tuple_dim());
  assert(Dims >= 1);
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);

  isl::union_set IsolateOption =
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
  isl::ctx Ctx = Node.ctx();
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(Prefix, 3);
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  Node = Node.child(0).child(0).child(0);
  return Node;
}

  846. /// Insert "Loop Vectorizer Disabled" mark node.
  847. ///
  848. /// @param Node The child of the mark node to be inserted.
  849. /// @return The modified isl_schedule_node.
  850. static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
  851. auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr);
  852. return Node.insert_mark(Id).child(0);
  853. }
/// Restore the initial ordering of dimensions of the band node.
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  if (Node.get_schedule_depth().release() != 0 ||
      (unsignedFromIslSize(isl::set(Domain).tuple_dim()) !=
       unsignedFromIslSize(Node.as<isl::schedule_node_band>().n_member())))
    return Node;
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
}

static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
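
  // The interchange above may have moved loop j or k: whichever of them
  // previously occupied position DimOutNum - 3 now sits where loop i used to
  // be. Track the updated positions before the remaining interchanges.
  // Worked example (illustration only): with (i, j, k) = (2, 0, 1) and
  // DimOutNum = 3, the first interchange swaps dimensions 2 and 0, so
  // NewJ = 2 and NewK = 1; the second interchange swaps dimensions 2 and 1,
  // after which NewK becomes 2 and the band order is [i, j, k].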
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (MapOldIndVar.is_null())
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}

/// Check if this node contains a partial schedule that could
/// probably be optimized with analytical modeling.
///
/// isMatrMultPattern tries to determine whether the following conditions
/// are true:
/// 1. the partial schedule contains only one statement.
/// 2. there are exactly three input dimensions.
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
///    interchange loops (switch the variable used in the inner loop to
///    the outer loop).
/// 4. all memory accesses of the statement except for the last one are read
///    memory accesses, and the last one is a write memory access.
/// 5. all subscripts of the last memory access of the statement don't
///    contain the variable used in the inner loop.
/// If this is the case, we could try to use an approach that is similar to
/// the one used to get close-to-peak performance of matrix multiplications.
///
/// @param Node The node to check.
/// @param D    The SCoP dependencies.
/// @param MMI  Parameters of the matrix multiplication operands.
static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D,
                              MatMulInfoTy &MMI) {
  auto PartialSchedule = isl::manage(
      isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
  Node = Node.child(0);
  auto LeafType = isl_schedule_node_get_type(Node.get());
  Node = Node.parent();
  if (LeafType != isl_schedule_node_leaf ||
      isl_schedule_node_band_n_member(Node.get()) < 3 ||
      Node.get_schedule_depth().release() != 0 ||
      isl_union_map_n_map(PartialSchedule.get()) != 1)
    return false;
  auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
  return containsMatrMult(NewPartialSchedule, D, MMI);
}

} // namespace

isl::schedule_node
polly::tryOptimizeMatMulPattern(isl::schedule_node Node,
                                const llvm::TargetTransformInfo *TTI,
                                const Dependences *D) {
  MatMulInfoTy MMI;
  if (isMatrMultPattern(Node, D, MMI)) {
    LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
    return optimizeMatMulPattern(Node, TTI, MMI);
  }
  return {};
}
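
// Usage sketch (hypothetical caller, for illustration only): a schedule-tree
// rewriter would invoke tryOptimizeMatMulPattern on each band node and keep
// the original node when an empty result signals that no pattern was found:
//
//   isl::schedule_node visitBand(isl::schedule_node Band,
//                                const llvm::TargetTransformInfo *TTI,
//                                const polly::Dependences *D) {
//     isl::schedule_node Opt = polly::tryOptimizeMatMulPattern(Band, TTI, D);
//     return Opt.is_null() ? Band : Opt;
//   }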