shared_memory.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. """Provides shared memory for direct access across processes.
  2. The API of this package is currently provisional. Refer to the
  3. documentation for details.
  4. """
  5. __all__ = [ 'SharedMemory', 'ShareableList' ]
  6. from functools import partial
  7. import mmap
  8. import os
  9. import errno
  10. import struct
  11. import secrets
  12. import types
  13. if os.name == "nt":
  14. import _winapi
  15. _USE_POSIX = False
  16. else:
  17. import _posixshmem
  18. _USE_POSIX = True
  19. from . import resource_tracker
  20. _O_CREX = os.O_CREAT | os.O_EXCL
  21. # FreeBSD (and perhaps other BSDs) limit names to 14 characters.
  22. _SHM_SAFE_NAME_LENGTH = 14
  23. # Shared memory block name prefix
  24. if _USE_POSIX:
  25. _SHM_NAME_PREFIX = '/psm_'
  26. else:
  27. _SHM_NAME_PREFIX = 'wnsm_'
  28. def _make_filename():
  29. "Create a random filename for the shared memory object."
  30. # number of random bytes to use for name
  31. nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2
  32. assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
  33. name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes)
  34. assert len(name) <= _SHM_SAFE_NAME_LENGTH
  35. return name
  36. class SharedMemory:
  37. """Creates a new shared memory block or attaches to an existing
  38. shared memory block.
  39. Every shared memory block is assigned a unique name. This enables
  40. one process to create a shared memory block with a particular name
  41. so that a different process can attach to that same shared memory
  42. block using that same name.
  43. As a resource for sharing data across processes, shared memory blocks
  44. may outlive the original process that created them. When one process
  45. no longer needs access to a shared memory block that might still be
  46. needed by other processes, the close() method should be called.
  47. When a shared memory block is no longer needed by any process, the
  48. unlink() method should be called to ensure proper cleanup."""
  49. # Defaults; enables close() and unlink() to run without errors.
  50. _name = None
  51. _fd = -1
  52. _mmap = None
  53. _buf = None
  54. _flags = os.O_RDWR
  55. _mode = 0o600
  56. _prepend_leading_slash = True if _USE_POSIX else False
  57. def __init__(self, name=None, create=False, size=0):
  58. if not size >= 0:
  59. raise ValueError("'size' must be a positive integer")
  60. if create:
  61. self._flags = _O_CREX | os.O_RDWR
  62. if size == 0:
  63. raise ValueError("'size' must be a positive number different from zero")
  64. if name is None and not self._flags & os.O_EXCL:
  65. raise ValueError("'name' can only be None if create=True")
  66. if _USE_POSIX:
  67. # POSIX Shared Memory
  68. if name is None:
  69. while True:
  70. name = _make_filename()
  71. try:
  72. self._fd = _posixshmem.shm_open(
  73. name,
  74. self._flags,
  75. mode=self._mode
  76. )
  77. except FileExistsError:
  78. continue
  79. self._name = name
  80. break
  81. else:
  82. name = "/" + name if self._prepend_leading_slash else name
  83. self._fd = _posixshmem.shm_open(
  84. name,
  85. self._flags,
  86. mode=self._mode
  87. )
  88. self._name = name
  89. try:
  90. if create and size:
  91. os.ftruncate(self._fd, size)
  92. stats = os.fstat(self._fd)
  93. size = stats.st_size
  94. self._mmap = mmap.mmap(self._fd, size)
  95. except OSError:
  96. self.unlink()
  97. raise
  98. resource_tracker.register(self._name, "shared_memory")
  99. else:
  100. # Windows Named Shared Memory
  101. if create:
  102. while True:
  103. temp_name = _make_filename() if name is None else name
  104. # Create and reserve shared memory block with this name
  105. # until it can be attached to by mmap.
  106. h_map = _winapi.CreateFileMapping(
  107. _winapi.INVALID_HANDLE_VALUE,
  108. _winapi.NULL,
  109. _winapi.PAGE_READWRITE,
  110. (size >> 32) & 0xFFFFFFFF,
  111. size & 0xFFFFFFFF,
  112. temp_name
  113. )
  114. try:
  115. last_error_code = _winapi.GetLastError()
  116. if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
  117. if name is not None:
  118. raise FileExistsError(
  119. errno.EEXIST,
  120. os.strerror(errno.EEXIST),
  121. name,
  122. _winapi.ERROR_ALREADY_EXISTS
  123. )
  124. else:
  125. continue
  126. self._mmap = mmap.mmap(-1, size, tagname=temp_name)
  127. finally:
  128. _winapi.CloseHandle(h_map)
  129. self._name = temp_name
  130. break
  131. else:
  132. self._name = name
  133. # Dynamically determine the existing named shared memory
  134. # block's size which is likely a multiple of mmap.PAGESIZE.
  135. h_map = _winapi.OpenFileMapping(
  136. _winapi.FILE_MAP_READ,
  137. False,
  138. name
  139. )
  140. try:
  141. p_buf = _winapi.MapViewOfFile(
  142. h_map,
  143. _winapi.FILE_MAP_READ,
  144. 0,
  145. 0,
  146. 0
  147. )
  148. finally:
  149. _winapi.CloseHandle(h_map)
  150. try:
  151. size = _winapi.VirtualQuerySize(p_buf)
  152. finally:
  153. _winapi.UnmapViewOfFile(p_buf)
  154. self._mmap = mmap.mmap(-1, size, tagname=name)
  155. self._size = size
  156. self._buf = memoryview(self._mmap)
  157. def __del__(self):
  158. try:
  159. self.close()
  160. except OSError:
  161. pass
  162. def __reduce__(self):
  163. return (
  164. self.__class__,
  165. (
  166. self.name,
  167. False,
  168. self.size,
  169. ),
  170. )
  171. def __repr__(self):
  172. return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
  173. @property
  174. def buf(self):
  175. "A memoryview of contents of the shared memory block."
  176. return self._buf
  177. @property
  178. def name(self):
  179. "Unique name that identifies the shared memory block."
  180. reported_name = self._name
  181. if _USE_POSIX and self._prepend_leading_slash:
  182. if self._name.startswith("/"):
  183. reported_name = self._name[1:]
  184. return reported_name
  185. @property
  186. def size(self):
  187. "Size in bytes."
  188. return self._size
  189. def close(self):
  190. """Closes access to the shared memory from this instance but does
  191. not destroy the shared memory block."""
  192. if self._buf is not None:
  193. self._buf.release()
  194. self._buf = None
  195. if self._mmap is not None:
  196. self._mmap.close()
  197. self._mmap = None
  198. if _USE_POSIX and self._fd >= 0:
  199. os.close(self._fd)
  200. self._fd = -1
  201. def unlink(self):
  202. """Requests that the underlying shared memory block be destroyed.
  203. In order to ensure proper cleanup of resources, unlink should be
  204. called once (and only once) across all processes which have access
  205. to the shared memory block."""
  206. if _USE_POSIX and self._name:
  207. _posixshmem.shm_unlink(self._name)
  208. resource_tracker.unregister(self._name, "shared_memory")
  209. _encoding = "utf8"
  210. class ShareableList:
  211. """Pattern for a mutable list-like object shareable via a shared
  212. memory block. It differs from the built-in list type in that these
  213. lists can not change their overall length (i.e. no append, insert,
  214. etc.)
  215. Because values are packed into a memoryview as bytes, the struct
  216. packing format for any storable value must require no more than 8
  217. characters to describe its format."""
  218. # The shared memory area is organized as follows:
  219. # - 8 bytes: number of items (N) as a 64-bit integer
  220. # - (N + 1) * 8 bytes: offsets of each element from the start of the
  221. # data area
  222. # - K bytes: the data area storing item values (with encoding and size
  223. # depending on their respective types)
  224. # - N * 8 bytes: `struct` format string for each element
  225. # - N bytes: index into _back_transforms_mapping for each element
  226. # (for reconstructing the corresponding Python value)
  227. _types_mapping = {
  228. int: "q",
  229. float: "d",
  230. bool: "xxxxxxx?",
  231. str: "%ds",
  232. bytes: "%ds",
  233. None.__class__: "xxxxxx?x",
  234. }
  235. _alignment = 8
  236. _back_transforms_mapping = {
  237. 0: lambda value: value, # int, float, bool
  238. 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
  239. 2: lambda value: value.rstrip(b'\x00'), # bytes
  240. 3: lambda _value: None, # None
  241. }
  242. @staticmethod
  243. def _extract_recreation_code(value):
  244. """Used in concert with _back_transforms_mapping to convert values
  245. into the appropriate Python objects when retrieving them from
  246. the list as well as when storing them."""
  247. if not isinstance(value, (str, bytes, None.__class__)):
  248. return 0
  249. elif isinstance(value, str):
  250. return 1
  251. elif isinstance(value, bytes):
  252. return 2
  253. else:
  254. return 3 # NoneType
  255. def __init__(self, sequence=None, *, name=None):
  256. if name is None or sequence is not None:
  257. sequence = sequence or ()
  258. _formats = [
  259. self._types_mapping[type(item)]
  260. if not isinstance(item, (str, bytes))
  261. else self._types_mapping[type(item)] % (
  262. self._alignment * (len(item) // self._alignment + 1),
  263. )
  264. for item in sequence
  265. ]
  266. self._list_len = len(_formats)
  267. assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
  268. offset = 0
  269. # The offsets of each list element into the shared memory's
  270. # data area (0 meaning the start of the data area, not the start
  271. # of the shared memory area).
  272. self._allocated_offsets = [0]
  273. for fmt in _formats:
  274. offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
  275. self._allocated_offsets.append(offset)
  276. _recreation_codes = [
  277. self._extract_recreation_code(item) for item in sequence
  278. ]
  279. requested_size = struct.calcsize(
  280. "q" + self._format_size_metainfo +
  281. "".join(_formats) +
  282. self._format_packing_metainfo +
  283. self._format_back_transform_codes
  284. )
  285. self.shm = SharedMemory(name, create=True, size=requested_size)
  286. else:
  287. self.shm = SharedMemory(name)
  288. if sequence is not None:
  289. _enc = _encoding
  290. struct.pack_into(
  291. "q" + self._format_size_metainfo,
  292. self.shm.buf,
  293. 0,
  294. self._list_len,
  295. *(self._allocated_offsets)
  296. )
  297. struct.pack_into(
  298. "".join(_formats),
  299. self.shm.buf,
  300. self._offset_data_start,
  301. *(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
  302. )
  303. struct.pack_into(
  304. self._format_packing_metainfo,
  305. self.shm.buf,
  306. self._offset_packing_formats,
  307. *(v.encode(_enc) for v in _formats)
  308. )
  309. struct.pack_into(
  310. self._format_back_transform_codes,
  311. self.shm.buf,
  312. self._offset_back_transform_codes,
  313. *(_recreation_codes)
  314. )
  315. else:
  316. self._list_len = len(self) # Obtains size from offset 0 in buffer.
  317. self._allocated_offsets = list(
  318. struct.unpack_from(
  319. self._format_size_metainfo,
  320. self.shm.buf,
  321. 1 * 8
  322. )
  323. )
  324. def _get_packing_format(self, position):
  325. "Gets the packing format for a single value stored in the list."
  326. position = position if position >= 0 else position + self._list_len
  327. if (position >= self._list_len) or (self._list_len < 0):
  328. raise IndexError("Requested position out of range.")
  329. v = struct.unpack_from(
  330. "8s",
  331. self.shm.buf,
  332. self._offset_packing_formats + position * 8
  333. )[0]
  334. fmt = v.rstrip(b'\x00')
  335. fmt_as_str = fmt.decode(_encoding)
  336. return fmt_as_str
  337. def _get_back_transform(self, position):
  338. "Gets the back transformation function for a single value."
  339. if (position >= self._list_len) or (self._list_len < 0):
  340. raise IndexError("Requested position out of range.")
  341. transform_code = struct.unpack_from(
  342. "b",
  343. self.shm.buf,
  344. self._offset_back_transform_codes + position
  345. )[0]
  346. transform_function = self._back_transforms_mapping[transform_code]
  347. return transform_function
  348. def _set_packing_format_and_transform(self, position, fmt_as_str, value):
  349. """Sets the packing format and back transformation code for a
  350. single value in the list at the specified position."""
  351. if (position >= self._list_len) or (self._list_len < 0):
  352. raise IndexError("Requested position out of range.")
  353. struct.pack_into(
  354. "8s",
  355. self.shm.buf,
  356. self._offset_packing_formats + position * 8,
  357. fmt_as_str.encode(_encoding)
  358. )
  359. transform_code = self._extract_recreation_code(value)
  360. struct.pack_into(
  361. "b",
  362. self.shm.buf,
  363. self._offset_back_transform_codes + position,
  364. transform_code
  365. )
  366. def __getitem__(self, position):
  367. position = position if position >= 0 else position + self._list_len
  368. try:
  369. offset = self._offset_data_start + self._allocated_offsets[position]
  370. (v,) = struct.unpack_from(
  371. self._get_packing_format(position),
  372. self.shm.buf,
  373. offset
  374. )
  375. except IndexError:
  376. raise IndexError("index out of range")
  377. back_transform = self._get_back_transform(position)
  378. v = back_transform(v)
  379. return v
  380. def __setitem__(self, position, value):
  381. position = position if position >= 0 else position + self._list_len
  382. try:
  383. item_offset = self._allocated_offsets[position]
  384. offset = self._offset_data_start + item_offset
  385. current_format = self._get_packing_format(position)
  386. except IndexError:
  387. raise IndexError("assignment index out of range")
  388. if not isinstance(value, (str, bytes)):
  389. new_format = self._types_mapping[type(value)]
  390. encoded_value = value
  391. else:
  392. allocated_length = self._allocated_offsets[position + 1] - item_offset
  393. encoded_value = (value.encode(_encoding)
  394. if isinstance(value, str) else value)
  395. if len(encoded_value) > allocated_length:
  396. raise ValueError("bytes/str item exceeds available storage")
  397. if current_format[-1] == "s":
  398. new_format = current_format
  399. else:
  400. new_format = self._types_mapping[str] % (
  401. allocated_length,
  402. )
  403. self._set_packing_format_and_transform(
  404. position,
  405. new_format,
  406. value
  407. )
  408. struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
  409. def __reduce__(self):
  410. return partial(self.__class__, name=self.shm.name), ()
  411. def __len__(self):
  412. return struct.unpack_from("q", self.shm.buf, 0)[0]
  413. def __repr__(self):
  414. return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
  415. @property
  416. def format(self):
  417. "The struct packing format used by all currently stored items."
  418. return "".join(
  419. self._get_packing_format(i) for i in range(self._list_len)
  420. )
  421. @property
  422. def _format_size_metainfo(self):
  423. "The struct packing format used for the items' storage offsets."
  424. return "q" * (self._list_len + 1)
  425. @property
  426. def _format_packing_metainfo(self):
  427. "The struct packing format used for the items' packing formats."
  428. return "8s" * self._list_len
  429. @property
  430. def _format_back_transform_codes(self):
  431. "The struct packing format used for the items' back transforms."
  432. return "b" * self._list_len
  433. @property
  434. def _offset_data_start(self):
  435. # - 8 bytes for the list length
  436. # - (N + 1) * 8 bytes for the element offsets
  437. return (self._list_len + 2) * 8
  438. @property
  439. def _offset_packing_formats(self):
  440. return self._offset_data_start + self._allocated_offsets[-1]
  441. @property
  442. def _offset_back_transform_codes(self):
  443. return self._offset_packing_formats + self._list_len * 8
  444. def count(self, value):
  445. "L.count(value) -> integer -- return number of occurrences of value."
  446. return sum(value == entry for entry in self)
  447. def index(self, value):
  448. """L.index(value) -> integer -- return first index of value.
  449. Raises ValueError if the value is not present."""
  450. for position, entry in enumerate(self):
  451. if value == entry:
  452. return position
  453. else:
  454. raise ValueError(f"{value!r} not in this container")
  455. __class_getitem__ = classmethod(types.GenericAlias)