//===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Make changes to isl's schedule tree data structure.
//
//===----------------------------------------------------------------------===//

#include "polly/ScheduleTreeTransform.h"
#include "polly/Support/GICHelper.h"
#include "polly/Support/ISLTools.h"
#include "polly/Support/ScopHelper.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#define DEBUG_TYPE "polly-opt-isl"

using namespace polly;
using namespace llvm;

namespace {

/// Copy the band member attributes (coincidence, loop type, isolate ast loop
/// type) from one band to another.
static isl::schedule_node_band
applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx,
                          const isl::schedule_node_band &Source,
                          int SourceIdx) {
  bool Coincident = Source.member_get_coincident(SourceIdx).release();
  Target = Target.member_set_coincident(TargetIdx, Coincident);

  isl_ast_loop_type LoopType =
      isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx);
  Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type(
                           Target.release(), TargetIdx, LoopType))
               .as<isl::schedule_node_band>();

  isl_ast_loop_type IsolateType =
      isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(),
                                                              SourceIdx);
  Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type(
                           Target.release(), TargetIdx, IsolateType))
               .as<isl::schedule_node_band>();

  return Target;
}

/// Create a new band by copying members from another band @p OldBand.
/// @p IncludeCb decides which band indices are copied to the result.
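///
/// For example (as done further below for greedy fusion), the outermost band
/// member can be dropped while all other members are kept:
///
///   rebuildBand(Band, Body, [](int i) { return i > 0; });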
template <typename CbTy>
static isl::schedule rebuildBand(isl::schedule_node_band OldBand,
                                 isl::schedule Body, CbTy IncludeCb) {
  int NumBandDims = unsignedFromIslSize(OldBand.n_member());

  bool ExcludeAny = false;
  bool IncludeAny = false;
  for (auto OldIdx : seq<int>(0, NumBandDims)) {
    if (IncludeCb(OldIdx))
      IncludeAny = true;
    else
      ExcludeAny = true;
  }

  // Instead of creating a zero-member band, don't create a band at all.
  if (!IncludeAny)
    return Body;

  isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule();
  isl::multi_union_pw_aff NewPartialSched;
  if (ExcludeAny) {
    // Select the included partial scatter functions.
    isl::union_pw_aff_list List = PartialSched.list();
    int NewIdx = 0;
    for (auto OldIdx : seq<int>(0, NumBandDims)) {
      if (IncludeCb(OldIdx))
        NewIdx += 1;
      else
        List = List.drop(NewIdx, 1);
    }
    isl::space ParamSpace = PartialSched.get_space().params();
    isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx);
    NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List);
  } else {
    // All members are included: just reuse the original scatter functions.
    NewPartialSched = PartialSched;
  }

  // Create the new band node.
  isl::schedule_node_band NewBand =
      Body.insert_partial_schedule(NewPartialSched)
          .get_root()
          .child(0)
          .as<isl::schedule_node_band>();

  // If OldBand was permutable, so is the new one, even if some dimensions are
  // missing.
  bool IsPermutable = OldBand.permutable().release();
  NewBand = NewBand.set_permutable(IsPermutable);

  // Reapply member attributes.
  int NewIdx = 0;
  for (auto OldIdx : seq<int>(0, NumBandDims)) {
    if (!IncludeCb(OldIdx))
      continue;
    NewBand =
        applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx);
    NewIdx += 1;
  }

  return NewBand.get_schedule();
}

/// Rewrite a schedule tree by reconstructing it bottom-up.
///
/// By default, the original schedule tree is reconstructed. To build a
/// different tree, redefine visitor methods in a derived class (CRTP).
///
/// Note that AST build options are not applied; setting the isolate[] option
/// makes the schedule tree 'anchored', and anchored trees cannot be modified
/// afterwards. Hence, AST build options must be set after the tree has been
/// constructed.
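///
/// For illustration, a derived rewriter (the name `MyRewriter` is only a
/// placeholder) that hooks into band nodes and otherwise reconstructs the tree
/// unchanged could look like:
///
///   struct MyRewriter : ScheduleTreeRewriter<MyRewriter> {
///     isl::schedule visitBand(isl::schedule_node_band Band) {
///       // Inspect or transform Band here, then fall back to the default
///       // reconstruction of the band and its subtree.
///       return ScheduleTreeRewriter::visitBand(Band);
///     }
///   };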
template <typename Derived, typename... Args>
struct ScheduleTreeRewriter
    : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
  Derived &getDerived() { return *static_cast<Derived *>(this); }
  const Derived &getDerived() const {
    return *static_cast<const Derived *>(this);
  }

  isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
    // Every schedule_tree already has a domain node, no need to add one.
    return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
  }

  isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
    isl::schedule NewChild =
        getDerived().visit(Band.child(0), std::forward<Args>(args)...);
    return rebuildBand(Band, NewChild, [](int) { return true; });
  }

  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              Args... args) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    isl::schedule Result =
        getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = Result.sequence(
          getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
    return Result;
  }

  isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    isl::schedule Result =
        getDerived().visit(Set.child(0), std::forward<Args>(args)...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = isl::manage(
          isl_schedule_set(Result.release(),
                           getDerived()
                               .visit(Set.child(i), std::forward<Args>(args)...)
                               .release()));
    return Result;
  }

  isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
    return isl::schedule::from_domain(Leaf.get_domain());
  }

  isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
    isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
    isl::schedule_node NewChild =
        getDerived()
            .visit(Mark.first_child(), std::forward<Args>(args)...)
            .get_root()
            .first_child();
    return NewChild.insert_mark(TheMark).get_schedule();
  }

  isl::schedule visitExtension(isl::schedule_node_extension Extension,
                               Args... args) {
    isl::union_map TheExtension =
        Extension.as<isl::schedule_node_extension>().get_extension();
    isl::schedule_node NewChild = getDerived()
                                      .visit(Extension.child(0), args...)
                                      .get_root()
                                      .first_child();
    isl::schedule_node NewExtension =
        isl::schedule_node::from_extension(TheExtension);
    return NewChild.graft_before(NewExtension).get_schedule();
  }

  isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
    isl::union_set FilterDomain =
        Filter.as<isl::schedule_node_filter>().get_filter();
    isl::schedule NewSchedule =
        getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
    return NewSchedule.intersect_domain(FilterDomain);
  }

  isl::schedule visitNode(isl::schedule_node Node, Args... args) {
    llvm_unreachable("Not implemented");
  }
};

/// Rewrite the schedule tree without any changes. Useful to copy a subtree
/// into a new schedule, discarding everything outside of it.
struct IdentityRewriter : public ScheduleTreeRewriter<IdentityRewriter> {};

/// Rewrite a schedule tree to an equivalent one without extension nodes.
///
/// Each visit method takes two additional arguments:
///
///  * The new domain of the node, which is the inherited domain plus any
///    domains added by extension nodes.
///
///  * A map of extension domains of all children is returned; it is required
///    by band nodes to schedule the additional domains at the same position
///    as the extension node would have.
///
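/// The rewrite is started through visitSchedule(); hoistExtensionNodes below
/// uses it as follows:
///
///   ExtensionNodeRewriter Rewriter;
///   isl::schedule NewSched = Rewriter.visitSchedule(Sched);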
struct ExtensionNodeRewriter
    : public ScheduleTreeRewriter<ExtensionNodeRewriter,
                                  const isl::union_set &, isl::union_map &> {
  using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
                                      const isl::union_set &, isl::union_map &>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  isl::schedule visitSchedule(isl::schedule Schedule) {
    isl::union_map Extensions;
    isl::schedule Result =
        visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
    assert(!Extensions.is_null() && Extensions.is_empty());
    return Result;
  }

  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              const isl::union_set &Domain,
                              isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Sequence.child(i);
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = NewNode.sequence(NewChildNode);
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  isl::schedule visitSet(isl::schedule_node_set Set,
                         const isl::union_set &Domain,
                         isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Set.child(i);
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = isl::manage(
          isl_schedule_set(NewNode.release(), NewChildNode.release()));
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
                          const isl::union_set &Domain,
                          isl::union_map &Extensions) {
    Extensions = isl::union_map::empty(Leaf.ctx());
    return isl::schedule::from_domain(Domain);
  }

  isl::schedule visitBand(isl::schedule_node_band OldNode,
                          const isl::union_set &Domain,
                          isl::union_map &OuterExtensions) {
    isl::schedule_node OldChild = OldNode.first_child();
    isl::multi_union_pw_aff PartialSched =
        isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));

    isl::union_map NewChildExtensions;
    isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);

    // Add the extensions to the partial schedule.
    OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
    isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
    unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
    for (isl::map Ext : NewChildExtensions.get_map_list()) {
      unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim());
      assert(ExtDims >= BandDims);
      unsigned OuterDims = ExtDims - BandDims;

      isl::map BandSched =
          Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
      NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);

      // There might be more outer bands that have to schedule the extensions.
      if (OuterDims > 0) {
        isl::map OuterSched =
            Ext.project_out(isl::dim::in, OuterDims, BandDims);
        OuterExtensions = OuterExtensions.unite(OuterSched);
      }
    }
    isl::multi_union_pw_aff NewPartialSchedAsMultiUnionPwAff =
        isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
    isl::schedule_node NewNode =
        NewChild.insert_partial_schedule(NewPartialSchedAsMultiUnionPwAff)
            .get_root()
            .child(0);

    // Reapply permutability and coincidence attributes.
    NewNode = isl::manage(isl_schedule_node_band_set_permutable(
        NewNode.release(),
        isl_schedule_node_band_get_permutable(OldNode.get())));
    for (unsigned i = 0; i < BandDims; i += 1)
      NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(),
                                          i, OldNode, i);

    return NewNode.get_schedule();
  }

  isl::schedule visitFilter(isl::schedule_node_filter Filter,
                            const isl::union_set &Domain,
                            isl::union_map &Extensions) {
    isl::union_set FilterDomain =
        Filter.as<isl::schedule_node_filter>().get_filter();
    isl::union_set NewDomain = Domain.intersect(FilterDomain);

    // A filter is added implicitly if necessary when joining schedule trees.
    return visit(Filter.first_child(), NewDomain, Extensions);
  }

  isl::schedule visitExtension(isl::schedule_node_extension Extension,
                               const isl::union_set &Domain,
                               isl::union_map &Extensions) {
    isl::union_map ExtDomain =
        Extension.as<isl::schedule_node_extension>().get_extension();
    isl::union_set NewDomain = Domain.unite(ExtDomain.range());
    isl::union_map ChildExtensions;
    isl::schedule NewChild =
        visit(Extension.first_child(), NewDomain, ChildExtensions);
    Extensions = ChildExtensions.unite(ExtDomain);
    return NewChild;
  }
};

/// Collect all AST build options in any schedule tree band.
///
/// ScheduleTreeRewriter cannot apply the schedule tree options. This class
/// collects these options to apply them later.
struct CollectASTBuildOptions
    : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
  using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;

  void visitBand(isl::schedule_node_band Band) {
    ASTBuildOptions.push_back(
        isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
    return getBase().visitBand(Band);
  }
};

/// Apply AST build options to the bands in a schedule tree.
///
/// This rewrites a schedule tree with the AST build options applied. We assume
/// that the band nodes are visited in the same order as they were when the
/// build options were collected, typically by CollectASTBuildOptions.
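///
/// The intended collect-then-apply pattern (used by hoistExtensionNodes below)
/// is:
///
///   CollectASTBuildOptions Collector;
///   Collector.visit(Sched);
///   // ... rewrite the tree without changing the order of band nodes ...
///   ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
///   NewSched = Applicator.visitSchedule(NewSched);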
struct ApplyASTBuildOptions
    : public ScheduleNodeRewriter<ApplyASTBuildOptions> {
  using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  size_t Pos;
  llvm::ArrayRef<isl::union_set> ASTBuildOptions;

  ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
      : ASTBuildOptions(ASTBuildOptions) {}

  isl::schedule visitSchedule(isl::schedule Schedule) {
    Pos = 0;
    isl::schedule Result = visit(Schedule).get_schedule();
    assert(Pos == ASTBuildOptions.size() &&
           "AST build options must match to band nodes");
    return Result;
  }

  isl::schedule_node visitBand(isl::schedule_node_band Band) {
    isl::schedule_node_band Result =
        Band.set_ast_build_options(ASTBuildOptions[Pos]);
    Pos += 1;
    return getBase().visitBand(Result);
  }
};

/// Return whether the schedule contains an extension node.
static bool containsExtensionNode(isl::schedule Schedule) {
  assert(!Schedule.is_null());

  auto Callback = [](__isl_keep isl_schedule_node *Node,
                     void *User) -> isl_bool {
    if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
      // Stop walking the schedule tree.
      return isl_bool_error;
    }

    // Continue searching the subtree.
    return isl_bool_true;
  };
  isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
      Schedule.get(), Callback, nullptr);

  // We assume that the traversal itself does not fail, i.e. the only reason to
  // return isl_stat_error is that an extension node was found.
  return RetVal == isl_stat_error;
}

/// Find a named MDNode property in a LoopID.
static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
  return dyn_cast_or_null<MDNode>(
      findMetadataOperand(LoopMD, Name).getValueOr(nullptr));
}

/// Is this node of type mark?
static bool isMark(const isl::schedule_node &Node) {
  return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark;
}

/// Is this node of type band?
static bool isBand(const isl::schedule_node &Node) {
  return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band;
}

#ifndef NDEBUG
/// Is this node a band of a single dimension (i.e. could represent a loop)?
static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
  return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1;
}
#endif

/// Is this node of type leaf?
static bool isLeaf(const isl::schedule_node &Node) {
  return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf;
}

/// Create an isl::id representing the output loop after a transformation.
static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
  // Nothing to annotate if there is no followup metadata.
  // TODO: Append llvm.loop.disable_heuristics metadata unless overridden by
  // user followup-MD.
  if (!FollowupLoopMD)
    return {};
  BandAttr *Attr = new BandAttr();
  Attr->Metadata = FollowupLoopMD;
  return getIslLoopAttr(Ctx, Attr);
}

/// A loop consists of a band and an optional marker that wraps it. Return the
/// outermost of the two.
/// That is, return the mark or, if there is no mark, the band itself. Can be
/// called with either the mark or the band.
static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
  if (isBandMark(BandOrMark)) {
    assert(isBandWithSingleLoop(BandOrMark.child(0)));
    return BandOrMark;
  }
  assert(isBandWithSingleLoop(BandOrMark));

  isl::schedule_node Mark = BandOrMark.parent();
  if (isBandMark(Mark))
    return Mark;

  // Band has no loop marker.
  return BandOrMark;
}

/// Remove the mark that wraps a loop, saving its attribute in @p Attr. Return
/// the band representing the loop.
static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
                                     BandAttr *&Attr) {
  MarkOrBand = moveToBandMark(MarkOrBand);

  isl::schedule_node Band;
  if (isMark(MarkOrBand)) {
    Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
    Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release()));
  } else {
    Attr = nullptr;
    Band = MarkOrBand;
  }

  assert(isBandWithSingleLoop(Band));
  return Band;
}

/// Remove the mark that wraps a loop. Return the band representing the loop.
static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
  BandAttr *Attr;
  return removeMark(MarkOrBand, Attr);
}

static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
  assert(isBand(Band));
  assert(moveToBandMark(Band).is_equal(Band) &&
         "Don't add two marks for a band");
  return Band.insert_mark(Mark).child(0);
}

/// Return the (one-dimensional) set of numbers that leave remainder @p Offset
/// when divided by @p Factor.
///
/// isDivisibleBySet(Ctx, 4, 0) = { [i] : i mod 4 = 0 }
/// isDivisibleBySet(Ctx, 4, 1) = { [i] : i mod 4 = 1 }
///
static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
                                       long Offset) {
  isl::val ValFactor{Ctx, Factor};
  isl::val ValOffset{Ctx, Offset};

  isl::space Unispace{Ctx, 0, 1};
  isl::local_space LUnispace{Unispace};
  isl::aff AffFactor{LUnispace, ValFactor};
  isl::aff AffOffset{LUnispace, ValOffset};
  isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
  isl::aff DivMul = Id.mod(ValFactor);
  isl::basic_map Divisible = isl::basic_map::from_aff(DivMul);
  isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset);
  return Modulo.domain();
}

/// Constrain the last dimension of @p Set to take values from 0 to
/// VectorWidth - 1.
///
/// @param Set         A set, which should be modified.
/// @param VectorWidth A parameter, which determines the constraint.
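///
/// For example, with VectorWidth = 4 the two-dimensional set { [i, j] }
/// becomes { [i, j] : 0 <= j <= 3 }.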
static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
  unsigned Dims = unsignedFromIslSize(Set.tuple_dim());
  assert(Dims >= 1);
  isl::space Space = Set.get_space();
  isl::local_space LocalSpace = isl::local_space(Space);
  isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
  ExtConstr = ExtConstr.set_constant_si(0);
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
  Set = Set.add_constraint(ExtConstr);
  ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
  ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
  return Set.add_constraint(ExtConstr);
}

/// Collapse perfectly nested bands into a single band.
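///
/// For example (schematically), a nest of two single-member bands
///
///   band: [{ Stmt[i, j] -> [i] }]
///     band: [{ Stmt[i, j] -> [j] }]
///       leaf
///
/// is rewritten into the equivalent two-member band
///
///   band: [{ Stmt[i, j] -> [i] }, { Stmt[i, j] -> [j] }]
///     leaf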
class BandCollapseRewriter : public ScheduleTreeRewriter<BandCollapseRewriter> {
private:
  using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

public:
  isl::schedule visitBand(isl::schedule_node_band RootBand) {
    isl::schedule_node_band Band = RootBand;
    isl::ctx Ctx = Band.ctx();

    // Do not merge a permutable band to avoid losing the permutability
    // property. Two loops that are each permutable individually are not
    // necessarily permutable across each other once collapsed.
    if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
      return getBase().visitBand(Band);

    // Find collapsible bands.
    SmallVector<isl::schedule_node_band> Nest;
    int NumTotalLoops = 0;
    isl::schedule_node Body;
    while (true) {
      Nest.push_back(Band);
      NumTotalLoops += unsignedFromIslSize(Band.n_member());
      Body = Band.first_child();
      if (!Body.isa<isl::schedule_node_band>())
        break;
      Band = Body.as<isl::schedule_node_band>();

      // Do not include the next band if it is permutable, to avoid losing its
      // permutability property.
      if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
        break;
    }

    // Nothing to collapse, preserve permutability.
    if (Nest.size() <= 1)
      return getBase().visitBand(Band);

    LLVM_DEBUG({
      dbgs() << "Found loops to collapse between\n";
      dumpIslObj(RootBand, dbgs());
      dbgs() << "and\n";
      dumpIslObj(Body, dbgs());
      dbgs() << "\n";
    });

    isl::schedule NewBody = visit(Body);

    // Collect partial schedules from all members.
    isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops};
    for (isl::schedule_node_band Band : Nest) {
      int NumLoops = unsignedFromIslSize(Band.n_member());
      isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule();
      for (auto j : seq<int>(0, NumLoops))
        PartScheds = PartScheds.add(BandScheds.at(j));
    }
    isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops);
    isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds};

    isl::schedule_node_band CollapsedBand =
        NewBody.insert_partial_schedule(PartSchedsMulti)
            .get_root()
            .first_child()
            .as<isl::schedule_node_band>();

    // Copy over loop attributes from the original bands.
    int LoopIdx = 0;
    for (isl::schedule_node_band Band : Nest) {
      int NumLoops = unsignedFromIslSize(Band.n_member());
      for (int i : seq<int>(0, NumLoops)) {
        CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand),
                                                  LoopIdx, Band, i);
        LoopIdx += 1;
      }
    }
    assert(LoopIdx == NumTotalLoops &&
           "Expect the same number of loops to add up again");

    return CollapsedBand.get_schedule();
  }
};
static isl::schedule collapseBands(isl::schedule Sched) {
  LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n");
  BandCollapseRewriter Rewriter;
  return Rewriter.visit(Sched);
}

/// Collect sequentially executed bands (or anything else), even if nested in a
/// mark or other nodes whose child is executed just once. If we can
/// successfully fuse the bands, we allow them to be removed.
static void collectPotentiallyFusableBands(
    isl::schedule_node Node,
    SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
        &ScheduleBands,
    const isl::schedule_node &DirectChild) {
  switch (isl_schedule_node_get_type(Node.get())) {
  case isl_schedule_node_sequence:
  case isl_schedule_node_set:
  case isl_schedule_node_mark:
  case isl_schedule_node_domain:
  case isl_schedule_node_filter:
    if (Node.has_children()) {
      isl::schedule_node C = Node.first_child();
      while (true) {
        collectPotentiallyFusableBands(C, ScheduleBands, DirectChild);
        if (!C.has_next_sibling())
          break;
        C = C.next_sibling();
      }
    }
    break;

  default:
    // Something that does not execute sequentially (e.g. a band).
    ScheduleBands.push_back({Node, DirectChild});
    break;
  }
}

/// Remove dependencies that are resolved by @p PartSched. That is, remove
/// everything that we already know is executed in-order.
static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched,
                                                       isl::union_map Deps) {
  unsigned NumDims = getNumScatterDims(PartSched);
  auto ParamSpace = PartSched.get_space().params();

  // { Scatter[] }
  isl::space ScatterSpace =
      ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims);

  // { Scatter[] -> Domain[] }
  isl::union_map PartSchedRev = PartSched.reverse();

  // { Scatter[] -> Scatter[] }
  isl::map MaybeBefore = isl::map::lex_le(ScatterSpace);

  // { Domain[] -> Domain[] }
  isl::union_map DomMaybeBefore =
      MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev);

  // { Domain[] -> Domain[] }
  isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore);

  return ChildRemainingDeps;
}

/// Remove dependencies that are resolved by executing them in the order
/// specified by @p Domains.
static isl::union_map
remainingDepsFromSequence(ArrayRef<isl::union_set> Domains,
                          isl::union_map Deps) {
  isl::ctx Ctx = Deps.ctx();
  isl::space ParamSpace = Deps.get_space().params();

  // Create a partial schedule mapping to constants that reflect the execution
  // order.
  isl::union_map PartialSchedules = isl::union_map::empty(Ctx);
  for (auto P : enumerate(Domains)) {
    isl::val ExecTime = isl::val(Ctx, P.index());
    isl::union_pw_aff DomSched{P.value(), ExecTime};
    PartialSchedules = PartialSchedules.unite(DomSched.as_union_map());
  }

  return remainingDepsFromPartialSchedule(PartialSchedules, Deps);
}

/// Determine whether the outermost loops of two bands can be fused while
/// respecting validity dependencies.
static bool canFuseOutermost(const isl::schedule_node_band &LHS,
                             const isl::schedule_node_band &RHS,
                             const isl::union_map &Deps) {
  // { LHSDomain[] -> Scatter[] }
  isl::union_map LHSPartSched =
      LHS.get_partial_schedule().get_at(0).as_union_map();

  // { Domain[] -> Scatter[] }
  isl::union_map RHSPartSched =
      RHS.get_partial_schedule().get_at(0).as_union_map();

  // Dependencies that are currently resolved because LHS executes before RHS,
  // but would not be resolved anymore after fusion.
  // { DefDomain[] -> UseDomain[] }
  isl::union_map OrderedBySequence =
      Deps.intersect_domain(LHSPartSched.domain())
          .intersect_range(RHSPartSched.domain());

  isl::space ParamSpace = OrderedBySequence.get_space().params();
  isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1);

  // { Scatter[] -> Scatter[] }
  isl::map After = isl::map::lex_gt(NewScatterSpace);

  // After fusion, instances with a smaller (or equal, which means they will be
  // executed in the same iteration, but the LHS instance is still sequenced
  // before RHS) scatter value will still be executed before. These are the
  // orderings where this is not necessarily the case.
  // { LHSDomain[] -> RHSDomain[] }
  isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse())
                                        .apply_range(RHSPartSched.reverse());

  // Dependencies that are not resolved by the new execution order.
  isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms);

  return WithBefore.is_empty();
}
/// Fuse @p LHS and @p RHS if possible while preserving validity dependencies.
static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS,
                                   isl::schedule_node_band RHS,
                                   const isl::union_map &Deps) {
  if (!canFuseOutermost(LHS, RHS, Deps))
    return {};

  LLVM_DEBUG({
    dbgs() << "Found loops for greedy fusion:\n";
    dumpIslObj(LHS, dbgs());
    dbgs() << "and\n";
    dumpIslObj(RHS, dbgs());
    dbgs() << "\n";
  });

  // The partial schedules of the bands' outermost loops that we need to
  // combine for the fusion.
  isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0);
  isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0);

  // Isolate band bodies as roots of their own schedule trees.
  IdentityRewriter Rewriter;
  isl::schedule LHSBody = Rewriter.visit(LHS.first_child());
  isl::schedule RHSBody = Rewriter.visit(RHS.first_child());

  // Reconstruct the non-outermost (not going to be fused) loops from both
  // bands.
  // TODO: Maybe it is possible to transfer the 'permutability' property from
  // LHS+RHS. At minimum we would need to merge multiple band members at once,
  // otherwise permutability has no meaning.
  isl::schedule LHSNewBody =
      rebuildBand(LHS, LHSBody, [](int i) { return i > 0; });
  isl::schedule RHSNewBody =
      rebuildBand(RHS, RHSBody, [](int i) { return i > 0; });

  // The loop body of the fused loop.
  isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody);

  // Combine the partial schedules of both loops into a new one. Instances with
  // the same scatter value are put together.
  isl::union_map NewCommonPartialSched =
      LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map());
  isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule(
      NewCommonPartialSched.as_multi_union_pw_aff());

  return NewCommonSchedule;
}

static isl::schedule tryGreedyFuse(isl::schedule_node LHS,
                                   isl::schedule_node RHS,
                                   const isl::union_map &Deps) {
  // TODO: Non-bands could be interpreted as a band with just a single
  // iteration. However, this is only useful if both ends of a fused loop were
  // originally loops themselves.
  if (!LHS.isa<isl::schedule_node_band>())
    return {};
  if (!RHS.isa<isl::schedule_node_band>())
    return {};

  return tryGreedyFuse(LHS.as<isl::schedule_node_band>(),
                       RHS.as<isl::schedule_node_band>(), Deps);
}

/// Fuse all fusable loops top-down in a schedule tree.
///
/// The isl::union_map parameter is the set of validity dependencies that have
/// not been resolved/carried by a parent schedule node.
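///
/// Typical use (see applyGreedyFusion below):
///
///   GreedyFusionRewriter Rewriter;
///   isl::schedule Result = Rewriter.visit(Sched, Deps);
///   // Rewriter.AnyChange tells whether any loops were actually fused.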
class GreedyFusionRewriter
    : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
private:
  using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

public:
  /// Is set to true if anything has been fused.
  bool AnyChange = false;

  isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) {
    // { Domain[] -> Scatter[] }
    isl::union_map PartSched =
        isl::union_map::from(Band.get_partial_schedule());
    assert(getNumScatterDims(PartSched) ==
           unsignedFromIslSize(Band.n_member()));
    isl::space ParamSpace = PartSched.get_space().params();

    // { Scatter[] -> Domain[] }
    isl::union_map PartSchedRev = PartSched.reverse();

    // Instances that are possibly executed in the same band iteration.
    // Dependencies pointing to an instance with a larger scatter value are
    // carried by this loop and have therefore already been resolved by the
    // in-order execution of its iterations. A dependency pointing back to a
    // smaller scatter value would be a dependency violation that we assume
    // did not happen.
    // { Domain[] -> Domain[] }
    isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev);

    // Actual dependencies within the same iteration.
    // { DefDomain[] -> UseDomain[] }
    isl::union_map RemDeps = Deps.intersect(Unsequenced);

    return getBase().visitBand(Band, RemDeps);
  }

  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              isl::union_map Deps) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());

    // List of fusion candidates. The first element is the fusion candidate,
    // the second is the candidate's ancestor that is the sequence's direct
    // child. The direct child is preferred if none of its descendants has been
    // fused, so that its structure, such as mark nodes, is preserved.
    SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
    for (auto i : seq<int>(0, NumChildren)) {
      isl::schedule_node Child = Sequence.child(i);
      collectPotentiallyFusableBands(Child, Bands, Child);
    }

    // Direct children that had at least one of their descendants fused.
    SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;

    // Fuse neighboring bands until reaching the end of candidates.
    int i = 0;
    while (i + 1 < (int)Bands.size()) {
      isl::schedule Fused =
          tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps);
      if (Fused.is_null()) {
        // Cannot merge this node with the next; look at next pair.
        i += 1;
        continue;
      }

      // Mark the direct children as (partially) fused.
      if (!Bands[i].second.is_null())
        ChangedDirectChildren.insert(Bands[i].second.get());
      if (!Bands[i + 1].second.is_null())
        ChangedDirectChildren.insert(Bands[i + 1].second.get());

      // Collapse the neighbors into a single new candidate that could be fused
      // with the next candidate.
      Bands[i] = {Fused.get_root(), {}};
      Bands.erase(Bands.begin() + i + 1);

      AnyChange = true;
    }

    // By construction, these equal the domains covered by
    // collectPotentiallyFusableBands's output.
    SmallVector<isl::union_set> SubDomains;
    SubDomains.reserve(NumChildren);
    for (int i = 0; i < NumChildren; i += 1)
      SubDomains.push_back(Sequence.child(i).domain());
    auto SubRemainingDeps = remainingDepsFromSequence(SubDomains, Deps);

    // We may iterate over direct children multiple times, be sure to add each
    // at most once.
    SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;

    isl::schedule Result;
    for (auto &P : Bands) {
      isl::schedule_node MaybeFused = P.first;
      isl::schedule_node DirectChild = P.second;

      // If not modified, use the direct child.
      if (!DirectChild.is_null() &&
          !ChangedDirectChildren.count(DirectChild.get())) {
        if (AlreadyAdded.count(DirectChild.get()))
          continue;
        AlreadyAdded.insert(DirectChild.get());
        MaybeFused = DirectChild;
      } else {
        assert(AnyChange &&
               "Changed flag must be consistent with the actual change");
      }

      // Top-down recursion: if the outermost loops have been fused, their
      // nested bands might be fusable now as well.
      isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps);

      // Reconstruct the sequence, with some of the children fused.
      if (Result.is_null())
        Result = InnerFused;
      else
        Result = Result.sequence(InnerFused);
    }

    return Result;
  }
};
} // namespace
bool polly::isBandMark(const isl::schedule_node &Node) {
  return isMark(Node) &&
         isLoopAttr(Node.as<isl::schedule_node_mark>().get_id());
}

BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
  MarkOrBand = moveToBandMark(MarkOrBand);
  if (!isMark(MarkOrBand))
    return nullptr;

  return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
}

isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
  // If there is no extension node in the first place, return the original
  // schedule tree.
  if (!containsExtensionNode(Sched))
    return Sched;

  // Build options can anchor schedule nodes, such that the schedule tree
  // cannot be modified anymore. Therefore, apply build options after the tree
  // has been created.
  CollectASTBuildOptions Collector;
  Collector.visit(Sched);

  // Rewrite the schedule tree without extension nodes.
  ExtensionNodeRewriter Rewriter;
  isl::schedule NewSched = Rewriter.visitSchedule(Sched);

  // Reapply the AST build options. The rewriter must not change the iteration
  // order of bands. Any other node type is ignored.
  ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
  NewSched = Applicator.visitSchedule(NewSched);

  return NewSched;
}

isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
  isl::ctx Ctx = BandToUnroll.ctx();

  // Remove the loop's mark, the loop will disappear anyway.
  BandToUnroll = removeMark(BandToUnroll);
  assert(isBandWithSingleLoop(BandToUnroll));

  isl::multi_union_pw_aff PartialSched = isl::manage(
      isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
  assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u &&
         "Can only unroll a single dimension");
  isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);

  isl::union_set Domain = BandToUnroll.get_domain();
  PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
  isl::union_map PartialSchedUMap =
      isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));

  // Enumerate only the scatter elements.
  isl::union_set ScatterList = PartialSchedUMap.range();

  // Enumerate all loop iterations.
  // TODO: Diagnose if not enumerable or depends on a parameter.
  SmallVector<isl::point, 16> Elts;
  ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
    Elts.push_back(P);
    return isl::stat::ok();
  });

  // Don't assume that foreach_point returns the points in execution order.
  llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool {
    isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
    isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
    return C1.lt(C2);
  });

  // Convert the points to a sequence of filters.
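  // For example, for a band with schedule { Stmt[i] -> [i] } and domain
  // { Stmt[i] : 0 <= i < 4 } (the names are only illustrative), this creates
  // the filters { Stmt[0] }, { Stmt[1] }, { Stmt[2] }, { Stmt[3] }.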
  isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
  for (isl::point P : Elts) {
    // Determine the domains that map to this scatter element.
    isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
    List = List.add(DomainFilter);
  }

  // Replace the original band with the unrolled sequence.
  isl::schedule_node Body =
      isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
  Body = Body.insert_sequence(List);
  return Body.get_schedule();
}

isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
                                        int Factor) {
  assert(Factor > 0 && "Positive unroll factor required");
  isl::ctx Ctx = BandToUnroll.ctx();

  // Remove the mark, save the attribute for later use.
  BandAttr *Attr;
  BandToUnroll = removeMark(BandToUnroll, Attr);
  assert(isBandWithSingleLoop(BandToUnroll));

  isl::multi_union_pw_aff PartialSched = isl::manage(
      isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));

  // { Stmt[] -> [x] }
  isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);

  // Here we assume the schedule stride is one and starts with 0, which is not
  // necessarily the case.
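  // Under that assumption, the strided schedule computed below is
  // floor(x / Factor) * Factor; e.g. with Factor = 4, the schedule values
  // 0..3 map to 0, 4..7 map to 4, and so on.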
  isl::union_pw_aff StridedPartialSchedUAff =
      isl::union_pw_aff::empty(PartialSchedUAff.get_space());
  isl::val ValFactor{Ctx, Factor};
  PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
                                   &ValFactor](isl::pw_aff PwAff) -> isl::stat {
    isl::space Space = PwAff.get_space();
    isl::set Universe = isl::set::universe(Space.domain());
    isl::pw_aff AffFactor{Universe, ValFactor};
    isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
    StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
    return isl::stat::ok();
  });

  isl::union_set_list List = isl::union_set_list(Ctx, Factor);
  for (auto i : seq<int>(0, Factor)) {
    // { Stmt[] -> [x] }
    isl::union_map UMap =
        isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));

    // { [x] }
    isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);

    // { Stmt[] }
    isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();

    List = List.add(UnrolledDomain);
  }

  isl::schedule_node Body =
      isl::manage(isl_schedule_node_delete(BandToUnroll.copy()));
  Body = Body.insert_sequence(List);
  isl::schedule_node NewLoop =
      Body.insert_partial_schedule(StridedPartialSchedUAff);

  MDNode *FollowupMD = nullptr;
  if (Attr && Attr->Metadata)
    FollowupMD =
        findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);

  isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
  if (!NewBandId.is_null())
    NewLoop = insertMark(NewLoop, NewBandId);

  return NewLoop.get_schedule();
}

isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
                                       int VectorWidth) {
  unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim());
  assert(Dims >= 1);
  isl::set LoopPrefixes =
      ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1);
  auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
  isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange);
  BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1);
  LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1);
  return LoopPrefixes.subtract(BadPrefixes);
}

isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
                                        unsigned OutDimsNum) {
  unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim());
  assert(OutDimsNum <= Dims &&
         "The isl::set IsolateDomain is used to describe the range of schedule "
         "dimension values, which should be isolated. Consequently, the "
         "number of its dimensions should be greater than or equal to the "
         "number of the schedule dimensions.");
  isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
  IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
                                              Dims - OutDimsNum, OutDimsNum);
  isl::set IsolateOption = IsolateRelation.wrap();
  isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr);
  IsolateOption = IsolateOption.set_tuple_id(Id);
  return isl::union_set(IsolateOption);
}

isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
  isl::space Space(Ctx, 0, 1);
  auto DimOption = isl::set::universe(Space);
  auto Id = isl::id::alloc(Ctx, Option, nullptr);
  DimOption = DimOption.set_tuple_id(Id);
  return isl::union_set(DimOption);
}

isl::schedule_node polly::tileNode(isl::schedule_node Node,
                                   const char *Identifier,
                                   ArrayRef<int> TileSizes,
                                   int DefaultTileSize) {
  auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  auto Dims = Space.dim(isl::dim::set);
  auto Sizes = isl::multi_val::zero(Space);
  std::string IdentifierString(Identifier);
  for (unsigned i : rangeIslSize(0, Dims)) {
    unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
    Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize));
  }
  auto TileLoopMarkerStr = IdentifierString + " - Tiles";
  auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr);
  Node = Node.insert_mark(TileLoopMarker);
  Node = Node.child(0);
  Node =
      isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
  Node = Node.child(0);
  auto PointLoopMarkerStr = IdentifierString + " - Points";
  auto PointLoopMarker =
      isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr);
  Node = Node.insert_mark(PointLoopMarker);
  return Node.child(0);
}
isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
                                              ArrayRef<int> TileSizes,
                                              int DefaultTileSize) {
  Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
  auto Ctx = Node.ctx();
  return Node.as<isl::schedule_node_band>().set_ast_build_options(
      isl::union_set(Ctx, "{unroll[x]}"));
}

/// Find statements and sub-loops in (possibly nested) sequences.
static void
collectFissionableStmts(isl::schedule_node Node,
                        SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
  if (isBand(Node) || isLeaf(Node)) {
    ScheduleStmts.push_back(Node);
    return;
  }

  if (Node.has_children()) {
    isl::schedule_node C = Node.first_child();
    while (true) {
      collectFissionableStmts(C, ScheduleStmts);
      if (!C.has_next_sibling())
        break;
      C = C.next_sibling();
    }
  }
}

isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
  isl::ctx Ctx = BandToFission.ctx();
  BandToFission = removeMark(BandToFission);
  isl::schedule_node BandBody = BandToFission.child(0);

  SmallVector<isl::schedule_node> FissionableStmts;
  collectFissionableStmts(BandBody, FissionableStmts);
  size_t N = FissionableStmts.size();

  // Collect the domain for each of the statements that will get their own
  // loop.
  isl::union_set_list DomList = isl::union_set_list(Ctx, N);
  for (size_t i = 0; i < N; ++i) {
    isl::schedule_node BodyPart = FissionableStmts[i];
    DomList = DomList.add(BodyPart.get_domain());
  }

  // Apply the fission by copying the entire loop, but inserting a filter for
  // the statement domains for each fissioned loop.
  isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList);

  return Fissioned.get_schedule();
}

isl::schedule polly::applyGreedyFusion(isl::schedule Sched,
                                       const isl::union_map &Deps) {
  LLVM_DEBUG(dbgs() << "Greedy loop fusion\n");

  GreedyFusionRewriter Rewriter;
  isl::schedule Result = Rewriter.visit(Sched, Deps);
  if (!Rewriter.AnyChange) {
    LLVM_DEBUG(dbgs() << "Found nothing to fuse\n");
    return Sched;
  }

  // Because GreedyFusionRewriter works loop-by-loop, bands with multiple loops
  // may have been split into multiple bands.
  return collapseBands(Result);
}