NVPTXAsmPrinter.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. //===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains a printer that converts from our internal representation
  10. // of machine-dependent LLVM code to NVPTX assembly language.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
  14. #define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H
  15. #include "NVPTX.h"
  16. #include "NVPTXSubtarget.h"
  17. #include "NVPTXTargetMachine.h"
  18. #include "llvm/ADT/DenseMap.h"
  19. #include "llvm/ADT/SmallVector.h"
  20. #include "llvm/ADT/StringRef.h"
  21. #include "llvm/CodeGen/AsmPrinter.h"
  22. #include "llvm/CodeGen/MachineFunction.h"
  23. #include "llvm/CodeGen/MachineLoopInfo.h"
  24. #include "llvm/IR/Constants.h"
  25. #include "llvm/IR/DebugLoc.h"
  26. #include "llvm/IR/DerivedTypes.h"
  27. #include "llvm/IR/Function.h"
  28. #include "llvm/IR/GlobalValue.h"
  29. #include "llvm/IR/Value.h"
  30. #include "llvm/MC/MCExpr.h"
  31. #include "llvm/MC/MCStreamer.h"
  32. #include "llvm/MC/MCSymbol.h"
  33. #include "llvm/Pass.h"
  34. #include "llvm/Support/Casting.h"
  35. #include "llvm/Support/Compiler.h"
  36. #include "llvm/Support/ErrorHandling.h"
  37. #include "llvm/Support/raw_ostream.h"
  38. #include "llvm/Target/TargetMachine.h"
  39. #include <algorithm>
  40. #include <cassert>
  41. #include <map>
  42. #include <memory>
  43. #include <string>
  44. #include <vector>
