AMDGPUMetadataVerifier.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. /// \file
  10. /// Implements a verifier for AMDGPU HSA metadata.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
  14. #include "llvm/ADT/STLExtras.h"
  15. #include "llvm/ADT/STLForwardCompat.h"
  16. #include "llvm/ADT/StringSwitch.h"
  17. #include "llvm/BinaryFormat/MsgPackDocument.h"
  18. #include <map>
  19. #include <utility>
  20. namespace llvm {
  21. namespace AMDGPU {
  22. namespace HSAMD {
  23. namespace V3 {
  24. bool MetadataVerifier::verifyScalar(
  25. msgpack::DocNode &Node, msgpack::Type SKind,
  26. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  27. if (!Node.isScalar())
  28. return false;
  29. if (Node.getKind() != SKind) {
  30. if (Strict)
  31. return false;
  32. // If we are not strict, we interpret string values as "implicitly typed"
  33. // and attempt to coerce them to the expected type here.
  34. if (Node.getKind() != msgpack::Type::String)
  35. return false;
  36. StringRef StringValue = Node.getString();
  37. Node.fromString(StringValue);
  38. if (Node.getKind() != SKind)
  39. return false;
  40. }
  41. if (verifyValue)
  42. return verifyValue(Node);
  43. return true;
  44. }
  45. bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
  46. if (!verifyScalar(Node, msgpack::Type::UInt))
  47. if (!verifyScalar(Node, msgpack::Type::Int))
  48. return false;
  49. return true;
  50. }
  51. bool MetadataVerifier::verifyArray(
  52. msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
  53. Optional<size_t> Size) {
  54. if (!Node.isArray())
  55. return false;
  56. auto &Array = Node.getArray();
  57. if (Size && Array.size() != *Size)
  58. return false;
  59. return llvm::all_of(Array, verifyNode);
  60. }
  61. bool MetadataVerifier::verifyEntry(
  62. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  63. function_ref<bool(msgpack::DocNode &)> verifyNode) {
  64. auto Entry = MapNode.find(Key);
  65. if (Entry == MapNode.end())
  66. return !Required;
  67. return verifyNode(Entry->second);
  68. }
  69. bool MetadataVerifier::verifyScalarEntry(
  70. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  71. msgpack::Type SKind,
  72. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  73. return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
  74. return verifyScalar(Node, SKind, verifyValue);
  75. });
  76. }
  77. bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
  78. StringRef Key, bool Required) {
  79. return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
  80. return verifyInteger(Node);
  81. });
  82. }
  83. bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
  84. if (!Node.isMap())
  85. return false;
  86. auto &ArgsMap = Node.getMap();
  87. if (!verifyScalarEntry(ArgsMap, ".name", false,
  88. msgpack::Type::String))
  89. return false;
  90. if (!verifyScalarEntry(ArgsMap, ".type_name", false,
  91. msgpack::Type::String))
  92. return false;
  93. if (!verifyIntegerEntry(ArgsMap, ".size", true))
  94. return false;
  95. if (!verifyIntegerEntry(ArgsMap, ".offset", true))
  96. return false;
  97. if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
  98. msgpack::Type::String,
  99. [](msgpack::DocNode &SNode) {
  100. return StringSwitch<bool>(SNode.getString())
  101. .Case("by_value", true)
  102. .Case("global_buffer", true)
  103. .Case("dynamic_shared_pointer", true)
  104. .Case("sampler", true)
  105. .Case("image", true)
  106. .Case("pipe", true)
  107. .Case("queue", true)
  108. .Case("hidden_block_count_x", true)
  109. .Case("hidden_block_count_y", true)
  110. .Case("hidden_block_count_z", true)
  111. .Case("hidden_group_size_x", true)
  112. .Case("hidden_group_size_y", true)
  113. .Case("hidden_group_size_z", true)
  114. .Case("hidden_remainder_x", true)
  115. .Case("hidden_remainder_y", true)
  116. .Case("hidden_remainder_z", true)
  117. .Case("hidden_global_offset_x", true)
  118. .Case("hidden_global_offset_y", true)
  119. .Case("hidden_global_offset_z", true)
  120. .Case("hidden_grid_dims", true)
  121. .Case("hidden_none", true)
  122. .Case("hidden_printf_buffer", true)
  123. .Case("hidden_hostcall_buffer", true)
  124. .Case("hidden_default_queue", true)
  125. .Case("hidden_completion_action", true)
  126. .Case("hidden_multigrid_sync_arg", true)
  127. .Case("hidden_private_base", true)
  128. .Case("hidden_shared_base", true)
  129. .Case("hidden_queue_ptr", true)
  130. .Default(false);
  131. }))
  132. return false;
  133. if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
  134. return false;
  135. if (!verifyScalarEntry(ArgsMap, ".address_space", false,
  136. msgpack::Type::String,
  137. [](msgpack::DocNode &SNode) {
  138. return StringSwitch<bool>(SNode.getString())
  139. .Case("private", true)
  140. .Case("global", true)
  141. .Case("constant", true)
  142. .Case("local", true)
  143. .Case("generic", true)
  144. .Case("region", true)
  145. .Default(false);
  146. }))
  147. return false;
  148. if (!verifyScalarEntry(ArgsMap, ".access", false,
  149. msgpack::Type::String,
  150. [](msgpack::DocNode &SNode) {
  151. return StringSwitch<bool>(SNode.getString())
  152. .Case("read_only", true)
  153. .Case("write_only", true)
  154. .Case("read_write", true)
  155. .Default(false);
  156. }))
  157. return false;
  158. if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
  159. msgpack::Type::String,
  160. [](msgpack::DocNode &SNode) {
  161. return StringSwitch<bool>(SNode.getString())
  162. .Case("read_only", true)
  163. .Case("write_only", true)
  164. .Case("read_write", true)
  165. .Default(false);
  166. }))
  167. return false;
  168. if (!verifyScalarEntry(ArgsMap, ".is_const", false,
  169. msgpack::Type::Boolean))
  170. return false;
  171. if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
  172. msgpack::Type::Boolean))
  173. return false;
  174. if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
  175. msgpack::Type::Boolean))
  176. return false;
  177. if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
  178. msgpack::Type::Boolean))
  179. return false;
  180. return true;
  181. }
  182. bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
  183. if (!Node.isMap())
  184. return false;
  185. auto &KernelMap = Node.getMap();
  186. if (!verifyScalarEntry(KernelMap, ".name", true,
  187. msgpack::Type::String))
  188. return false;
  189. if (!verifyScalarEntry(KernelMap, ".symbol", true,
  190. msgpack::Type::String))
  191. return false;
  192. if (!verifyScalarEntry(KernelMap, ".language", false,
  193. msgpack::Type::String,
  194. [](msgpack::DocNode &SNode) {
  195. return StringSwitch<bool>(SNode.getString())
  196. .Case("OpenCL C", true)
  197. .Case("OpenCL C++", true)
  198. .Case("HCC", true)
  199. .Case("HIP", true)
  200. .Case("OpenMP", true)
  201. .Case("Assembler", true)
  202. .Default(false);
  203. }))
  204. return false;
  205. if (!verifyEntry(
  206. KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
  207. return verifyArray(
  208. Node,
  209. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  210. }))
  211. return false;
  212. if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
  213. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  214. return verifyKernelArgs(Node);
  215. });
  216. }))
  217. return false;
  218. if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
  219. [this](msgpack::DocNode &Node) {
  220. return verifyArray(Node,
  221. [this](msgpack::DocNode &Node) {
  222. return verifyInteger(Node);
  223. },
  224. 3);
  225. }))
  226. return false;
  227. if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
  228. [this](msgpack::DocNode &Node) {
  229. return verifyArray(Node,
  230. [this](msgpack::DocNode &Node) {
  231. return verifyInteger(Node);
  232. },
  233. 3);
  234. }))
  235. return false;
  236. if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
  237. msgpack::Type::String))
  238. return false;
  239. if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
  240. msgpack::Type::String))
  241. return false;
  242. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
  243. return false;
  244. if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
  245. return false;
  246. if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
  247. return false;
  248. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
  249. return false;
  250. if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
  251. return false;
  252. if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
  253. return false;
  254. if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
  255. return false;
  256. if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
  257. return false;
  258. if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
  259. return false;
  260. if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
  261. return false;
  262. return true;
  263. }
  264. bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
  265. if (!HSAMetadataRoot.isMap())
  266. return false;
  267. auto &RootMap = HSAMetadataRoot.getMap();
  268. if (!verifyEntry(
  269. RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
  270. return verifyArray(
  271. Node,
  272. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  273. }))
  274. return false;
  275. if (!verifyEntry(
  276. RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
  277. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  278. return verifyScalar(Node, msgpack::Type::String);
  279. });
  280. }))
  281. return false;
  282. if (!verifyEntry(RootMap, "amdhsa.kernels", true,
  283. [this](msgpack::DocNode &Node) {
  284. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  285. return verifyKernel(Node);
  286. });
  287. }))
  288. return false;
  289. return true;
  290. }
  291. } // end namespace V3
  292. } // end namespace HSAMD
  293. } // end namespace AMDGPU
  294. } // end namespace llvm