AMDGPUMetadataVerifier.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. /// \file
  10. /// Implements a verifier for AMDGPU HSA metadata.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
  14. #include "llvm/ADT/StringSwitch.h"
  15. #include "llvm/Support/AMDGPUMetadata.h"
  16. namespace llvm {
  17. namespace AMDGPU {
  18. namespace HSAMD {
  19. namespace V3 {
  20. bool MetadataVerifier::verifyScalar(
  21. msgpack::DocNode &Node, msgpack::Type SKind,
  22. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  23. if (!Node.isScalar())
  24. return false;
  25. if (Node.getKind() != SKind) {
  26. if (Strict)
  27. return false;
  28. // If we are not strict, we interpret string values as "implicitly typed"
  29. // and attempt to coerce them to the expected type here.
  30. if (Node.getKind() != msgpack::Type::String)
  31. return false;
  32. StringRef StringValue = Node.getString();
  33. Node.fromString(StringValue);
  34. if (Node.getKind() != SKind)
  35. return false;
  36. }
  37. if (verifyValue)
  38. return verifyValue(Node);
  39. return true;
  40. }
  41. bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
  42. if (!verifyScalar(Node, msgpack::Type::UInt))
  43. if (!verifyScalar(Node, msgpack::Type::Int))
  44. return false;
  45. return true;
  46. }
  47. bool MetadataVerifier::verifyArray(
  48. msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
  49. Optional<size_t> Size) {
  50. if (!Node.isArray())
  51. return false;
  52. auto &Array = Node.getArray();
  53. if (Size && Array.size() != *Size)
  54. return false;
  55. for (auto &Item : Array)
  56. if (!verifyNode(Item))
  57. return false;
  58. return true;
  59. }
  60. bool MetadataVerifier::verifyEntry(
  61. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  62. function_ref<bool(msgpack::DocNode &)> verifyNode) {
  63. auto Entry = MapNode.find(Key);
  64. if (Entry == MapNode.end())
  65. return !Required;
  66. return verifyNode(Entry->second);
  67. }
  68. bool MetadataVerifier::verifyScalarEntry(
  69. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  70. msgpack::Type SKind,
  71. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  72. return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
  73. return verifyScalar(Node, SKind, verifyValue);
  74. });
  75. }
  76. bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
  77. StringRef Key, bool Required) {
  78. return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
  79. return verifyInteger(Node);
  80. });
  81. }
  82. bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
  83. if (!Node.isMap())
  84. return false;
  85. auto &ArgsMap = Node.getMap();
  86. if (!verifyScalarEntry(ArgsMap, ".name", false,
  87. msgpack::Type::String))
  88. return false;
  89. if (!verifyScalarEntry(ArgsMap, ".type_name", false,
  90. msgpack::Type::String))
  91. return false;
  92. if (!verifyIntegerEntry(ArgsMap, ".size", true))
  93. return false;
  94. if (!verifyIntegerEntry(ArgsMap, ".offset", true))
  95. return false;
  96. if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
  97. msgpack::Type::String,
  98. [](msgpack::DocNode &SNode) {
  99. return StringSwitch<bool>(SNode.getString())
  100. .Case("by_value", true)
  101. .Case("global_buffer", true)
  102. .Case("dynamic_shared_pointer", true)
  103. .Case("sampler", true)
  104. .Case("image", true)
  105. .Case("pipe", true)
  106. .Case("queue", true)
  107. .Case("hidden_global_offset_x", true)
  108. .Case("hidden_global_offset_y", true)
  109. .Case("hidden_global_offset_z", true)
  110. .Case("hidden_none", true)
  111. .Case("hidden_printf_buffer", true)
  112. .Case("hidden_hostcall_buffer", true)
  113. .Case("hidden_default_queue", true)
  114. .Case("hidden_completion_action", true)
  115. .Case("hidden_multigrid_sync_arg", true)
  116. .Default(false);
  117. }))
  118. return false;
  119. if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
  120. return false;
  121. if (!verifyScalarEntry(ArgsMap, ".address_space", false,
  122. msgpack::Type::String,
  123. [](msgpack::DocNode &SNode) {
  124. return StringSwitch<bool>(SNode.getString())
  125. .Case("private", true)
  126. .Case("global", true)
  127. .Case("constant", true)
  128. .Case("local", true)
  129. .Case("generic", true)
  130. .Case("region", true)
  131. .Default(false);
  132. }))
  133. return false;
  134. if (!verifyScalarEntry(ArgsMap, ".access", false,
  135. msgpack::Type::String,
  136. [](msgpack::DocNode &SNode) {
  137. return StringSwitch<bool>(SNode.getString())
  138. .Case("read_only", true)
  139. .Case("write_only", true)
  140. .Case("read_write", true)
  141. .Default(false);
  142. }))
  143. return false;
  144. if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
  145. msgpack::Type::String,
  146. [](msgpack::DocNode &SNode) {
  147. return StringSwitch<bool>(SNode.getString())
  148. .Case("read_only", true)
  149. .Case("write_only", true)
  150. .Case("read_write", true)
  151. .Default(false);
  152. }))
  153. return false;
  154. if (!verifyScalarEntry(ArgsMap, ".is_const", false,
  155. msgpack::Type::Boolean))
  156. return false;
  157. if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
  158. msgpack::Type::Boolean))
  159. return false;
  160. if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
  161. msgpack::Type::Boolean))
  162. return false;
  163. if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
  164. msgpack::Type::Boolean))
  165. return false;
  166. return true;
  167. }
  168. bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
  169. if (!Node.isMap())
  170. return false;
  171. auto &KernelMap = Node.getMap();
  172. if (!verifyScalarEntry(KernelMap, ".name", true,
  173. msgpack::Type::String))
  174. return false;
  175. if (!verifyScalarEntry(KernelMap, ".symbol", true,
  176. msgpack::Type::String))
  177. return false;
  178. if (!verifyScalarEntry(KernelMap, ".language", false,
  179. msgpack::Type::String,
  180. [](msgpack::DocNode &SNode) {
  181. return StringSwitch<bool>(SNode.getString())
  182. .Case("OpenCL C", true)
  183. .Case("OpenCL C++", true)
  184. .Case("HCC", true)
  185. .Case("HIP", true)
  186. .Case("OpenMP", true)
  187. .Case("Assembler", true)
  188. .Default(false);
  189. }))
  190. return false;
  191. if (!verifyEntry(
  192. KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
  193. return verifyArray(
  194. Node,
  195. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  196. }))
  197. return false;
  198. if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
  199. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  200. return verifyKernelArgs(Node);
  201. });
  202. }))
  203. return false;
  204. if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
  205. [this](msgpack::DocNode &Node) {
  206. return verifyArray(Node,
  207. [this](msgpack::DocNode &Node) {
  208. return verifyInteger(Node);
  209. },
  210. 3);
  211. }))
  212. return false;
  213. if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
  214. [this](msgpack::DocNode &Node) {
  215. return verifyArray(Node,
  216. [this](msgpack::DocNode &Node) {
  217. return verifyInteger(Node);
  218. },
  219. 3);
  220. }))
  221. return false;
  222. if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
  223. msgpack::Type::String))
  224. return false;
  225. if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
  226. msgpack::Type::String))
  227. return false;
  228. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
  229. return false;
  230. if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
  231. return false;
  232. if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
  233. return false;
  234. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
  235. return false;
  236. if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
  237. return false;
  238. if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
  239. return false;
  240. if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
  241. return false;
  242. if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
  243. return false;
  244. if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
  245. return false;
  246. if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
  247. return false;
  248. return true;
  249. }
  250. bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
  251. if (!HSAMetadataRoot.isMap())
  252. return false;
  253. auto &RootMap = HSAMetadataRoot.getMap();
  254. if (!verifyEntry(
  255. RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
  256. return verifyArray(
  257. Node,
  258. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  259. }))
  260. return false;
  261. if (!verifyEntry(
  262. RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
  263. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  264. return verifyScalar(Node, msgpack::Type::String);
  265. });
  266. }))
  267. return false;
  268. if (!verifyEntry(RootMap, "amdhsa.kernels", true,
  269. [this](msgpack::DocNode &Node) {
  270. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  271. return verifyKernel(Node);
  272. });
  273. }))
  274. return false;
  275. return true;
  276. }
  277. } // end namespace V3
  278. } // end namespace HSAMD
  279. } // end namespace AMDGPU
  280. } // end namespace llvm