OffloadWrapper.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "OffloadWrapper.h"
  9. #include "llvm/ADT/ArrayRef.h"
  10. #include "llvm/ADT/Triple.h"
  11. #include "llvm/IR/Constants.h"
  12. #include "llvm/IR/GlobalVariable.h"
  13. #include "llvm/IR/IRBuilder.h"
  14. #include "llvm/IR/LLVMContext.h"
  15. #include "llvm/IR/Module.h"
  16. #include "llvm/Object/OffloadBinary.h"
  17. #include "llvm/Support/Error.h"
  18. #include "llvm/Transforms/Utils/ModuleUtils.h"
  19. using namespace llvm;
  20. namespace {
  21. /// Magic number that begins the section containing the CUDA fatbinary.
  22. constexpr unsigned CudaFatMagic = 0x466243b1;
  23. constexpr unsigned HIPFatMagic = 0x48495046;
  24. /// Copied from clang/CGCudaRuntime.h.
  25. enum OffloadEntryKindFlag : uint32_t {
  26. /// Mark the entry as a global entry. This indicates the presense of a
  27. /// kernel if the size size field is zero and a variable otherwise.
  28. OffloadGlobalEntry = 0x0,
  29. /// Mark the entry as a managed global variable.
  30. OffloadGlobalManagedEntry = 0x1,
  31. /// Mark the entry as a surface variable.
  32. OffloadGlobalSurfaceEntry = 0x2,
  33. /// Mark the entry as a texture variable.
  34. OffloadGlobalTextureEntry = 0x3,
  35. };
  36. IntegerType *getSizeTTy(Module &M) {
  37. LLVMContext &C = M.getContext();
  38. switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
  39. case 4u:
  40. return Type::getInt32Ty(C);
  41. case 8u:
  42. return Type::getInt64Ty(C);
  43. }
  44. llvm_unreachable("unsupported pointer type size");
  45. }
  46. // struct __tgt_offload_entry {
  47. // void *addr;
  48. // char *name;
  49. // size_t size;
  50. // int32_t flags;
  51. // int32_t reserved;
  52. // };
  53. StructType *getEntryTy(Module &M) {
  54. LLVMContext &C = M.getContext();
  55. StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
  56. if (!EntryTy)
  57. EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
  58. Type::getInt8PtrTy(C), getSizeTTy(M),
  59. Type::getInt32Ty(C), Type::getInt32Ty(C));
  60. return EntryTy;
  61. }
  62. PointerType *getEntryPtrTy(Module &M) {
  63. return PointerType::getUnqual(getEntryTy(M));
  64. }
  65. // struct __tgt_device_image {
  66. // void *ImageStart;
  67. // void *ImageEnd;
  68. // __tgt_offload_entry *EntriesBegin;
  69. // __tgt_offload_entry *EntriesEnd;
  70. // };
  71. StructType *getDeviceImageTy(Module &M) {
  72. LLVMContext &C = M.getContext();
  73. StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
  74. if (!ImageTy)
  75. ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
  76. Type::getInt8PtrTy(C), getEntryPtrTy(M),
  77. getEntryPtrTy(M));
  78. return ImageTy;
  79. }
  80. PointerType *getDeviceImagePtrTy(Module &M) {
  81. return PointerType::getUnqual(getDeviceImageTy(M));
  82. }
  83. // struct __tgt_bin_desc {
  84. // int32_t NumDeviceImages;
  85. // __tgt_device_image *DeviceImages;
  86. // __tgt_offload_entry *HostEntriesBegin;
  87. // __tgt_offload_entry *HostEntriesEnd;
  88. // };
  89. StructType *getBinDescTy(Module &M) {
  90. LLVMContext &C = M.getContext();
  91. StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
  92. if (!DescTy)
  93. DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
  94. getDeviceImagePtrTy(M), getEntryPtrTy(M),
  95. getEntryPtrTy(M));
  96. return DescTy;
  97. }
  98. PointerType *getBinDescPtrTy(Module &M) {
  99. return PointerType::getUnqual(getBinDescTy(M));
  100. }
  101. /// Creates binary descriptor for the given device images. Binary descriptor
  102. /// is an object that is passed to the offloading runtime at program startup
  103. /// and it describes all device images available in the executable or shared
  104. /// library. It is defined as follows
  105. ///
  106. /// __attribute__((visibility("hidden")))
  107. /// extern __tgt_offload_entry *__start_omp_offloading_entries;
  108. /// __attribute__((visibility("hidden")))
  109. /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
  110. ///
  111. /// static const char Image0[] = { <Bufs.front() contents> };
  112. /// ...
  113. /// static const char ImageN[] = { <Bufs.back() contents> };
  114. ///
  115. /// static const __tgt_device_image Images[] = {
  116. /// {
  117. /// Image0, /*ImageStart*/
  118. /// Image0 + sizeof(Image0), /*ImageEnd*/
  119. /// __start_omp_offloading_entries, /*EntriesBegin*/
  120. /// __stop_omp_offloading_entries /*EntriesEnd*/
  121. /// },
  122. /// ...
  123. /// {
  124. /// ImageN, /*ImageStart*/
  125. /// ImageN + sizeof(ImageN), /*ImageEnd*/
  126. /// __start_omp_offloading_entries, /*EntriesBegin*/
  127. /// __stop_omp_offloading_entries /*EntriesEnd*/
  128. /// }
  129. /// };
  130. ///
  131. /// static const __tgt_bin_desc BinDesc = {
  132. /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
  133. /// Images, /*DeviceImages*/
  134. /// __start_omp_offloading_entries, /*HostEntriesBegin*/
  135. /// __stop_omp_offloading_entries /*HostEntriesEnd*/
  136. /// };
  137. ///
  138. /// Global variable that represents BinDesc is returned.
  139. GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
  140. LLVMContext &C = M.getContext();
  141. // Create external begin/end symbols for the offload entries table.
  142. auto *EntriesB = new GlobalVariable(
  143. M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
  144. /*Initializer*/ nullptr, "__start_omp_offloading_entries");
  145. EntriesB->setVisibility(GlobalValue::HiddenVisibility);
  146. auto *EntriesE = new GlobalVariable(
  147. M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
  148. /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
  149. EntriesE->setVisibility(GlobalValue::HiddenVisibility);
  150. // We assume that external begin/end symbols that we have created above will
  151. // be defined by the linker. But linker will do that only if linker inputs
  152. // have section with "omp_offloading_entries" name which is not guaranteed.
  153. // So, we just create dummy zero sized object in the offload entries section
  154. // to force linker to define those symbols.
  155. auto *DummyInit =
  156. ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
  157. auto *DummyEntry = new GlobalVariable(
  158. M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
  159. "__dummy.omp_offloading.entry");
  160. DummyEntry->setSection("omp_offloading_entries");
  161. DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
  162. auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
  163. Constant *ZeroZero[] = {Zero, Zero};
  164. // Create initializer for the images array.
  165. SmallVector<Constant *, 4u> ImagesInits;
  166. ImagesInits.reserve(Bufs.size());
  167. for (ArrayRef<char> Buf : Bufs) {
  168. auto *Data = ConstantDataArray::get(C, Buf);
  169. auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
  170. GlobalVariable::InternalLinkage, Data,
  171. ".omp_offloading.device_image");
  172. Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
  173. Image->setSection(".llvm.offloading");
  174. Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
  175. auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
  176. Constant *ZeroSize[] = {Zero, Size};
  177. auto *ImageB =
  178. ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
  179. auto *ImageE =
  180. ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
  181. ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
  182. ImageE, EntriesB, EntriesE));
  183. }
  184. // Then create images array.
  185. auto *ImagesData = ConstantArray::get(
  186. ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
  187. auto *Images =
  188. new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
  189. GlobalValue::InternalLinkage, ImagesData,
  190. ".omp_offloading.device_images");
  191. Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
  192. auto *ImagesB =
  193. ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
  194. // And finally create the binary descriptor object.
  195. auto *DescInit = ConstantStruct::get(
  196. getBinDescTy(M),
  197. ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
  198. EntriesB, EntriesE);
  199. return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
  200. GlobalValue::InternalLinkage, DescInit,
  201. ".omp_offloading.descriptor");
  202. }
  203. void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
  204. LLVMContext &C = M.getContext();
  205. auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
  206. auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
  207. ".omp_offloading.descriptor_reg", &M);
  208. Func->setSection(".text.startup");
  209. // Get __tgt_register_lib function declaration.
  210. auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
  211. /*isVarArg*/ false);
  212. FunctionCallee RegFuncC =
  213. M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
  214. // Construct function body
  215. IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
  216. Builder.CreateCall(RegFuncC, BinDesc);
  217. Builder.CreateRetVoid();
  218. // Add this function to constructors.
  219. // Set priority to 1 so that __tgt_register_lib is executed AFTER
  220. // __tgt_register_requires (we want to know what requirements have been
  221. // asked for before we load a libomptarget plugin so that by the time the
  222. // plugin is loaded it can report how many devices there are which can
  223. // satisfy these requirements).
  224. appendToGlobalCtors(M, Func, /*Priority*/ 1);
  225. }
  226. void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
  227. LLVMContext &C = M.getContext();
  228. auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
  229. auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
  230. ".omp_offloading.descriptor_unreg", &M);
  231. Func->setSection(".text.startup");
  232. // Get __tgt_unregister_lib function declaration.
  233. auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
  234. /*isVarArg*/ false);
  235. FunctionCallee UnRegFuncC =
  236. M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
  237. // Construct function body
  238. IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
  239. Builder.CreateCall(UnRegFuncC, BinDesc);
  240. Builder.CreateRetVoid();
  241. // Add this function to global destructors.
  242. // Match priority of __tgt_register_lib
  243. appendToGlobalDtors(M, Func, /*Priority*/ 1);
  244. }
  245. // struct fatbin_wrapper {
  246. // int32_t magic;
  247. // int32_t version;
  248. // void *image;
  249. // void *reserved;
  250. //};
  251. StructType *getFatbinWrapperTy(Module &M) {
  252. LLVMContext &C = M.getContext();
  253. StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
  254. if (!FatbinTy)
  255. FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
  256. Type::getInt32Ty(C), Type::getInt8PtrTy(C),
  257. Type::getInt8PtrTy(C));
  258. return FatbinTy;
  259. }
  260. /// Embed the image \p Image into the module \p M so it can be found by the
  261. /// runtime.
  262. GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
  263. LLVMContext &C = M.getContext();
  264. llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
  265. llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
  266. // Create the global string containing the fatbinary.
  267. StringRef FatbinConstantSection =
  268. IsHIP ? ".hip_fatbin"
  269. : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
  270. auto *Data = ConstantDataArray::get(C, Image);
  271. auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
  272. GlobalVariable::InternalLinkage, Data,
  273. ".fatbin_image");
  274. Fatbin->setSection(FatbinConstantSection);
  275. // Create the fatbinary wrapper
  276. StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment"
  277. : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
  278. : ".nvFatBinSegment";
  279. Constant *FatbinWrapper[] = {
  280. ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
  281. ConstantInt::get(Type::getInt32Ty(C), 1),
  282. ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
  283. ConstantPointerNull::get(Type::getInt8PtrTy(C))};
  284. Constant *FatbinInitializer =
  285. ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
  286. auto *FatbinDesc =
  287. new GlobalVariable(M, getFatbinWrapperTy(M),
  288. /*isConstant*/ true, GlobalValue::InternalLinkage,
  289. FatbinInitializer, ".fatbin_wrapper");
  290. FatbinDesc->setSection(FatbinWrapperSection);
  291. FatbinDesc->setAlignment(Align(8));
  292. // We create a dummy entry to ensure the linker will define the begin / end
  293. // symbols. The CUDA runtime should ignore the null address if we attempt to
  294. // register it.
  295. auto *DummyInit =
  296. ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
  297. auto *DummyEntry = new GlobalVariable(
  298. M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
  299. IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
  300. DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
  301. DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
  302. : "cuda_offloading_entries");
  303. return FatbinDesc;
  304. }
  305. /// Create the register globals function. We will iterate all of the offloading
  306. /// entries stored at the begin / end symbols and register them according to
  307. /// their type. This creates the following function in IR:
  308. ///
  309. /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
  310. /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
  311. ///
  312. /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
  313. /// void *, void *, void *, void *, int *);
  314. /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
  315. /// int64_t, int32_t, int32_t);
  316. ///
  317. /// void __cudaRegisterTest(void **fatbinHandle) {
  318. /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
  319. /// entry != &__stop_cuda_offloading_entries; ++entry) {
  320. /// if (!entry->size)
  321. /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
  322. /// entry->name, -1, 0, 0, 0, 0, 0);
  323. /// else
  324. /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
  325. /// 0, entry->size, 0, 0);
  326. /// }
  327. /// }
  328. Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
  329. LLVMContext &C = M.getContext();
  330. // Get the __cudaRegisterFunction function declaration.
  331. auto *RegFuncTy = FunctionType::get(
  332. Type::getInt32Ty(C),
  333. {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
  334. Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
  335. Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
  336. Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
  337. /*isVarArg*/ false);
  338. FunctionCallee RegFunc = M.getOrInsertFunction(
  339. IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
  340. // Get the __cudaRegisterVar function declaration.
  341. auto *RegVarTy = FunctionType::get(
  342. Type::getVoidTy(C),
  343. {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
  344. Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
  345. getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
  346. /*isVarArg*/ false);
  347. FunctionCallee RegVar = M.getOrInsertFunction(
  348. IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
  349. // Create the references to the start / stop symbols defined by the linker.
  350. auto *EntriesB =
  351. new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
  352. /*isConstant*/ true, GlobalValue::ExternalLinkage,
  353. /*Initializer*/ nullptr,
  354. IsHIP ? "__start_hip_offloading_entries"
  355. : "__start_cuda_offloading_entries");
  356. EntriesB->setVisibility(GlobalValue::HiddenVisibility);
  357. auto *EntriesE =
  358. new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
  359. /*isConstant*/ true, GlobalValue::ExternalLinkage,
  360. /*Initializer*/ nullptr,
  361. IsHIP ? "__stop_hip_offloading_entries"
  362. : "__stop_cuda_offloading_entries");
  363. EntriesE->setVisibility(GlobalValue::HiddenVisibility);
  364. auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
  365. Type::getInt8PtrTy(C)->getPointerTo(),
  366. /*isVarArg*/ false);
  367. auto *RegGlobalsFn =
  368. Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
  369. IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
  370. RegGlobalsFn->setSection(".text.startup");
  371. // Create the loop to register all the entries.
  372. IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
  373. auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
  374. auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
  375. auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
  376. auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
  377. auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
  378. auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
  379. auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
  380. auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
  381. auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
  382. auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
  383. Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
  384. Builder.SetInsertPoint(EntryBB);
  385. auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
  386. auto *AddrPtr =
  387. Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
  388. {ConstantInt::get(getSizeTTy(M), 0),
  389. ConstantInt::get(Type::getInt32Ty(C), 0)});
  390. auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
  391. auto *NamePtr =
  392. Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
  393. {ConstantInt::get(getSizeTTy(M), 0),
  394. ConstantInt::get(Type::getInt32Ty(C), 1)});
  395. auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
  396. auto *SizePtr =
  397. Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
  398. {ConstantInt::get(getSizeTTy(M), 0),
  399. ConstantInt::get(Type::getInt32Ty(C), 2)});
  400. auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
  401. auto *FlagsPtr =
  402. Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
  403. {ConstantInt::get(getSizeTTy(M), 0),
  404. ConstantInt::get(Type::getInt32Ty(C), 3)});
  405. auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
  406. auto *FnCond =
  407. Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
  408. Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
  409. // Create kernel registration code.
  410. Builder.SetInsertPoint(IfThenBB);
  411. Builder.CreateCall(RegFunc,
  412. {RegGlobalsFn->arg_begin(), Addr, Name, Name,
  413. ConstantInt::get(Type::getInt32Ty(C), -1),
  414. ConstantPointerNull::get(Type::getInt8PtrTy(C)),
  415. ConstantPointerNull::get(Type::getInt8PtrTy(C)),
  416. ConstantPointerNull::get(Type::getInt8PtrTy(C)),
  417. ConstantPointerNull::get(Type::getInt8PtrTy(C)),
  418. ConstantPointerNull::get(Type::getInt32PtrTy(C))});
  419. Builder.CreateBr(IfEndBB);
  420. Builder.SetInsertPoint(IfElseBB);
  421. auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
  422. // Create global variable registration code.
  423. Builder.SetInsertPoint(SwGlobalBB);
  424. Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
  425. ConstantInt::get(Type::getInt32Ty(C), 0), Size,
  426. ConstantInt::get(Type::getInt32Ty(C), 0),
  427. ConstantInt::get(Type::getInt32Ty(C), 0)});
  428. Builder.CreateBr(IfEndBB);
  429. Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);
  430. // Create managed variable registration code.
  431. Builder.SetInsertPoint(SwManagedBB);
  432. Builder.CreateBr(IfEndBB);
  433. Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);
  434. // Create surface variable registration code.
  435. Builder.SetInsertPoint(SwSurfaceBB);
  436. Builder.CreateBr(IfEndBB);
  437. Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);
  438. // Create texture variable registration code.
  439. Builder.SetInsertPoint(SwTextureBB);
  440. Builder.CreateBr(IfEndBB);
  441. Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);
  442. Builder.SetInsertPoint(IfEndBB);
  443. auto *NewEntry = Builder.CreateInBoundsGEP(
  444. getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
  445. auto *Cmp = Builder.CreateICmpEQ(
  446. NewEntry,
  447. ConstantExpr::getInBoundsGetElementPtr(
  448. ArrayType::get(getEntryTy(M), 0), EntriesE,
  449. ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
  450. ConstantInt::get(getSizeTTy(M), 0)})));
  451. Entry->addIncoming(
  452. ConstantExpr::getInBoundsGetElementPtr(
  453. ArrayType::get(getEntryTy(M), 0), EntriesB,
  454. ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
  455. ConstantInt::get(getSizeTTy(M), 0)})),
  456. &RegGlobalsFn->getEntryBlock());
  457. Entry->addIncoming(NewEntry, IfEndBB);
  458. Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
  459. Builder.SetInsertPoint(ExitBB);
  460. Builder.CreateRetVoid();
  461. return RegGlobalsFn;
  462. }
  463. // Create the constructor and destructor to register the fatbinary with the CUDA
  464. // runtime.
  465. void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
  466. bool IsHIP) {
  467. LLVMContext &C = M.getContext();
  468. auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
  469. auto *CtorFunc =
  470. Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
  471. IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
  472. CtorFunc->setSection(".text.startup");
  473. auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
  474. auto *DtorFunc =
  475. Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
  476. IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
  477. DtorFunc->setSection(".text.startup");
  478. // Get the __cudaRegisterFatBinary function declaration.
  479. auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
  480. Type::getInt8PtrTy(C),
  481. /*isVarArg*/ false);
  482. FunctionCallee RegFatbin = M.getOrInsertFunction(
  483. IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
  484. // Get the __cudaRegisterFatBinaryEnd function declaration.
  485. auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
  486. Type::getInt8PtrTy(C)->getPointerTo(),
  487. /*isVarArg*/ false);
  488. FunctionCallee RegFatbinEnd =
  489. M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
  490. // Get the __cudaUnregisterFatBinary function declaration.
  491. auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
  492. Type::getInt8PtrTy(C)->getPointerTo(),
  493. /*isVarArg*/ false);
  494. FunctionCallee UnregFatbin = M.getOrInsertFunction(
  495. IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
  496. UnregFatTy);
  497. auto *AtExitTy =
  498. FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
  499. /*isVarArg*/ false);
  500. FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
  501. auto *BinaryHandleGlobal = new llvm::GlobalVariable(
  502. M, Type::getInt8PtrTy(C)->getPointerTo(), false,
  503. llvm::GlobalValue::InternalLinkage,
  504. llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
  505. IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
  506. // Create the constructor to register this image with the runtime.
  507. IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
  508. CallInst *Handle = CtorBuilder.CreateCall(
  509. RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
  510. FatbinDesc, Type::getInt8PtrTy(C)));
  511. CtorBuilder.CreateAlignedStore(
  512. Handle, BinaryHandleGlobal,
  513. Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
  514. CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
  515. if (!IsHIP)
  516. CtorBuilder.CreateCall(RegFatbinEnd, Handle);
  517. CtorBuilder.CreateCall(AtExit, DtorFunc);
  518. CtorBuilder.CreateRetVoid();
  519. // Create the destructor to unregister the image with the runtime. We cannot
  520. // use a standard global destructor after CUDA 9.2 so this must be called by
  521. // `atexit()` intead.
  522. IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
  523. LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
  524. Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
  525. Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
  526. DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
  527. DtorBuilder.CreateRetVoid();
  528. // Add this function to constructors.
  529. appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
  530. }
  531. } // namespace
  532. Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
  533. GlobalVariable *Desc = createBinDesc(M, Images);
  534. if (!Desc)
  535. return createStringError(inconvertibleErrorCode(),
  536. "No binary descriptors created.");
  537. createRegisterFunction(M, Desc);
  538. createUnregisterFunction(M, Desc);
  539. return Error::success();
  540. }
  541. Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
  542. GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
  543. if (!Desc)
  544. return createStringError(inconvertibleErrorCode(),
  545. "No fatinbary section created.");
  546. createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
  547. return Error::success();
  548. }
  549. Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
  550. GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
  551. if (!Desc)
  552. return createStringError(inconvertibleErrorCode(),
  553. "No fatinbary section created.");
  554. createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
  555. return Error::success();
  556. }