// The PTX syntax and format are very different from those usually seen in a
// .s file, therefore we are not able to use the MCAsmStreamer interface here.
//
// We are handcrafting the output method here.
//
// A better approach is to clone the MCAsmStreamer to a MCPTXAsmStreamer
// (subclass of MCStreamer).
  53. namespace llvm {
  54. class MCOperand;
  55. class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
  56. class AggBuffer {
  57. // Used to buffer the emitted string for initializing global aggregates.
  58. //
  59. // Normally an aggregate (array, vector, or structure) is emitted as a u8[].
  60. // However, if either element/field of the aggregate is a non-NULL address,
  61. // and all such addresses are properly aligned, then the aggregate is
  62. // emitted as u32[] or u64[]. In the case of unaligned addresses, the
  63. // aggregate is emitted as u8[], and the mask() operator is used for all
  64. // pointers.
  65. //
  66. // We first layout the aggregate in 'buffer' in bytes, except for those
  67. // symbol addresses. For the i-th symbol address in the aggregate, its
  68. // corresponding 4-byte or 8-byte elements in 'buffer' are filled with 0s.
  69. // symbolPosInBuffer[i-1] records its position in 'buffer', and Symbols[i-1]
  70. // records the Value*.
  71. //
  72. // Once we have this AggBuffer setup, we can choose how to print it out.
  73. public:
  74. // number of symbol addresses
  75. unsigned numSymbols() const { return Symbols.size(); }
  76. bool allSymbolsAligned(unsigned ptrSize) const {
  77. return llvm::all_of(symbolPosInBuffer,
  78. [=](unsigned pos) { return pos % ptrSize == 0; });
  79. }
  80. private:
  81. const unsigned size; // size of the buffer in bytes
  82. std::vector<unsigned char> buffer; // the buffer
  83. SmallVector<unsigned, 4> symbolPosInBuffer;
  84. SmallVector<const Value *, 4> Symbols;
  85. // SymbolsBeforeStripping[i] is the original form of Symbols[i] before
  86. // stripping pointer casts, i.e.,
  87. // Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts().
  88. //
  89. // We need to keep these values because AggBuffer::print decides whether to
  90. // emit a "generic()" cast for Symbols[i] depending on the address space of
  91. // SymbolsBeforeStripping[i].
  92. SmallVector<const Value *, 4> SymbolsBeforeStripping;
  93. unsigned curpos;
  94. NVPTXAsmPrinter &AP;
  95. bool EmitGeneric;
  96. public:
  97. AggBuffer(unsigned size, NVPTXAsmPrinter &AP)
  98. : size(size), buffer(size), AP(AP) {
  99. curpos = 0;
  100. EmitGeneric = AP.EmitGeneric;
  101. }
  102. // Copy Num bytes from Ptr.
  103. // if Bytes > Num, zero fill up to Bytes.
  104. unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
  105. assert((curpos + Num) <= size);
  106. assert((curpos + Bytes) <= size);
  107. for (int i = 0; i < Num; ++i) {
  108. buffer[curpos] = Ptr[i];
  109. curpos++;
  110. }
  111. for (int i = Num; i < Bytes; ++i) {
  112. buffer[curpos] = 0;
  113. curpos++;
  114. }
  115. return curpos;
  116. }
  117. unsigned addZeros(int Num) {
  118. assert((curpos + Num) <= size);
  119. for (int i = 0; i < Num; ++i) {
  120. buffer[curpos] = 0;
  121. curpos++;
  122. }
  123. return curpos;
  124. }
  125. void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {
  126. symbolPosInBuffer.push_back(curpos);
  127. Symbols.push_back(GVar);
  128. SymbolsBeforeStripping.push_back(GVarBeforeStripping);
  129. }
  130. void printBytes(raw_ostream &os);
  131. void printWords(raw_ostream &os);
  132. private:
  133. void printSymbol(unsigned nSym, raw_ostream &os);
  134. };
  135. friend class AggBuffer;
  136. private:
  137. StringRef getPassName() const override { return "NVPTX Assembly Printer"; }
  138. const Function *F;
  139. std::string CurrentFnName;
  140. void emitStartOfAsmFile(Module &M) override;
  141. void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
  142. void emitFunctionEntryLabel() override;
  143. void emitFunctionBodyStart() override;
  144. void emitFunctionBodyEnd() override;
  145. void emitImplicitDef(const MachineInstr *MI) const override;
  146. void emitInstruction(const MachineInstr *) override;
  147. void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
  148. bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
  149. MCOperand GetSymbolRef(const MCSymbol *Symbol);
  150. unsigned encodeVirtualRegister(unsigned Reg);
  151. void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O,
  152. const char *Modifier = nullptr);
  153. void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O,
  154. bool processDemoted, const NVPTXSubtarget &STI);
  155. void printParamName(Function::const_arg_iterator I, int paramIndex,
  156. raw_ostream &O);
  157. void emitGlobals(const Module &M);
  158. void emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI);
  159. void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const;
  160. void emitVirtualRegister(unsigned int vr, raw_ostream &);
  161. void emitFunctionParamList(const Function *, raw_ostream &O);
  162. void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O);
  163. void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF);
  164. void printReturnValStr(const Function *, raw_ostream &O);
  165. void printReturnValStr(const MachineFunction &MF, raw_ostream &O);
  166. bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
  167. const char *ExtraCode, raw_ostream &) override;
  168. void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O);
  169. bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
  170. const char *ExtraCode, raw_ostream &) override;
  171. const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
  172. void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
  173. protected:
  174. bool doInitialization(Module &M) override;
  175. bool doFinalization(Module &M) override;
  176. private:
  177. bool GlobalsEmitted;
  178. // This is specific per MachineFunction.
  179. const MachineRegisterInfo *MRI;
  180. // The contents are specific for each
  181. // MachineFunction. But the size of the
  182. // array is not.
  183. typedef DenseMap<unsigned, unsigned> VRegMap;
  184. typedef DenseMap<const TargetRegisterClass *, VRegMap> VRegRCMap;
  185. VRegRCMap VRegMapping;
  186. // List of variables demoted to a function scope.
  187. std::map<const Function *, std::vector<const GlobalVariable *>> localDecls;
  188. void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O,
  189. const NVPTXSubtarget &STI);
  190. void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const;
  191. std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const;
  192. void printScalarConstant(const Constant *CPV, raw_ostream &O);
  193. void printFPConstant(const ConstantFP *Fp, raw_ostream &O);
  194. void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer);
  195. void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer);
  196. void emitLinkageDirective(const GlobalValue *V, raw_ostream &O);
  197. void emitDeclarations(const Module &, raw_ostream &O);
  198. void emitDeclaration(const Function *, raw_ostream &O);
  199. void emitDemotedVars(const Function *, raw_ostream &);
  200. bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
  201. MCOperand &MCOp);
  202. void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
  203. bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
  204. // Used to control the need to emit .generic() in the initializer of
  205. // module scope variables.
  206. // Although ptx supports the hybrid mode like the following,
  207. // .global .u32 a;
  208. // .global .u32 b;
  209. // .global .u32 addr[] = {a, generic(b)}
  210. // we have difficulty representing the difference in the NVVM IR.
  211. //
  212. // Since the address value should always be generic in CUDA C and always
  213. // be specific in OpenCL, we use this simple control here.
  214. //
  215. bool EmitGeneric;
  216. public:
  217. NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)
  218. : AsmPrinter(TM, std::move(Streamer)),
  219. EmitGeneric(static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
  220. NVPTX::CUDA) {}
  221. bool runOnMachineFunction(MachineFunction &F) override;
  222. void getAnalysisUsage(AnalysisUsage &AU) const override {
  223. AU.addRequired<MachineLoopInfo>();
  224. AsmPrinter::getAnalysisUsage(AU);
  225. }
  226. std::string getVirtualRegisterName(unsigned) const;
  227. const MCSymbol *getFunctionFrameSymbol() const override;
  228. // Make emitGlobalVariable() no-op for NVPTX.
  229. // Global variables have been already emitted by the time the base AsmPrinter
  230. // attempts to do so in doFinalization() (see NVPTXAsmPrinter::emitGlobals()).
  231. void emitGlobalVariable(const GlobalVariable *GV) override {}
  232. };
  233. } // end namespace llvm
  234. #endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H