gpu_print.c 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /*
  2. * Copyright 2012 Ecole Normale Superieure
  3. *
  4. * Use of this software is governed by the MIT license
  5. *
  6. * Written by Sven Verdoolaege,
  7. * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
  8. */
  9. #include <string.h>
  10. #include <isl/aff.h>
  11. #include "gpu_print.h"
  12. #include "print.h"
  13. #include "schedule.h"
  14. /* Print declarations to "p" for arrays that are local to "prog"
  15. * but that are used on the host and therefore require a declaration.
  16. */
  17. __isl_give isl_printer *gpu_print_local_declarations(__isl_take isl_printer *p,
  18. struct gpu_prog *prog)
  19. {
  20. int i;
  21. if (!prog)
  22. return isl_printer_free(p);
  23. for (i = 0; i < prog->n_array; ++i) {
  24. struct gpu_array_info *array = &prog->array[i];
  25. isl_ast_expr *size;
  26. if (!array->declare_local)
  27. continue;
  28. size = array->declared_size;
  29. p = ppcg_print_declaration_with_size(p, array->type, size);
  30. }
  31. return p;
  32. }
  33. /* Print an expression for the size of "array" in bytes.
  34. */
  35. __isl_give isl_printer *gpu_array_info_print_size(__isl_take isl_printer *prn,
  36. struct gpu_array_info *array)
  37. {
  38. int i;
  39. for (i = 0; i < array->n_index; ++i) {
  40. isl_ast_expr *bound;
  41. prn = isl_printer_print_str(prn, "(");
  42. bound = isl_ast_expr_get_op_arg(array->bound_expr, 1 + i);
  43. prn = isl_printer_print_ast_expr(prn, bound);
  44. isl_ast_expr_free(bound);
  45. prn = isl_printer_print_str(prn, ") * ");
  46. }
  47. prn = isl_printer_print_str(prn, "sizeof(");
  48. prn = isl_printer_print_str(prn, array->type);
  49. prn = isl_printer_print_str(prn, ")");
  50. return prn;
  51. }
  52. /* Print the declaration of a non-linearized array argument.
  53. */
  54. static __isl_give isl_printer *print_non_linearized_declaration_argument(
  55. __isl_take isl_printer *p, struct gpu_array_info *array)
  56. {
  57. p = isl_printer_print_str(p, array->type);
  58. p = isl_printer_print_str(p, " ");
  59. p = isl_printer_print_ast_expr(p, array->bound_expr);
  60. return p;
  61. }
  62. /* Print the declaration of an array argument.
  63. * "memory_space" allows to specify a memory space prefix.
  64. */
  65. __isl_give isl_printer *gpu_array_info_print_declaration_argument(
  66. __isl_take isl_printer *p, struct gpu_array_info *array,
  67. const char *memory_space)
  68. {
  69. if (gpu_array_is_read_only_scalar(array)) {
  70. p = isl_printer_print_str(p, array->type);
  71. p = isl_printer_print_str(p, " ");
  72. p = isl_printer_print_str(p, array->name);
  73. return p;
  74. }
  75. if (memory_space) {
  76. p = isl_printer_print_str(p, memory_space);
  77. p = isl_printer_print_str(p, " ");
  78. }
  79. if (array->n_index != 0 && !array->linearize)
  80. return print_non_linearized_declaration_argument(p, array);
  81. p = isl_printer_print_str(p, array->type);
  82. p = isl_printer_print_str(p, " ");
  83. p = isl_printer_print_str(p, "*");
  84. p = isl_printer_print_str(p, array->name);
  85. return p;
  86. }
  87. /* Print the call of an array argument.
  88. */
  89. __isl_give isl_printer *gpu_array_info_print_call_argument(
  90. __isl_take isl_printer *p, struct gpu_array_info *array)
  91. {
  92. if (gpu_array_is_read_only_scalar(array))
  93. return isl_printer_print_str(p, array->name);
  94. p = isl_printer_print_str(p, "dev_");
  95. p = isl_printer_print_str(p, array->name);
  96. return p;
  97. }
  98. /* Print an access to the element in the private/shared memory copy
  99. * described by "stmt". The index of the copy is recorded in
  100. * stmt->local_index as an access to the array.
  101. */
  102. static __isl_give isl_printer *stmt_print_local_index(__isl_take isl_printer *p,
  103. struct ppcg_kernel_stmt *stmt)
  104. {
  105. return isl_printer_print_ast_expr(p, stmt->u.c.local_index);
  106. }
  107. /* Print an access to the element in the global memory copy
  108. * described by "stmt". The index of the copy is recorded in
  109. * stmt->index as an access to the array.
  110. */
  111. static __isl_give isl_printer *stmt_print_global_index(
  112. __isl_take isl_printer *p, struct ppcg_kernel_stmt *stmt)
  113. {
  114. struct gpu_array_info *array = stmt->u.c.array;
  115. isl_ast_expr *index;
  116. if (gpu_array_is_scalar(array)) {
  117. if (!gpu_array_is_read_only_scalar(array))
  118. p = isl_printer_print_str(p, "*");
  119. p = isl_printer_print_str(p, array->name);
  120. return p;
  121. }
  122. index = isl_ast_expr_copy(stmt->u.c.index);
  123. p = isl_printer_print_ast_expr(p, index);
  124. isl_ast_expr_free(index);
  125. return p;
  126. }
  127. /* Print a copy statement.
  128. *
  129. * A read copy statement is printed as
  130. *
  131. * local = global;
  132. *
  133. * while a write copy statement is printed as
  134. *
  135. * global = local;
  136. */
  137. __isl_give isl_printer *ppcg_kernel_print_copy(__isl_take isl_printer *p,
  138. struct ppcg_kernel_stmt *stmt)
  139. {
  140. p = isl_printer_start_line(p);
  141. if (stmt->u.c.read) {
  142. p = stmt_print_local_index(p, stmt);
  143. p = isl_printer_print_str(p, " = ");
  144. p = stmt_print_global_index(p, stmt);
  145. } else {
  146. p = stmt_print_global_index(p, stmt);
  147. p = isl_printer_print_str(p, " = ");
  148. p = stmt_print_local_index(p, stmt);
  149. }
  150. p = isl_printer_print_str(p, ";");
  151. p = isl_printer_end_line(p);
  152. return p;
  153. }
  154. __isl_give isl_printer *ppcg_kernel_print_domain(__isl_take isl_printer *p,
  155. struct ppcg_kernel_stmt *stmt)
  156. {
  157. return pet_stmt_print_body(stmt->u.d.stmt->stmt, p, stmt->u.d.ref2expr);
  158. }
  159. /* This function is called for each node in a GPU AST.
  160. * In case of a user node, print the macro definitions required
  161. * for printing the AST expressions in the annotation, if any.
  162. * For other nodes, return true such that descendants are also
  163. * visited.
  164. *
  165. * In particular, for a kernel launch, print the macro definitions
  166. * needed for the grid size.
  167. * For a copy statement, print the macro definitions needed
  168. * for the two index expressions.
  169. * For an original user statement, print the macro definitions
  170. * needed for the substitutions.
  171. */
  172. static isl_bool at_node(__isl_keep isl_ast_node *node, void *user)
  173. {
  174. const char *name;
  175. isl_id *id;
  176. int is_kernel;
  177. struct ppcg_kernel *kernel;
  178. struct ppcg_kernel_stmt *stmt;
  179. isl_printer **p = user;
  180. if (isl_ast_node_get_type(node) != isl_ast_node_user)
  181. return isl_bool_true;
  182. id = isl_ast_node_get_annotation(node);
  183. if (!id)
  184. return isl_bool_false;
  185. name = isl_id_get_name(id);
  186. if (!name)
  187. return isl_bool_error;
  188. is_kernel = !strcmp(name, "kernel");
  189. kernel = is_kernel ? isl_id_get_user(id) : NULL;
  190. stmt = is_kernel ? NULL : isl_id_get_user(id);
  191. isl_id_free(id);
  192. if ((is_kernel && !kernel) || (!is_kernel && !stmt))
  193. return isl_bool_error;
  194. if (is_kernel) {
  195. *p = ppcg_ast_expr_print_macros(kernel->grid_size_expr, *p);
  196. } else if (stmt->type == ppcg_kernel_copy) {
  197. *p = ppcg_ast_expr_print_macros(stmt->u.c.index, *p);
  198. *p = ppcg_ast_expr_print_macros(stmt->u.c.local_index, *p);
  199. } else if (stmt->type == ppcg_kernel_domain) {
  200. *p = ppcg_print_body_macros(*p, stmt->u.d.ref2expr);
  201. }
  202. if (!*p)
  203. return isl_bool_error;
  204. return isl_bool_false;
  205. }
  206. /* Print the required macros for the GPU AST "node" to "p",
  207. * including those needed for the user statements inside the AST.
  208. */
  209. __isl_give isl_printer *gpu_print_macros(__isl_take isl_printer *p,
  210. __isl_keep isl_ast_node *node)
  211. {
  212. if (isl_ast_node_foreach_descendant_top_down(node, &at_node, &p) < 0)
  213. return isl_printer_free(p);
  214. p = ppcg_print_macros(p, node);
  215. return p;
  216. }
  217. /* Was the definition of "type" printed before?
  218. * That is, does its name appear in the list of printed types "types"?
  219. */
  220. static int already_printed(struct gpu_types *types,
  221. struct pet_type *type)
  222. {
  223. int i;
  224. for (i = 0; i < types->n; ++i)
  225. if (!strcmp(types->name[i], type->name))
  226. return 1;
  227. return 0;
  228. }
  229. /* Print the definitions of all types prog->scop that have not been
  230. * printed before (according to "types") on "p".
  231. * Extend the list of printed types "types" with the newly printed types.
  232. */
  233. __isl_give isl_printer *gpu_print_types(__isl_take isl_printer *p,
  234. struct gpu_types *types, struct gpu_prog *prog)
  235. {
  236. int i, n;
  237. isl_ctx *ctx;
  238. char **name;
  239. n = prog->scop->pet->n_type;
  240. if (n == 0)
  241. return p;
  242. ctx = isl_printer_get_ctx(p);
  243. name = isl_realloc_array(ctx, types->name, char *, types->n + n);
  244. if (!name)
  245. return isl_printer_free(p);
  246. types->name = name;
  247. for (i = 0; i < n; ++i) {
  248. struct pet_type *type = prog->scop->pet->types[i];
  249. if (already_printed(types, type))
  250. continue;
  251. p = isl_printer_start_line(p);
  252. p = isl_printer_print_str(p, type->definition);
  253. p = isl_printer_print_str(p, ";");
  254. p = isl_printer_end_line(p);
  255. types->name[types->n++] = strdup(type->name);
  256. }
  257. return p;
  258. }