CGHLSLRuntime.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This provides an abstract class for HLSL code generation. Concrete
  10. // subclasses of this implement code generation for specific HLSL
  11. // runtime libraries.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "CGHLSLRuntime.h"
  15. #include "CGDebugInfo.h"
  16. #include "CodeGenModule.h"
  17. #include "clang/AST/Decl.h"
  18. #include "clang/Basic/TargetOptions.h"
  19. #include "llvm/IR/IntrinsicsDirectX.h"
  20. #include "llvm/IR/Metadata.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/Support/FormatVariadic.h"
  23. using namespace clang;
  24. using namespace CodeGen;
  25. using namespace clang::hlsl;
  26. using namespace llvm;
  27. namespace {
  28. void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
  29. // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
  30. // Assume ValVersionStr is legal here.
  31. VersionTuple Version;
  32. if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
  33. Version.getSubminor() || !Version.getMinor()) {
  34. return;
  35. }
  36. uint64_t Major = Version.getMajor();
  37. uint64_t Minor = *Version.getMinor();
  38. auto &Ctx = M.getContext();
  39. IRBuilder<> B(M.getContext());
  40. MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
  41. ConstantAsMetadata::get(B.getInt32(Minor))});
  42. StringRef DXILValKey = "dx.valver";
  43. auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
  44. DXILValMD->addOperand(Val);
  45. }
  46. void addDisableOptimizations(llvm::Module &M) {
  47. StringRef Key = "dx.disable_optimizations";
  48. M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
  49. }
  50. // cbuffer will be translated into global variable in special address space.
  51. // If translate into C,
  52. // cbuffer A {
  53. // float a;
  54. // float b;
  55. // }
  56. // float foo() { return a + b; }
  57. //
  58. // will be translated into
  59. //
  60. // struct A {
  61. // float a;
  62. // float b;
  63. // } cbuffer_A __attribute__((address_space(4)));
  64. // float foo() { return cbuffer_A.a + cbuffer_A.b; }
  65. //
  66. // layoutBuffer will create the struct A type.
  67. // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
  68. // and cbuffer_A.b.
  69. //
  70. void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
  71. if (Buf.Constants.empty())
  72. return;
  73. std::vector<llvm::Type *> EltTys;
  74. for (auto &Const : Buf.Constants) {
  75. GlobalVariable *GV = Const.first;
  76. Const.second = EltTys.size();
  77. llvm::Type *Ty = GV->getValueType();
  78. EltTys.emplace_back(Ty);
  79. }
  80. Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
  81. }
  82. GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
  83. // Create global variable for CB.
  84. GlobalVariable *CBGV = new GlobalVariable(
  85. Buf.LayoutStruct, /*isConstant*/ true,
  86. GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
  87. llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
  88. GlobalValue::NotThreadLocal);
  89. IRBuilder<> B(CBGV->getContext());
  90. Value *ZeroIdx = B.getInt32(0);
  91. // Replace Const use with CB use.
  92. for (auto &[GV, Offset] : Buf.Constants) {
  93. Value *GEP =
  94. B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
  95. assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
  96. "constant type mismatch");
  97. // Replace.
  98. GV->replaceAllUsesWith(GEP);
  99. // Erase GV.
  100. GV->removeDeadConstantUsers();
  101. GV->eraseFromParent();
  102. }
  103. return CBGV;
  104. }
  105. } // namespace
  106. void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
  107. if (D->getStorageClass() == SC_Static) {
  108. // For static inside cbuffer, take as global static.
  109. // Don't add to cbuffer.
  110. CGM.EmitGlobal(D);
  111. return;
  112. }
  113. auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
  114. // Add debug info for constVal.
  115. if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
  116. if (CGM.getCodeGenOpts().getDebugInfo() >=
  117. codegenoptions::DebugInfoKind::LimitedDebugInfo)
  118. DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
  119. // FIXME: support packoffset.
  120. // See https://github.com/llvm/llvm-project/issues/57914.
  121. uint32_t Offset = 0;
  122. bool HasUserOffset = false;
  123. unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
  124. CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
  125. }
  126. void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
  127. for (Decl *it : DC->decls()) {
  128. if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
  129. addConstant(ConstDecl, CB);
  130. } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
  131. // Nothing to do for this declaration.
  132. } else if (isa<FunctionDecl>(it)) {
  133. // A function within an cbuffer is effectively a top-level function,
  134. // as it only refers to globally scoped declarations.
  135. CGM.EmitTopLevelDecl(it);
  136. }
  137. }
  138. }
  139. void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
  140. Buffers.emplace_back(Buffer(D));
  141. addBufferDecls(D, Buffers.back());
  142. }
  143. void CGHLSLRuntime::finishCodeGen() {
  144. auto &TargetOpts = CGM.getTarget().getTargetOpts();
  145. llvm::Module &M = CGM.getModule();
  146. Triple T(M.getTargetTriple());
  147. if (T.getArch() == Triple::ArchType::dxil)
  148. addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
  149. generateGlobalCtorDtorCalls();
  150. if (CGM.getCodeGenOpts().OptimizationLevel == 0)
  151. addDisableOptimizations(M);
  152. const DataLayout &DL = M.getDataLayout();
  153. for (auto &Buf : Buffers) {
  154. layoutBuffer(Buf, DL);
  155. GlobalVariable *GV = replaceBuffer(Buf);
  156. M.getGlobalList().push_back(GV);
  157. llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
  158. ? llvm::hlsl::ResourceClass::CBuffer
  159. : llvm::hlsl::ResourceClass::SRV;
  160. llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
  161. ? llvm::hlsl::ResourceKind::CBuffer
  162. : llvm::hlsl::ResourceKind::TBuffer;
  163. std::string TyName =
  164. Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
  165. addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding);
  166. }
  167. }
  168. CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
  169. : Name(D->getName()), IsCBuffer(D->isCBuffer()),
  170. Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
  171. void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
  172. llvm::StringRef TyName,
  173. llvm::hlsl::ResourceClass RC,
  174. llvm::hlsl::ResourceKind RK,
  175. BufferResBinding &Binding) {
  176. llvm::Module &M = CGM.getModule();
  177. NamedMDNode *ResourceMD = nullptr;
  178. switch (RC) {
  179. case llvm::hlsl::ResourceClass::UAV:
  180. ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
  181. break;
  182. case llvm::hlsl::ResourceClass::SRV:
  183. ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
  184. break;
  185. case llvm::hlsl::ResourceClass::CBuffer:
  186. ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
  187. break;
  188. default:
  189. assert(false && "Unsupported buffer type!");
  190. return;
  191. }
  192. assert(ResourceMD != nullptr &&
  193. "ResourceMD must have been set by the switch above.");
  194. llvm::hlsl::FrontendResource Res(
  195. GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space);
  196. ResourceMD->addOperand(Res.getMetadata());
  197. }
  198. static llvm::hlsl::ResourceKind
  199. castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) {
  200. switch (RK) {
  201. case HLSLResourceAttr::ResourceKind::Texture1D:
  202. return llvm::hlsl::ResourceKind::Texture1D;
  203. case HLSLResourceAttr::ResourceKind::Texture2D:
  204. return llvm::hlsl::ResourceKind::Texture2D;
  205. case HLSLResourceAttr::ResourceKind::Texture2DMS:
  206. return llvm::hlsl::ResourceKind::Texture2DMS;
  207. case HLSLResourceAttr::ResourceKind::Texture3D:
  208. return llvm::hlsl::ResourceKind::Texture3D;
  209. case HLSLResourceAttr::ResourceKind::TextureCube:
  210. return llvm::hlsl::ResourceKind::TextureCube;
  211. case HLSLResourceAttr::ResourceKind::Texture1DArray:
  212. return llvm::hlsl::ResourceKind::Texture1DArray;
  213. case HLSLResourceAttr::ResourceKind::Texture2DArray:
  214. return llvm::hlsl::ResourceKind::Texture2DArray;
  215. case HLSLResourceAttr::ResourceKind::Texture2DMSArray:
  216. return llvm::hlsl::ResourceKind::Texture2DMSArray;
  217. case HLSLResourceAttr::ResourceKind::TextureCubeArray:
  218. return llvm::hlsl::ResourceKind::TextureCubeArray;
  219. case HLSLResourceAttr::ResourceKind::TypedBuffer:
  220. return llvm::hlsl::ResourceKind::TypedBuffer;
  221. case HLSLResourceAttr::ResourceKind::RawBuffer:
  222. return llvm::hlsl::ResourceKind::RawBuffer;
  223. case HLSLResourceAttr::ResourceKind::StructuredBuffer:
  224. return llvm::hlsl::ResourceKind::StructuredBuffer;
  225. case HLSLResourceAttr::ResourceKind::CBufferKind:
  226. return llvm::hlsl::ResourceKind::CBuffer;
  227. case HLSLResourceAttr::ResourceKind::SamplerKind:
  228. return llvm::hlsl::ResourceKind::Sampler;
  229. case HLSLResourceAttr::ResourceKind::TBuffer:
  230. return llvm::hlsl::ResourceKind::TBuffer;
  231. case HLSLResourceAttr::ResourceKind::RTAccelerationStructure:
  232. return llvm::hlsl::ResourceKind::RTAccelerationStructure;
  233. case HLSLResourceAttr::ResourceKind::FeedbackTexture2D:
  234. return llvm::hlsl::ResourceKind::FeedbackTexture2D;
  235. case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray:
  236. return llvm::hlsl::ResourceKind::FeedbackTexture2DArray;
  237. }
  238. // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to
  239. // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for
  240. // HLSLResourceAttr::ResourceKind.
  241. static_assert(
  242. static_cast<uint32_t>(
  243. HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) ==
  244. (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2));
  245. llvm_unreachable("all switch cases should be covered");
  246. }
  247. void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
  248. const Type *Ty = D->getType()->getPointeeOrArrayElementType();
  249. if (!Ty)
  250. return;
  251. const auto *RD = Ty->getAsCXXRecordDecl();
  252. if (!RD)
  253. return;
  254. const auto *Attr = RD->getAttr<HLSLResourceAttr>();
  255. if (!Attr)
  256. return;
  257. HLSLResourceAttr::ResourceClass RC = Attr->getResourceType();
  258. llvm::hlsl::ResourceKind RK =
  259. castResourceShapeToResourceKind(Attr->getResourceShape());
  260. QualType QT(Ty, 0);
  261. BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
  262. addBufferResourceAnnotation(GV, QT.getAsString(),
  263. static_cast<llvm::hlsl::ResourceClass>(RC), RK,
  264. Binding);
  265. }
  266. CGHLSLRuntime::BufferResBinding::BufferResBinding(
  267. HLSLResourceBindingAttr *Binding) {
  268. if (Binding) {
  269. llvm::APInt RegInt(64, 0);
  270. Binding->getSlot().substr(1).getAsInteger(10, RegInt);
  271. Reg = RegInt.getLimitedValue();
  272. llvm::APInt SpaceInt(64, 0);
  273. Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
  274. Space = SpaceInt.getLimitedValue();
  275. } else {
  276. Space = 0;
  277. }
  278. }
  279. void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
  280. const FunctionDecl *FD, llvm::Function *Fn) {
  281. const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  282. assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
  283. const StringRef ShaderAttrKindStr = "hlsl.shader";
  284. Fn->addFnAttr(ShaderAttrKindStr,
  285. ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
  286. if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
  287. const StringRef NumThreadsKindStr = "hlsl.numthreads";
  288. std::string NumThreadsStr =
  289. formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
  290. NumThreadsAttr->getZ());
  291. Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
  292. }
  293. }
  294. static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
  295. if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
  296. Value *Result = PoisonValue::get(Ty);
  297. for (unsigned I = 0; I < VT->getNumElements(); ++I) {
  298. Value *Elt = B.CreateCall(F, {B.getInt32(I)});
  299. Result = B.CreateInsertElement(Result, Elt, I);
  300. }
  301. return Result;
  302. }
  303. return B.CreateCall(F, {B.getInt32(0)});
  304. }
  305. llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
  306. const ParmVarDecl &D,
  307. llvm::Type *Ty) {
  308. assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
  309. if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
  310. llvm::Function *DxGroupIndex =
  311. CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
  312. return B.CreateCall(FunctionCallee(DxGroupIndex));
  313. }
  314. if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
  315. llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
  316. return buildVectorInput(B, DxThreadID, Ty);
  317. }
  318. assert(false && "Unhandled parameter attribute");
  319. return nullptr;
  320. }
  321. void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
  322. llvm::Function *Fn) {
  323. llvm::Module &M = CGM.getModule();
  324. llvm::LLVMContext &Ctx = M.getContext();
  325. auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
  326. Function *EntryFn =
  327. Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
  328. // Copy function attributes over, we have no argument or return attributes
  329. // that can be valid on the real entry.
  330. AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
  331. Fn->getAttributes().getFnAttrs());
  332. EntryFn->setAttributes(NewAttrs);
  333. setHLSLEntryAttributes(FD, EntryFn);
  334. // Set the called function as internal linkage.
  335. Fn->setLinkage(GlobalValue::InternalLinkage);
  336. BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
  337. IRBuilder<> B(BB);
  338. llvm::SmallVector<Value *> Args;
  339. // FIXME: support struct parameters where semantics are on members.
  340. // See: https://github.com/llvm/llvm-project/issues/57874
  341. unsigned SRetOffset = 0;
  342. for (const auto &Param : Fn->args()) {
  343. if (Param.hasStructRetAttr()) {
  344. // FIXME: support output.
  345. // See: https://github.com/llvm/llvm-project/issues/57874
  346. SRetOffset = 1;
  347. Args.emplace_back(PoisonValue::get(Param.getType()));
  348. continue;
  349. }
  350. const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
  351. Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
  352. }
  353. CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
  354. (void)CI;
  355. // FIXME: Handle codegen for return type semantics.
  356. // See: https://github.com/llvm/llvm-project/issues/57875
  357. B.CreateRetVoid();
  358. }
  359. static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
  360. bool CtorOrDtor) {
  361. const auto *GV =
  362. M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
  363. if (!GV)
  364. return;
  365. const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
  366. if (!CA)
  367. return;
  368. // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
  369. // HLSL neither supports priorities or COMDat values, so we will check those
  370. // in an assert but not handle them.
  371. llvm::SmallVector<Function *> CtorFns;
  372. for (const auto &Ctor : CA->operands()) {
  373. if (isa<ConstantAggregateZero>(Ctor))
  374. continue;
  375. ConstantStruct *CS = cast<ConstantStruct>(Ctor);
  376. assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
  377. "HLSL doesn't support setting priority for global ctors.");
  378. assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
  379. "HLSL doesn't support COMDat for global ctors.");
  380. Fns.push_back(cast<Function>(CS->getOperand(1)));
  381. }
  382. }
  383. void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
  384. llvm::Module &M = CGM.getModule();
  385. SmallVector<Function *> CtorFns;
  386. SmallVector<Function *> DtorFns;
  387. gatherFunctions(CtorFns, M, true);
  388. gatherFunctions(DtorFns, M, false);
  389. // Insert a call to the global constructor at the beginning of the entry block
  390. // to externally exported functions. This is a bit of a hack, but HLSL allows
  391. // global constructors, but doesn't support driver initialization of globals.
  392. for (auto &F : M.functions()) {
  393. if (!F.hasFnAttribute("hlsl.shader"))
  394. continue;
  395. IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
  396. for (auto *Fn : CtorFns)
  397. B.CreateCall(FunctionCallee(Fn));
  398. // Insert global dtors before the terminator of the last instruction
  399. B.SetInsertPoint(F.back().getTerminator());
  400. for (auto *Fn : DtorFns)
  401. B.CreateCall(FunctionCallee(Fn));
  402. }
  403. // No need to keep global ctors/dtors for non-lib profile after call to
  404. // ctors/dtors added for entry.
  405. Triple T(M.getTargetTriple());
  406. if (T.getEnvironment() != Triple::EnvironmentType::Library) {
  407. if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
  408. GV->eraseFromParent();
  409. if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
  410. GV->eraseFromParent();
  411. }
  412. }