//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via the ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. Like #1, this allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. At the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // Handle byval parameters.
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)
// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space on the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the same
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================
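// For example (an illustrative sketch, not verbatim pass output), when %d is
// only read through GEPs and loads:
//
//   %v_ptr = getelementptr %struct.x, %struct.x* %d, i64 0, i32 1
//   %v     = load i32, i32* %v_ptr
//
// the uses are rewritten to read directly from param space:
//
//   %d_param = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
//   %v_ptr   = getelementptr %struct.x, %struct.x addrspace(101)* %d_param,
//              i64 0, i32 1
//   %v       = load i32, i32 addrspace(101)* %v_ptr
//
// so the backend can use ld.param directly and no local copy is needed.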
// Replaces the \p OldUser instruction with the same one rewritten to use
// parameter AS. Only Load, GEP, BitCast, and AddrSpaceCast (to param AS)
// are supported.
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  auto CloneInstInParamAS = [](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
                                               I.NewParam, Indices,
                                               GEP->getName(), GEP);
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      auto *NewBCType = PointerType::getWithSamePointeeType(
          cast<PointerType>(BC->getType()), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC);
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      for (Value *V : I.OldInstruction->users())
        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP, BitCast}. GEP can't be deleted first, because it's still used by
  // the BitCast.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}
// Adjust alignment of arguments passed byval in .param address space. We can
// increase the alignment of such arguments in a way that ensures we can
// effectively vectorize their loads. We also traverse all loads from the byval
// pointer and adjust their alignment if they read from a known constant
// offset. Such alignment changes must stay consistent with the parameter
// stores and loads emitted in NVPTXTargetLowering::LowerCall.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  const DataLayout DL(Func->getParent());

  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();

  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
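
  // Walk all users of the param-space pointer, tracking the constant byte
  // offset at which each load reads, so that the loads collected here can
  // have their alignment raised below.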
  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr.");
    }
  }
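
  // Raise the alignment of each collected load to what the new argument
  // alignment guarantees at that load's offset. Illustrative example (assumed
  // values, not from a real kernel): with NewArgAlign = 16, a load at offset 8
  // may be promoted to gcd(16, 8) = 8, while a load at offset 12 only to
  // gcd(16, 12) = 4. A load's existing alignment is never lowered.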
  for (Load &CurLoad : Loads) {
    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
    auto IsALoadChainInstr = [](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASCs to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
      if (!IsALoadChainInstr(V)) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
                          << "\n");
        (void)Arg;
        return false;
      }
      if (!isa<LoadInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

  if (llvm::all_of(Arg->users(), IsALoadChain)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    SmallVector<User *, 16> UsersToUpdate(Arg->users());
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    for (Value *V : UsersToUpdate)
      convertToParamAS(V, ArgInParamAS);
    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

    // Further optimizations require target lowering info.
    if (!TM)
      return;

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM->getSubtargetImpl()->getTargetLowering());

    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);

    return;
  }

  // Otherwise we have to create a temporary copy.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to the alignment of the byval parameter, because later
  // loads/stores assume that alignment, and we are going to replace the use of
  // the byval parameter with this alloca instruction.
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .value_or(DL.getPrefTypeAlign(StructType)));
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this load is
  // definitely not volatile.
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(),
                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Decide where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr,
      PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()),
                                          ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}
// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
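  // Lower the pointer arguments themselves: byval aggregates are handled via
  // handleByValParam; under the CUDA driver interface, other pointer arguments
  // are assumed to point to global memory and are marked accordingly.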
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(&Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
}

FunctionPass *
llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  return new NVPTXLowerArgs(TM);
}