forkserver.py 12 KB


  1. import errno
  2. import os
  3. import selectors
  4. import signal
  5. import socket
  6. import struct
  7. import sys
  8. import threading
  9. import warnings
  10. from . import connection
  11. from . import process
  12. from .context import reduction
  13. from . import resource_tracker
  14. from . import spawn
  15. from . import util
# Public names of this module; each one is bound at the bottom of the file
# to a method of the module-level ForkServer instance.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']
  18. #
  19. #
  20. #
# Limit on the number of fds in one fork request: connect_to_new_process()
# refuses a request when len(fds) + 4 reaches this value, and main() raises
# if it receives more than this many descriptors.
MAXFDS_TO_SEND = 256
# Fixed-size binary encoding used by read_signed() / write_signed() for the
# pids and exit codes exchanged over pipes.
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
  23. #
  24. # Forkserver class
  25. #
  26. class ForkServer(object):
  27. def __init__(self):
  28. self._forkserver_address = None
  29. self._forkserver_alive_fd = None
  30. self._forkserver_pid = None
  31. self._inherited_fds = None
  32. self._lock = threading.Lock()
  33. self._preload_modules = ['__main__']
  34. def _stop(self):
  35. # Method used by unit tests to stop the server
  36. with self._lock:
  37. self._stop_unlocked()
  38. def _stop_unlocked(self):
  39. if self._forkserver_pid is None:
  40. return
  41. # close the "alive" file descriptor asks the server to stop
  42. os.close(self._forkserver_alive_fd)
  43. self._forkserver_alive_fd = None
  44. os.waitpid(self._forkserver_pid, 0)
  45. self._forkserver_pid = None
  46. if not util.is_abstract_socket_namespace(self._forkserver_address):
  47. os.unlink(self._forkserver_address)
  48. self._forkserver_address = None
  49. def set_forkserver_preload(self, modules_names):
  50. '''Set list of module names to try to load in forkserver process.'''
  51. if not all(type(mod) is str for mod in modules_names):
  52. raise TypeError('module_names must be a list of strings')
  53. self._preload_modules = modules_names
  54. def get_inherited_fds(self):
  55. '''Return list of fds inherited from parent process.
  56. This returns None if the current process was not started by fork
  57. server.
  58. '''
  59. return self._inherited_fds
  60. def connect_to_new_process(self, fds):
  61. '''Request forkserver to create a child process.
  62. Returns a pair of fds (status_r, data_w). The calling process can read
  63. the child process's pid and (eventually) its returncode from status_r.
  64. The calling process should write to data_w the pickled preparation and
  65. process data.
  66. '''
  67. self.ensure_running()
  68. if len(fds) + 4 >= MAXFDS_TO_SEND:
  69. raise ValueError('too many fds')
  70. with socket.socket(socket.AF_UNIX) as client:
  71. client.connect(self._forkserver_address)
  72. parent_r, child_w = os.pipe()
  73. child_r, parent_w = os.pipe()
  74. allfds = [child_r, child_w, self._forkserver_alive_fd,
  75. resource_tracker.getfd()]
  76. allfds += fds
  77. try:
  78. reduction.sendfds(client, allfds)
  79. return parent_r, parent_w
  80. except:
  81. os.close(parent_r)
  82. os.close(parent_w)
  83. raise
  84. finally:
  85. os.close(child_r)
  86. os.close(child_w)
  87. def ensure_running(self):
  88. '''Make sure that a fork server is running.
  89. This can be called from any process. Note that usually a child
  90. process will just reuse the forkserver started by its parent, so
  91. ensure_running() will do nothing.
  92. '''
  93. with self._lock:
  94. resource_tracker.ensure_running()
  95. if self._forkserver_pid is not None:
  96. # forkserver was launched before, is it still running?
  97. pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
  98. if not pid:
  99. # still alive
  100. return
  101. # dead, launch it again
  102. os.close(self._forkserver_alive_fd)
  103. self._forkserver_address = None
  104. self._forkserver_alive_fd = None
  105. self._forkserver_pid = None
  106. cmd = ('from multiprocessing.forkserver import main; ' +
  107. 'main(%d, %d, %r, **%r)')
  108. if self._preload_modules:
  109. desired_keys = {'main_path', 'sys_path'}
  110. data = spawn.get_preparation_data('ignore')
  111. data = {x: y for x, y in data.items() if x in desired_keys}
  112. else:
  113. data = {}
  114. with socket.socket(socket.AF_UNIX) as listener:
  115. address = connection.arbitrary_address('AF_UNIX')
  116. listener.bind(address)
  117. if not util.is_abstract_socket_namespace(address):
  118. os.chmod(address, 0o600)
  119. listener.listen()
  120. # all client processes own the write end of the "alive" pipe;
  121. # when they all terminate the read end becomes ready.
  122. alive_r, alive_w = os.pipe()
  123. try:
  124. fds_to_pass = [listener.fileno(), alive_r]
  125. cmd %= (listener.fileno(), alive_r, self._preload_modules,
  126. data)
  127. exe = spawn.get_executable()
  128. args = [exe] + util._args_from_interpreter_flags()
  129. args += ['-c', cmd]
  130. pid = util.spawnv_passfds(exe, args, fds_to_pass)
  131. except:
  132. os.close(alive_w)
  133. raise
  134. finally:
  135. os.close(alive_r)
  136. self._forkserver_address = address
  137. self._forkserver_alive_fd = alive_w
  138. self._forkserver_pid = pid
  139. #
  140. #
  141. #
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd is the file descriptor of the listening AF_UNIX socket on
    which fork requests arrive.  alive_r is the read end of the "alive"
    pipe; when it becomes readable the server exits.  preload is the
    collection of module names to try to import before serving; main_path
    and sys_path are only used when preload is non-empty.

    Loops until alive_r becomes readable, then raises SystemExit.  An
    OSError other than ECONNABORTED raised inside the loop propagates, as
    does the RuntimeError raised when a request carries too many fds.
    '''
    if preload:
        if sys_path is not None:
            sys.path[:] = sys_path
        if '__main__' in preload and main_path is not None:
            # Flag the current process object while the main module is
            # imported; the flag is removed again whatever happens.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            # Preloading is best effort: modules that fail to import are
            # skipped silently.
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    # Non-blocking pipe used as the signal wakeup fd (see set_wakeup_fd
    # below); sig_r is watched by the selector.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # signal.signal() returns the previous handler; these are kept so that
    # forked children can restore them in _serve_one().
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # No child processes at all.
                            break
                        if pid == 0:
                            # Children exist but none has exited yet.
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            returncode = os.waitstatus_to_exitcode(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # First two fds are the per-request pipe ends; the
                        # rest are handed to _serve_one() in the child.
                        child_r, child_w, *fds = fds
                        # The connection is closed before forking.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                # Descriptors that only the server loop
                                # uses; _serve_one() closes them.
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # The child never returns into the server
                                # loop; code stays 1 if serving failed.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Kept open so the exit code can be written
                            # when the child is reaped above.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # ECONNABORTED is ignored and the loop continues; any
                # other OSError is re-raised.
                if e.errno != errno.ECONNABORTED:
                    raise
  260. def _serve_one(child_r, fds, unused_fds, handlers):
  261. # close unnecessary stuff and reset signal handlers
  262. signal.set_wakeup_fd(-1)
  263. for sig, val in handlers.items():
  264. signal.signal(sig, val)
  265. for fd in unused_fds:
  266. os.close(fd)
  267. (_forkserver._forkserver_alive_fd,
  268. resource_tracker._resource_tracker._fd,
  269. *_forkserver._inherited_fds) = fds
  270. # Run process object received over pipe
  271. parent_sentinel = os.dup(child_r)
  272. code = spawn._main(child_r, parent_sentinel)
  273. return code
  274. #
  275. # Read and write signed numbers
  276. #
  277. def read_signed(fd):
  278. data = b''
  279. length = SIGNED_STRUCT.size
  280. while len(data) < length:
  281. s = os.read(fd, length - len(data))
  282. if not s:
  283. raise EOFError('unexpected EOF')
  284. data += s
  285. return SIGNED_STRUCT.unpack(data)[0]
  286. def write_signed(fd, n):
  287. msg = SIGNED_STRUCT.pack(n)
  288. while msg:
  289. nbytes = os.write(fd, msg)
  290. if nbytes == 0:
  291. raise RuntimeError('should not get here')
  292. msg = msg[nbytes:]
  293. #
  294. #
  295. #
# Single module-level ForkServer instance.  main() and _serve_one() also
# update attributes on this object.
_forkserver = ForkServer()
# The module's public functions are the bound methods of that instance.
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload