NVVMIntrRange.cpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. //===- NVVMIntrRange.cpp - Set !range metadata for NVVM intrinsics --------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass adds appropriate !range metadata for calls to NVVM
  10. // intrinsics that return a limited range of values.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "NVPTX.h"
  14. #include "llvm/IR/Constants.h"
  15. #include "llvm/IR/InstIterator.h"
  16. #include "llvm/IR/Instructions.h"
  17. #include "llvm/IR/Intrinsics.h"
  18. #include "llvm/IR/IntrinsicsNVPTX.h"
  19. #include "llvm/IR/PassManager.h"
  20. #include "llvm/Support/CommandLine.h"
  21. using namespace llvm;
  22. #define DEBUG_TYPE "nvvm-intr-range"
  23. namespace llvm { void initializeNVVMIntrRangePass(PassRegistry &); }
  24. // Add !range metadata based on limits of given SM variant.
  25. static cl::opt<unsigned> NVVMIntrRangeSM("nvvm-intr-range-sm", cl::init(20),
  26. cl::Hidden, cl::desc("SM variant"));
  27. namespace {
  28. class NVVMIntrRange : public FunctionPass {
  29. private:
  30. unsigned SmVersion;
  31. public:
  32. static char ID;
  33. NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
  34. NVVMIntrRange(unsigned int SmVersion)
  35. : FunctionPass(ID), SmVersion(SmVersion) {
  36. initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
  37. }
  38. bool runOnFunction(Function &) override;
  39. };
  40. }
  41. FunctionPass *llvm::createNVVMIntrRangePass(unsigned int SmVersion) {
  42. return new NVVMIntrRange(SmVersion);
  43. }
  44. char NVVMIntrRange::ID = 0;
  45. INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
  46. "Add !range metadata to NVVM intrinsics.", false, false)
  47. // Adds the passed-in [Low,High) range information as metadata to the
  48. // passed-in call instruction.
  49. static bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
  50. // This call already has range metadata, nothing to do.
  51. if (C->getMetadata(LLVMContext::MD_range))
  52. return false;
  53. LLVMContext &Context = C->getParent()->getContext();
  54. IntegerType *Int32Ty = Type::getInt32Ty(Context);
  55. Metadata *LowAndHigh[] = {
  56. ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
  57. ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
  58. C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
  59. return true;
  60. }
  61. static bool runNVVMIntrRange(Function &F, unsigned SmVersion) {
  62. struct {
  63. unsigned x, y, z;
  64. } MaxBlockSize, MaxGridSize;
  65. MaxBlockSize.x = 1024;
  66. MaxBlockSize.y = 1024;
  67. MaxBlockSize.z = 64;
  68. MaxGridSize.x = SmVersion >= 30 ? 0x7fffffff : 0xffff;
  69. MaxGridSize.y = 0xffff;
  70. MaxGridSize.z = 0xffff;
  71. // Go through the calls in this function.
  72. bool Changed = false;
  73. for (Instruction &I : instructions(F)) {
  74. CallInst *Call = dyn_cast<CallInst>(&I);
  75. if (!Call)
  76. continue;
  77. if (Function *Callee = Call->getCalledFunction()) {
  78. switch (Callee->getIntrinsicID()) {
  79. // Index within block
  80. case Intrinsic::nvvm_read_ptx_sreg_tid_x:
  81. Changed |= addRangeMetadata(0, MaxBlockSize.x, Call);
  82. break;
  83. case Intrinsic::nvvm_read_ptx_sreg_tid_y:
  84. Changed |= addRangeMetadata(0, MaxBlockSize.y, Call);
  85. break;
  86. case Intrinsic::nvvm_read_ptx_sreg_tid_z:
  87. Changed |= addRangeMetadata(0, MaxBlockSize.z, Call);
  88. break;
  89. // Block size
  90. case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
  91. Changed |= addRangeMetadata(1, MaxBlockSize.x+1, Call);
  92. break;
  93. case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
  94. Changed |= addRangeMetadata(1, MaxBlockSize.y+1, Call);
  95. break;
  96. case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
  97. Changed |= addRangeMetadata(1, MaxBlockSize.z+1, Call);
  98. break;
  99. // Index within grid
  100. case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
  101. Changed |= addRangeMetadata(0, MaxGridSize.x, Call);
  102. break;
  103. case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
  104. Changed |= addRangeMetadata(0, MaxGridSize.y, Call);
  105. break;
  106. case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
  107. Changed |= addRangeMetadata(0, MaxGridSize.z, Call);
  108. break;
  109. // Grid size
  110. case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
  111. Changed |= addRangeMetadata(1, MaxGridSize.x+1, Call);
  112. break;
  113. case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
  114. Changed |= addRangeMetadata(1, MaxGridSize.y+1, Call);
  115. break;
  116. case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
  117. Changed |= addRangeMetadata(1, MaxGridSize.z+1, Call);
  118. break;
  119. // warp size is constant 32.
  120. case Intrinsic::nvvm_read_ptx_sreg_warpsize:
  121. Changed |= addRangeMetadata(32, 32+1, Call);
  122. break;
  123. // Lane ID is [0..warpsize)
  124. case Intrinsic::nvvm_read_ptx_sreg_laneid:
  125. Changed |= addRangeMetadata(0, 32, Call);
  126. break;
  127. default:
  128. break;
  129. }
  130. }
  131. }
  132. return Changed;
  133. }
  134. bool NVVMIntrRange::runOnFunction(Function &F) {
  135. return runNVVMIntrRange(F, SmVersion);
  136. }
  137. NVVMIntrRangePass::NVVMIntrRangePass() : NVVMIntrRangePass(NVVMIntrRangeSM) {}
  138. PreservedAnalyses NVVMIntrRangePass::run(Function &F,
  139. FunctionAnalysisManager &AM) {
  140. return runNVVMIntrRange(F, SmVersion) ? PreservedAnalyses::none()
  141. : PreservedAnalyses::all();
  142. }