FlattenAlgo.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
  10. // the unittest for this requiring linking against LLVM.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "polly/FlattenAlgo.h"
  14. #include "polly/Support/ISLOStream.h"
  15. #include "polly/Support/ISLTools.h"
  16. #include "llvm/Support/Debug.h"
  17. #define DEBUG_TYPE "polly-flatten-algo"
  18. using namespace polly;
  19. using namespace llvm;
  20. namespace {
  21. /// Whether a dimension of a set is bounded (lower and upper) by a constant,
  22. /// i.e. there are two constants Min and Max, such that every value x of the
  23. /// chosen dimensions is Min <= x <= Max.
  24. bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
  25. auto ParamDims = unsignedFromIslSize(Set.dim(isl::dim::param));
  26. Set = Set.project_out(isl::dim::param, 0, ParamDims);
  27. Set = Set.project_out(isl::dim::set, 0, dim);
  28. auto SetDims = unsignedFromIslSize(Set.tuple_dim());
  29. assert(SetDims >= 1);
  30. Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
  31. return bool(Set.is_bounded());
  32. }
  33. /// Whether a dimension of a set is (lower and upper) bounded by a constant or
  34. /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
  35. /// p, such that every value x of the chosen dimensions is
  36. /// Min_p <= x <= Max_p.
  37. bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
  38. Set = Set.project_out(isl::dim::set, 0, dim);
  39. auto SetDims = unsignedFromIslSize(Set.tuple_dim());
  40. assert(SetDims >= 1);
  41. Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
  42. return bool(Set.is_bounded());
  43. }
  44. /// Whether BMap's first out-dimension is not a constant.
  45. bool isVariableDim(const isl::basic_map &BMap) {
  46. auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
  47. return FixedVal.is_null() || FixedVal.is_nan();
  48. }
  49. /// Whether Map's first out dimension is no constant nor piecewise constant.
  50. bool isVariableDim(const isl::map &Map) {
  51. for (isl::basic_map BMap : Map.get_basic_map_list())
  52. if (isVariableDim(BMap))
  53. return false;
  54. return true;
  55. }
  56. /// Whether UMap's first out dimension is no (piecewise) constant.
  57. bool isVariableDim(const isl::union_map &UMap) {
  58. for (isl::map Map : UMap.get_map_list())
  59. if (isVariableDim(Map))
  60. return false;
  61. return true;
  62. }
  63. /// Compute @p UPwAff - @p Val.
  64. isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
  65. if (Val.is_zero())
  66. return UPwAff;
  67. auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
  68. isl::stat Stat =
  69. UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
  70. auto ValAff =
  71. isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
  72. auto Subtracted = PwAff.sub(ValAff);
  73. Result = Result.union_add(isl::union_pw_aff(Subtracted));
  74. return isl::stat::ok();
  75. });
  76. if (Stat.is_error())
  77. return {};
  78. return Result;
  79. }
  80. /// Compute @UPwAff * @p Val.
  81. isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
  82. if (Val.is_one())
  83. return UPwAff;
  84. auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
  85. isl::stat Stat =
  86. UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
  87. auto ValAff =
  88. isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
  89. auto Multiplied = PwAff.mul(ValAff);
  90. Result = Result.union_add(Multiplied);
  91. return isl::stat::ok();
  92. });
  93. if (Stat.is_error())
  94. return {};
  95. return Result;
  96. }
  97. /// Remove @p n dimensions from @p UMap's range, starting at @p first.
  98. ///
  99. /// It is assumed that all maps in the maps have at least the necessary number
  100. /// of out dimensions.
  101. isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
  102. unsigned n) {
  103. if (n == 0)
  104. return UMap; /* isl_map_project_out would also reset the tuple, which should
  105. have no effect on schedule ranges */
  106. auto Result = isl::union_map::empty(UMap.ctx());
  107. for (isl::map Map : UMap.get_map_list()) {
  108. auto Outprojected = Map.project_out(isl::dim::out, first, n);
  109. Result = Result.unite(Outprojected);
  110. }
  111. return Result;
  112. }
  113. /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
  114. isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
  115. auto SingleUMap = isl::union_map::empty(UMap.ctx());
  116. for (isl::map Map : UMap.get_map_list()) {
  117. unsigned MapDims = unsignedFromIslSize(Map.range_tuple_dim());
  118. assert(MapDims > pos);
  119. isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
  120. SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
  121. SingleUMap = SingleUMap.unite(SingleMap);
  122. };
  123. auto UAff = isl::union_pw_multi_aff(SingleUMap);
  124. auto FirstMAff = isl::multi_union_pw_aff(UAff);
  125. return FirstMAff.at(0);
  126. }
  127. /// Flatten a sequence-like first dimension.
  128. ///
  129. /// A sequence-like scatter dimension is constant, or at least only small
  130. /// variation, typically the result of ordering a sequence of different
  131. /// statements. An example would be:
  132. /// { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
  133. /// to schedule all instances of Stmt_A before any instance of Stmt_B.
  134. ///
  135. /// To flatten, first begin with an offset of zero. Then determine the lowest
  136. /// possible value of the dimension, call it "i" [In the example we start at 0].
  137. /// Considering only schedules with that value, consider only instances with
  138. /// that value and determine the extent of the next dimension. Let l_X(i) and
  139. /// u_X(i) its minimum (lower bound) and maximum (upper bound) value. Add them
  140. /// as "Offset + X - l_X(i)" to the new schedule, then add "u_X(i) - l_X(i) + 1"
  141. /// to Offset and remove all i-instances from the old schedule. Repeat with the
  142. /// remaining lowest value i' until there are no instances in the old schedule
  143. /// left.
  144. /// The example schedule would be transformed to:
  145. /// { Stmt_X[] -> [X - l_X, ...]; Stmt_B -> [l_X - u_X + 1 + Y - l_Y, ...] }
  146. isl::union_map tryFlattenSequence(isl::union_map Schedule) {
  147. auto IslCtx = Schedule.ctx();
  148. auto ScatterSet = isl::set(Schedule.range());
  149. auto ParamSpace = Schedule.get_space().params();
  150. auto Dims = unsignedFromIslSize(ScatterSet.tuple_dim());
  151. assert(Dims >= 2u);
  152. // Would cause an infinite loop.
  153. if (!isDimBoundedByConstant(ScatterSet, 0)) {
  154. LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
  155. return {};
  156. }
  157. auto AllDomains = Schedule.domain();
  158. auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);
  159. auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
  160. auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));
  161. while (!ScatterSet.is_empty()) {
  162. LLVM_DEBUG(dbgs() << "Next counter:\n " << Counter << "\n");
  163. LLVM_DEBUG(dbgs() << "Remaining scatter set:\n " << ScatterSet << "\n");
  164. auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
  165. auto ThisFirst = ThisSet.lexmin();
  166. auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);
  167. auto SubSchedule = Schedule.intersect_range(ScatterFirst);
  168. SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
  169. SubSchedule = flattenSchedule(SubSchedule);
  170. unsigned SubDims = getNumScatterDims(SubSchedule);
  171. assert(SubDims >= 1);
  172. auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
  173. auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
  174. auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
  175. auto FirstSubScatter = isl::set(FirstSubSchedule.range());
  176. LLVM_DEBUG(dbgs() << "Next step in sequence is:\n " << FirstSubScatter
  177. << "\n");
  178. if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
  179. LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
  180. return {};
  181. }
  182. auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);
  183. // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
  184. // 'none'. It doesn't match with any space including a 0-dimensional
  185. // anonymous tuple.
  186. // Interesting, one can create such a set using
  187. // isl_set_universe(ParamSpace). Bug?
  188. auto PartMin = FirstSubScatterMap.dim_min(0);
  189. auto PartMax = FirstSubScatterMap.dim_max(0);
  190. auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
  191. isl::val::one(IslCtx));
  192. auto PartLen = PartMax.add(PartMin.neg()).add(One);
  193. auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
  194. auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
  195. auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
  196. auto FirstScheduleAffWithOffset =
  197. FirstScheduleAffNormalized.add(AllCounter);
  198. auto ScheduleWithOffset =
  199. isl::union_map::from(
  200. isl::union_pw_multi_aff(FirstScheduleAffWithOffset))
  201. .flat_range_product(RemainingSubSchedule);
  202. NewSchedule = NewSchedule.unite(ScheduleWithOffset);
  203. ScatterSet = ScatterSet.subtract(ScatterFirst);
  204. Counter = Counter.add(PartLen);
  205. }
  206. LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n " << NewSchedule
  207. << "\n");
  208. return NewSchedule;
  209. }
  210. /// Flatten a loop-like first dimension.
  211. ///
  212. /// A loop-like dimension is one that depends on a variable (usually a loop's
  213. /// induction variable). Let the input schedule look like this:
  214. /// { Stmt[i] -> [i, X, ...] }
  215. ///
  216. /// To flatten, we determine the largest extent of X which may not depend on the
  217. /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
  218. /// largest value. Then, construct a new schedule
  219. /// { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
  220. isl::union_map tryFlattenLoop(isl::union_map Schedule) {
  221. assert(getNumScatterDims(Schedule) >= 2);
  222. auto Remaining = scheduleProjectOut(Schedule, 0, 1);
  223. auto SubSchedule = flattenSchedule(Remaining);
  224. unsigned SubDims = getNumScatterDims(SubSchedule);
  225. assert(SubDims >= 1);
  226. auto SubExtent = isl::set(SubSchedule.range());
  227. auto SubExtentDims = unsignedFromIslSize(SubExtent.dim(isl::dim::param));
  228. SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
  229. SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);
  230. if (!isDimBoundedByConstant(SubExtent, 0)) {
  231. LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
  232. return {};
  233. }
  234. auto Min = SubExtent.dim_min(0);
  235. LLVM_DEBUG(dbgs() << "Min bound:\n " << Min << "\n");
  236. auto MinVal = getConstant(Min, false, true);
  237. auto Max = SubExtent.dim_max(0);
  238. LLVM_DEBUG(dbgs() << "Max bound:\n " << Max << "\n");
  239. auto MaxVal = getConstant(Max, true, false);
  240. if (MinVal.is_null() || MaxVal.is_null() || MinVal.is_nan() ||
  241. MaxVal.is_nan()) {
  242. LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
  243. return {};
  244. }
  245. auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
  246. auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
  247. auto LenVal = MaxVal.sub(MinVal).add(1);
  248. auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
  249. // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
  250. // subtract it)
  251. auto FirstAff = scheduleExtractDimAff(Schedule, 0);
  252. auto Offset = multiply(FirstAff, LenVal);
  253. isl::union_pw_multi_aff Index = FirstSubScheduleNormalized.add(Offset);
  254. auto IndexMap = isl::union_map::from(Index);
  255. auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
  256. LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n " << Result << "\n");
  257. return Result;
  258. }
  259. } // anonymous namespace
  260. isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
  261. unsigned Dims = getNumScatterDims(Schedule);
  262. LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n " << Schedule
  263. << "\n");
  264. // Base case; no dimensions left
  265. if (Dims == 0) {
  266. // TODO: Add one dimension?
  267. return Schedule;
  268. }
  269. // Base case; already one-dimensional
  270. if (Dims == 1)
  271. return Schedule;
  272. // Fixed dimension; no need to preserve variabledness.
  273. if (!isVariableDim(Schedule)) {
  274. LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
  275. auto NewScheduleSequence = tryFlattenSequence(Schedule);
  276. if (!NewScheduleSequence.is_null())
  277. return NewScheduleSequence;
  278. }
  279. // Constant stride
  280. LLVM_DEBUG(dbgs() << "Try loop flattening\n");
  281. auto NewScheduleLoop = tryFlattenLoop(Schedule);
  282. if (!NewScheduleLoop.is_null())
  283. return NewScheduleLoop;
  284. // Try again without loop condition (may blow up the number of pieces!!)
  285. LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
  286. auto NewScheduleSequence = tryFlattenSequence(Schedule);
  287. if (!NewScheduleSequence.is_null())
  288. return NewScheduleSequence;
  289. // Cannot flatten
  290. return Schedule;
  291. }