forkserver.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import errno
  2. import os
  3. import selectors
  4. import signal
  5. import socket
  6. import struct
  7. import sys
  8. import threading
  9. import warnings
  10. from . import connection
  11. from . import process
  12. from .context import reduction
  13. from . import resource_tracker
  14. from . import spawn
  15. from . import util
# Public names of this module.  Each one is bound, at the bottom of the
# file, to a method of the module-level ForkServer instance.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

# Limit on the number of file descriptors in one fork request.  The client
# (connect_to_new_process) raises ValueError when the request would reach
# it; the server (main) raises RuntimeError when it receives more.
MAXFDS_TO_SEND = 256
# Wire format used by read_signed()/write_signed() for pids and exit codes.
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
  23. #
  24. # Forkserver class
  25. #
  26. class ForkServer(object):
  27. def __init__(self):
  28. self._forkserver_address = None
  29. self._forkserver_alive_fd = None
  30. self._forkserver_pid = None
  31. self._inherited_fds = None
  32. self._lock = threading.Lock()
  33. self._preload_modules = ['__main__']
  34. def _stop(self):
  35. # Method used by unit tests to stop the server
  36. with self._lock:
  37. self._stop_unlocked()
  38. def _stop_unlocked(self):
  39. if self._forkserver_pid is None:
  40. return
  41. # close the "alive" file descriptor asks the server to stop
  42. os.close(self._forkserver_alive_fd)
  43. self._forkserver_alive_fd = None
  44. os.waitpid(self._forkserver_pid, 0)
  45. self._forkserver_pid = None
  46. if not util.is_abstract_socket_namespace(self._forkserver_address):
  47. os.unlink(self._forkserver_address)
  48. self._forkserver_address = None
  49. def set_forkserver_preload(self, modules_names):
  50. '''Set list of module names to try to load in forkserver process.'''
  51. if not all(type(mod) is str for mod in modules_names):
  52. raise TypeError('module_names must be a list of strings')
  53. self._preload_modules = modules_names
  54. def get_inherited_fds(self):
  55. '''Return list of fds inherited from parent process.
  56. This returns None if the current process was not started by fork
  57. server.
  58. '''
  59. return self._inherited_fds
  60. def connect_to_new_process(self, fds):
  61. '''Request forkserver to create a child process.
  62. Returns a pair of fds (status_r, data_w). The calling process can read
  63. the child process's pid and (eventually) its returncode from status_r.
  64. The calling process should write to data_w the pickled preparation and
  65. process data.
  66. '''
  67. self.ensure_running()
  68. if len(fds) + 4 >= MAXFDS_TO_SEND:
  69. raise ValueError('too many fds')
  70. with socket.socket(socket.AF_UNIX) as client:
  71. client.connect(self._forkserver_address)
  72. parent_r, child_w = os.pipe()
  73. child_r, parent_w = os.pipe()
  74. allfds = [child_r, child_w, self._forkserver_alive_fd,
  75. resource_tracker.getfd()]
  76. allfds += fds
  77. try:
  78. reduction.sendfds(client, allfds)
  79. return parent_r, parent_w
  80. except:
  81. os.close(parent_r)
  82. os.close(parent_w)
  83. raise
  84. finally:
  85. os.close(child_r)
  86. os.close(child_w)
  87. def ensure_running(self):
  88. '''Make sure that a fork server is running.
  89. This can be called from any process. Note that usually a child
  90. process will just reuse the forkserver started by its parent, so
  91. ensure_running() will do nothing.
  92. '''
  93. with self._lock:
  94. resource_tracker.ensure_running()
  95. if self._forkserver_pid is not None:
  96. # forkserver was launched before, is it still running?
  97. pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
  98. if not pid:
  99. # still alive
  100. return
  101. # dead, launch it again
  102. os.close(self._forkserver_alive_fd)
  103. self._forkserver_address = None
  104. self._forkserver_alive_fd = None
  105. self._forkserver_pid = None
  106. cmd = ('from multiprocessing.forkserver import main; ' +
  107. 'main(%d, %d, %r, **%r)')
  108. if self._preload_modules:
  109. desired_keys = {'main_path', 'sys_path'}
  110. data = spawn.get_preparation_data('ignore')
  111. data = {x: y for x, y in data.items() if x in desired_keys}
  112. else:
  113. data = {}
  114. with socket.socket(socket.AF_UNIX) as listener:
  115. address = connection.arbitrary_address('AF_UNIX')
  116. listener.bind(address)
  117. if not util.is_abstract_socket_namespace(address):
  118. os.chmod(address, 0o600)
  119. listener.listen()
  120. # all client processes own the write end of the "alive" pipe;
  121. # when they all terminate the read end becomes ready.
  122. alive_r, alive_w = os.pipe()
  123. try:
  124. fds_to_pass = [listener.fileno(), alive_r]
  125. cmd %= (listener.fileno(), alive_r, self._preload_modules,
  126. data)
  127. exe = spawn.get_executable()
  128. args = [exe] + util._args_from_interpreter_flags()
  129. args += ['-c', cmd]
  130. pid = util.spawnv_passfds(exe, args, fds_to_pass)
  131. except:
  132. os.close(alive_w)
  133. raise
  134. finally:
  135. os.close(alive_r)
  136. self._forkserver_address = address
  137. self._forkserver_alive_fd = alive_w
  138. self._forkserver_pid = pid
  139. #
  140. #
  141. #
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd is the file descriptor of the AF_UNIX socket on which fork
    requests are accepted.  alive_r is the read end of the "alive" pipe;
    when it becomes readable the server exits by raising SystemExit.
    preload is a collection of module names to import before serving;
    import failures are ignored.  If it contains '__main__' and main_path
    is not None, the main module is imported from main_path first.
    sys_path is accepted but not referenced in this function body.

    For each request a child is forked; the child's pid, and later its
    exit code, are written to the pipe supplied by the client.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best effort.
                pass

    util._close_stdin()

    # Pipe used as the signal wakeup fd; both ends are non-blocking.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Install the handlers and remember the previous ones so that forked
    # children can restore them (see _serve_one).
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): the read happens inside the assert, so
                    # it is skipped entirely when asserts are disabled (-O).
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            returncode = os.waitstatus_to_exitcode(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # The first two fds are the child's pipe ends; the
                        # rest are handed to _serve_one() in the child.
                        child_r, child_w, *fds = fds
                        # Closed before forking, so the child does not get
                        # a copy of the connection.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # Never return into the server loop from
                                # the child.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Kept open so the exit code can be sent later.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # Only an aborted connection is tolerated; anything else
                # propagates and ends the server.
                if e.errno != errno.ECONNABORTED:
                    raise
  258. def _serve_one(child_r, fds, unused_fds, handlers):
  259. # close unnecessary stuff and reset signal handlers
  260. signal.set_wakeup_fd(-1)
  261. for sig, val in handlers.items():
  262. signal.signal(sig, val)
  263. for fd in unused_fds:
  264. os.close(fd)
  265. (_forkserver._forkserver_alive_fd,
  266. resource_tracker._resource_tracker._fd,
  267. *_forkserver._inherited_fds) = fds
  268. # Run process object received over pipe
  269. parent_sentinel = os.dup(child_r)
  270. code = spawn._main(child_r, parent_sentinel)
  271. return code
  272. #
  273. # Read and write signed numbers
  274. #
  275. def read_signed(fd):
  276. data = b''
  277. length = SIGNED_STRUCT.size
  278. while len(data) < length:
  279. s = os.read(fd, length - len(data))
  280. if not s:
  281. raise EOFError('unexpected EOF')
  282. data += s
  283. return SIGNED_STRUCT.unpack(data)[0]
  284. def write_signed(fd, n):
  285. msg = SIGNED_STRUCT.pack(n)
  286. while msg:
  287. nbytes = os.write(fd, msg)
  288. if nbytes == 0:
  289. raise RuntimeError('should not get here')
  290. msg = msg[nbytes:]
  291. #
  292. #
  293. #
# Module-level ForkServer instance.  main() and _serve_one() store state on
# it, and the public functions below are its bound methods.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload