CGGPUBuiltin.cpp 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. //===------ CGGPUBuiltin.cpp - Codegen for GPU builtins -------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Generates code for built-in GPU calls which are not runtime-specific.
  10. // (Runtime-specific codegen lives in programming model specific files.)
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "CodeGenFunction.h"
  14. #include "clang/Basic/Builtins.h"
  15. #include "llvm/IR/DataLayout.h"
  16. #include "llvm/IR/Instruction.h"
  17. #include "llvm/Support/MathExtras.h"
  18. #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
  19. using namespace clang;
  20. using namespace CodeGen;
  21. namespace {
  22. llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
  23. llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()),
  24. llvm::Type::getInt8PtrTy(M.getContext())};
  25. llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
  26. llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
  27. if (auto *F = M.getFunction("vprintf")) {
  28. // Our CUDA system header declares vprintf with the right signature, so
  29. // nobody else should have been able to declare vprintf with a bogus
  30. // signature.
  31. assert(F->getFunctionType() == VprintfFuncType);
  32. return F;
  33. }
  34. // vprintf doesn't already exist; create a declaration and insert it into the
  35. // module.
  36. return llvm::Function::Create(
  37. VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M);
  38. }
  39. llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) {
  40. const char *Name = "__llvm_omp_vprintf";
  41. llvm::Module &M = CGM.getModule();
  42. llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()),
  43. llvm::Type::getInt8PtrTy(M.getContext()),
  44. llvm::Type::getInt32Ty(M.getContext())};
  45. llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
  46. llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
  47. if (auto *F = M.getFunction(Name)) {
  48. if (F->getFunctionType() != VprintfFuncType) {
  49. CGM.Error(SourceLocation(),
  50. "Invalid type declaration for __llvm_omp_vprintf");
  51. return nullptr;
  52. }
  53. return F;
  54. }
  55. return llvm::Function::Create(
  56. VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M);
  57. }
  58. // Transforms a call to printf into a call to the NVPTX vprintf syscall (which
  59. // isn't particularly special; it's invoked just like a regular function).
  60. // vprintf takes two args: A format string, and a pointer to a buffer containing
  61. // the varargs.
  62. //
  63. // For example, the call
  64. //
  65. // printf("format string", arg1, arg2, arg3);
  66. //
  67. // is converted into something resembling
  68. //
  69. // struct Tmp {
  70. // Arg1 a1;
  71. // Arg2 a2;
  72. // Arg3 a3;
  73. // };
  74. // char* buf = alloca(sizeof(Tmp));
  75. // *(Tmp*)buf = {a1, a2, a3};
  76. // vprintf("format string", buf);
  77. //
  78. // buf is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of the
  79. // args is itself aligned to its preferred alignment.
  80. //
  81. // Note that by the time this function runs, E's args have already undergone the
  82. // standard C vararg promotion (short -> int, float -> double, etc.).
  83. std::pair<llvm::Value *, llvm::TypeSize>
  84. packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) {
  85. const llvm::DataLayout &DL = CGF->CGM.getDataLayout();
  86. llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext();
  87. CGBuilderTy &Builder = CGF->Builder;
  88. // Construct and fill the args buffer that we'll pass to vprintf.
  89. if (Args.size() <= 1) {
  90. // If there are no args, pass a null pointer and size 0
  91. llvm::Value * BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
  92. return {BufferPtr, llvm::TypeSize::Fixed(0)};
  93. } else {
  94. llvm::SmallVector<llvm::Type *, 8> ArgTypes;
  95. for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)
  96. ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType());
  97. // Using llvm::StructType is correct only because printf doesn't accept
  98. // aggregates. If we had to handle aggregates here, we'd have to manually
  99. // compute the offsets within the alloca -- we wouldn't be able to assume
  100. // that the alignment of the llvm type was the same as the alignment of the
  101. // clang type.
  102. llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");
  103. llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy);
  104. for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {
  105. llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);
  106. llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal();
  107. Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType()));
  108. }
  109. llvm::Value *BufferPtr =
  110. Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx));
  111. return {BufferPtr, DL.getTypeAllocSize(AllocaTy)};
  112. }
  113. }
  114. bool containsNonScalarVarargs(CodeGenFunction *CGF, CallArgList Args) {
  115. return llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
  116. return !A.getRValue(*CGF).isScalar();
  117. });
  118. }
  119. RValue EmitDevicePrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF,
  120. llvm::Function *Decl, bool WithSizeArg) {
  121. CodeGenModule &CGM = CGF->CGM;
  122. CGBuilderTy &Builder = CGF->Builder;
  123. assert(E->getBuiltinCallee() == Builtin::BIprintf);
  124. assert(E->getNumArgs() >= 1); // printf always has at least one arg.
  125. // Uses the same format as nvptx for the argument packing, but also passes
  126. // an i32 for the total size of the passed pointer
  127. CallArgList Args;
  128. CGF->EmitCallArgs(Args,
  129. E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
  130. E->arguments(), E->getDirectCallee(),
  131. /* ParamsToSkip = */ 0);
  132. // We don't know how to emit non-scalar varargs.
  133. if (containsNonScalarVarargs(CGF, Args)) {
  134. CGM.ErrorUnsupported(E, "non-scalar arg to printf");
  135. return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0));
  136. }
  137. auto r = packArgsIntoNVPTXFormatBuffer(CGF, Args);
  138. llvm::Value *BufferPtr = r.first;
  139. llvm::SmallVector<llvm::Value *, 3> Vec = {
  140. Args[0].getRValue(*CGF).getScalarVal(), BufferPtr};
  141. if (WithSizeArg) {
  142. // Passing > 32bit of data as a local alloca doesn't work for nvptx or
  143. // amdgpu
  144. llvm::Constant *Size =
  145. llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM.getLLVMContext()),
  146. static_cast<uint32_t>(r.second.getFixedValue()));
  147. Vec.push_back(Size);
  148. }
  149. return RValue::get(Builder.CreateCall(Decl, Vec));
  150. }
  151. } // namespace
  152. RValue CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E) {
  153. assert(getTarget().getTriple().isNVPTX());
  154. return EmitDevicePrintfCallExpr(
  155. E, this, GetVprintfDeclaration(CGM.getModule()), false);
  156. }
  157. RValue CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E) {
  158. assert(getTarget().getTriple().getArch() == llvm::Triple::amdgcn);
  159. assert(E->getBuiltinCallee() == Builtin::BIprintf ||
  160. E->getBuiltinCallee() == Builtin::BI__builtin_printf);
  161. assert(E->getNumArgs() >= 1); // printf always has at least one arg.
  162. CallArgList CallArgs;
  163. EmitCallArgs(CallArgs,
  164. E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
  165. E->arguments(), E->getDirectCallee(),
  166. /* ParamsToSkip = */ 0);
  167. SmallVector<llvm::Value *, 8> Args;
  168. for (auto A : CallArgs) {
  169. // We don't know how to emit non-scalar varargs.
  170. if (!A.getRValue(*this).isScalar()) {
  171. CGM.ErrorUnsupported(E, "non-scalar arg to printf");
  172. return RValue::get(llvm::ConstantInt::get(IntTy, -1));
  173. }
  174. llvm::Value *Arg = A.getRValue(*this).getScalarVal();
  175. Args.push_back(Arg);
  176. }
  177. llvm::IRBuilder<> IRB(Builder.GetInsertBlock(), Builder.GetInsertPoint());
  178. IRB.SetCurrentDebugLocation(Builder.getCurrentDebugLocation());
  179. auto Printf = llvm::emitAMDGPUPrintfCall(IRB, Args);
  180. Builder.SetInsertPoint(IRB.GetInsertBlock(), IRB.GetInsertPoint());
  181. return RValue::get(Printf);
  182. }
  183. RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) {
  184. assert(getTarget().getTriple().isNVPTX() ||
  185. getTarget().getTriple().isAMDGCN());
  186. return EmitDevicePrintfCallExpr(E, this, GetOpenMPVprintfDeclaration(CGM),
  187. true);
  188. }