AMDGPUMetadataVerifier.cpp 12 KB


  1. //===- AMDGPUMetadataVerifier.cpp - HSA metadata verifier -------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. /// \file
  10. /// Implements a verifier for AMDGPU HSA metadata.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
  14. #include "llvm/ADT/STLExtras.h"
  15. #include "llvm/ADT/StringSwitch.h"
  16. #include "llvm/BinaryFormat/MsgPackDocument.h"
  17. #include <map>
  18. #include <utility>
  19. namespace llvm {
  20. namespace AMDGPU {
  21. namespace HSAMD {
  22. namespace V3 {
  23. bool MetadataVerifier::verifyScalar(
  24. msgpack::DocNode &Node, msgpack::Type SKind,
  25. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  26. if (!Node.isScalar())
  27. return false;
  28. if (Node.getKind() != SKind) {
  29. if (Strict)
  30. return false;
  31. // If we are not strict, we interpret string values as "implicitly typed"
  32. // and attempt to coerce them to the expected type here.
  33. if (Node.getKind() != msgpack::Type::String)
  34. return false;
  35. StringRef StringValue = Node.getString();
  36. Node.fromString(StringValue);
  37. if (Node.getKind() != SKind)
  38. return false;
  39. }
  40. if (verifyValue)
  41. return verifyValue(Node);
  42. return true;
  43. }
  44. bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
  45. if (!verifyScalar(Node, msgpack::Type::UInt))
  46. if (!verifyScalar(Node, msgpack::Type::Int))
  47. return false;
  48. return true;
  49. }
  50. bool MetadataVerifier::verifyArray(
  51. msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
  52. std::optional<size_t> Size) {
  53. if (!Node.isArray())
  54. return false;
  55. auto &Array = Node.getArray();
  56. if (Size && Array.size() != *Size)
  57. return false;
  58. return llvm::all_of(Array, verifyNode);
  59. }
  60. bool MetadataVerifier::verifyEntry(
  61. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  62. function_ref<bool(msgpack::DocNode &)> verifyNode) {
  63. auto Entry = MapNode.find(Key);
  64. if (Entry == MapNode.end())
  65. return !Required;
  66. return verifyNode(Entry->second);
  67. }
  68. bool MetadataVerifier::verifyScalarEntry(
  69. msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
  70. msgpack::Type SKind,
  71. function_ref<bool(msgpack::DocNode &)> verifyValue) {
  72. return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
  73. return verifyScalar(Node, SKind, verifyValue);
  74. });
  75. }
  76. bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
  77. StringRef Key, bool Required) {
  78. return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
  79. return verifyInteger(Node);
  80. });
  81. }
  82. bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
  83. if (!Node.isMap())
  84. return false;
  85. auto &ArgsMap = Node.getMap();
  86. if (!verifyScalarEntry(ArgsMap, ".name", false,
  87. msgpack::Type::String))
  88. return false;
  89. if (!verifyScalarEntry(ArgsMap, ".type_name", false,
  90. msgpack::Type::String))
  91. return false;
  92. if (!verifyIntegerEntry(ArgsMap, ".size", true))
  93. return false;
  94. if (!verifyIntegerEntry(ArgsMap, ".offset", true))
  95. return false;
  96. if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
  97. [](msgpack::DocNode &SNode) {
  98. return StringSwitch<bool>(SNode.getString())
  99. .Case("by_value", true)
  100. .Case("global_buffer", true)
  101. .Case("dynamic_shared_pointer", true)
  102. .Case("sampler", true)
  103. .Case("image", true)
  104. .Case("pipe", true)
  105. .Case("queue", true)
  106. .Case("hidden_block_count_x", true)
  107. .Case("hidden_block_count_y", true)
  108. .Case("hidden_block_count_z", true)
  109. .Case("hidden_group_size_x", true)
  110. .Case("hidden_group_size_y", true)
  111. .Case("hidden_group_size_z", true)
  112. .Case("hidden_remainder_x", true)
  113. .Case("hidden_remainder_y", true)
  114. .Case("hidden_remainder_z", true)
  115. .Case("hidden_global_offset_x", true)
  116. .Case("hidden_global_offset_y", true)
  117. .Case("hidden_global_offset_z", true)
  118. .Case("hidden_grid_dims", true)
  119. .Case("hidden_none", true)
  120. .Case("hidden_printf_buffer", true)
  121. .Case("hidden_hostcall_buffer", true)
  122. .Case("hidden_heap_v1", true)
  123. .Case("hidden_default_queue", true)
  124. .Case("hidden_completion_action", true)
  125. .Case("hidden_multigrid_sync_arg", true)
  126. .Case("hidden_private_base", true)
  127. .Case("hidden_shared_base", true)
  128. .Case("hidden_queue_ptr", true)
  129. .Default(false);
  130. }))
  131. return false;
  132. if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
  133. return false;
  134. if (!verifyScalarEntry(ArgsMap, ".address_space", false,
  135. msgpack::Type::String,
  136. [](msgpack::DocNode &SNode) {
  137. return StringSwitch<bool>(SNode.getString())
  138. .Case("private", true)
  139. .Case("global", true)
  140. .Case("constant", true)
  141. .Case("local", true)
  142. .Case("generic", true)
  143. .Case("region", true)
  144. .Default(false);
  145. }))
  146. return false;
  147. if (!verifyScalarEntry(ArgsMap, ".access", false,
  148. msgpack::Type::String,
  149. [](msgpack::DocNode &SNode) {
  150. return StringSwitch<bool>(SNode.getString())
  151. .Case("read_only", true)
  152. .Case("write_only", true)
  153. .Case("read_write", true)
  154. .Default(false);
  155. }))
  156. return false;
  157. if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
  158. msgpack::Type::String,
  159. [](msgpack::DocNode &SNode) {
  160. return StringSwitch<bool>(SNode.getString())
  161. .Case("read_only", true)
  162. .Case("write_only", true)
  163. .Case("read_write", true)
  164. .Default(false);
  165. }))
  166. return false;
  167. if (!verifyScalarEntry(ArgsMap, ".is_const", false,
  168. msgpack::Type::Boolean))
  169. return false;
  170. if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
  171. msgpack::Type::Boolean))
  172. return false;
  173. if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
  174. msgpack::Type::Boolean))
  175. return false;
  176. if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
  177. msgpack::Type::Boolean))
  178. return false;
  179. return true;
  180. }
  181. bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
  182. if (!Node.isMap())
  183. return false;
  184. auto &KernelMap = Node.getMap();
  185. if (!verifyScalarEntry(KernelMap, ".name", true,
  186. msgpack::Type::String))
  187. return false;
  188. if (!verifyScalarEntry(KernelMap, ".symbol", true,
  189. msgpack::Type::String))
  190. return false;
  191. if (!verifyScalarEntry(KernelMap, ".language", false,
  192. msgpack::Type::String,
  193. [](msgpack::DocNode &SNode) {
  194. return StringSwitch<bool>(SNode.getString())
  195. .Case("OpenCL C", true)
  196. .Case("OpenCL C++", true)
  197. .Case("HCC", true)
  198. .Case("HIP", true)
  199. .Case("OpenMP", true)
  200. .Case("Assembler", true)
  201. .Default(false);
  202. }))
  203. return false;
  204. if (!verifyEntry(
  205. KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
  206. return verifyArray(
  207. Node,
  208. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  209. }))
  210. return false;
  211. if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
  212. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  213. return verifyKernelArgs(Node);
  214. });
  215. }))
  216. return false;
  217. if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
  218. [this](msgpack::DocNode &Node) {
  219. return verifyArray(Node,
  220. [this](msgpack::DocNode &Node) {
  221. return verifyInteger(Node);
  222. },
  223. 3);
  224. }))
  225. return false;
  226. if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
  227. [this](msgpack::DocNode &Node) {
  228. return verifyArray(Node,
  229. [this](msgpack::DocNode &Node) {
  230. return verifyInteger(Node);
  231. },
  232. 3);
  233. }))
  234. return false;
  235. if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
  236. msgpack::Type::String))
  237. return false;
  238. if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
  239. msgpack::Type::String))
  240. return false;
  241. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
  242. return false;
  243. if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
  244. return false;
  245. if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
  246. return false;
  247. if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
  248. msgpack::Type::Boolean))
  249. return false;
  250. if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false))
  251. return false;
  252. if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
  253. return false;
  254. if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
  255. return false;
  256. if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
  257. return false;
  258. if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
  259. return false;
  260. if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
  261. return false;
  262. if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
  263. return false;
  264. if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
  265. return false;
  266. if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false))
  267. return false;
  268. return true;
  269. }
  270. bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
  271. if (!HSAMetadataRoot.isMap())
  272. return false;
  273. auto &RootMap = HSAMetadataRoot.getMap();
  274. if (!verifyEntry(
  275. RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
  276. return verifyArray(
  277. Node,
  278. [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
  279. }))
  280. return false;
  281. if (!verifyEntry(
  282. RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
  283. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  284. return verifyScalar(Node, msgpack::Type::String);
  285. });
  286. }))
  287. return false;
  288. if (!verifyEntry(RootMap, "amdhsa.kernels", true,
  289. [this](msgpack::DocNode &Node) {
  290. return verifyArray(Node, [this](msgpack::DocNode &Node) {
  291. return verifyKernel(Node);
  292. });
  293. }))
  294. return false;
  295. return true;
  296. }
  297. } // end namespace V3
  298. } // end namespace HSAMD
  299. } // end namespace AMDGPU
  300. } // end namespace llvm