NVVMReflect.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
  10. // with an integer.
  11. //
  12. // We choose the value we use by looking at metadata in the module itself. Note
  13. // that we intentionally only have one way to choose these values, because other
  14. // parts of LLVM (particularly, InstCombineCall) rely on being able to predict
  15. // the values chosen by this pass.
  16. //
  17. // If we see an unknown string, we replace its call with 0.
  18. //
  19. //===----------------------------------------------------------------------===//
  #include "NVPTX.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/ADT/StringMap.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Function.h"
  #include "llvm/IR/InstIterator.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/IR/Intrinsics.h"
  #include "llvm/IR/IntrinsicsNVPTX.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/IR/Type.h"
  #include "llvm/Pass.h"
  #include "llvm/Support/CommandLine.h"
  #include "llvm/Support/Debug.h"
  #include "llvm/Support/ErrorHandling.h"
  #include "llvm/Support/raw_os_ostream.h"
  #include "llvm/Support/raw_ostream.h"
  #include "llvm/Transforms/Scalar.h"
  #include <sstream>
  #include <string>
  41. #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
  42. #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
  43. using namespace llvm;
  44. #define DEBUG_TYPE "nvptx-reflect"
  45. namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
  46. namespace {
  47. class NVVMReflect : public FunctionPass {
  48. public:
  49. static char ID;
  50. unsigned int SmVersion;
  51. NVVMReflect() : NVVMReflect(0) {}
  52. explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {
  53. initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
  54. }
  55. bool runOnFunction(Function &) override;
  56. };
  57. }
  58. FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
  59. return new NVVMReflect(SmVersion);
  60. }
  61. static cl::opt<bool>
  62. NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
  63. cl::desc("NVVM reflection, enabled by default"));
  64. char NVVMReflect::ID = 0;
  65. INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
  66. "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
  67. false)
  68. static bool runNVVMReflect(Function &F, unsigned SmVersion) {
  69. if (!NVVMReflectEnabled)
  70. return false;
  71. if (F.getName() == NVVM_REFLECT_FUNCTION ||
  72. F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
  73. assert(F.isDeclaration() && "_reflect function should not have a body");
  74. assert(F.getReturnType()->isIntegerTy() &&
  75. "_reflect's return type should be integer");
  76. return false;
  77. }
  78. SmallVector<Instruction *, 4> ToRemove;
  79. // Go through the calls in this function. Each call to __nvvm_reflect or
  80. // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
  81. // First validate that. If the c-string corresponding to the ConstantArray can
  82. // be found successfully, see if it can be found in VarMap. If so, replace the
  83. // uses of CallInst with the value found in VarMap. If not, replace the use
  84. // with value 0.
  85. // The IR for __nvvm_reflect calls differs between CUDA versions.
  86. //
  87. // CUDA 6.5 and earlier uses this sequence:
  88. // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
  89. // (i8 addrspace(4)* getelementptr inbounds
  90. // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
  91. // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
  92. //
  93. // The value returned by Sym->getOperand(0) is a Constant with a
  94. // ConstantDataSequential operand which can be converted to string and used
  95. // for lookup.
  96. //
  97. // CUDA 7.0 does it slightly differently:
  98. // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
  99. // (i8 addrspace(1)* getelementptr inbounds
  100. // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
  101. //
  102. // In this case, we get a Constant with a GlobalVariable operand and we need
  103. // to dig deeper to find its initializer with the string we'll use for lookup.
  104. for (Instruction &I : instructions(F)) {
  105. CallInst *Call = dyn_cast<CallInst>(&I);
  106. if (!Call)
  107. continue;
  108. Function *Callee = Call->getCalledFunction();
  109. if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
  110. Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
  111. Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
  112. continue;
  113. // FIXME: Improve error handling here and elsewhere in this pass.
  114. assert(Call->getNumOperands() == 2 &&
  115. "Wrong number of operands to __nvvm_reflect function");
  116. // In cuda 6.5 and earlier, we will have an extra constant-to-generic
  117. // conversion of the string.
  118. const Value *Str = Call->getArgOperand(0);
  119. if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
  120. // FIXME: Add assertions about ConvCall.
  121. Str = ConvCall->getArgOperand(0);
  122. }
  123. // Pre opaque pointers we have a constant expression wrapping the constant
  124. // string.
  125. Str = Str->stripPointerCasts();
  126. assert(isa<Constant>(Str) &&
  127. "Format of __nvvm_reflect function not recognized");
  128. const Value *Operand = cast<Constant>(Str)->getOperand(0);
  129. if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
  130. // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
  131. // initializer.
  132. assert(GV->hasInitializer() &&
  133. "Format of _reflect function not recognized");
  134. const Constant *Initializer = GV->getInitializer();
  135. Operand = Initializer;
  136. }
  137. assert(isa<ConstantDataSequential>(Operand) &&
  138. "Format of _reflect function not recognized");
  139. assert(cast<ConstantDataSequential>(Operand)->isCString() &&
  140. "Format of _reflect function not recognized");
  141. StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
  142. ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
  143. LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
  144. int ReflectVal = 0; // The default value is 0
  145. if (ReflectArg == "__CUDA_FTZ") {
  146. // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our
  147. // choice here must be kept in sync with AutoUpgrade, which uses the same
  148. // technique to detect whether ftz is enabled.
  149. if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
  150. F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
  151. ReflectVal = Flag->getSExtValue();
  152. } else if (ReflectArg == "__CUDA_ARCH") {
  153. ReflectVal = SmVersion * 10;
  154. }
  155. Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
  156. ToRemove.push_back(Call);
  157. }
  158. for (Instruction *I : ToRemove)
  159. I->eraseFromParent();
  160. return ToRemove.size() > 0;
  161. }
  162. bool NVVMReflect::runOnFunction(Function &F) {
  163. return runNVVMReflect(F, SmVersion);
  164. }
  165. NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
  166. PreservedAnalyses NVVMReflectPass::run(Function &F,
  167. FunctionAnalysisManager &AM) {
  168. return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
  169. : PreservedAnalyses::all();
  170. }