compile_cuda.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import sys
  2. import subprocess
  3. import os
  4. import collections
  5. import re
  6. import tempfile
  7. def is_clang(command):
  8. for word in command:
  9. if '--compiler-bindir' in word and 'clang' in word:
  10. return True
  11. return False
  12. def main():
  13. try:
  14. sys.argv.remove('--y_skip_nocxxinc')
  15. skip_nocxxinc = True
  16. except ValueError:
  17. skip_nocxxinc = False
  18. spl = sys.argv.index('--cflags')
  19. cmd = 1
  20. mtime0 = None
  21. if sys.argv[1] == '--mtime':
  22. mtime0 = sys.argv[2]
  23. cmd = 3
  24. command = sys.argv[cmd:spl]
  25. cflags = sys.argv[spl + 1 :]
  26. dump_args = False
  27. if '--y_dump_args' in command:
  28. command.remove('--y_dump_args')
  29. dump_args = True
  30. executable = command[0]
  31. if not os.path.exists(executable):
  32. print >> sys.stderr, '{} not found'.format(executable)
  33. sys.exit(1)
  34. if is_clang(command):
  35. # nvcc concatenates the sources for clang, and clang reports unused
  36. # things from .h files as if they they were defined in a .cpp file.
  37. cflags += ['-Wno-unused-function', '-Wno-unused-parameter']
  38. if not is_clang(command) and '-fopenmp=libomp' in cflags:
  39. cflags.append('-fopenmp')
  40. cflags.remove('-fopenmp=libomp')
  41. skip_list = [
  42. '-gline-tables-only',
  43. # clang coverage
  44. '-fprofile-instr-generate',
  45. '-fcoverage-mapping',
  46. '/Zc:inline', # disable unreferenced functions (kernel registrators) remove
  47. '-Wno-c++17-extensions',
  48. '-flto',
  49. '-faligned-allocation',
  50. '-fsized-deallocation',
  51. # While it might be reasonable to compile host part of .cu sources with these optimizations enabled,
  52. # nvcc passes these options down towards cicc which lacks x86_64 extensions support.
  53. '-msse2',
  54. '-msse3',
  55. '-mssse3',
  56. '-msse4.1',
  57. '-msse4.2',
  58. ]
  59. if skip_nocxxinc:
  60. skip_list.append('-nostdinc++')
  61. for flag in skip_list:
  62. if flag in cflags:
  63. cflags.remove(flag)
  64. skip_prefix_list = [
  65. '-fsanitize=',
  66. '-fsanitize-coverage=',
  67. '-fsanitize-blacklist=',
  68. '--system-header-prefix',
  69. ]
  70. new_cflags = []
  71. for flag in cflags:
  72. if all(not flag.startswith(skip_prefix) for skip_prefix in skip_prefix_list):
  73. if flag.startswith('-fopenmp-version='):
  74. new_cflags.append(
  75. '-fopenmp-version=45'
  76. ) # Clang 11 only supports OpenMP 4.5, but the default is 5.0, so we need to forcefully redefine it.
  77. else:
  78. new_cflags.append(flag)
  79. cflags = new_cflags
  80. if not is_clang(command):
  81. def good(arg):
  82. if arg.startswith('--target='):
  83. return False
  84. return True
  85. cflags = filter(good, cflags)
  86. cpp_args = []
  87. compiler_args = []
  88. # NVCC requires particular MSVC versions which may differ from the version
  89. # used to compile regular C++ code. We have a separate MSVC in Arcadia for
  90. # the CUDA builds and pass it's root in $Y_VC_Root.
  91. # The separate MSVC for CUDA may absent in Yandex Open Source builds.
  92. vc_root = os.environ.get('Y_VC_Root')
  93. cflags_queue = collections.deque(cflags)
  94. while cflags_queue:
  95. arg = cflags_queue.popleft()
  96. if arg == '-mllvm':
  97. compiler_args.append(arg)
  98. compiler_args.append(cflags_queue.popleft())
  99. continue
  100. if arg[:2].upper() in ('-I', '/I', '-B'):
  101. value = arg[2:]
  102. if not value:
  103. value = cflags_queue.popleft()
  104. if arg[1] == 'I':
  105. cpp_args.append('-I{}'.format(value))
  106. elif arg[1] == 'B': # todo: delete "B" flag check when cuda stop to use gcc
  107. pass
  108. continue
  109. match = re.match(r'[-/]D(.*)', arg)
  110. if match:
  111. define = match.group(1)
  112. # We have C++ flags configured for the regular C++ build.
  113. # There is Y_MSVC_INCLUDE define with a path to the VC header files.
  114. # We need to change the path accordingly when using a separate MSVC for CUDA.
  115. if vc_root and define.startswith('Y_MSVC_INCLUDE'):
  116. define = os.path.expandvars('Y_MSVC_INCLUDE={}/include'.format(vc_root))
  117. cpp_args.append('-D' + define.replace('\\', '/'))
  118. continue
  119. compiler_args.append(arg)
  120. command += cpp_args
  121. if compiler_args:
  122. command += ['--compiler-options', ','.join(compiler_args)]
  123. # --keep is necessary to prevent nvcc from embedding nvcc pid in generated
  124. # symbols. It makes nvcc use the original file name as the prefix in the
  125. # generated files (otherwise it also prepends tmpxft_{pid}_00000000-5), and
  126. # cicc derives the module name from its {input}.cpp1.ii file name.
  127. command += ['--keep', '--keep-dir', tempfile.mkdtemp(prefix='compile_cuda.py.')]
  128. # nvcc generates symbols like __fatbinwrap_{len}_{basename}_{hash} where
  129. # {basename} is {input}.cpp1.ii with non-C chars translated to _, {len} is
  130. # {basename} length, and {hash} is the hash of first exported symbol in
  131. # {input}.cpp1.ii if there is one, otherwise it is based on its modification
  132. # time (converted to string in the local timezone) and the current working
  133. # directory. To stabilize the names of these symbols we need to fix mtime,
  134. # timezone, and cwd.
  135. if mtime0:
  136. os.environ['LD_PRELOAD'] = mtime0
  137. os.environ['TZ'] = 'UTC0' # POSIX fixed offset format.
  138. os.environ['TZDIR'] = '/var/empty' # Against counterfeit /usr/share/zoneinfo/$TZ.
  139. if dump_args:
  140. sys.stdout.write('\n'.join(command))
  141. else:
  142. sys.exit(subprocess.Popen(command, stdout=sys.stderr, stderr=sys.stderr, cwd='/').wait())
  143. if __name__ == '__main__':
  144. main()