NVPTXLowerArgs.cpp

//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via the ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As with #1, this allows NVPTX to emit more ld/st.global.
//    E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. At the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes
// don't cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // handle byval parameters
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space on the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d with %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the
// same ones in parameter AS, so we can access them using ld.param (see the
// sketch below).
// =============================================================================
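//
// Illustrative sketch only (value names are made up): when every user of the
// byval argument is such a load chain, IR like
//
//   %gep = getelementptr %struct.x, %struct.x* %d, i64 0, i32 0
//   %v = load i32, i32* %gep
//
// is rewritten, instead of being copied, to
//
//   %d1 = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
//   %gep1 = getelementptr %struct.x, %struct.x addrspace(101)* %d1, i64 0, i32 0
//   %v = load i32, i32 addrspace(101)* %gep1
//
// so the loads can be selected as ld.param without a local copy.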

// Replaces the \p OldUser instruction with an equivalent one that operates on
// the parameter AS. Only loads, GEPs, bitcasts, and addrspacecasts to the
// parameter AS are supported.
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  auto CloneInstInParamAS = [](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
                                               I.NewParam, Indices,
                                               GEP->getName(), GEP);
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      auto *NewBCType = PointerType::getWithSamePointeeType(
          cast<PointerType>(BC->getType()), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC);
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      llvm::for_each(
          I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) {
            ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
          });
      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP, BitCast}. GEP can't be deleted first, because it's still used
  // by the BitCast.
  llvm::for_each(reverse(InstructionsToDelete),
                 [](Instruction *I) { I->eraseFromParent(); });
}

void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  PointerType *PType = dyn_cast<PointerType>(Arg->getType());

  assert(PType && "Expecting pointer type in handleByValParam");

  Type *StructType = PType->getPointerElementType();

  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
    auto IsALoadChainInstr = [](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASCs to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
      if (!IsALoadChainInstr(V)) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
                          << "\n");
        (void)Arg;
        return false;
      }
      if (!isa<LoadInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

  if (llvm::all_of(Arg->users(), IsALoadChain)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    SmallVector<User *, 16> UsersToUpdate(Arg->users());
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) {
      convertToParamAS(V, ArgInParamAS);
    });
    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
    return;
  }

  // Otherwise we have to create a temporary copy.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA =
      new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to the alignment of the byval parameter. This is
  // because later loads/stores assume that alignment, and we are going to
  // replace the uses of the byval parameter with this alloca instruction.
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .getValueOr(DL.getPrefTypeAlign(StructType)));
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this load is
  // definitely not volatile.
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(),
                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Decide where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr,
      PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()),
                                          ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace all uses of Ptr with PtrInGeneric, except the one in PtrInGlobal.
  // RAUW also rewrites PtrInGlobal's own operand, so restore it afterwards.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(&Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
}

FunctionPass *llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  return new NVPTXLowerArgs(TM);
}
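
// Note: in-tree, this pass is created by the NVPTX codegen pipeline (see
// NVPTXTargetMachine.cpp). A rough sketch of the call site, assuming the
// usual NVPTXPassConfig::addIRPasses() setup:
//
//   void NVPTXPassConfig::addIRPasses() {
//     ...
//     addPass(createNVPTXLowerArgsPass(&getNVPTXTargetMachine()));
//     ...
//   }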