class_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. from pyrsistent._compat import Hashable
  2. import math
  3. import pickle
  4. import pytest
  5. import sys
  6. import uuid
  7. from pyrsistent import (
  8. field, InvariantException, PClass, optional, CheckedPVector,
  9. pmap_field, pset_field, pvector_field)
  class Point(PClass):
      """Three-field integer record shared by most tests in this module.

      ``x`` is mandatory and must be non-negative, ``y`` is serialized through
      the formatter handed to ``serialize``, and ``z`` has an initial value of 0.
      """
      x = field(
          type=int,
          mandatory=True,
          invariant=lambda value: (value >= 0, 'X negative'),
      )
      y = field(type=int, serializer=lambda fmt, value: fmt(value))
      z = field(type=int, initial=0)
  class Hierarchy(PClass):
      """Record holding a single nested ``Point`` field."""
      point = field(type=Point)
  class TypedContainerObj(PClass):
      """Record with one typed map, one typed set and one typed vector field."""
      # The field names mirror the container kinds; they shadow the builtins
      # only inside this class body.
      map = pmap_field(key_type=str, value_type=str)
      set = pset_field(str)
      vec = pvector_field(str)
  class UniqueThing(PClass):
      """Record whose ``id`` field uses ``uuid.UUID`` as both type and factory."""
      id = field(factory=uuid.UUID, type=uuid.UUID)
      x = field(type=int)
  def test_create_ignore_extra():
      """``create`` with ``ignore_extra=True`` accepts a dict with an undeclared key."""
      raw = {'x': 5, 'y': 10, 'z': 15, 'a': 0}
      point = Point.create(raw, ignore_extra=True)
      assert isinstance(point, Point)
      assert (point.x, point.y, point.z) == (5, 10, 15)
  def test_create_ignore_extra_false():
      """Without ``ignore_extra`` an undeclared key makes ``create`` raise."""
      raw = {'x': 5, 'y': 10, 'z': 15, 'a': 0}
      with pytest.raises(AttributeError):
          Point.create(raw)
  def test_create_ignore_extra_true():
      """``create`` succeeds when extra keys appear at both the outer and nested level."""
      nested_point = {'x': 5, 'y': 10, 'z': 15, 'extra_field_0': 'extra_data_0'}
      raw = {'point': nested_point, 'extra_field_1': 'extra_data_1'}
      hierarchy = Hierarchy.create(raw, ignore_extra=True)
      assert isinstance(hierarchy, Hierarchy)
  def test_evolve_pclass_instance():
      """``set`` returns an updated copy and leaves the source instance alone."""
      original = Point(x=1, y=2)
      bumped = original.set(x=original.x + 2)

      # The source instance keeps its values.
      assert (original.x, original.y) == (1, 2)

      # The evolved instance carries the new x and the old y.
      assert (bumped.x, bumped.y) == (3, 2)

      # ``set`` also accepts the positional (name, value) form.
      replaced = bumped.set('x', 4)
      assert (replaced.x, replaced.y) == (4, 2)
  def test_direct_assignment_not_possible():
      """Attribute assignment is rejected, both as a statement and via ``setattr``."""
      point = Point(x=1, y=2)

      with pytest.raises(AttributeError):
          point.x = 1

      with pytest.raises(AttributeError):
          setattr(point, 'x', 1)
  def test_direct_delete_not_possible():
      """Attribute deletion is rejected, both as a statement and via ``delattr``."""
      point = Point(x=1, y=2)

      with pytest.raises(AttributeError):
          del point.x

      with pytest.raises(AttributeError):
          delattr(point, 'x')
  def test_cannot_construct_with_undeclared_fields():
      """Passing a keyword that is not a declared field raises ``AttributeError``."""
      arguments = {'x': 1, 'p': 5}
      with pytest.raises(AttributeError):
          Point(**arguments)
  def test_cannot_construct_with_wrong_type():
      """A string given for the ``int``-typed ``x`` field raises ``TypeError``."""
      wrong_typed_x = 'a'
      with pytest.raises(TypeError):
          Point(x=wrong_typed_x)
  def test_cannot_construct_without_mandatory_fields():
      """Omitting the mandatory ``x`` field raises and the message names the field."""
      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          Point(y=1)
      assert "[Point.x]" in str(excinfo.value)
  def test_field_invariant_must_hold():
      """A negative ``x`` violates the field invariant and reports its error text."""
      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          Point(x=-1)
      assert "X negative" in str(excinfo.value)
  def test_initial_value_set_when_not_present_in_arguments():
      """``z`` takes its declared initial value when the caller does not pass it."""
      point = Point(x=1, y=2)
      expected_initial = 0
      assert point.z == expected_initial
  class Line(PClass):
      """Record made of two ``Point`` fields, ``p1`` and ``p2``."""
      p1 = field(type=Point)
      p2 = field(type=Point)
  def test_can_create_nested_structures_from_dict_and_serialize_back_to_dict():
      """A nested dict round-trips through ``Line.create`` and ``serialize``."""
      source = {
          'p1': {'x': 1, 'y': 2, 'z': 3},
          'p2': {'x': 10, 'y': 20, 'z': 30},
      }
      line = Line.create(source)

      assert (line.p1.x, line.p1.y, line.p1.z) == (1, 2, 3)
      assert (line.p2.x, line.p2.y, line.p2.z) == (10, 20, 30)

      # An identity formatter must reproduce the input exactly.
      assert line.serialize(format=lambda val: val) == source
  def test_can_serialize_with_custom_serializer():
      """Only ``y``, which declares a serializer, is run through the formatter."""
      point = Point(x=1, y=1, z=1)
      serialized = point.serialize(format=lambda v: v + 17)
      assert serialized == {'x': 1, 'y': 18, 'z': 1}
  def test_implements_proper_equality_based_on_equality_of_fields():
      """``==`` and ``!=`` both follow field values, not object identity."""
      first = Point(x=1, y=2)
      different = Point(x=3)
      twin = Point(x=1, y=2)

      # Same field values: equal, and not unequal.
      assert first == twin
      assert not first != twin

      # Different field values: unequal, and not equal.
      assert first != different
      assert not first == different
  def test_is_hashable():
      """Instances work as dict keys, and equal instances find the same entry."""
      first = Point(x=1, y=2)
      second = Point(x=3, y=2)
      lookup = {first: 'A point', second: 'Another point'}

      first_like = Point(x=1, y=2)
      second_like = Point(x=3, y=2)

      assert isinstance(first, Hashable)
      assert lookup[first_like] == 'A point'
      assert lookup[second_like] == 'Another point'
      assert Point(x=10) not in lookup
  def test_supports_nested_transformation():
      """``transform`` with a path updates a nested field on a new instance only."""
      before = Line(p1=Point(x=2, y=1), p2=Point(x=20, y=10))
      after = before.transform(['p1', 'x'], 3)

      assert before.p1.x == 2
      assert (after.p1.x, after.p1.y) == (3, 1)
      assert (after.p2.x, after.p2.y) == (20, 10)
  def test_repr():
      """``repr`` lists the class name and field values; either field order is accepted."""
      class ARecord(PClass):
          a = field()
          b = field()

      rendered = repr(ARecord(a=1, b=2))
      accepted = ('ARecord(a=1, b=2)', 'ARecord(b=2, a=1)')
      assert rendered in accepted
  def test_global_invariant_check():
      """A class-level ``__invariant__`` is checked when an instance is constructed."""
      class UnitCirclePoint(PClass):
          # Accept points whose distance from the origin is within 0.01 of 1.
          __invariant__ = lambda pt: (
              0.99 < math.sqrt(pt.x*pt.x + pt.y*pt.y) < 1.01,
              "Point not on unit circle",
          )
          x = field(type=float)
          y = field(type=float)

      # A point on the circle constructs without error.
      UnitCirclePoint(x=1.0, y=0.0)

      with pytest.raises(InvariantException):
          UnitCirclePoint(x=1.0, y=1.0)
  def test_supports_pickling():
      """An instance survives a pickle round trip at the highest protocol."""
      original = Point(x=2, y=1)
      payload = pickle.dumps(original, pickle.HIGHEST_PROTOCOL)
      restored = pickle.loads(payload)
      assert original == restored
      assert isinstance(restored, Point)
  def test_supports_pickling_with_typed_container_fields():
      """A record with typed map/set/vector fields survives a pickle round trip."""
      original = TypedContainerObj(
          map={'foo': 'bar'},
          set=['hello', 'there'],
          vec=['a', 'b'],
      )
      restored = pickle.loads(pickle.dumps(original))
      assert original == restored
  def test_can_remove_optional_member():
      """Removing the optional ``y`` yields an instance equal to one built without it."""
      full = Point(x=1, y=2)
      trimmed = full.remove('y')
      assert trimmed == Point(x=1)
  def test_cannot_remove_mandatory_member():
      """Removing the mandatory ``x`` raises ``InvariantException``."""
      point = Point(x=1, y=2)
      with pytest.raises(InvariantException):
          point.remove('x')
  def test_cannot_remove_non_existing_member():
      """Removing a field that was never set raises ``AttributeError``."""
      point = Point(x=1)
      with pytest.raises(AttributeError):
          point.remove('y')
  def test_evolver_without_evolution_returns_original_instance():
      """An untouched evolver hands back the very same instance."""
      point = Point(x=1)
      evolver = point.evolver()
      assert evolver.persistent() is point
  def test_evolver_with_evolution_to_same_element_returns_original_instance():
      """Setting a field to its current value still yields the same instance."""
      point = Point(x=1)
      evolver = point.evolver()
      current_x = point.x
      evolver.set('x', current_x)
      assert evolver.persistent() is point
  def test_evolver_supports_chained_set_and_remove():
      """``set`` and ``remove`` on an evolver can be chained fluently."""
      point = Point(x=1, y=2)
      result = point.evolver().set('x', 3).remove('y').persistent()
      assert result == Point(x=3)
  def test_evolver_supports_dot_notation_for_setting_and_getting_elements():
      """Evolver fields can be written and read with plain attribute syntax."""
      evolver = Point(x=1, y=2).evolver()
      evolver.x = 3
      assert evolver.x == 3

      frozen = evolver.persistent()
      assert frozen == Point(x=3, y=2)
  class Numbers(CheckedPVector):
      """Checked persistent vector whose declared element type is ``int``."""
      __type__ = int
  class LinkedList(PClass):
      """Singly linked node whose field types are given as dotted-path strings."""
      # String specifiers let ``next`` refer to this class before it is defined.
      value = field(type='__tests__.class_test.Numbers')
      next = field(type=optional('__tests__.class_test.LinkedList'))
  def test_string_as_type_specifier():
      """Fields typed by string resolve to the named classes when values are set."""
      tail = LinkedList(value=[3, 4], next=None)
      head = LinkedList(value=[1, 2], next=tail)

      assert isinstance(head.value, Numbers)
      assert list(head.value) == [1, 2]
      assert head.next.next is None
  def test_multiple_invariants_on_field():
      """A field invariant may return several (ok, code) pairs at once.

      If the invariant returns a list of tests the results of running those
      tests will be a tuple containing result data of all failing tests.
      """
      class MultiInvariantField(PClass):
          one = field(type=int, invariant=lambda x: ((False, 'one_one'),
                                                     (False, 'one_two'),
                                                     (True, 'one_three')))
          two = field(invariant=lambda x: (False, 'two_one'))

      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          MultiInvariantField(one=1, two=2)

      # Compared as a set because the order of the per-field errors is not
      # asserted by this test.
      expected = {('one_one', 'one_two'), 'two_one'}
      assert set(excinfo.value.invariant_errors) == expected
  def test_multiple_global_invariants():
      """A global invariant returning several failing pairs reports all their codes."""
      class MultiInvariantGlobal(PClass):
          __invariant__ = lambda self: ((False, 'x'), (False, 'y'))
          one = field()

      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          MultiInvariantGlobal(one=1)
      assert excinfo.value.invariant_errors == (('x', 'y'),)
  def test_inherited_global_invariants():
      """``__invariant__`` methods on plain base classes are all checked.

      The expected error tuple lists the nearer base's failure before the more
      distant one.
      """
      class Distant(object):
          def __invariant__(self):
              return [(self.distant, "distant")]

      class Nearby(Distant):
          def __invariant__(self):
              return [(self.nearby, "nearby")]

      class MultipleInvariantGlobal(Nearby, PClass):
          distant = field()
          nearby = field()

      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          MultipleInvariantGlobal(distant=False, nearby=False)
      assert excinfo.value.invariant_errors == (("nearby",), ("distant",),)
  def test_diamond_inherited_global_invariants():
      """An invariant reachable through two inheritance paths runs only once."""
      # Each invocation of the shared invariant appends one entry here.
      counter = []

      class Base(object):
          def __invariant__(self):
              counter.append(None)
              return [(False, "base")]

      class Left(Base):
          pass

      class Right(Base):
          pass

      class SingleInvariantGlobal(Left, Right, PClass):
          pass

      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          SingleInvariantGlobal()
      assert excinfo.value.invariant_errors == (("base",),)
      assert counter == [None]
  def test_supports_weakref():
      """Creating a weak reference to an instance does not raise."""
      import weakref

      target = Point(x=1, y=2)
      weakref.ref(target)
  def test_supports_weakref_with_multi_level_inheritance():
      """Weak references also work for a subclass of a PClass subclass."""
      import weakref

      class PPoint(Point):
          a = field()

      target = PPoint(x=1, y=2)
      weakref.ref(target)
  def test_supports_lazy_initial_value_for_field():
      """A callable ``initial`` is invoked to produce the default value."""
      class MyClass(PClass):
          a = field(type=int, initial=lambda: 2)

      defaulted = MyClass()
      explicit = MyClass(a=2)
      assert defaulted == explicit
  def test_type_checks_lazy_initial_value_for_field():
      """A lazily produced initial value of the wrong type raises ``TypeError``."""
      class MyClass(PClass):
          a = field(type=int, initial=lambda: "a")

      with pytest.raises(TypeError):
          MyClass()
  def test_invariant_checks_lazy_initial_value_for_field():
      """A lazily produced initial value is still subject to the field invariant."""
      class MyClass(PClass):
          a = field(
              type=int,
              invariant=lambda value: (value < 5, "Too large"),
              initial=lambda: 10,
          )

      with pytest.raises(InvariantException):
          MyClass()
  def test_invariant_checks_static_initial_value():
      """A plain (non-callable) initial value is still subject to the field invariant."""
      class MyClass(PClass):
          a = field(
              type=int,
              invariant=lambda value: (value < 5, "Too large"),
              initial=10,
          )

      with pytest.raises(InvariantException):
          MyClass()
  def test_lazy_invariant_message():
      """An invariant may supply its error code as a callable that builds the text."""
      class MyClass(PClass):
          a = field(int, invariant=lambda x: (x < 5, lambda: "{x} is too large".format(x=x)))

      # pytest.raises fails the test on its own when nothing is raised, so no
      # ``assert False`` sentinel (which ``python -O`` would strip) is needed.
      with pytest.raises(InvariantException) as excinfo:
          MyClass(a=5)
      assert '5 is too large' in excinfo.value.invariant_errors
  # Skipped for now: this test describes a corner case with using Enums as
  # types in Python 3, together with a workaround that makes it work.
  # The original marker was ``skipif(sys.version_info < (3, 4) or True, ...)``,
  # which always skips; an explicit unconditional skip states that honestly.
  @pytest.mark.skip(reason="documents a known corner case: a bare Enum class as key_type fails")
  def test_enum_key_type():
      """Show the Enum-as-key-type corner case and its tuple-wrapping workaround."""
      import enum

      class Foo(enum.Enum):
          Bar = 1
          Baz = 2

      # This currently fails because the enum is iterable
      class MyClass1(PClass):
          f = pmap_field(key_type=Foo, value_type=int)

      MyClass1()

      # This is OK since it's wrapped in a tuple
      class MyClass2(PClass):
          f = pmap_field(key_type=(Foo,), value_type=int)

      MyClass2()
  def test_pickle_with_one_way_factory():
      """A record whose factory turns a string into a UUID survives pickling."""
      original = UniqueThing(id='25544626-86da-4bce-b6b6-9186c0804d64')
      restored = pickle.loads(pickle.dumps(original))
      assert restored == original
  def test_evolver_with_one_way_factory():
      """Evolving another field leaves the factory-produced ``id`` usable."""
      original = UniqueThing(id='cc65249a-56fe-4995-8719-ea02e124b234')
      evolver = original.evolver()
      # A real change is necessary to prevent persistent() returning the original.
      evolver.x = 5
      expected = UniqueThing(id=str(original.id), x=5)
      assert evolver.persistent() == expected
  def test_set_doesnt_trigger_other_factories():
      """Setting ``x`` on a record with a factory-built ``id`` must not raise."""
      original = UniqueThing(id='b413b280-de76-4e28-a8e3-5470ca83ea2c')
      # No assertion: the test passes as long as this call does not raise.
      original.set(x=5)
  def test_set_does_trigger_factories():
      """The factory of the field being set runs on construction and on ``set``."""
      class SquaredPoint(PClass):
          x = field(factory=lambda value: value ** 2)
          y = field()

      built = SquaredPoint(x=3, y=10)
      assert (built.x, built.y) == (9, 10)

      updated = built.set(x=4)
      assert (updated.x, updated.y) == (16, 10)
  def test_value_can_be_overridden_in_subclass_new():
      """A subclass ``__new__`` that rewrites kwargs is also applied by ``set``."""
      class X(PClass):
          y = pvector_field(int)

          def __new__(cls, **kwargs):
              # Substitute an empty tuple when ``y`` is missing or None.
              if kwargs.get('y', None) is None:
                  kwargs['y'] = ()
              return super(X, cls).__new__(cls, **kwargs)

      from_empty_list = X(y=[])
      from_none = from_empty_list.set(y=None)
      assert from_empty_list == from_none