IslAst.cpp 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. //===- IslAst.cpp - isl code generator interface --------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // The isl code generator interface takes a Scop and generates an isl_ast. This
  10. // isl_ast can either be returned directly or it can be pretty printed to
  11. // stdout.
  12. //
  13. // A typical isl_ast output looks like this:
  14. //
  15. // for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) {
  16. // bb2(c2);
  17. // }
  18. //
  19. // An in-depth discussion of our AST generation approach can be found in:
  20. //
  21. // Polyhedral AST generation is more than scanning polyhedra
  22. // Tobias Grosser, Sven Verdoolaege, Albert Cohen
  23. // ACM Transactions on Programming Languages and Systems (TOPLAS),
  24. // 37(4), July 2015
  25. // http://www.grosser.es/#pub-polyhedral-AST-generation
  26. //
  27. //===----------------------------------------------------------------------===//
  28. #include "polly/CodeGen/IslAst.h"
  29. #include "polly/CodeGen/CodeGeneration.h"
  30. #include "polly/DependenceInfo.h"
  31. #include "polly/LinkAllPasses.h"
  32. #include "polly/Options.h"
  33. #include "polly/ScopDetection.h"
  34. #include "polly/ScopInfo.h"
  35. #include "polly/ScopPass.h"
  36. #include "polly/Support/GICHelper.h"
  37. #include "llvm/ADT/Statistic.h"
  38. #include "llvm/IR/Function.h"
  39. #include "llvm/Support/Debug.h"
  40. #include "llvm/Support/raw_ostream.h"
  41. #include "isl/aff.h"
  42. #include "isl/ast.h"
  43. #include "isl/ast_build.h"
  44. #include "isl/id.h"
  45. #include "isl/isl-noexceptions.h"
  46. #include "isl/printer.h"
  47. #include "isl/schedule.h"
  48. #include "isl/set.h"
  49. #include "isl/union_map.h"
  50. #include "isl/val.h"
  51. #include <cassert>
  52. #include <cstdlib>
  53. #define DEBUG_TYPE "polly-ast"
  54. using namespace llvm;
  55. using namespace polly;
  56. using IslAstUserPayload = IslAstInfo::IslAstUserPayload;
  57. static cl::opt<bool>
  58. PollyParallel("polly-parallel",
  59. cl::desc("Generate thread parallel code (isl codegen only)"),
  60. cl::cat(PollyCategory));
  61. static cl::opt<bool> PrintAccesses("polly-ast-print-accesses",
  62. cl::desc("Print memory access functions"),
  63. cl::cat(PollyCategory));
  64. static cl::opt<bool> PollyParallelForce(
  65. "polly-parallel-force",
  66. cl::desc(
  67. "Force generation of thread parallel code ignoring any cost model"),
  68. cl::cat(PollyCategory));
  69. static cl::opt<bool> UseContext("polly-ast-use-context",
  70. cl::desc("Use context"), cl::Hidden,
  71. cl::init(true), cl::cat(PollyCategory));
  72. static cl::opt<bool> DetectParallel("polly-ast-detect-parallel",
  73. cl::desc("Detect parallelism"), cl::Hidden,
  74. cl::cat(PollyCategory));
  75. STATISTIC(ScopsProcessed, "Number of SCoPs processed");
  76. STATISTIC(ScopsBeneficial, "Number of beneficial SCoPs");
  77. STATISTIC(BeneficialAffineLoops, "Number of beneficial affine loops");
  78. STATISTIC(BeneficialBoxedLoops, "Number of beneficial boxed loops");
  79. STATISTIC(NumForLoops, "Number of for-loops");
  80. STATISTIC(NumParallel, "Number of parallel for-loops");
  81. STATISTIC(NumInnermostParallel, "Number of innermost parallel for-loops");
  82. STATISTIC(NumOutermostParallel, "Number of outermost parallel for-loops");
  83. STATISTIC(NumReductionParallel, "Number of reduction-parallel for-loops");
  84. STATISTIC(NumExecutedInParallel, "Number of for-loops executed in parallel");
  85. STATISTIC(NumIfConditions, "Number of if-conditions");
  86. namespace polly {
  87. /// Temporary information used when building the ast.
  88. struct AstBuildUserInfo {
  89. /// Construct and initialize the helper struct for AST creation.
  90. AstBuildUserInfo() = default;
  91. /// The dependence information used for the parallelism check.
  92. const Dependences *Deps = nullptr;
  93. /// Flag to indicate that we are inside a parallel for node.
  94. bool InParallelFor = false;
  95. /// Flag to indicate that we are inside an SIMD node.
  96. bool InSIMD = false;
  97. /// The last iterator id created for the current SCoP.
  98. isl_id *LastForNodeId = nullptr;
  99. };
  100. } // namespace polly
  101. /// Free an IslAstUserPayload object pointed to by @p Ptr.
  102. static void freeIslAstUserPayload(void *Ptr) {
  103. delete ((IslAstInfo::IslAstUserPayload *)Ptr);
  104. }
  105. /// Print a string @p str in a single line using @p Printer.
  106. static isl_printer *printLine(__isl_take isl_printer *Printer,
  107. const std::string &str,
  108. __isl_keep isl_pw_aff *PWA = nullptr) {
  109. Printer = isl_printer_start_line(Printer);
  110. Printer = isl_printer_print_str(Printer, str.c_str());
  111. if (PWA)
  112. Printer = isl_printer_print_pw_aff(Printer, PWA);
  113. return isl_printer_end_line(Printer);
  114. }
  115. /// Return all broken reductions as a string of clauses (OpenMP style).
  116. static const std::string getBrokenReductionsStr(const isl::ast_node &Node) {
  117. IslAstInfo::MemoryAccessSet *BrokenReductions;
  118. std::string str;
  119. BrokenReductions = IslAstInfo::getBrokenReductions(Node);
  120. if (!BrokenReductions || BrokenReductions->empty())
  121. return "";
  122. // Map each type of reduction to a comma separated list of the base addresses.
  123. std::map<MemoryAccess::ReductionType, std::string> Clauses;
  124. for (MemoryAccess *MA : *BrokenReductions)
  125. if (MA->isWrite())
  126. Clauses[MA->getReductionType()] +=
  127. ", " + MA->getScopArrayInfo()->getName();
  128. // Now print the reductions sorted by type. Each type will cause a clause
  129. // like: reduction (+ : sum0, sum1, sum2)
  130. for (const auto &ReductionClause : Clauses) {
  131. str += " reduction (";
  132. str += MemoryAccess::getReductionOperatorStr(ReductionClause.first);
  133. // Remove the first two symbols (", ") to make the output look pretty.
  134. str += " : " + ReductionClause.second.substr(2) + ")";
  135. }
  136. return str;
  137. }
  138. /// Callback executed for each for node in the ast in order to print it.
  139. static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
  140. __isl_take isl_ast_print_options *Options,
  141. __isl_keep isl_ast_node *Node, void *) {
  142. isl::pw_aff DD =
  143. IslAstInfo::getMinimalDependenceDistance(isl::manage_copy(Node));
  144. const std::string BrokenReductionsStr =
  145. getBrokenReductionsStr(isl::manage_copy(Node));
  146. const std::string KnownParallelStr = "#pragma known-parallel";
  147. const std::string DepDisPragmaStr = "#pragma minimal dependence distance: ";
  148. const std::string SimdPragmaStr = "#pragma simd";
  149. const std::string OmpPragmaStr = "#pragma omp parallel for";
  150. if (!DD.is_null())
  151. Printer = printLine(Printer, DepDisPragmaStr, DD.get());
  152. if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
  153. Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
  154. if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
  155. Printer = printLine(Printer, OmpPragmaStr);
  156. else if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
  157. Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr);
  158. return isl_ast_node_for_print(Node, Printer, Options);
  159. }
  160. /// Check if the current scheduling dimension is parallel.
  161. ///
  162. /// In case the dimension is parallel we also check if any reduction
  163. /// dependences is broken when we exploit this parallelism. If so,
  164. /// @p IsReductionParallel will be set to true. The reduction dependences we use
  165. /// to check are actually the union of the transitive closure of the initial
  166. /// reduction dependences together with their reversal. Even though these
  167. /// dependences connect all iterations with each other (thus they are cyclic)
  168. /// we can perform the parallelism check as we are only interested in a zero
  169. /// (or non-zero) dependence distance on the dimension in question.
  170. static bool astScheduleDimIsParallel(const isl::ast_build &Build,
  171. const Dependences *D,
  172. IslAstUserPayload *NodeInfo) {
  173. if (!D->hasValidDependences())
  174. return false;
  175. isl::union_map Schedule = Build.get_schedule();
  176. isl::union_map Dep = D->getDependences(
  177. Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR);
  178. if (!D->isParallel(Schedule.get(), Dep.release())) {
  179. isl::union_map DepsAll =
  180. D->getDependences(Dependences::TYPE_RAW | Dependences::TYPE_WAW |
  181. Dependences::TYPE_WAR | Dependences::TYPE_TC_RED);
  182. // TODO: We will need to change isParallel to stop the unwrapping
  183. isl_pw_aff *MinimalDependenceDistanceIsl = nullptr;
  184. D->isParallel(Schedule.get(), DepsAll.release(),
  185. &MinimalDependenceDistanceIsl);
  186. NodeInfo->MinimalDependenceDistance =
  187. isl::manage(MinimalDependenceDistanceIsl);
  188. return false;
  189. }
  190. isl::union_map RedDeps = D->getDependences(Dependences::TYPE_TC_RED);
  191. if (!D->isParallel(Schedule.get(), RedDeps.release()))
  192. NodeInfo->IsReductionParallel = true;
  193. if (!NodeInfo->IsReductionParallel)
  194. return true;
  195. for (const auto &MaRedPair : D->getReductionDependences()) {
  196. if (!MaRedPair.second)
  197. continue;
  198. isl::union_map MaRedDeps = isl::manage_copy(MaRedPair.second);
  199. if (!D->isParallel(Schedule.get(), MaRedDeps.release()))
  200. NodeInfo->BrokenReductions.insert(MaRedPair.first);
  201. }
  202. return true;
  203. }
  204. // This method is executed before the construction of a for node. It creates
  205. // an isl_id that is used to annotate the subsequently generated ast for nodes.
  206. //
  207. // In this function we also run the following analyses:
  208. //
  209. // - Detection of openmp parallel loops
  210. //
  211. static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
  212. void *User) {
  213. AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
  214. IslAstUserPayload *Payload = new IslAstUserPayload();
  215. isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload);
  216. Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
  217. BuildInfo->LastForNodeId = Id;
  218. Payload->IsParallel = astScheduleDimIsParallel(isl::manage_copy(Build),
  219. BuildInfo->Deps, Payload);
  220. // Test for parallelism only if we are not already inside a parallel loop
  221. if (!BuildInfo->InParallelFor && !BuildInfo->InSIMD)
  222. BuildInfo->InParallelFor = Payload->IsOutermostParallel =
  223. Payload->IsParallel;
  224. return Id;
  225. }
  226. // This method is executed after the construction of a for node.
  227. //
  228. // It performs the following actions:
  229. //
  230. // - Reset the 'InParallelFor' flag, as soon as we leave a for node,
  231. // that is marked as openmp parallel.
  232. //
  233. static __isl_give isl_ast_node *
  234. astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
  235. void *User) {
  236. isl_id *Id = isl_ast_node_get_annotation(Node);
  237. assert(Id && "Post order visit assumes annotated for nodes");
  238. IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id);
  239. assert(Payload && "Post order visit assumes annotated for nodes");
  240. AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
  241. assert(Payload->Build.is_null() && "Build environment already set");
  242. Payload->Build = isl::manage_copy(Build);
  243. Payload->IsInnermost = (Id == BuildInfo->LastForNodeId);
  244. Payload->IsInnermostParallel =
  245. Payload->IsInnermost && (BuildInfo->InSIMD || Payload->IsParallel);
  246. if (Payload->IsOutermostParallel)
  247. BuildInfo->InParallelFor = false;
  248. isl_id_free(Id);
  249. return Node;
  250. }
  251. static isl_stat astBuildBeforeMark(__isl_keep isl_id *MarkId,
  252. __isl_keep isl_ast_build *Build,
  253. void *User) {
  254. if (!MarkId)
  255. return isl_stat_error;
  256. AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
  257. if (strcmp(isl_id_get_name(MarkId), "SIMD") == 0)
  258. BuildInfo->InSIMD = true;
  259. return isl_stat_ok;
  260. }
  261. static __isl_give isl_ast_node *
  262. astBuildAfterMark(__isl_take isl_ast_node *Node,
  263. __isl_keep isl_ast_build *Build, void *User) {
  264. assert(isl_ast_node_get_type(Node) == isl_ast_node_mark);
  265. AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
  266. auto *Id = isl_ast_node_mark_get_id(Node);
  267. if (strcmp(isl_id_get_name(Id), "SIMD") == 0)
  268. BuildInfo->InSIMD = false;
  269. isl_id_free(Id);
  270. return Node;
  271. }
  272. static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node,
  273. __isl_keep isl_ast_build *Build,
  274. void *User) {
  275. assert(!isl_ast_node_get_annotation(Node) && "Node already annotated");
  276. IslAstUserPayload *Payload = new IslAstUserPayload();
  277. isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload);
  278. Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
  279. Payload->Build = isl::manage_copy(Build);
  280. return isl_ast_node_set_annotation(Node, Id);
  281. }
  282. // Build alias check condition given a pair of minimal/maximal access.
  283. static isl::ast_expr buildCondition(Scop &S, isl::ast_build Build,
  284. const Scop::MinMaxAccessTy *It0,
  285. const Scop::MinMaxAccessTy *It1) {
  286. isl::pw_multi_aff AFirst = It0->first;
  287. isl::pw_multi_aff ASecond = It0->second;
  288. isl::pw_multi_aff BFirst = It1->first;
  289. isl::pw_multi_aff BSecond = It1->second;
  290. isl::id Left = AFirst.get_tuple_id(isl::dim::set);
  291. isl::id Right = BFirst.get_tuple_id(isl::dim::set);
  292. isl::ast_expr True =
  293. isl::ast_expr::from_val(isl::val::int_from_ui(Build.ctx(), 1));
  294. isl::ast_expr False =
  295. isl::ast_expr::from_val(isl::val::int_from_ui(Build.ctx(), 0));
  296. const ScopArrayInfo *BaseLeft =
  297. ScopArrayInfo::getFromId(Left)->getBasePtrOriginSAI();
  298. const ScopArrayInfo *BaseRight =
  299. ScopArrayInfo::getFromId(Right)->getBasePtrOriginSAI();
  300. if (BaseLeft && BaseLeft == BaseRight)
  301. return True;
  302. isl::set Params = S.getContext();
  303. isl::ast_expr NonAliasGroup, MinExpr, MaxExpr;
  304. // In the following, we first check if any accesses will be empty under
  305. // the execution context of the scop and do not code generate them if this
  306. // is the case as isl will fail to derive valid AST expressions for such
  307. // accesses.
  308. if (!AFirst.intersect_params(Params).domain().is_empty() &&
  309. !BSecond.intersect_params(Params).domain().is_empty()) {
  310. MinExpr = Build.access_from(AFirst).address_of();
  311. MaxExpr = Build.access_from(BSecond).address_of();
  312. NonAliasGroup = MaxExpr.le(MinExpr);
  313. }
  314. if (!BFirst.intersect_params(Params).domain().is_empty() &&
  315. !ASecond.intersect_params(Params).domain().is_empty()) {
  316. MinExpr = Build.access_from(BFirst).address_of();
  317. MaxExpr = Build.access_from(ASecond).address_of();
  318. isl::ast_expr Result = MaxExpr.le(MinExpr);
  319. if (!NonAliasGroup.is_null())
  320. NonAliasGroup = isl::manage(
  321. isl_ast_expr_or(NonAliasGroup.release(), Result.release()));
  322. else
  323. NonAliasGroup = Result;
  324. }
  325. if (NonAliasGroup.is_null())
  326. NonAliasGroup = True;
  327. return NonAliasGroup;
  328. }
  329. isl::ast_expr IslAst::buildRunCondition(Scop &S, const isl::ast_build &Build) {
  330. isl::ast_expr RunCondition;
  331. // The conditions that need to be checked at run-time for this scop are
  332. // available as an isl_set in the runtime check context from which we can
  333. // directly derive a run-time condition.
  334. auto PosCond = Build.expr_from(S.getAssumedContext());
  335. if (S.hasTrivialInvalidContext()) {
  336. RunCondition = std::move(PosCond);
  337. } else {
  338. auto ZeroV = isl::val::zero(Build.ctx());
  339. auto NegCond = Build.expr_from(S.getInvalidContext());
  340. auto NotNegCond =
  341. isl::ast_expr::from_val(std::move(ZeroV)).eq(std::move(NegCond));
  342. RunCondition =
  343. isl::manage(isl_ast_expr_and(PosCond.release(), NotNegCond.release()));
  344. }
  345. // Create the alias checks from the minimal/maximal accesses in each alias
  346. // group which consists of read only and non read only (read write) accesses.
  347. // This operation is by construction quadratic in the read-write pointers and
  348. // linear in the read only pointers in each alias group.
  349. for (const Scop::MinMaxVectorPairTy &MinMaxAccessPair : S.getAliasGroups()) {
  350. auto &MinMaxReadWrite = MinMaxAccessPair.first;
  351. auto &MinMaxReadOnly = MinMaxAccessPair.second;
  352. auto RWAccEnd = MinMaxReadWrite.end();
  353. for (auto RWAccIt0 = MinMaxReadWrite.begin(); RWAccIt0 != RWAccEnd;
  354. ++RWAccIt0) {
  355. for (auto RWAccIt1 = RWAccIt0 + 1; RWAccIt1 != RWAccEnd; ++RWAccIt1)
  356. RunCondition = isl::manage(isl_ast_expr_and(
  357. RunCondition.release(),
  358. buildCondition(S, Build, RWAccIt0, RWAccIt1).release()));
  359. for (const Scop::MinMaxAccessTy &ROAccIt : MinMaxReadOnly)
  360. RunCondition = isl::manage(isl_ast_expr_and(
  361. RunCondition.release(),
  362. buildCondition(S, Build, RWAccIt0, &ROAccIt).release()));
  363. }
  364. }
  365. return RunCondition;
  366. }
  367. /// Simple cost analysis for a given SCoP.
  368. ///
  369. /// TODO: Improve this analysis and extract it to make it usable in other
  370. /// places too.
  371. /// In order to improve the cost model we could either keep track of
  372. /// performed optimizations (e.g., tiling) or compute properties on the
  373. /// original as well as optimized SCoP (e.g., #stride-one-accesses).
  374. static bool benefitsFromPolly(Scop &Scop, bool PerformParallelTest) {
  375. if (PollyProcessUnprofitable)
  376. return true;
  377. // Check if nothing interesting happened.
  378. if (!PerformParallelTest && !Scop.isOptimized() &&
  379. Scop.getAliasGroups().empty())
  380. return false;
  381. // The default assumption is that Polly improves the code.
  382. return true;
  383. }
  384. /// Collect statistics for the syntax tree rooted at @p Ast.
  385. static void walkAstForStatistics(const isl::ast_node &Ast) {
  386. assert(!Ast.is_null());
  387. isl_ast_node_foreach_descendant_top_down(
  388. Ast.get(),
  389. [](__isl_keep isl_ast_node *Node, void *User) -> isl_bool {
  390. switch (isl_ast_node_get_type(Node)) {
  391. case isl_ast_node_for:
  392. NumForLoops++;
  393. if (IslAstInfo::isParallel(isl::manage_copy(Node)))
  394. NumParallel++;
  395. if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node)))
  396. NumInnermostParallel++;
  397. if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node)))
  398. NumOutermostParallel++;
  399. if (IslAstInfo::isReductionParallel(isl::manage_copy(Node)))
  400. NumReductionParallel++;
  401. if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node)))
  402. NumExecutedInParallel++;
  403. break;
  404. case isl_ast_node_if:
  405. NumIfConditions++;
  406. break;
  407. default:
  408. break;
  409. }
  410. // Continue traversing subtrees.
  411. return isl_bool_true;
  412. },
  413. nullptr);
  414. }
  415. IslAst::IslAst(Scop &Scop) : S(Scop), Ctx(Scop.getSharedIslCtx()) {}
  416. IslAst::IslAst(IslAst &&O)
  417. : S(O.S), Ctx(O.Ctx), RunCondition(std::move(O.RunCondition)),
  418. Root(std::move(O.Root)) {}
  419. void IslAst::init(const Dependences &D) {
  420. bool PerformParallelTest = PollyParallel || DetectParallel ||
  421. PollyVectorizerChoice != VECTORIZER_NONE;
  422. auto ScheduleTree = S.getScheduleTree();
  423. // Skip AST and code generation if there was no benefit achieved.
  424. if (!benefitsFromPolly(S, PerformParallelTest))
  425. return;
  426. auto ScopStats = S.getStatistics();
  427. ScopsBeneficial++;
  428. BeneficialAffineLoops += ScopStats.NumAffineLoops;
  429. BeneficialBoxedLoops += ScopStats.NumBoxedLoops;
  430. auto Ctx = S.getIslCtx();
  431. isl_options_set_ast_build_atomic_upper_bound(Ctx.get(), true);
  432. isl_options_set_ast_build_detect_min_max(Ctx.get(), true);
  433. isl_ast_build *Build;
  434. AstBuildUserInfo BuildInfo;
  435. if (UseContext)
  436. Build = isl_ast_build_from_context(S.getContext().release());
  437. else
  438. Build = isl_ast_build_from_context(
  439. isl_set_universe(S.getParamSpace().release()));
  440. Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr);
  441. if (PerformParallelTest) {
  442. BuildInfo.Deps = &D;
  443. BuildInfo.InParallelFor = false;
  444. BuildInfo.InSIMD = false;
  445. Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor,
  446. &BuildInfo);
  447. Build =
  448. isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo);
  449. Build = isl_ast_build_set_before_each_mark(Build, &astBuildBeforeMark,
  450. &BuildInfo);
  451. Build = isl_ast_build_set_after_each_mark(Build, &astBuildAfterMark,
  452. &BuildInfo);
  453. }
  454. RunCondition = buildRunCondition(S, isl::manage_copy(Build));
  455. Root = isl::manage(
  456. isl_ast_build_node_from_schedule(Build, S.getScheduleTree().release()));
  457. walkAstForStatistics(Root);
  458. isl_ast_build_free(Build);
  459. }
  460. IslAst IslAst::create(Scop &Scop, const Dependences &D) {
  461. IslAst Ast{Scop};
  462. Ast.init(D);
  463. return Ast;
  464. }
  465. isl::ast_node IslAst::getAst() { return Root; }
  466. isl::ast_expr IslAst::getRunCondition() { return RunCondition; }
  467. isl::ast_node IslAstInfo::getAst() { return Ast.getAst(); }
  468. isl::ast_expr IslAstInfo::getRunCondition() { return Ast.getRunCondition(); }
  469. IslAstUserPayload *IslAstInfo::getNodePayload(const isl::ast_node &Node) {
  470. isl::id Id = Node.get_annotation();
  471. if (Id.is_null())
  472. return nullptr;
  473. IslAstUserPayload *Payload = (IslAstUserPayload *)Id.get_user();
  474. return Payload;
  475. }
  476. bool IslAstInfo::isInnermost(const isl::ast_node &Node) {
  477. IslAstUserPayload *Payload = getNodePayload(Node);
  478. return Payload && Payload->IsInnermost;
  479. }
  480. bool IslAstInfo::isParallel(const isl::ast_node &Node) {
  481. return IslAstInfo::isInnermostParallel(Node) ||
  482. IslAstInfo::isOutermostParallel(Node);
  483. }
  484. bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) {
  485. IslAstUserPayload *Payload = getNodePayload(Node);
  486. return Payload && Payload->IsInnermostParallel;
  487. }
  488. bool IslAstInfo::isOutermostParallel(const isl::ast_node &Node) {
  489. IslAstUserPayload *Payload = getNodePayload(Node);
  490. return Payload && Payload->IsOutermostParallel;
  491. }
  492. bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) {
  493. IslAstUserPayload *Payload = getNodePayload(Node);
  494. return Payload && Payload->IsReductionParallel;
  495. }
  496. bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) {
  497. if (!PollyParallel)
  498. return false;
  499. // Do not parallelize innermost loops.
  500. //
  501. // Parallelizing innermost loops is often not profitable, especially if
  502. // they have a low number of iterations.
  503. //
  504. // TODO: Decide this based on the number of loop iterations that will be
  505. // executed. This can possibly require run-time checks, which again
  506. // raises the question of both run-time check overhead and code size
  507. // costs.
  508. if (!PollyParallelForce && isInnermost(Node))
  509. return false;
  510. return isOutermostParallel(Node) && !isReductionParallel(Node);
  511. }
  512. isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) {
  513. IslAstUserPayload *Payload = getNodePayload(Node);
  514. return Payload ? Payload->Build.get_schedule() : isl::union_map();
  515. }
  516. isl::pw_aff
  517. IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) {
  518. IslAstUserPayload *Payload = getNodePayload(Node);
  519. return Payload ? Payload->MinimalDependenceDistance : isl::pw_aff();
  520. }
  521. IslAstInfo::MemoryAccessSet *
  522. IslAstInfo::getBrokenReductions(const isl::ast_node &Node) {
  523. IslAstUserPayload *Payload = getNodePayload(Node);
  524. return Payload ? &Payload->BrokenReductions : nullptr;
  525. }
  526. isl::ast_build IslAstInfo::getBuild(const isl::ast_node &Node) {
  527. IslAstUserPayload *Payload = getNodePayload(Node);
  528. return Payload ? Payload->Build : isl::ast_build();
  529. }
  530. static std::unique_ptr<IslAstInfo> runIslAst(
  531. Scop &Scop,
  532. function_ref<const Dependences &(Dependences::AnalysisLevel)> GetDeps) {
  533. // Skip SCoPs in case they're already handled by PPCGCodeGeneration.
  534. if (Scop.isToBeSkipped())
  535. return {};
  536. ScopsProcessed++;
  537. const Dependences &D = GetDeps(Dependences::AL_Statement);
  538. if (D.getSharedIslCtx() != Scop.getSharedIslCtx()) {
  539. LLVM_DEBUG(
  540. dbgs() << "Got dependence analysis for different SCoP/isl_ctx\n");
  541. return {};
  542. }
  543. std::unique_ptr<IslAstInfo> Ast = std::make_unique<IslAstInfo>(Scop, D);
  544. LLVM_DEBUG({
  545. if (Ast)
  546. Ast->print(dbgs());
  547. });
  548. return Ast;
  549. }
  550. IslAstInfo IslAstAnalysis::run(Scop &S, ScopAnalysisManager &SAM,
  551. ScopStandardAnalysisResults &SAR) {
  552. auto GetDeps = [&](Dependences::AnalysisLevel Lvl) -> const Dependences & {
  553. return SAM.getResult<DependenceAnalysis>(S, SAR).getDependences(Lvl);
  554. };
  555. return std::move(*runIslAst(S, GetDeps));
  556. }
  557. static __isl_give isl_printer *cbPrintUser(__isl_take isl_printer *P,
  558. __isl_take isl_ast_print_options *O,
  559. __isl_keep isl_ast_node *Node,
  560. void *User) {
  561. isl::ast_node_user AstNode = isl::manage_copy(Node).as<isl::ast_node_user>();
  562. isl::ast_expr NodeExpr = AstNode.expr();
  563. isl::ast_expr CallExpr = NodeExpr.get_op_arg(0);
  564. isl::id CallExprId = CallExpr.get_id();
  565. ScopStmt *AccessStmt = (ScopStmt *)CallExprId.get_user();
  566. P = isl_printer_start_line(P);
  567. P = isl_printer_print_str(P, AccessStmt->getBaseName());
  568. P = isl_printer_print_str(P, "(");
  569. P = isl_printer_end_line(P);
  570. P = isl_printer_indent(P, 2);
  571. for (MemoryAccess *MemAcc : *AccessStmt) {
  572. P = isl_printer_start_line(P);
  573. if (MemAcc->isRead())
  574. P = isl_printer_print_str(P, "/* read */ &");
  575. else
  576. P = isl_printer_print_str(P, "/* write */ ");
  577. isl::ast_build Build = IslAstInfo::getBuild(isl::manage_copy(Node));
  578. if (MemAcc->isAffine()) {
  579. isl_pw_multi_aff *PwmaPtr =
  580. MemAcc->applyScheduleToAccessRelation(Build.get_schedule()).release();
  581. isl::pw_multi_aff Pwma = isl::manage(PwmaPtr);
  582. isl::ast_expr AccessExpr = Build.access_from(Pwma);
  583. P = isl_printer_print_ast_expr(P, AccessExpr.get());
  584. } else {
  585. P = isl_printer_print_str(
  586. P, MemAcc->getLatestScopArrayInfo()->getName().c_str());
  587. P = isl_printer_print_str(P, "[*]");
  588. }
  589. P = isl_printer_end_line(P);
  590. }
  591. P = isl_printer_indent(P, -2);
  592. P = isl_printer_start_line(P);
  593. P = isl_printer_print_str(P, ");");
  594. P = isl_printer_end_line(P);
  595. isl_ast_print_options_free(O);
  596. return P;
  597. }
  598. void IslAstInfo::print(raw_ostream &OS) {
  599. isl_ast_print_options *Options;
  600. isl::ast_node RootNode = Ast.getAst();
  601. Function &F = S.getFunction();
  602. OS << ":: isl ast :: " << F.getName() << " :: " << S.getNameStr() << "\n";
  603. if (RootNode.is_null()) {
  604. OS << ":: isl ast generation and code generation was skipped!\n\n";
  605. OS << ":: This is either because no useful optimizations could be applied "
  606. "(use -polly-process-unprofitable to enforce code generation) or "
  607. "because earlier passes such as dependence analysis timed out (use "
  608. "-polly-dependences-computeout=0 to set dependence analysis timeout "
  609. "to infinity)\n\n";
  610. return;
  611. }
  612. isl::ast_expr RunCondition = Ast.getRunCondition();
  613. char *RtCStr, *AstStr;
  614. Options = isl_ast_print_options_alloc(S.getIslCtx().get());
  615. if (PrintAccesses)
  616. Options =
  617. isl_ast_print_options_set_print_user(Options, cbPrintUser, nullptr);
  618. Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr);
  619. isl_printer *P = isl_printer_to_str(S.getIslCtx().get());
  620. P = isl_printer_set_output_format(P, ISL_FORMAT_C);
  621. P = isl_printer_print_ast_expr(P, RunCondition.get());
  622. RtCStr = isl_printer_get_str(P);
  623. P = isl_printer_flush(P);
  624. P = isl_printer_indent(P, 4);
  625. P = isl_ast_node_print(RootNode.get(), P, Options);
  626. AstStr = isl_printer_get_str(P);
  627. LLVM_DEBUG({
  628. dbgs() << S.getContextStr() << "\n";
  629. dbgs() << stringFromIslObj(S.getScheduleTree(), "null");
  630. });
  631. OS << "\nif (" << RtCStr << ")\n\n";
  632. OS << AstStr << "\n";
  633. OS << "else\n";
  634. OS << " { /* original code */ }\n\n";
  635. free(RtCStr);
  636. free(AstStr);
  637. isl_printer_free(P);
  638. }
  639. AnalysisKey IslAstAnalysis::Key;
  640. PreservedAnalyses IslAstPrinterPass::run(Scop &S, ScopAnalysisManager &SAM,
  641. ScopStandardAnalysisResults &SAR,
  642. SPMUpdater &U) {
  643. auto &Ast = SAM.getResult<IslAstAnalysis>(S, SAR);
  644. Ast.print(OS);
  645. return PreservedAnalyses::all();
  646. }
  647. void IslAstInfoWrapperPass::releaseMemory() { Ast.reset(); }
  648. bool IslAstInfoWrapperPass::runOnScop(Scop &Scop) {
  649. auto GetDeps = [this](Dependences::AnalysisLevel Lvl) -> const Dependences & {
  650. return getAnalysis<DependenceInfo>().getDependences(Lvl);
  651. };
  652. Ast = runIslAst(Scop, GetDeps);
  653. return false;
  654. }
  655. void IslAstInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  656. // Get the Common analysis usage of ScopPasses.
  657. ScopPass::getAnalysisUsage(AU);
  658. AU.addRequiredTransitive<ScopInfoRegionPass>();
  659. AU.addRequired<DependenceInfo>();
  660. AU.addPreserved<DependenceInfo>();
  661. }
  662. void IslAstInfoWrapperPass::printScop(raw_ostream &OS, Scop &S) const {
  663. OS << "Printing analysis 'Polly - Generate an AST of the SCoP (isl)'"
  664. << S.getName() << "' in function '" << S.getFunction().getName() << "':\n";
  665. if (Ast)
  666. Ast->print(OS);
  667. }
  668. char IslAstInfoWrapperPass::ID = 0;
  669. Pass *polly::createIslAstInfoWrapperPassPass() {
  670. return new IslAstInfoWrapperPass();
  671. }
  672. INITIALIZE_PASS_BEGIN(IslAstInfoWrapperPass, "polly-ast",
  673. "Polly - Generate an AST of the SCoP (isl)", false,
  674. false);
  675. INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
  676. INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
  677. INITIALIZE_PASS_END(IslAstInfoWrapperPass, "polly-ast",
  678. "Polly - Generate an AST from the SCoP (isl)", false, false)
  679. //===----------------------------------------------------------------------===//
  680. namespace {
  681. /// Print result from IslAstInfoWrapperPass.
  682. class IslAstInfoPrinterLegacyPass final : public ScopPass {
  683. public:
  684. static char ID;
  685. IslAstInfoPrinterLegacyPass() : IslAstInfoPrinterLegacyPass(outs()) {}
  686. explicit IslAstInfoPrinterLegacyPass(llvm::raw_ostream &OS)
  687. : ScopPass(ID), OS(OS) {}
  688. bool runOnScop(Scop &S) override {
  689. IslAstInfoWrapperPass &P = getAnalysis<IslAstInfoWrapperPass>();
  690. OS << "Printing analysis '" << P.getPassName() << "' for region: '"
  691. << S.getRegion().getNameStr() << "' in function '"
  692. << S.getFunction().getName() << "':\n";
  693. P.printScop(OS, S);
  694. return false;
  695. }
  696. void getAnalysisUsage(AnalysisUsage &AU) const override {
  697. ScopPass::getAnalysisUsage(AU);
  698. AU.addRequired<IslAstInfoWrapperPass>();
  699. AU.setPreservesAll();
  700. }
  701. private:
  702. llvm::raw_ostream &OS;
  703. };
  704. char IslAstInfoPrinterLegacyPass::ID = 0;
  705. } // namespace
  706. Pass *polly::createIslAstInfoPrinterLegacyPass(raw_ostream &OS) {
  707. return new IslAstInfoPrinterLegacyPass(OS);
  708. }
  709. INITIALIZE_PASS_BEGIN(IslAstInfoPrinterLegacyPass, "polly-print-ast",
  710. "Polly - Print the AST from a SCoP (isl)", false, false);
  711. INITIALIZE_PASS_DEPENDENCY(IslAstInfoWrapperPass);
  712. INITIALIZE_PASS_END(IslAstInfoPrinterLegacyPass, "polly-print-ast",
  713. "Polly - Print the AST from a SCoP (isl)", false, false)