styles.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # -*- test-case-name: twisted.test.test_persisted -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Different styles of persisted objects.
  6. """
  7. import copy
  8. import copyreg as copy_reg
  9. import inspect
  10. import pickle
  11. import types
  12. from io import StringIO as _cStringIO
  13. from typing import Dict
  14. from twisted.python import log, reflect
  15. from twisted.python.compat import _PYPY
  16. oldModules: Dict[str, types.ModuleType] = {}
  17. _UniversalPicklingError = pickle.PicklingError
  18. def pickleMethod(method):
  19. "support function for copy_reg to pickle method refs"
  20. return (
  21. unpickleMethod,
  22. (method.__name__, method.__self__, method.__self__.__class__),
  23. )
  24. def _methodFunction(classObject, methodName):
  25. """
  26. Retrieve the function object implementing a method name given the class
  27. it's on and a method name.
  28. @param classObject: A class to retrieve the method's function from.
  29. @type classObject: L{type}
  30. @param methodName: The name of the method whose function to retrieve.
  31. @type methodName: native L{str}
  32. @return: the function object corresponding to the given method name.
  33. @rtype: L{types.FunctionType}
  34. """
  35. methodObject = getattr(classObject, methodName)
  36. return methodObject
  37. def unpickleMethod(im_name, im_self, im_class):
  38. """
  39. Support function for copy_reg to unpickle method refs.
  40. @param im_name: The name of the method.
  41. @type im_name: native L{str}
  42. @param im_self: The instance that the method was present on.
  43. @type im_self: L{object}
  44. @param im_class: The class where the method was declared.
  45. @type im_class: L{type} or L{None}
  46. """
  47. if im_self is None:
  48. return getattr(im_class, im_name)
  49. try:
  50. methodFunction = _methodFunction(im_class, im_name)
  51. except AttributeError:
  52. log.msg("Method", im_name, "not on class", im_class)
  53. assert im_self is not None, "No recourse: no instance to guess from."
  54. # Attempt a last-ditch fix before giving up. If classes have changed
  55. # around since we pickled this method, we may still be able to get it
  56. # by looking on the instance's current class.
  57. if im_self.__class__ is im_class:
  58. raise
  59. return unpickleMethod(im_name, im_self, im_self.__class__)
  60. else:
  61. maybeClass = ()
  62. bound = types.MethodType(methodFunction, im_self, *maybeClass)
  63. return bound
  64. copy_reg.pickle(types.MethodType, pickleMethod)
  65. def _pickleFunction(f):
  66. """
  67. Reduce, in the sense of L{pickle}'s C{object.__reduce__} special method, a
  68. function object into its constituent parts.
  69. @param f: The function to reduce.
  70. @type f: L{types.FunctionType}
  71. @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
  72. its arguments, a 1-tuple of the function's fully qualified name.
  73. @rtype: 2-tuple of C{callable, native string}
  74. """
  75. if f.__name__ == "<lambda>":
  76. raise _UniversalPicklingError(f"Cannot pickle lambda function: {f}")
  77. return (_unpickleFunction, tuple([".".join([f.__module__, f.__qualname__])]))
  78. def _unpickleFunction(fullyQualifiedName):
  79. """
  80. Convert a function name into a function by importing it.
  81. This is a synonym for L{twisted.python.reflect.namedAny}, but imported
  82. locally to avoid circular imports, and also to provide a persistent name
  83. that can be stored (and deprecated) independently of C{namedAny}.
  84. @param fullyQualifiedName: The fully qualified name of a function.
  85. @type fullyQualifiedName: native C{str}
  86. @return: A function object imported from the given location.
  87. @rtype: L{types.FunctionType}
  88. """
  89. from twisted.python.reflect import namedAny
  90. return namedAny(fullyQualifiedName)
  91. copy_reg.pickle(types.FunctionType, _pickleFunction)
  92. def pickleModule(module):
  93. "support function for copy_reg to pickle module refs"
  94. return unpickleModule, (module.__name__,)
  95. def unpickleModule(name):
  96. "support function for copy_reg to unpickle module refs"
  97. if name in oldModules:
  98. log.msg("Module has moved: %s" % name)
  99. name = oldModules[name]
  100. log.msg(name)
  101. return __import__(name, {}, {}, "x")
  102. copy_reg.pickle(types.ModuleType, pickleModule)
  103. def pickleStringO(stringo):
  104. """
  105. Reduce the given cStringO.
  106. This is only called on Python 2, because the cStringIO module only exists
  107. on Python 2.
  108. @param stringo: The string output to pickle.
  109. @type stringo: C{cStringIO.OutputType}
  110. """
  111. "support function for copy_reg to pickle StringIO.OutputTypes"
  112. return unpickleStringO, (stringo.getvalue(), stringo.tell())
  113. def unpickleStringO(val, sek):
  114. """
  115. Convert the output of L{pickleStringO} into an appropriate type for the
  116. current python version. This may be called on Python 3 and will convert a
  117. cStringIO into an L{io.StringIO}.
  118. @param val: The content of the file.
  119. @type val: L{bytes}
  120. @param sek: The seek position of the file.
  121. @type sek: L{int}
  122. @return: a file-like object which you can write bytes to.
  123. @rtype: C{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
  124. """
  125. x = _cStringIO()
  126. x.write(val)
  127. x.seek(sek)
  128. return x
  129. def pickleStringI(stringi):
  130. """
  131. Reduce the given cStringI.
  132. This is only called on Python 2, because the cStringIO module only exists
  133. on Python 2.
  134. @param stringi: The string input to pickle.
  135. @type stringi: C{cStringIO.InputType}
  136. @return: a 2-tuple of (C{unpickleStringI}, (bytes, pointer))
  137. @rtype: 2-tuple of (function, (bytes, int))
  138. """
  139. return unpickleStringI, (stringi.getvalue(), stringi.tell())
  140. def unpickleStringI(val, sek):
  141. """
  142. Convert the output of L{pickleStringI} into an appropriate type for the
  143. current Python version.
  144. This may be called on Python 3 and will convert a cStringIO into an
  145. L{io.StringIO}.
  146. @param val: The content of the file.
  147. @type val: L{bytes}
  148. @param sek: The seek position of the file.
  149. @type sek: L{int}
  150. @return: a file-like object which you can read bytes from.
  151. @rtype: C{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
  152. """
  153. x = _cStringIO(val)
  154. x.seek(sek)
  155. return x
  156. class Ephemeral:
  157. """
  158. This type of object is never persisted; if possible, even references to it
  159. are eliminated.
  160. """
  161. def __reduce__(self):
  162. """
  163. Serialize any subclass of L{Ephemeral} in a way which replaces it with
  164. L{Ephemeral} itself.
  165. """
  166. return (Ephemeral, ())
  167. def __getstate__(self):
  168. log.msg("WARNING: serializing ephemeral %s" % self)
  169. if not _PYPY:
  170. import gc
  171. if getattr(gc, "get_referrers", None):
  172. for r in gc.get_referrers(self):
  173. log.msg(f" referred to by {r}")
  174. return None
  175. def __setstate__(self, state):
  176. log.msg("WARNING: unserializing ephemeral %s" % self.__class__)
  177. self.__class__ = Ephemeral
  178. versionedsToUpgrade: Dict[int, "Versioned"] = {}
  179. upgraded = {}
  180. def doUpgrade():
  181. global versionedsToUpgrade, upgraded
  182. for versioned in list(versionedsToUpgrade.values()):
  183. requireUpgrade(versioned)
  184. versionedsToUpgrade = {}
  185. upgraded = {}
  186. def requireUpgrade(obj):
  187. """Require that a Versioned instance be upgraded completely first."""
  188. objID = id(obj)
  189. if objID in versionedsToUpgrade and objID not in upgraded:
  190. upgraded[objID] = 1
  191. obj.versionUpgrade()
  192. return obj
  193. def _aybabtu(c):
  194. """
  195. Get all of the parent classes of C{c}, not including C{c} itself, which are
  196. strict subclasses of L{Versioned}.
  197. @param c: a class
  198. @returns: list of classes
  199. """
  200. # begin with two classes that should *not* be included in the
  201. # final result
  202. l = [c, Versioned]
  203. for b in inspect.getmro(c):
  204. if b not in l and issubclass(b, Versioned):
  205. l.append(b)
  206. # return all except the unwanted classes
  207. return l[2:]
  208. class Versioned:
  209. """
  210. This type of object is persisted with versioning information.
  211. I have a single class attribute, the int persistenceVersion. After I am
  212. unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
  213. will be called for each version upgrade I must undergo.
  214. For example, if I serialize an instance of a Foo(Versioned) at version 4
  215. and then unserialize it when the code is at version 9, the calls::
  216. self.upgradeToVersion5()
  217. self.upgradeToVersion6()
  218. self.upgradeToVersion7()
  219. self.upgradeToVersion8()
  220. self.upgradeToVersion9()
  221. will be made. If any of these methods are undefined, a warning message
  222. will be printed.
  223. """
  224. persistenceVersion = 0
  225. persistenceForgets = ()
  226. def __setstate__(self, state):
  227. versionedsToUpgrade[id(self)] = self
  228. self.__dict__ = state
  229. def __getstate__(self, dict=None):
  230. """Get state, adding a version number to it on its way out."""
  231. dct = copy.copy(dict or self.__dict__)
  232. bases = _aybabtu(self.__class__)
  233. bases.reverse()
  234. bases.append(self.__class__) # don't forget me!!
  235. for base in bases:
  236. if "persistenceForgets" in base.__dict__:
  237. for slot in base.persistenceForgets:
  238. if slot in dct:
  239. del dct[slot]
  240. if "persistenceVersion" in base.__dict__:
  241. dct[
  242. f"{reflect.qual(base)}.persistenceVersion"
  243. ] = base.persistenceVersion
  244. return dct
  245. def versionUpgrade(self):
  246. """(internal) Do a version upgrade."""
  247. bases = _aybabtu(self.__class__)
  248. # put the bases in order so superclasses' persistenceVersion methods
  249. # will be called first.
  250. bases.reverse()
  251. bases.append(self.__class__) # don't forget me!!
  252. # first let's look for old-skool versioned's
  253. if "persistenceVersion" in self.__dict__:
  254. # Hacky heuristic: if more than one class subclasses Versioned,
  255. # we'll assume that the higher version number wins for the older
  256. # class, so we'll consider the attribute the version of the older
  257. # class. There are obviously possibly times when this will
  258. # eventually be an incorrect assumption, but hopefully old-school
  259. # persistenceVersion stuff won't make it that far into multiple
  260. # classes inheriting from Versioned.
  261. pver = self.__dict__["persistenceVersion"]
  262. del self.__dict__["persistenceVersion"]
  263. highestVersion = 0
  264. highestBase = None
  265. for base in bases:
  266. if "persistenceVersion" not in base.__dict__:
  267. continue
  268. if base.persistenceVersion > highestVersion:
  269. highestBase = base
  270. highestVersion = base.persistenceVersion
  271. if highestBase:
  272. self.__dict__[
  273. "%s.persistenceVersion" % reflect.qual(highestBase)
  274. ] = pver
  275. for base in bases:
  276. # ugly hack, but it's what the user expects, really
  277. if (
  278. Versioned not in base.__bases__
  279. and "persistenceVersion" not in base.__dict__
  280. ):
  281. continue
  282. currentVers = base.persistenceVersion
  283. pverName = "%s.persistenceVersion" % reflect.qual(base)
  284. persistVers = self.__dict__.get(pverName) or 0
  285. if persistVers:
  286. del self.__dict__[pverName]
  287. assert persistVers <= currentVers, "Sorry, can't go backwards in time."
  288. while persistVers < currentVers:
  289. persistVers = persistVers + 1
  290. method = base.__dict__.get("upgradeToVersion%s" % persistVers, None)
  291. if method:
  292. log.msg(
  293. "Upgrading %s (of %s @ %s) to version %s"
  294. % (
  295. reflect.qual(base),
  296. reflect.qual(self.__class__),
  297. id(self),
  298. persistVers,
  299. )
  300. )
  301. method(self)
  302. else:
  303. log.msg(
  304. "Warning: cannot upgrade {} to version {}".format(
  305. base, persistVers
  306. )
  307. )