# class_test.py: tests for pyrsistent's PClass.
from collections.abc import Hashable
import math
import pickle
import uuid

import pytest
from pyrsistent import (
    field, InvariantException, PClass, optional, CheckedPVector,
    pmap_field, pset_field, pvector_field)
  9. class Point(PClass):
  10. x = field(type=int, mandatory=True, invariant=lambda x: (x >= 0, 'X negative'))
  11. y = field(type=int, serializer=lambda formatter, y: formatter(y))
  12. z = field(type=int, initial=0)
  13. class Hierarchy(PClass):
  14. point = field(type=Point)
  15. class TypedContainerObj(PClass):
  16. map = pmap_field(str, str)
  17. set = pset_field(str)
  18. vec = pvector_field(str)
  19. class UniqueThing(PClass):
  20. id = field(type=uuid.UUID, factory=uuid.UUID)
  21. x = field(type=int)
  22. def test_create_ignore_extra():
  23. p = Point.create({'x': 5, 'y': 10, 'z': 15, 'a': 0}, ignore_extra=True)
  24. assert p.x == 5
  25. assert p.y == 10
  26. assert p.z == 15
  27. assert isinstance(p, Point)
  28. def test_create_ignore_extra_false():
  29. with pytest.raises(AttributeError):
  30. _ = Point.create({'x': 5, 'y': 10, 'z': 15, 'a': 0})
  31. def test_create_ignore_extra_true():
  32. h = Hierarchy.create(
  33. {'point': {'x': 5, 'y': 10, 'z': 15, 'extra_field_0': 'extra_data_0'}, 'extra_field_1': 'extra_data_1'},
  34. ignore_extra=True)
  35. assert isinstance(h, Hierarchy)
  36. def test_evolve_pclass_instance():
  37. p = Point(x=1, y=2)
  38. p2 = p.set(x=p.x+2)
  39. # Original remains
  40. assert p.x == 1
  41. assert p.y == 2
  42. # Evolved object updated
  43. assert p2.x == 3
  44. assert p2.y == 2
  45. p3 = p2.set('x', 4)
  46. assert p3.x == 4
  47. assert p3.y == 2
  48. def test_direct_assignment_not_possible():
  49. p = Point(x=1, y=2)
  50. with pytest.raises(AttributeError):
  51. p.x = 1
  52. with pytest.raises(AttributeError):
  53. setattr(p, 'x', 1)
  54. def test_direct_delete_not_possible():
  55. p = Point(x=1, y=2)
  56. with pytest.raises(AttributeError):
  57. del p.x
  58. with pytest.raises(AttributeError):
  59. delattr(p, 'x')
  60. def test_cannot_construct_with_undeclared_fields():
  61. with pytest.raises(AttributeError):
  62. Point(x=1, p=5)
  63. def test_cannot_construct_with_wrong_type():
  64. with pytest.raises(TypeError):
  65. Point(x='a')
  66. def test_cannot_construct_without_mandatory_fields():
  67. try:
  68. Point(y=1)
  69. assert False
  70. except InvariantException as e:
  71. assert "[Point.x]" in str(e)
  72. def test_field_invariant_must_hold():
  73. try:
  74. Point(x=-1)
  75. assert False
  76. except InvariantException as e:
  77. assert "X negative" in str(e)
  78. def test_initial_value_set_when_not_present_in_arguments():
  79. p = Point(x=1, y=2)
  80. assert p.z == 0
  81. class Line(PClass):
  82. p1 = field(type=Point)
  83. p2 = field(type=Point)
  84. def test_can_create_nested_structures_from_dict_and_serialize_back_to_dict():
  85. source = dict(p1=dict(x=1, y=2, z=3), p2=dict(x=10, y=20, z=30))
  86. l = Line.create(source)
  87. assert l.p1.x == 1
  88. assert l.p1.y == 2
  89. assert l.p1.z == 3
  90. assert l.p2.x == 10
  91. assert l.p2.y == 20
  92. assert l.p2.z == 30
  93. assert l.serialize(format=lambda val: val) == source
  94. def test_can_serialize_with_custom_serializer():
  95. p = Point(x=1, y=1, z=1)
  96. assert p.serialize(format=lambda v: v + 17) == {'x': 1, 'y': 18, 'z': 1}
  97. def test_implements_proper_equality_based_on_equality_of_fields():
  98. p1 = Point(x=1, y=2)
  99. p2 = Point(x=3)
  100. p3 = Point(x=1, y=2)
  101. assert p1 == p3
  102. assert not p1 != p3
  103. assert p1 != p2
  104. assert not p1 == p2
  105. def test_is_hashable():
  106. p1 = Point(x=1, y=2)
  107. p2 = Point(x=3, y=2)
  108. d = {p1: 'A point', p2: 'Another point'}
  109. p1_like = Point(x=1, y=2)
  110. p2_like = Point(x=3, y=2)
  111. assert isinstance(p1, Hashable)
  112. assert d[p1_like] == 'A point'
  113. assert d[p2_like] == 'Another point'
  114. assert Point(x=10) not in d
  115. def test_supports_nested_transformation():
  116. l1 = Line(p1=Point(x=2, y=1), p2=Point(x=20, y=10))
  117. l2 = l1.transform(['p1', 'x'], 3)
  118. assert l1.p1.x == 2
  119. assert l2.p1.x == 3
  120. assert l2.p1.y == 1
  121. assert l2.p2.x == 20
  122. assert l2.p2.y == 10
  123. def test_repr():
  124. class ARecord(PClass):
  125. a = field()
  126. b = field()
  127. assert repr(ARecord(a=1, b=2)) in ('ARecord(a=1, b=2)', 'ARecord(b=2, a=1)')
  128. def test_global_invariant_check():
  129. class UnitCirclePoint(PClass):
  130. __invariant__ = lambda cp: (0.99 < math.sqrt(cp.x*cp.x + cp.y*cp.y) < 1.01,
  131. "Point not on unit circle")
  132. x = field(type=float)
  133. y = field(type=float)
  134. UnitCirclePoint(x=1.0, y=0.0)
  135. with pytest.raises(InvariantException):
  136. UnitCirclePoint(x=1.0, y=1.0)
  137. def test_supports_pickling():
  138. p1 = Point(x=2, y=1)
  139. p2 = pickle.loads(pickle.dumps(p1, -1))
  140. assert p1 == p2
  141. assert isinstance(p2, Point)
  142. def test_supports_pickling_with_typed_container_fields():
  143. obj = TypedContainerObj(map={'foo': 'bar'}, set=['hello', 'there'], vec=['a', 'b'])
  144. obj2 = pickle.loads(pickle.dumps(obj))
  145. assert obj == obj2
  146. def test_can_remove_optional_member():
  147. p1 = Point(x=1, y=2)
  148. p2 = p1.remove('y')
  149. assert p2 == Point(x=1)
  150. def test_cannot_remove_mandatory_member():
  151. p1 = Point(x=1, y=2)
  152. with pytest.raises(InvariantException):
  153. p1.remove('x')
  154. def test_cannot_remove_non_existing_member():
  155. p1 = Point(x=1)
  156. with pytest.raises(AttributeError):
  157. p1.remove('y')
  158. def test_evolver_without_evolution_returns_original_instance():
  159. p1 = Point(x=1)
  160. e = p1.evolver()
  161. assert e.persistent() is p1
  162. def test_evolver_with_evolution_to_same_element_returns_original_instance():
  163. p1 = Point(x=1)
  164. e = p1.evolver()
  165. e.set('x', p1.x)
  166. assert e.persistent() is p1
  167. def test_evolver_supports_chained_set_and_remove():
  168. p1 = Point(x=1, y=2)
  169. assert p1.evolver().set('x', 3).remove('y').persistent() == Point(x=3)
  170. def test_evolver_supports_dot_notation_for_setting_and_getting_elements():
  171. e = Point(x=1, y=2).evolver()
  172. e.x = 3
  173. assert e.x == 3
  174. assert e.persistent() == Point(x=3, y=2)
  175. class Numbers(CheckedPVector):
  176. __type__ = int
  177. class LinkedList(PClass):
  178. value = field(type='__tests__.class_test.Numbers')
  179. next = field(type=optional('__tests__.class_test.LinkedList'))
  180. def test_string_as_type_specifier():
  181. l = LinkedList(value=[1, 2], next=LinkedList(value=[3, 4], next=None))
  182. assert isinstance(l.value, Numbers)
  183. assert list(l.value) == [1, 2]
  184. assert l.next.next is None
  185. def test_multiple_invariants_on_field():
  186. # If the invariant returns a list of tests the results of running those tests will be
  187. # a tuple containing result data of all failing tests.
  188. class MultiInvariantField(PClass):
  189. one = field(type=int, invariant=lambda x: ((False, 'one_one'),
  190. (False, 'one_two'),
  191. (True, 'one_three')))
  192. two = field(invariant=lambda x: (False, 'two_one'))
  193. try:
  194. MultiInvariantField(one=1, two=2)
  195. assert False
  196. except InvariantException as e:
  197. assert set(e.invariant_errors) == set([('one_one', 'one_two'), 'two_one'])
  198. def test_multiple_global_invariants():
  199. class MultiInvariantGlobal(PClass):
  200. __invariant__ = lambda self: ((False, 'x'), (False, 'y'))
  201. one = field()
  202. try:
  203. MultiInvariantGlobal(one=1)
  204. assert False
  205. except InvariantException as e:
  206. assert e.invariant_errors == (('x', 'y'),)
  207. def test_inherited_global_invariants():
  208. class Distant(object):
  209. def __invariant__(self):
  210. return [(self.distant, "distant")]
  211. class Nearby(Distant):
  212. def __invariant__(self):
  213. return [(self.nearby, "nearby")]
  214. class MultipleInvariantGlobal(Nearby, PClass):
  215. distant = field()
  216. nearby = field()
  217. try:
  218. MultipleInvariantGlobal(distant=False, nearby=False)
  219. assert False
  220. except InvariantException as e:
  221. assert e.invariant_errors == (("nearby",), ("distant",),)
  222. def test_diamond_inherited_global_invariants():
  223. counter = []
  224. class Base(object):
  225. def __invariant__(self):
  226. counter.append(None)
  227. return [(False, "base")]
  228. class Left(Base):
  229. pass
  230. class Right(Base):
  231. pass
  232. class SingleInvariantGlobal(Left, Right, PClass):
  233. pass
  234. try:
  235. SingleInvariantGlobal()
  236. assert False
  237. except InvariantException as e:
  238. assert e.invariant_errors == (("base",),)
  239. assert counter == [None]
  240. def test_supports_weakref():
  241. import weakref
  242. weakref.ref(Point(x=1, y=2))
  243. def test_supports_weakref_with_multi_level_inheritance():
  244. import weakref
  245. class PPoint(Point):
  246. a = field()
  247. weakref.ref(PPoint(x=1, y=2))
  248. def test_supports_lazy_initial_value_for_field():
  249. class MyClass(PClass):
  250. a = field(int, initial=lambda: 2)
  251. assert MyClass() == MyClass(a=2)
  252. def test_type_checks_lazy_initial_value_for_field():
  253. class MyClass(PClass):
  254. a = field(int, initial=lambda: "a")
  255. with pytest.raises(TypeError):
  256. MyClass()
  257. def test_invariant_checks_lazy_initial_value_for_field():
  258. class MyClass(PClass):
  259. a = field(int, invariant=lambda x: (x < 5, "Too large"), initial=lambda: 10)
  260. with pytest.raises(InvariantException):
  261. MyClass()
  262. def test_invariant_checks_static_initial_value():
  263. class MyClass(PClass):
  264. a = field(int, invariant=lambda x: (x < 5, "Too large"), initial=10)
  265. with pytest.raises(InvariantException):
  266. MyClass()
  267. def test_lazy_invariant_message():
  268. class MyClass(PClass):
  269. a = field(int, invariant=lambda x: (x < 5, lambda: "{x} is too large".format(x=x)))
  270. try:
  271. MyClass(a=5)
  272. assert False
  273. except InvariantException as e:
  274. assert '5 is too large' in e.invariant_errors
  275. def test_enum_key_type():
  276. import enum
  277. class Foo(enum.Enum):
  278. Bar = 1
  279. Baz = 2
  280. # This currently fails because the enum is iterable
  281. class MyClass1(PClass):
  282. f = pmap_field(key_type=Foo, value_type=int)
  283. MyClass1()
  284. # This is OK since it's wrapped in a tuple
  285. class MyClass2(PClass):
  286. f = pmap_field(key_type=(Foo,), value_type=int)
  287. MyClass2()
  288. def test_pickle_with_one_way_factory():
  289. thing = UniqueThing(id='25544626-86da-4bce-b6b6-9186c0804d64')
  290. assert pickle.loads(pickle.dumps(thing)) == thing
  291. def test_evolver_with_one_way_factory():
  292. thing = UniqueThing(id='cc65249a-56fe-4995-8719-ea02e124b234')
  293. ev = thing.evolver()
  294. ev.x = 5 # necessary to prevent persistent() returning the original
  295. assert ev.persistent() == UniqueThing(id=str(thing.id), x=5)
  296. def test_set_doesnt_trigger_other_factories():
  297. thing = UniqueThing(id='b413b280-de76-4e28-a8e3-5470ca83ea2c')
  298. thing.set(x=5)
  299. def test_set_does_trigger_factories():
  300. class SquaredPoint(PClass):
  301. x = field(factory=lambda x: x ** 2)
  302. y = field()
  303. sp = SquaredPoint(x=3, y=10)
  304. assert (sp.x, sp.y) == (9, 10)
  305. sp2 = sp.set(x=4)
  306. assert (sp2.x, sp2.y) == (16, 10)
  307. def test_value_can_be_overridden_in_subclass_new():
  308. class X(PClass):
  309. y = pvector_field(int)
  310. def __new__(cls, **kwargs):
  311. items = kwargs.get('y', None)
  312. if items is None:
  313. kwargs['y'] = ()
  314. return super(X, cls).__new__(cls, **kwargs)
  315. a = X(y=[])
  316. b = a.set(y=None)
  317. assert a == b