test_traitlets.py 78 KB


  1. """Tests for traitlets.traitlets."""
  2. # Copyright (c) IPython Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. #
  5. # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
  6. # also under the terms of the Modified BSD License.
  7. from __future__ import annotations
  8. import pickle
  9. import re
  10. import typing as t
  11. from unittest import TestCase
  12. import pytest
  13. from traitlets import (
  14. All,
  15. Any,
  16. BaseDescriptor,
  17. Bool,
  18. Bytes,
  19. Callable,
  20. CBytes,
  21. CFloat,
  22. CInt,
  23. CLong,
  24. Complex,
  25. CRegExp,
  26. CUnicode,
  27. Dict,
  28. DottedObjectName,
  29. Enum,
  30. Float,
  31. ForwardDeclaredInstance,
  32. ForwardDeclaredType,
  33. HasDescriptors,
  34. HasTraits,
  35. Instance,
  36. Int,
  37. Integer,
  38. List,
  39. Long,
  40. MetaHasTraits,
  41. ObjectName,
  42. Set,
  43. TCPAddress,
  44. This,
  45. TraitError,
  46. TraitType,
  47. Tuple,
  48. Type,
  49. Undefined,
  50. Unicode,
  51. Union,
  52. default,
  53. directional_link,
  54. link,
  55. observe,
  56. observe_compat,
  57. traitlets,
  58. validate,
  59. )
  60. from traitlets.utils import cast_unicode
  61. from ._warnings import expected_warnings
  62. def change_dict(*ordered_values):
  63. change_names = ("name", "old", "new", "owner", "type")
  64. return dict(zip(change_names, ordered_values))
  65. # -----------------------------------------------------------------------------
  66. # Helper classes for testing
  67. # -----------------------------------------------------------------------------
  68. class HasTraitsStub(HasTraits):
  69. def notify_change(self, change):
  70. self._notify_name = change["name"]
  71. self._notify_old = change["old"]
  72. self._notify_new = change["new"]
  73. self._notify_type = change["type"]
class CrossValidationStub(HasTraits):
    """HasTraits helper whose ``_cross_validation_lock`` class attribute starts False."""

    # NOTE(review): this stub is not referenced in the visible part of the
    # file; presumably it is used by tests further down the module - confirm.
    _cross_validation_lock = False
  76. # -----------------------------------------------------------------------------
  77. # Test classes
  78. # -----------------------------------------------------------------------------
class TestTraitType(TestCase):
    """Tests for the base ``TraitType`` descriptor.

    Covers defaults, validation, metadata, dynamic default initializers
    (both ``@default`` and the deprecated magic-name form) and deprecation
    warnings.
    """

    def test_get_undefined(self):
        # A bare TraitType has no default value, so reading it yields the
        # Undefined sentinel. Passing the class rather than an instance is the
        # deprecated form (see test_trait_types_deprecated below).
        class A(HasTraits):
            a = TraitType

        a = A()
        assert a.a is Undefined  # type:ignore

    def test_set(self):
        # HasTraitsStub records the change handed to notify_change, so the
        # name/old/new values of the assignment can be checked directly.
        class A(HasTraitsStub):
            a = TraitType

        a = A()
        a.a = 10  # type:ignore
        self.assertEqual(a.a, 10)
        self.assertEqual(a._notify_name, "a")
        self.assertEqual(a._notify_old, Undefined)
        self.assertEqual(a._notify_new, 10)

    def test_validate(self):
        # A custom validate() replaces every incoming value. This holds for
        # both attribute assignment and constructor keyword arguments.
        class MyTT(TraitType[int, int]):
            def validate(self, inst, value):
                return -1

        class A(HasTraitsStub):
            tt = MyTT

        a = A()
        a.tt = 10  # type:ignore
        self.assertEqual(a.tt, -1)

        a = A(tt=11)
        self.assertEqual(a.tt, -1)

    def test_default_validate(self):
        class MyIntTT(TraitType[int, int]):
            def validate(self, obj, value):
                if isinstance(value, int):
                    return value
                self.error(obj, value)

        class A(HasTraits):
            tt = MyIntTT(10)

        a = A()
        self.assertEqual(a.tt, 10)

        # Defaults are validated lazily, on first access. B() is evaluated
        # before assertRaises is called, so constructing it must succeed; the
        # TraitError has to come from reading ``tt``.
        class B(HasTraits):
            tt = MyIntTT("bad default")

        self.assertRaises(TraitError, getattr, B(), "tt")

    def test_info(self):
        class A(HasTraits):
            tt = TraitType

        a = A()
        self.assertEqual(A.tt.info(), "any value")  # type:ignore

    def test_error(self):
        # The base error() always raises TraitError.
        class A(HasTraits):
            tt = TraitType[int, int]()

        a = A()
        self.assertRaises(TraitError, A.tt.error, a, 10)

    def test_deprecated_dynamic_initializer(self):
        # Same scenarios as test_dynamic_initializer below, but using the
        # deprecated magic method name ``_x_default`` instead of @default.
        class A(HasTraits):
            x = Int(10)

            def _x_default(self):
                return 11

        class B(A):
            # Redeclaring the trait: the assertions below expect its static
            # default (20) to win over the _x_default inherited from A.
            x = Int(20)

        class C(A):
            def _x_default(self):
                return 21

        # _trait_values stays empty until the trait is first read; the
        # computed default is then stored there.
        a = A()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {"x": 11})
        b = B()
        self.assertEqual(b.x, 20)
        self.assertEqual(b._trait_values, {"x": 20})
        c = C()
        self.assertEqual(c._trait_values, {})
        self.assertEqual(c.x, 21)
        self.assertEqual(c._trait_values, {"x": 21})
        # Ensure that the base class remains unmolested when the _default
        # initializer gets overridden in a subclass.
        a = A()
        c = C()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {"x": 11})

    def test_deprecated_method_warnings(self):
        # Decorated handlers must not emit any warning.
        with expected_warnings([]):

            class ShouldntWarn(HasTraits):
                x = Integer()

                @default("x")
                def _x_default(self):
                    return 10

                @validate("x")
                def _x_validate(self, proposal):
                    return proposal.value

                @observe("x")
                def _x_changed(self, change):
                    pass

            obj = ShouldntWarn()
            obj.x = 5

        assert obj.x == 5

        # Magic-named validate/changed handlers warn, pointing at @validate and
        # @observe. The instance is created and assigned inside the block so
        # that any warning raised at that time is also captured.
        with expected_warnings(["@validate", "@observe"]) as w:

            class ShouldWarn(HasTraits):
                x = Integer()

                def _x_default(self):
                    return 10

                def _x_validate(self, value, _):
                    return value

                def _x_changed(self):
                    pass

            obj = ShouldWarn()  # type:ignore
            obj.x = 5

        assert obj.x == 5

    def test_dynamic_initializer(self):
        class A(HasTraits):
            x = Int(10)

            @default("x")
            def _default_x(self):
                return 11

        class B(A):
            # Redeclared trait; its static default (20) is expected to win.
            x = Int(20)

        class C(A):
            @default("x")
            def _default_x(self):
                return 21

        # Defaults are computed on first read and then cached in _trait_values.
        a = A()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {"x": 11})
        b = B()
        self.assertEqual(b.x, 20)
        self.assertEqual(b._trait_values, {"x": 20})
        c = C()
        self.assertEqual(c._trait_values, {})
        self.assertEqual(c.x, 21)
        self.assertEqual(c._trait_values, {"x": 21})
        # Ensure that the base class remains unmolested when the _default
        # initializer gets overridden in a subclass.
        a = A()
        c = C()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {"x": 11})

    def test_tag_metadata(self):
        # tag() merges into the class-level metadata, overriding existing keys.
        class MyIntTT(TraitType[int, int]):
            metadata = {"a": 1, "b": 2}

        a = MyIntTT(10).tag(b=3, c=4)
        self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4})

    def test_metadata_localized_instance(self):
        class MyIntTT(TraitType[int, int]):
            metadata = {"a": 1, "b": 2}

        a = MyIntTT(10)
        b = MyIntTT(10)
        a.metadata["c"] = 3
        # make sure that changing a's metadata didn't change b's metadata
        self.assertNotIn("c", b.metadata)

    def test_union_metadata(self):
        class Foo(HasTraits):
            bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a")

        foo = Foo()
        # No value has been set for bar yet, so no member's value-specific
        # metadata applies; only the Union's own tag ("ti" = "a") is visible.
        self.assertEqual(foo.trait_metadata("bar", "ta"), None)
        self.assertEqual(foo.trait_metadata("bar", "ti"), "a")
        # Once a value is set, the metadata of the member that accepted it is
        # used. Keys that member lacks fall back to the Union's tag.
        foo.bar = {}
        self.assertEqual(foo.trait_metadata("bar", "ta"), 2)
        self.assertEqual(foo.trait_metadata("bar", "ti"), "b")
        foo.bar = 1
        self.assertEqual(foo.trait_metadata("bar", "ta"), 1)
        self.assertEqual(foo.trait_metadata("bar", "ti"), "a")

    def test_union_default_value(self):
        # An explicit default_value overrides the members' defaults.
        class Foo(HasTraits):
            bar = Union([Dict(), Int()], default_value=1)

        foo = Foo()
        self.assertEqual(foo.bar, 1)

    def test_union_validation_priority(self):
        class Foo(HasTraits):
            bar = Union([CInt(), Unicode()])

        foo = Foo()
        foo.bar = "1"
        # validation in order of the TraitTypes given: CInt casts "1" to 1
        # before Unicode is tried.
        self.assertEqual(foo.bar, 1)

    def test_union_trait_default_value(self):
        # Without an explicit default, the first member's (Dict's) default wins.
        class Foo(HasTraits):
            bar = Union([Dict(), Int()])

        self.assertEqual(Foo().bar, {})

    def test_deprecated_metadata_access(self):
        class MyIntTT(TraitType[int, int]):
            metadata = {"a": 1, "b": 2}

        a = MyIntTT(10)
        # set_metadata and get_metadata each emit one deprecation warning.
        with expected_warnings(["use the instance .metadata dictionary directly"] * 2):
            a.set_metadata("key", "value")
            v = a.get_metadata("key")
        self.assertEqual(v, "value")
        # The "help" key gets a different message that points at .help.
        with expected_warnings(["use the instance .help string directly"] * 2):
            a.set_metadata("help", "some help")
            v = a.get_metadata("help")
        self.assertEqual(v, "some help")

    # The next four tests check that passing a TraitType *class* (rather than
    # an instance) warns, both directly and as a container element type.
    def test_trait_types_deprecated(self):
        with expected_warnings(["Traits should be given as instances"]):

            class C(HasTraits):
                t = Int

    def test_trait_types_list_deprecated(self):
        with expected_warnings(["Traits should be given as instances"]):

            class C(HasTraits):
                t = List(Int)

    def test_trait_types_tuple_deprecated(self):
        with expected_warnings(["Traits should be given as instances"]):

            class C(HasTraits):
                t = Tuple(Int)

    def test_trait_types_dict_deprecated(self):
        with expected_warnings(["Traits should be given as instances"]):

            class C(HasTraits):
                t = Dict(Int)
class TestHasDescriptorsMeta(TestCase):
    """Tests for the MetaHasTraits metaclass and for how ``This`` binds its class."""

    def test_metaclass(self):
        # HasTraits and its subclasses are built by MetaHasTraits, and each
        # subclass gets working Int traits with their declared defaults.
        self.assertEqual(type(HasTraits), MetaHasTraits)

        class A(HasTraits):
            a = Int()

        a = A()
        self.assertEqual(type(a.__class__), MetaHasTraits)
        self.assertEqual(a.a, 0)
        a.a = 10
        self.assertEqual(a.a, 10)

        class B(HasTraits):
            b = Int()

        b = B()
        self.assertEqual(b.b, 0)
        b.b = 10
        self.assertEqual(b.b, 10)

        class C(HasTraits):
            c = Int(30)

        c = C()
        self.assertEqual(c.c, 30)
        c.c = 10
        self.assertEqual(c.c, 10)

    def test_this_class(self):
        # this_class is the class in which the trait object is declared.
        # ``t`` is inherited unchanged by B, so it stays bound to A. ``tt`` is
        # redeclared and ``ttt`` is new on B, so both bind to B, even though
        # the subscript says "A".
        class A(HasTraits):
            t = This["A"]()
            tt = This["A"]()

        class B(A):
            tt = This["A"]()
            ttt = This["A"]()

        self.assertEqual(A.t.this_class, A)
        self.assertEqual(B.t.this_class, A)
        self.assertEqual(B.tt.this_class, B)
        self.assertEqual(B.ttt.this_class, B)
  319. class TestHasDescriptors(TestCase):
  320. def test_setup_instance(self):
  321. class FooDescriptor(BaseDescriptor):
  322. def instance_init(self, inst):
  323. foo = inst.foo # instance should have the attr
  324. class HasFooDescriptors(HasDescriptors):
  325. fd = FooDescriptor()
  326. def setup_instance(self, *args, **kwargs):
  327. self.foo = kwargs.get("foo", None)
  328. super().setup_instance(*args, **kwargs)
  329. hfd = HasFooDescriptors(foo="bar")
  330. class TestHasTraitsNotify(TestCase):
  331. def setUp(self):
  332. self._notify1 = []
  333. self._notify2 = []
  334. def notify1(self, name, old, new):
  335. self._notify1.append((name, old, new))
  336. def notify2(self, name, old, new):
  337. self._notify2.append((name, old, new))
  338. def test_notify_all(self):
  339. class A(HasTraits):
  340. a = Int()
  341. b = Float()
  342. a = A()
  343. a.on_trait_change(self.notify1)
  344. a.a = 0
  345. self.assertEqual(len(self._notify1), 0)
  346. a.b = 0.0
  347. self.assertEqual(len(self._notify1), 0)
  348. a.a = 10
  349. self.assertTrue(("a", 0, 10) in self._notify1)
  350. a.b = 10.0
  351. self.assertTrue(("b", 0.0, 10.0) in self._notify1)
  352. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  353. self.assertRaises(TraitError, setattr, a, "b", "bad string")
  354. self._notify1 = []
  355. a.on_trait_change(self.notify1, remove=True)
  356. a.a = 20
  357. a.b = 20.0
  358. self.assertEqual(len(self._notify1), 0)
  359. def test_notify_one(self):
  360. class A(HasTraits):
  361. a = Int()
  362. b = Float()
  363. a = A()
  364. a.on_trait_change(self.notify1, "a")
  365. a.a = 0
  366. self.assertEqual(len(self._notify1), 0)
  367. a.a = 10
  368. self.assertTrue(("a", 0, 10) in self._notify1)
  369. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  370. def test_subclass(self):
  371. class A(HasTraits):
  372. a = Int()
  373. class B(A):
  374. b = Float()
  375. b = B()
  376. self.assertEqual(b.a, 0)
  377. self.assertEqual(b.b, 0.0)
  378. b.a = 100
  379. b.b = 100.0
  380. self.assertEqual(b.a, 100)
  381. self.assertEqual(b.b, 100.0)
  382. def test_notify_subclass(self):
  383. class A(HasTraits):
  384. a = Int()
  385. class B(A):
  386. b = Float()
  387. b = B()
  388. b.on_trait_change(self.notify1, "a")
  389. b.on_trait_change(self.notify2, "b")
  390. b.a = 0
  391. b.b = 0.0
  392. self.assertEqual(len(self._notify1), 0)
  393. self.assertEqual(len(self._notify2), 0)
  394. b.a = 10
  395. b.b = 10.0
  396. self.assertTrue(("a", 0, 10) in self._notify1)
  397. self.assertTrue(("b", 0.0, 10.0) in self._notify2)
  398. def test_static_notify(self):
  399. class A(HasTraits):
  400. a = Int()
  401. _notify1 = []
  402. def _a_changed(self, name, old, new):
  403. self._notify1.append((name, old, new))
  404. a = A()
  405. a.a = 0
  406. # This is broken!!!
  407. self.assertEqual(len(a._notify1), 0)
  408. a.a = 10
  409. self.assertTrue(("a", 0, 10) in a._notify1)
  410. class B(A):
  411. b = Float()
  412. _notify2 = []
  413. def _b_changed(self, name, old, new):
  414. self._notify2.append((name, old, new))
  415. b = B()
  416. b.a = 10
  417. b.b = 10.0
  418. self.assertTrue(("a", 0, 10) in b._notify1)
  419. self.assertTrue(("b", 0.0, 10.0) in b._notify2)
  420. def test_notify_args(self):
  421. def callback0():
  422. self.cb = ()
  423. def callback1(name):
  424. self.cb = (name,) # type:ignore
  425. def callback2(name, new):
  426. self.cb = (name, new) # type:ignore
  427. def callback3(name, old, new):
  428. self.cb = (name, old, new) # type:ignore
  429. def callback4(name, old, new, obj):
  430. self.cb = (name, old, new, obj) # type:ignore
  431. class A(HasTraits):
  432. a = Int()
  433. a = A()
  434. a.on_trait_change(callback0, "a")
  435. a.a = 10
  436. self.assertEqual(self.cb, ())
  437. a.on_trait_change(callback0, "a", remove=True)
  438. a.on_trait_change(callback1, "a")
  439. a.a = 100
  440. self.assertEqual(self.cb, ("a",))
  441. a.on_trait_change(callback1, "a", remove=True)
  442. a.on_trait_change(callback2, "a")
  443. a.a = 1000
  444. self.assertEqual(self.cb, ("a", 1000))
  445. a.on_trait_change(callback2, "a", remove=True)
  446. a.on_trait_change(callback3, "a")
  447. a.a = 10000
  448. self.assertEqual(self.cb, ("a", 1000, 10000))
  449. a.on_trait_change(callback3, "a", remove=True)
  450. a.on_trait_change(callback4, "a")
  451. a.a = 100000
  452. self.assertEqual(self.cb, ("a", 10000, 100000, a))
  453. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
  454. a.on_trait_change(callback4, "a", remove=True)
  455. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)
  456. def test_notify_only_once(self):
  457. class A(HasTraits):
  458. listen_to = ["a"]
  459. a = Int(0)
  460. b = 0
  461. def __init__(self, **kwargs):
  462. super().__init__(**kwargs)
  463. self.on_trait_change(self.listener1, ["a"])
  464. def listener1(self, name, old, new):
  465. self.b += 1
  466. class B(A):
  467. c = 0
  468. d = 0
  469. def __init__(self, **kwargs):
  470. super().__init__(**kwargs)
  471. self.on_trait_change(self.listener2)
  472. def listener2(self, name, old, new):
  473. self.c += 1
  474. def _a_changed(self, name, old, new):
  475. self.d += 1
  476. b = B()
  477. b.a += 1
  478. self.assertEqual(b.b, b.c)
  479. self.assertEqual(b.b, b.d)
  480. b.a += 1
  481. self.assertEqual(b.b, b.c)
  482. self.assertEqual(b.b, b.d)
class TestObserveDecorator(TestCase):
    """Tests for ``observe``/``unobserve`` and the ``@observe`` decorator.

    Observers receive a single change dict. ``notify1``/``notify2`` record
    these dicts, and expected entries are built with ``change_dict``.
    """

    def setUp(self):
        self._notify1 = []
        self._notify2 = []

    def notify1(self, change):
        self._notify1.append(change)

    def notify2(self, change):
        self._notify2.append(change)

    def test_notify_all(self):
        class A(HasTraits):
            a = Int()
            b = Float()

        a = A()
        a.observe(self.notify1)
        # Assigning the value a trait already holds does not notify.
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        change = change_dict("a", 0, 10, a, "change")
        self.assertTrue(change in self._notify1)
        a.b = 10.0
        change = change_dict("b", 0.0, 10.0, a, "change")
        self.assertTrue(change in self._notify1)
        self.assertRaises(TraitError, setattr, a, "a", "bad string")
        self.assertRaises(TraitError, setattr, a, "b", "bad string")
        # After unobserve, further changes must not reach the observer.
        self._notify1 = []
        a.unobserve(self.notify1)
        a.a = 20
        a.b = 20.0
        self.assertEqual(len(self._notify1), 0)

    def test_notify_one(self):
        class A(HasTraits):
            a = Int()
            b = Float()

        a = A()
        a.observe(self.notify1, "a")
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        change = change_dict("a", 0, 10, a, "change")
        self.assertTrue(change in self._notify1)
        self.assertRaises(TraitError, setattr, a, "a", "bad string")

    def test_subclass(self):
        class A(HasTraits):
            a = Int()

        class B(A):
            b = Float()

        b = B()
        self.assertEqual(b.a, 0)
        self.assertEqual(b.b, 0.0)
        b.a = 100
        b.b = 100.0
        self.assertEqual(b.a, 100)
        self.assertEqual(b.b, 100.0)

    def test_notify_subclass(self):
        class A(HasTraits):
            a = Int()

        class B(A):
            b = Float()

        b = B()
        b.observe(self.notify1, "a")
        b.observe(self.notify2, "b")
        b.a = 0
        b.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        self.assertEqual(len(self._notify2), 0)
        b.a = 10
        b.b = 10.0
        change = change_dict("a", 0, 10, b, "change")
        self.assertTrue(change in self._notify1)
        change = change_dict("b", 0.0, 10.0, b, "change")
        self.assertTrue(change in self._notify2)

    def test_static_notify(self):
        class A(HasTraits):
            a = Int()
            b = Int()
            _notify1 = []
            _notify_any = []

            @observe("a")
            def _a_changed(self, change):
                self._notify1.append(change)

            @observe(All)
            def _any_changed(self, change):
                self._notify_any.append(change)

        a = A()
        a.a = 0
        self.assertEqual(len(a._notify1), 0)
        a.a = 10
        change = change_dict("a", 0, 10, a, "change")
        self.assertTrue(change in a._notify1)
        a.b = 1
        # The All observer saw both real changes (a: 0 -> 10 and b: 0 -> 1).
        self.assertEqual(len(a._notify_any), 2)
        change = change_dict("b", 0, 1, a, "change")
        self.assertTrue(change in a._notify_any)

        class B(A):
            b = Float()  # type:ignore
            _notify2 = []

            @observe("b")
            def _b_changed(self, change):
                self._notify2.append(change)

        b = B()
        b.a = 10
        b.b = 10.0  # type:ignore
        # _notify1 is the class-level list shared with A. The expected dict
        # carries owner=b, though, so a's earlier entry cannot satisfy it.
        change = change_dict("a", 0, 10, b, "change")
        self.assertTrue(change in b._notify1)
        change = change_dict("b", 0.0, 10.0, b, "change")
        self.assertTrue(change in b._notify2)

    def test_notify_args(self):
        def callback0():
            self.cb = ()

        def callback1(change):
            self.cb = change

        class A(HasTraits):
            a = Int()

        a = A()
        # A callback registered through the legacy on_trait_change API can be
        # removed with unobserve.
        a.on_trait_change(callback0, "a")
        a.a = 10
        self.assertEqual(self.cb, ())
        a.unobserve(callback0, "a")

        a.observe(callback1, "a")
        a.a = 100
        change = change_dict("a", 10, 100, a, "change")
        self.assertEqual(self.cb, change)
        self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
        a.unobserve(callback1, "a")

        self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)

    def test_notify_only_once(self):
        # Three observers watch ``a``: a name-specific one, a catch-all one and
        # the decorated _a_changed. Their counters must advance in lockstep.
        class A(HasTraits):
            listen_to = ["a"]

            a = Int(0)
            b = 0

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.observe(self.listener1, ["a"])

            def listener1(self, change):
                self.b += 1

        class B(A):
            c = 0
            d = 0

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.observe(self.listener2)

            def listener2(self, change):
                self.c += 1

            @observe("a")
            def _a_changed(self, change):
                self.d += 1

        b = B()
        b.a += 1
        self.assertEqual(b.b, b.c)
        self.assertEqual(b.b, b.d)
        b.a += 1
        self.assertEqual(b.b, b.c)
        self.assertEqual(b.b, b.d)
class TestHasTraits(TestCase):
    """Tests for HasTraits introspection (names, values, metadata) and construction."""

    def test_trait_names(self):
        class A(HasTraits):
            i = Int()
            f = Float()

        a = A()
        self.assertEqual(sorted(a.trait_names()), ["f", "i"])
        self.assertEqual(sorted(A.class_trait_names()), ["f", "i"])
        self.assertTrue(a.has_trait("f"))
        self.assertFalse(a.has_trait("g"))

    def test_trait_has_value(self):
        class A(HasTraits):
            i = Int()
            f = Float()

        a = A()
        # Nothing has been assigned or read yet, and "g" is not a trait at all.
        self.assertFalse(a.trait_has_value("f"))
        self.assertFalse(a.trait_has_value("g"))
        a.i = 1
        # Merely reading f (a bare expression) is enough for it to count as
        # having a value, per the assertion below.
        a.f
        self.assertTrue(a.trait_has_value("i"))
        self.assertTrue(a.trait_has_value("f"))

    def test_trait_metadata_deprecated(self):
        # Passing metadata as constructor kwargs still works but warns.
        with expected_warnings([r"metadata should be set using the \.tag\(\) method"]):

            class A(HasTraits):
                i = Int(config_key="MY_VALUE")

        a = A()
        self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")

    def test_trait_metadata(self):
        class A(HasTraits):
            i = Int().tag(config_key="MY_VALUE")

        a = A()
        self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")

    def test_trait_metadata_default(self):
        # Missing metadata keys return None, or the supplied default.
        class A(HasTraits):
            i = Int()

        a = A()
        self.assertEqual(a.trait_metadata("i", "config_key"), None)
        self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default")

    def test_traits(self):
        class A(HasTraits):
            i = Int()
            f = Float()

        a = A()
        self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
        self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))

    def test_traits_metadata(self):
        class A(HasTraits):
            i = Int().tag(config_key="VALUE1", other_thing="VALUE2")
            f = Float().tag(config_key="VALUE3", other_thing="VALUE2")
            j = Int(0)

        a = A()
        self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
        # Keyword filters must all match for a trait to be selected.
        traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
        self.assertEqual(traits, dict(i=A.i))

        # This passes, but it shouldn't because I am replicating a bug in
        # traits.
        # (A callable filter is applied to the metadata value, so j, which has
        # no config_key, is still selected.)
        traits = a.traits(config_key=lambda v: True)
        self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))

    def test_traits_metadata_deprecated(self):
        # Same checks as test_traits_metadata, using the deprecated
        # constructor-kwarg metadata form (one warning per tagged trait).
        with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2):

            class A(HasTraits):
                i = Int(config_key="VALUE1", other_thing="VALUE2")
                f = Float(config_key="VALUE3", other_thing="VALUE2")
                j = Int(0)

        a = A()
        self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
        traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
        self.assertEqual(traits, dict(i=A.i))

        # This passes, but it shouldn't because I am replicating a bug in
        # traits.
        traits = a.traits(config_key=lambda v: True)
        self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))

    def test_init(self):
        # Keyword arguments to the constructor set trait values.
        class A(HasTraits):
            i = Int()
            x = Float()

        a = A(i=1, x=10.0)
        self.assertEqual(a.i, 1)
        self.assertEqual(a.x, 10.0)

    def test_positional_args(self):
        class A(HasTraits):
            i = Int(0)

            def __init__(self, i):
                super().__init__()
                self.i = i

        a = A(5)
        self.assertEqual(a.i, 5)
        # should raise TypeError if no positional arg given
        self.assertRaises(TypeError, A)
  727. # -----------------------------------------------------------------------------
  728. # Tests for specific trait types
  729. # -----------------------------------------------------------------------------
class TestType(TestCase):
    """Tests for the ``Type`` trait, whose values are classes rather than instances."""

    def test_default(self):
        class B:
            pass

        class A(HasTraits):
            klass = Type(allow_none=True)

        a = A()
        # With no positional arguments the default value is ``object``.
        self.assertEqual(a.klass, object)

        a.klass = B
        self.assertEqual(a.klass, B)
        self.assertRaises(TraitError, setattr, a, "klass", 10)

    def test_default_options(self):
        class B:
            pass

        class C(B):
            pass

        class A(HasTraits):
            # Different possible combinations of options for default_value
            # and klass. default_value=None is only valid with allow_none=True.
            k1 = Type()
            k2 = Type(None, allow_none=True)
            k3 = Type(B)
            k4 = Type(klass=B)
            k5 = Type(default_value=None, klass=B, allow_none=True)
            k6 = Type(default_value=C, klass=B)

        # When only one of default_value/klass is given, the other follows it
        # (object when neither is given). A None default leaves klass as object.
        self.assertIs(A.k1.default_value, object)
        self.assertIs(A.k1.klass, object)
        self.assertIs(A.k2.default_value, None)
        self.assertIs(A.k2.klass, object)
        self.assertIs(A.k3.default_value, B)
        self.assertIs(A.k3.klass, B)
        self.assertIs(A.k4.default_value, B)
        self.assertIs(A.k4.klass, B)
        self.assertIs(A.k5.default_value, None)
        self.assertIs(A.k5.klass, B)
        self.assertIs(A.k6.default_value, C)
        self.assertIs(A.k6.klass, B)

        a = A()
        self.assertIs(a.k1, object)
        self.assertIs(a.k2, None)
        self.assertIs(a.k3, B)
        self.assertIs(a.k4, B)
        self.assertIs(a.k5, None)
        self.assertIs(a.k6, C)

    def test_value(self):
        class B:
            pass

        class C:
            pass

        class A(HasTraits):
            klass = Type(B)

        a = A()
        self.assertEqual(a.klass, B)
        # Unrelated classes and the more general ``object`` are rejected.
        self.assertRaises(TraitError, setattr, a, "klass", C)
        self.assertRaises(TraitError, setattr, a, "klass", object)
        a.klass = B

    def test_allow_none(self):
        class B:
            pass

        class C(B):
            pass

        class A(HasTraits):
            klass = Type(B)

        a = A()
        self.assertEqual(a.klass, B)
        # allow_none was not requested, so None is rejected; a subclass is accepted.
        self.assertRaises(TraitError, setattr, a, "klass", None)
        a.klass = C
        self.assertEqual(a.klass, C)

    def test_validate_klass(self):
        # A string klass is resolved as a dotted import path when the owning
        # class is instantiated. Unimportable names make A() raise ImportError.
        class A(HasTraits):
            klass = Type("no strings allowed")

        self.assertRaises(ImportError, A)

        class A(HasTraits):  # type:ignore
            klass = Type("rub.adub.Duck")

        self.assertRaises(ImportError, A)

    def test_validate_default(self):
        class B:
            pass

        # An unimportable string default fails at instantiation...
        class A(HasTraits):
            klass = Type("bad default", B)

        self.assertRaises(ImportError, A)

        # ...whereas a None default without allow_none fails only when read.
        class C(HasTraits):
            klass = Type(None, B)

        self.assertRaises(TraitError, getattr, C(), "klass")

    def test_str_klass(self):
        class A(HasTraits):
            klass = Type("traitlets.config.Config")

        from traitlets.config import Config

        a = A()
        a.klass = Config
        self.assertEqual(a.klass, Config)

        self.assertRaises(TraitError, setattr, a, "klass", 10)

    def test_set_str_klass(self):
        # A dotted-path string assigned as the value is resolved to the class.
        class A(HasTraits):
            klass = Type()

        a = A(klass="traitlets.config.Config")

        from traitlets.config import Config

        self.assertEqual(a.klass, Config)
class TestInstance(TestCase):
    """Tests for the ``Instance`` trait: class checks, defaults and construction args."""

    def test_basic(self):
        class Foo:
            pass

        class Bar(Foo):
            pass

        class Bah:
            pass

        class A(HasTraits):
            inst = Instance(Foo, allow_none=True)

        a = A()
        self.assertTrue(a.inst is None)
        a.inst = Foo()
        self.assertTrue(isinstance(a.inst, Foo))
        a.inst = Bar()
        self.assertTrue(isinstance(a.inst, Foo))
        # Classes themselves (rather than instances) and instances of
        # unrelated classes are rejected.
        self.assertRaises(TraitError, setattr, a, "inst", Foo)
        self.assertRaises(TraitError, setattr, a, "inst", Bar)
        self.assertRaises(TraitError, setattr, a, "inst", Bah())

    def test_default_klass(self):
        class Foo:
            pass

        class Bar(Foo):
            pass

        class Bah:
            pass

        # Same checks as test_basic, but klass comes from a class attribute on
        # an Instance subclass instead of a constructor argument.
        class FooInstance(Instance[Foo]):
            klass = Foo

        class A(HasTraits):
            inst = FooInstance(allow_none=True)

        a = A()
        self.assertTrue(a.inst is None)
        a.inst = Foo()
        self.assertTrue(isinstance(a.inst, Foo))
        a.inst = Bar()
        self.assertTrue(isinstance(a.inst, Foo))
        self.assertRaises(TraitError, setattr, a, "inst", Foo)
        self.assertRaises(TraitError, setattr, a, "inst", Bar)
        self.assertRaises(TraitError, setattr, a, "inst", Bah())

    def test_unique_default_value(self):
        # With args/kw supplied, each HasTraits instance builds its own default
        # object instead of sharing one.
        class Foo:
            pass

        class A(HasTraits):
            inst = Instance(Foo, (), {})

        a = A()
        b = A()
        self.assertTrue(a.inst is not b.inst)

    def test_args_kw(self):
        # args and kw are forwarded to the class constructor to build the default.
        class Foo:
            def __init__(self, c):
                self.c = c

        class Bar:
            pass

        class Bah:
            def __init__(self, c, d):
                self.c = c
                self.d = d

        class A(HasTraits):
            inst = Instance(Foo, (10,))

        a = A()
        self.assertEqual(a.inst.c, 10)

        class B(HasTraits):
            inst = Instance(Bah, args=(10,), kw=dict(d=20))

        b = B()
        self.assertEqual(b.inst.c, 10)
        self.assertEqual(b.inst.d, 20)

        class C(HasTraits):
            inst = Instance(Foo, allow_none=True)

        c = C()
        self.assertTrue(c.inst is None)

    def test_bad_default(self):
        # Without args/kw the default is None. allow_none was not requested,
        # so reading the trait raises.
        class Foo:
            pass

        class A(HasTraits):
            inst = Instance(Foo)

        a = A()
        with self.assertRaises(TraitError):
            a.inst

    def test_instance(self):
        # Passing an instance instead of a class as klass fails when the
        # owning class body is executed.
        class Foo:
            pass

        def inner():
            class A(HasTraits):
                inst = Instance(Foo())  # type:ignore

        self.assertRaises(TraitError, inner)
class TestThis(TestCase):
    """Tests for the ``This`` trait, which holds instances of the class that declares it."""

    def test_this_class(self):
        class Foo(HasTraits):
            this = This["Foo"]()

        f = Foo()
        # The default value is None.
        self.assertEqual(f.this, None)
        g = Foo()
        f.this = g
        self.assertEqual(f.this, g)
        self.assertRaises(TraitError, setattr, f, "this", 10)

    def test_this_inst(self):
        class Foo(HasTraits):
            this = This["Foo"]()

        f = Foo()
        f.this = Foo()
        self.assertTrue(isinstance(f.this, Foo))

    def test_subclass(self):
        class Foo(HasTraits):
            t = This["Foo"]()

        class Bar(Foo):
            pass

        f = Foo()
        b = Bar()
        # Bar inherits t unchanged, so it stays bound to Foo and accepts a Foo.
        f.t = b
        b.t = f
        self.assertEqual(f.t, b)
        self.assertEqual(b.t, f)

    def test_subclass_override(self):
        class Foo(HasTraits):
            t = This["Foo"]()

        class Bar(Foo):
            t = This()

        f = Foo()
        b = Bar()
        f.t = b
        self.assertEqual(f.t, b)
        # Redeclaring t on Bar binds it to Bar, so a plain Foo is rejected.
        self.assertRaises(TraitError, setattr, b, "t", f)

    def test_this_in_container(self):
        # This() also works as a container element type.
        class Tree(HasTraits):
            value = Unicode()
            leaves = List(This())

        tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")])

        with self.assertRaises(TraitError):
            tree.leaves = [1, 2]
  957. class TraitTestBase(TestCase):
  958. """A best testing class for basic trait types."""
  959. def assign(self, value):
  960. self.obj.value = value # type:ignore
  961. def coerce(self, value):
  962. return value
  963. def test_good_values(self):
  964. if hasattr(self, "_good_values"):
  965. for value in self._good_values:
  966. self.assign(value)
  967. self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore
  968. def test_bad_values(self):
  969. if hasattr(self, "_bad_values"):
  970. for value in self._bad_values:
  971. try:
  972. self.assertRaises(TraitError, self.assign, value)
  973. except AssertionError:
  974. assert False, value # noqa: PT015
  975. def test_default_value(self):
  976. if hasattr(self, "_default_value"):
  977. self.assertEqual(self._default_value, self.obj.value) # type:ignore
  978. def test_allow_none(self):
  979. if (
  980. hasattr(self, "_bad_values")
  981. and hasattr(self, "_good_values")
  982. and None in self._bad_values
  983. ):
  984. trait = self.obj.traits()["value"] # type:ignore
  985. try:
  986. trait.allow_none = True
  987. self._bad_values.remove(None)
  988. # skip coerce. Allow None casts None to None.
  989. self.assign(None)
  990. self.assertEqual(self.obj.value, None) # type:ignore
  991. self.test_good_values()
  992. self.test_bad_values()
  993. finally:
  994. # tear down
  995. trait.allow_none = False
  996. self._bad_values.append(None)
  997. def tearDown(self):
  998. # restore default value after tests, if set
  999. if hasattr(self, "_default_value"):
  1000. self.obj.value = self._default_value # type:ignore
  1001. class AnyTrait(HasTraits):
  1002. value = Any()
  1003. class AnyTraitTest(TraitTestBase):
  1004. obj = AnyTrait()
  1005. _default_value = None
  1006. _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j]
  1007. _bad_values: t.Any = []
  1008. class UnionTrait(HasTraits):
  1009. value = Union([Type(), Bool()])
  1010. class UnionTraitTest(TraitTestBase):
  1011. obj = UnionTrait(value="traitlets.config.Config")
  1012. _good_values = [int, float, True]
  1013. _bad_values = [[], (0,), 1j]
  1014. class CallableTrait(HasTraits):
  1015. value = Callable()
  1016. class CallableTraitTest(TraitTestBase):
  1017. obj = CallableTrait(value=lambda x: type(x))
  1018. _good_values = [int, sorted, lambda x: print(x)]
  1019. _bad_values = [[], 1, ""]
  1020. class OrTrait(HasTraits):
  1021. value = Bool() | Unicode()
  1022. class OrTraitTest(TraitTestBase):
  1023. obj = OrTrait()
  1024. _good_values = [True, False, "ten"]
  1025. _bad_values = [[], (0,), 1j]
  1026. class IntTrait(HasTraits):
  1027. value = Int(99, min=-100)
  1028. class TestInt(TraitTestBase):
  1029. obj = IntTrait()
  1030. _default_value = 99
  1031. _good_values = [10, -10]
  1032. _bad_values = [
  1033. "ten",
  1034. [10],
  1035. {"ten": 10},
  1036. (10,),
  1037. None,
  1038. 1j,
  1039. 10.1,
  1040. -10.1,
  1041. "10L",
  1042. "-10L",
  1043. "10.1",
  1044. "-10.1",
  1045. "10",
  1046. "-10",
  1047. -200,
  1048. ]
  1049. class CIntTrait(HasTraits):
  1050. value = CInt("5")
  1051. class TestCInt(TraitTestBase):
  1052. obj = CIntTrait()
  1053. _default_value = 5
  1054. _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
  1055. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
  1056. def coerce(self, n):
  1057. return int(n)
  1058. class MinBoundCIntTrait(HasTraits):
  1059. value = CInt("5", min=3)
  1060. class TestMinBoundCInt(TestCInt):
  1061. obj = MinBoundCIntTrait() # type:ignore
  1062. _default_value = 5
  1063. _good_values = [3, 3.0, "3"]
  1064. _bad_values = [2.6, 2, -3, -3.0]
  1065. class LongTrait(HasTraits):
  1066. value = Long(99)
  1067. class TestLong(TraitTestBase):
  1068. obj = LongTrait()
  1069. _default_value = 99
  1070. _good_values = [10, -10]
  1071. _bad_values = [
  1072. "ten",
  1073. [10],
  1074. {"ten": 10},
  1075. (10,),
  1076. None,
  1077. 1j,
  1078. 10.1,
  1079. -10.1,
  1080. "10",
  1081. "-10",
  1082. "10L",
  1083. "-10L",
  1084. "10.1",
  1085. "-10.1",
  1086. ]
  1087. class MinBoundLongTrait(HasTraits):
  1088. value = Long(99, min=5)
  1089. class TestMinBoundLong(TraitTestBase):
  1090. obj = MinBoundLongTrait()
  1091. _default_value = 99
  1092. _good_values = [5, 10]
  1093. _bad_values = [4, -10]
  1094. class MaxBoundLongTrait(HasTraits):
  1095. value = Long(5, max=10)
  1096. class TestMaxBoundLong(TraitTestBase):
  1097. obj = MaxBoundLongTrait()
  1098. _default_value = 5
  1099. _good_values = [10, -2]
  1100. _bad_values = [11, 20]
  1101. class CLongTrait(HasTraits):
  1102. value = CLong("5")
  1103. class TestCLong(TraitTestBase):
  1104. obj = CLongTrait()
  1105. _default_value = 5
  1106. _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
  1107. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
  1108. def coerce(self, n):
  1109. return int(n)
  1110. class MaxBoundCLongTrait(HasTraits):
  1111. value = CLong("5", max=10)
  1112. class TestMaxBoundCLong(TestCLong):
  1113. obj = MaxBoundCLongTrait() # type:ignore
  1114. _default_value = 5
  1115. _good_values = [10, "10", 10.3]
  1116. _bad_values = [11.0, "11"]
  1117. class IntegerTrait(HasTraits):
  1118. value = Integer(1)
  1119. class TestInteger(TestLong):
  1120. obj = IntegerTrait() # type:ignore
  1121. _default_value = 1
  1122. def coerce(self, n):
  1123. return int(n)
  1124. class MinBoundIntegerTrait(HasTraits):
  1125. value = Integer(5, min=3)
  1126. class TestMinBoundInteger(TraitTestBase):
  1127. obj = MinBoundIntegerTrait()
  1128. _default_value = 5
  1129. _good_values = 3, 20
  1130. _bad_values = [2, -10]
  1131. class MaxBoundIntegerTrait(HasTraits):
  1132. value = Integer(1, max=3)
  1133. class TestMaxBoundInteger(TraitTestBase):
  1134. obj = MaxBoundIntegerTrait()
  1135. _default_value = 1
  1136. _good_values = 3, -2
  1137. _bad_values = [4, 10]
  1138. class FloatTrait(HasTraits):
  1139. value = Float(99.0, max=200.0)
  1140. class TestFloat(TraitTestBase):
  1141. obj = FloatTrait()
  1142. _default_value = 99.0
  1143. _good_values = [10, -10, 10.1, -10.1]
  1144. _bad_values = [
  1145. "ten",
  1146. [10],
  1147. {"ten": 10},
  1148. (10,),
  1149. None,
  1150. 1j,
  1151. "10",
  1152. "-10",
  1153. "10L",
  1154. "-10L",
  1155. "10.1",
  1156. "-10.1",
  1157. 201.0,
  1158. ]
  1159. class CFloatTrait(HasTraits):
  1160. value = CFloat("99.0", max=200.0)
  1161. class TestCFloat(TraitTestBase):
  1162. obj = CFloatTrait()
  1163. _default_value = 99.0
  1164. _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"]
  1165. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"]
  1166. def coerce(self, v):
  1167. return float(v)
  1168. class ComplexTrait(HasTraits):
  1169. value = Complex(99.0 - 99.0j)
  1170. class TestComplex(TraitTestBase):
  1171. obj = ComplexTrait()
  1172. _default_value = 99.0 - 99.0j
  1173. _good_values = [
  1174. 10,
  1175. -10,
  1176. 10.1,
  1177. -10.1,
  1178. 10j,
  1179. 10 + 10j,
  1180. 10 - 10j,
  1181. 10.1j,
  1182. 10.1 + 10.1j,
  1183. 10.1 - 10.1j,
  1184. ]
  1185. _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None]
  1186. class BytesTrait(HasTraits):
  1187. value = Bytes(b"string")
  1188. class TestBytes(TraitTestBase):
  1189. obj = BytesTrait()
  1190. _default_value = b"string"
  1191. _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"]
  1192. _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"]
  1193. class UnicodeTrait(HasTraits):
  1194. value = Unicode("unicode")
  1195. class TestUnicode(TraitTestBase):
  1196. obj = UnicodeTrait()
  1197. _default_value = "unicode"
  1198. _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"]
  1199. _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None]
  1200. def coerce(self, v):
  1201. return cast_unicode(v)
  1202. class ObjectNameTrait(HasTraits):
  1203. value = ObjectName("abc")
  1204. class TestObjectName(TraitTestBase):
  1205. obj = ObjectNameTrait()
  1206. _default_value = "abc"
  1207. _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"]
  1208. _bad_values = [
  1209. 1,
  1210. "",
  1211. "€",
  1212. "9g",
  1213. "!",
  1214. "#abc",
  1215. "aj@",
  1216. "a.b",
  1217. "a()",
  1218. "a[0]",
  1219. None,
  1220. object(),
  1221. object,
  1222. ]
  1223. _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131).
  1224. class DottedObjectNameTrait(HasTraits):
  1225. value = DottedObjectName("a.b")
  1226. class TestDottedObjectName(TraitTestBase):
  1227. obj = DottedObjectNameTrait()
  1228. _default_value = "a.b"
  1229. _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"]
  1230. _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
  1231. _good_values.append("t.þ")
  1232. class TCPAddressTrait(HasTraits):
  1233. value = TCPAddress()
  1234. class TestTCPAddress(TraitTestBase):
  1235. obj = TCPAddressTrait()
  1236. _default_value = ("127.0.0.1", 0)
  1237. _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)]
  1238. _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None]
  1239. class ListTrait(HasTraits):
  1240. value = List(Int())
  1241. class TestList(TraitTestBase):
  1242. obj = ListTrait()
  1243. _default_value: t.List[t.Any] = []
  1244. _good_values = [[], [1], list(range(10)), (1, 2)]
  1245. _bad_values = [10, [1, "a"], "a"]
  1246. def coerce(self, value):
  1247. if value is not None:
  1248. value = list(value)
  1249. return value
  1250. class Foo:
  1251. pass
  1252. class NoneInstanceListTrait(HasTraits):
  1253. value = List(Instance(Foo))
  1254. class TestNoneInstanceList(TraitTestBase):
  1255. obj = NoneInstanceListTrait()
  1256. _default_value: t.List[t.Any] = []
  1257. _good_values = [[Foo(), Foo()], []]
  1258. _bad_values = [[None], [Foo(), None]]
  1259. class InstanceListTrait(HasTraits):
  1260. value = List(Instance(__name__ + ".Foo"))
  1261. class TestInstanceList(TraitTestBase):
  1262. obj = InstanceListTrait()
  1263. def test_klass(self):
  1264. """Test that the instance klass is properly assigned."""
  1265. self.assertIs(self.obj.traits()["value"]._trait.klass, Foo)
  1266. _default_value: t.List[t.Any] = []
  1267. _good_values = [[Foo(), Foo()], []]
  1268. _bad_values = [
  1269. [
  1270. "1",
  1271. 2,
  1272. ],
  1273. "1",
  1274. [Foo],
  1275. None,
  1276. ]
  1277. class UnionListTrait(HasTraits):
  1278. value = List(Int() | Bool())
  1279. class TestUnionListTrait(TraitTestBase):
  1280. obj = UnionListTrait()
  1281. _default_value: t.List[t.Any] = []
  1282. _good_values = [[True, 1], [False, True]]
  1283. _bad_values = [[1, "True"], False]
  1284. class LenListTrait(HasTraits):
  1285. value = List(Int(), [0], minlen=1, maxlen=2)
  1286. class TestLenList(TraitTestBase):
  1287. obj = LenListTrait()
  1288. _default_value = [0]
  1289. _good_values = [[1], [1, 2], (1, 2)]
  1290. _bad_values = [10, [1, "a"], "a", [], list(range(3))]
  1291. def coerce(self, value):
  1292. if value is not None:
  1293. value = list(value)
  1294. return value
  1295. class TupleTrait(HasTraits):
  1296. value = Tuple(Int(allow_none=True), default_value=(1,))
  1297. class TestTupleTrait(TraitTestBase):
  1298. obj = TupleTrait()
  1299. _default_value = (1,)
  1300. _good_values = [(1,), (0,), [1]]
  1301. _bad_values = [10, (1, 2), ("a"), (), None]
  1302. def coerce(self, value):
  1303. if value is not None:
  1304. value = tuple(value)
  1305. return value
  1306. def test_invalid_args(self):
  1307. self.assertRaises(TypeError, Tuple, 5)
  1308. self.assertRaises(TypeError, Tuple, default_value="hello")
  1309. t = Tuple(Int(), CBytes(), default_value=(1, 5))
  1310. class LooseTupleTrait(HasTraits):
  1311. value = Tuple((1, 2, 3))
  1312. class TestLooseTupleTrait(TraitTestBase):
  1313. obj = LooseTupleTrait()
  1314. _default_value = (1, 2, 3)
  1315. _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()]
  1316. _bad_values = [10, "hello", {}, None]
  1317. def coerce(self, value):
  1318. if value is not None:
  1319. value = tuple(value)
  1320. return value
  1321. def test_invalid_args(self):
  1322. self.assertRaises(TypeError, Tuple, 5)
  1323. self.assertRaises(TypeError, Tuple, default_value="hello")
  1324. t = Tuple(Int(), CBytes(), default_value=(1, 5))
  1325. class MultiTupleTrait(HasTraits):
  1326. value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"])
  1327. class TestMultiTuple(TraitTestBase):
  1328. obj = MultiTupleTrait()
  1329. _default_value = (99, b"bottles")
  1330. _good_values = [(1, b"a"), (2, b"b")]
  1331. _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a"))
  1332. @pytest.mark.parametrize(
  1333. "Trait",
  1334. ( # noqa: PT007
  1335. List,
  1336. Tuple,
  1337. Set,
  1338. Dict,
  1339. Integer,
  1340. Unicode,
  1341. ),
  1342. )
  1343. def test_allow_none_default_value(Trait):
  1344. class C(HasTraits):
  1345. t = Trait(default_value=None, allow_none=True)
  1346. # test default value
  1347. c = C()
  1348. assert c.t is None
  1349. # and in constructor
  1350. c = C(t=None)
  1351. assert c.t is None
  1352. @pytest.mark.parametrize(
  1353. "Trait, default_value",
  1354. ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), # noqa: PT007
  1355. )
  1356. def test_default_value(Trait, default_value):
  1357. class C(HasTraits):
  1358. t = Trait()
  1359. # test default value
  1360. c = C()
  1361. assert type(c.t) is type(default_value)
  1362. assert c.t == default_value
  1363. @pytest.mark.parametrize(
  1364. "Trait, default_value",
  1365. ((List, []), (Tuple, ()), (Set, set())), # noqa: PT007
  1366. )
  1367. def test_subclass_default_value(Trait, default_value):
  1368. """Test deprecated default_value=None behavior for Container subclass traits"""
  1369. class SubclassTrait(Trait): # type:ignore
  1370. def __init__(self, default_value=None):
  1371. super().__init__(default_value=default_value)
  1372. class C(HasTraits):
  1373. t = SubclassTrait()
  1374. # test default value
  1375. c = C()
  1376. assert type(c.t) is type(default_value)
  1377. assert c.t == default_value
  1378. class CRegExpTrait(HasTraits):
  1379. value = CRegExp(r"")
  1380. class TestCRegExp(TraitTestBase):
  1381. def coerce(self, value):
  1382. return re.compile(value)
  1383. obj = CRegExpTrait()
  1384. _default_value = re.compile(r"")
  1385. _good_values = [r"\d+", re.compile(r"\d+")]
  1386. _bad_values = ["(", None, ()]
  1387. class DictTrait(HasTraits):
  1388. value = Dict()
  1389. def test_dict_assignment():
  1390. d: t.Dict[str, int] = {}
  1391. c = DictTrait()
  1392. c.value = d
  1393. d["a"] = 5
  1394. assert d == c.value
  1395. assert c.value is d
  1396. class UniformlyValueValidatedDictTrait(HasTraits):
  1397. value = Dict(value_trait=Unicode(), default_value={"foo": "1"})
  1398. class TestInstanceUniformlyValueValidatedDict(TraitTestBase):
  1399. obj = UniformlyValueValidatedDictTrait()
  1400. _default_value = {"foo": "1"}
  1401. _good_values = [{"foo": "0", "bar": "1"}]
  1402. _bad_values = [{"foo": 0, "bar": "1"}]
  1403. class NonuniformlyValueValidatedDictTrait(HasTraits):
  1404. value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1})
  1405. class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase):
  1406. obj = NonuniformlyValueValidatedDictTrait()
  1407. _default_value = {"foo": 1}
  1408. _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}]
  1409. _bad_values = [{"foo": "0", "bar": "1"}]
  1410. class KeyValidatedDictTrait(HasTraits):
  1411. value = Dict(key_trait=Unicode(), default_value={"foo": "1"})
  1412. class TestInstanceKeyValidatedDict(TraitTestBase):
  1413. obj = KeyValidatedDictTrait()
  1414. _default_value = {"foo": "1"}
  1415. _good_values = [{"foo": "0", "bar": "1"}]
  1416. _bad_values = [{"foo": "0", 0: "1"}]
  1417. class FullyValidatedDictTrait(HasTraits):
  1418. value = Dict(
  1419. value_trait=Unicode(),
  1420. key_trait=Unicode(),
  1421. per_key_traits={"foo": Int()},
  1422. default_value={"foo": 1},
  1423. )
  1424. class TestInstanceFullyValidatedDict(TraitTestBase):
  1425. obj = FullyValidatedDictTrait()
  1426. _default_value = {"foo": 1}
  1427. _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}]
  1428. _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}]
  1429. def test_dict_default_value():
  1430. """Check that the `{}` default value of the Dict traitlet constructor is
  1431. actually copied."""
  1432. class Foo(HasTraits):
  1433. d1 = Dict()
  1434. d2 = Dict()
  1435. foo = Foo()
  1436. assert foo.d1 == {}
  1437. assert foo.d2 == {}
  1438. assert foo.d1 is not foo.d2
  1439. class TestValidationHook(TestCase):
  1440. def test_parity_trait(self):
  1441. """Verify that the early validation hook is effective"""
  1442. class Parity(HasTraits):
  1443. value = Int(0)
  1444. parity = Enum(["odd", "even"], default_value="even")
  1445. @validate("value")
  1446. def _value_validate(self, proposal):
  1447. value = proposal["value"]
  1448. if self.parity == "even" and value % 2:
  1449. raise TraitError("Expected an even number")
  1450. if self.parity == "odd" and (value % 2 == 0):
  1451. raise TraitError("Expected an odd number")
  1452. return value
  1453. u = Parity()
  1454. u.parity = "odd"
  1455. u.value = 1 # OK
  1456. with self.assertRaises(TraitError):
  1457. u.value = 2 # Trait Error
  1458. u.parity = "even"
  1459. u.value = 2 # OK
  1460. def test_multiple_validate(self):
  1461. """Verify that we can register the same validator to multiple names"""
  1462. class OddEven(HasTraits):
  1463. odd = Int(1)
  1464. even = Int(0)
  1465. @validate("odd", "even")
  1466. def check_valid(self, proposal):
  1467. if proposal["trait"].name == "odd" and not proposal["value"] % 2:
  1468. raise TraitError("odd should be odd")
  1469. if proposal["trait"].name == "even" and proposal["value"] % 2:
  1470. raise TraitError("even should be even")
  1471. u = OddEven()
  1472. u.odd = 3 # OK
  1473. with self.assertRaises(TraitError):
  1474. u.odd = 2 # Trait Error
  1475. u.even = 2 # OK
  1476. with self.assertRaises(TraitError):
  1477. u.even = 3 # Trait Error
  1478. def test_validate_used(self):
  1479. """Verify that the validate value is being used"""
  1480. class FixedValue(HasTraits):
  1481. value = Int(0)
  1482. @validate("value")
  1483. def _value_validate(self, proposal):
  1484. return -1
  1485. u = FixedValue(value=2)
  1486. assert u.value == -1
  1487. u = FixedValue()
  1488. u.value = 3
  1489. assert u.value == -1
  1490. class TestLink(TestCase):
  1491. def test_connect_same(self):
  1492. """Verify two traitlets of the same type can be linked together using link."""
  1493. # Create two simple classes with Int traitlets.
  1494. class A(HasTraits):
  1495. value = Int()
  1496. a = A(value=9)
  1497. b = A(value=8)
  1498. # Connect the two classes.
  1499. c = link((a, "value"), (b, "value"))
  1500. # Make sure the values are the same at the point of linking.
  1501. self.assertEqual(a.value, b.value)
  1502. # Change one of the values to make sure they stay in sync.
  1503. a.value = 5
  1504. self.assertEqual(a.value, b.value)
  1505. b.value = 6
  1506. self.assertEqual(a.value, b.value)
  1507. def test_link_different(self):
  1508. """Verify two traitlets of different types can be linked together using link."""
  1509. # Create two simple classes with Int traitlets.
  1510. class A(HasTraits):
  1511. value = Int()
  1512. class B(HasTraits):
  1513. count = Int()
  1514. a = A(value=9)
  1515. b = B(count=8)
  1516. # Connect the two classes.
  1517. c = link((a, "value"), (b, "count"))
  1518. # Make sure the values are the same at the point of linking.
  1519. self.assertEqual(a.value, b.count)
  1520. # Change one of the values to make sure they stay in sync.
  1521. a.value = 5
  1522. self.assertEqual(a.value, b.count)
  1523. b.count = 4
  1524. self.assertEqual(a.value, b.count)
  1525. def test_unlink_link(self):
  1526. """Verify two linked traitlets can be unlinked and relinked."""
  1527. # Create two simple classes with Int traitlets.
  1528. class A(HasTraits):
  1529. value = Int()
  1530. a = A(value=9)
  1531. b = A(value=8)
  1532. # Connect the two classes.
  1533. c = link((a, "value"), (b, "value"))
  1534. a.value = 4
  1535. c.unlink()
  1536. # Change one of the values to make sure they don't stay in sync.
  1537. a.value = 5
  1538. self.assertNotEqual(a.value, b.value)
  1539. c.link()
  1540. self.assertEqual(a.value, b.value)
  1541. a.value += 1
  1542. self.assertEqual(a.value, b.value)
  1543. def test_callbacks(self):
  1544. """Verify two linked traitlets have their callbacks called once."""
  1545. # Create two simple classes with Int traitlets.
  1546. class A(HasTraits):
  1547. value = Int()
  1548. class B(HasTraits):
  1549. count = Int()
  1550. a = A(value=9)
  1551. b = B(count=8)
  1552. # Register callbacks that count.
  1553. callback_count = []
  1554. def a_callback(name, old, new):
  1555. callback_count.append("a")
  1556. a.on_trait_change(a_callback, "value")
  1557. def b_callback(name, old, new):
  1558. callback_count.append("b")
  1559. b.on_trait_change(b_callback, "count")
  1560. # Connect the two classes.
  1561. c = link((a, "value"), (b, "count"))
  1562. # Make sure b's count was set to a's value once.
  1563. self.assertEqual("".join(callback_count), "b")
  1564. del callback_count[:]
  1565. # Make sure a's value was set to b's count once.
  1566. b.count = 5
  1567. self.assertEqual("".join(callback_count), "ba")
  1568. del callback_count[:]
  1569. # Make sure b's count was set to a's value once.
  1570. a.value = 4
  1571. self.assertEqual("".join(callback_count), "ab")
  1572. del callback_count[:]
  1573. def test_tranform(self):
  1574. """Test transform link."""
  1575. # Create two simple classes with Int traitlets.
  1576. class A(HasTraits):
  1577. value = Int()
  1578. a = A(value=9)
  1579. b = A(value=8)
  1580. # Connect the two classes.
  1581. c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0)))
  1582. # Make sure the values are correct at the point of linking.
  1583. self.assertEqual(b.value, 2 * a.value)
  1584. # Change one the value of the source and check that it modifies the target.
  1585. a.value = 5
  1586. self.assertEqual(b.value, 10)
  1587. # Change one the value of the target and check that it modifies the
  1588. # source.
  1589. b.value = 6
  1590. self.assertEqual(a.value, 3)
  1591. def test_link_broken_at_source(self):
  1592. class MyClass(HasTraits):
  1593. i = Int()
  1594. j = Int()
  1595. @observe("j")
  1596. def another_update(self, change):
  1597. self.i = change.new * 2
  1598. mc = MyClass()
  1599. l = link((mc, "i"), (mc, "j")) # noqa: E741
  1600. self.assertRaises(TraitError, setattr, mc, "i", 2)
  1601. def test_link_broken_at_target(self):
  1602. class MyClass(HasTraits):
  1603. i = Int()
  1604. j = Int()
  1605. @observe("i")
  1606. def another_update(self, change):
  1607. self.j = change.new * 2
  1608. mc = MyClass()
  1609. l = link((mc, "i"), (mc, "j")) # noqa: E741
  1610. self.assertRaises(TraitError, setattr, mc, "j", 2)
  1611. class TestDirectionalLink(TestCase):
  1612. def test_connect_same(self):
  1613. """Verify two traitlets of the same type can be linked together using directional_link."""
  1614. # Create two simple classes with Int traitlets.
  1615. class A(HasTraits):
  1616. value = Int()
  1617. a = A(value=9)
  1618. b = A(value=8)
  1619. # Connect the two classes.
  1620. c = directional_link((a, "value"), (b, "value"))
  1621. # Make sure the values are the same at the point of linking.
  1622. self.assertEqual(a.value, b.value)
  1623. # Change one the value of the source and check that it synchronizes the target.
  1624. a.value = 5
  1625. self.assertEqual(b.value, 5)
  1626. # Change one the value of the target and check that it has no impact on the source
  1627. b.value = 6
  1628. self.assertEqual(a.value, 5)
  1629. def test_tranform(self):
  1630. """Test transform link."""
  1631. # Create two simple classes with Int traitlets.
  1632. class A(HasTraits):
  1633. value = Int()
  1634. a = A(value=9)
  1635. b = A(value=8)
  1636. # Connect the two classes.
  1637. c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x)
  1638. # Make sure the values are correct at the point of linking.
  1639. self.assertEqual(b.value, 2 * a.value)
  1640. # Change one the value of the source and check that it modifies the target.
  1641. a.value = 5
  1642. self.assertEqual(b.value, 10)
  1643. # Change one the value of the target and check that it has no impact on the source
  1644. b.value = 6
  1645. self.assertEqual(a.value, 5)
  1646. def test_link_different(self):
  1647. """Verify two traitlets of different types can be linked together using link."""
  1648. # Create two simple classes with Int traitlets.
  1649. class A(HasTraits):
  1650. value = Int()
  1651. class B(HasTraits):
  1652. count = Int()
  1653. a = A(value=9)
  1654. b = B(count=8)
  1655. # Connect the two classes.
  1656. c = directional_link((a, "value"), (b, "count"))
  1657. # Make sure the values are the same at the point of linking.
  1658. self.assertEqual(a.value, b.count)
  1659. # Change one the value of the source and check that it synchronizes the target.
  1660. a.value = 5
  1661. self.assertEqual(b.count, 5)
  1662. # Change one the value of the target and check that it has no impact on the source
  1663. b.value = 6 # type:ignore
  1664. self.assertEqual(a.value, 5)
  1665. def test_unlink_link(self):
  1666. """Verify two linked traitlets can be unlinked and relinked."""
  1667. # Create two simple classes with Int traitlets.
  1668. class A(HasTraits):
  1669. value = Int()
  1670. a = A(value=9)
  1671. b = A(value=8)
  1672. # Connect the two classes.
  1673. c = directional_link((a, "value"), (b, "value"))
  1674. a.value = 4
  1675. c.unlink()
  1676. # Change one of the values to make sure they don't stay in sync.
  1677. a.value = 5
  1678. self.assertNotEqual(a.value, b.value)
  1679. c.link()
  1680. self.assertEqual(a.value, b.value)
  1681. a.value += 1
  1682. self.assertEqual(a.value, b.value)
  1683. class Pickleable(HasTraits):
  1684. i = Int()
  1685. @observe("i")
  1686. def _i_changed(self, change):
  1687. pass
  1688. @validate("i")
  1689. def _i_validate(self, commit):
  1690. return commit["value"]
  1691. j = Int()
  1692. def __init__(self):
  1693. with self.hold_trait_notifications():
  1694. self.i = 1
  1695. self.on_trait_change(self._i_changed, "i")
  1696. def test_pickle_hastraits():
  1697. c = Pickleable()
  1698. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1699. p = pickle.dumps(c, protocol)
  1700. c2 = pickle.loads(p)
  1701. assert c2.i == c.i
  1702. assert c2.j == c.j
  1703. c.i = 5
  1704. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1705. p = pickle.dumps(c, protocol)
  1706. c2 = pickle.loads(p)
  1707. assert c2.i == c.i
  1708. assert c2.j == c.j
  1709. def test_hold_trait_notifications():
  1710. changes = []
  1711. class Test(HasTraits):
  1712. a = Integer(0)
  1713. b = Integer(0)
  1714. def _a_changed(self, name, old, new):
  1715. changes.append((old, new))
  1716. def _b_validate(self, value, trait):
  1717. if value != 0:
  1718. raise TraitError("Only 0 is a valid value")
  1719. return value
  1720. # Test context manager and nesting
  1721. t = Test()
  1722. with t.hold_trait_notifications():
  1723. with t.hold_trait_notifications():
  1724. t.a = 1
  1725. assert t.a == 1
  1726. assert changes == []
  1727. t.a = 2
  1728. assert t.a == 2
  1729. with t.hold_trait_notifications():
  1730. t.a = 3
  1731. assert t.a == 3
  1732. assert changes == []
  1733. t.a = 4
  1734. assert t.a == 4
  1735. assert changes == []
  1736. t.a = 4
  1737. assert t.a == 4
  1738. assert changes == []
  1739. assert changes == [(0, 4)]
  1740. # Test roll-back
  1741. try:
  1742. with t.hold_trait_notifications():
  1743. t.b = 1 # raises a Trait error
  1744. except Exception:
  1745. pass
  1746. assert t.b == 0
  1747. class RollBack(HasTraits):
  1748. bar = Int()
  1749. def _bar_validate(self, value, trait):
  1750. if value:
  1751. raise TraitError("foobar")
  1752. return value
  1753. class TestRollback(TestCase):
  1754. def test_roll_back(self):
  1755. def assign_rollback():
  1756. RollBack(bar=1)
  1757. self.assertRaises(TraitError, assign_rollback)
  1758. class CacheModification(HasTraits):
  1759. foo = Int()
  1760. bar = Int()
  1761. def _bar_validate(self, value, trait):
  1762. self.foo = value
  1763. return value
  1764. def _foo_validate(self, value, trait):
  1765. self.bar = value
  1766. return value
  1767. def test_cache_modification():
  1768. CacheModification(foo=1)
  1769. CacheModification(bar=1)
  1770. class OrderTraits(HasTraits):
  1771. notified = Dict()
  1772. a = Unicode()
  1773. b = Unicode()
  1774. c = Unicode()
  1775. d = Unicode()
  1776. e = Unicode()
  1777. f = Unicode()
  1778. g = Unicode()
  1779. h = Unicode()
  1780. i = Unicode()
  1781. j = Unicode()
  1782. k = Unicode()
  1783. l = Unicode() # noqa: E741
  1784. def _notify(self, name, old, new):
  1785. """check the value of all traits when each trait change is triggered
  1786. This verifies that the values are not sensitive
  1787. to dict ordering when loaded from kwargs
  1788. """
  1789. # check the value of the other traits
  1790. # when a given trait change notification fires
  1791. self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"}
  1792. def __init__(self, **kwargs):
  1793. self.on_trait_change(self._notify)
  1794. super().__init__(**kwargs)
  1795. def test_notification_order():
  1796. d = {c: c for c in "abcdefghijkl"}
  1797. obj = OrderTraits()
  1798. assert obj.notified == {}
  1799. obj = OrderTraits(**d)
  1800. notifications = {c: d for c in "abcdefghijkl"}
  1801. assert obj.notified == notifications
  1802. ###
  1803. # Traits for Forward Declaration Tests
  1804. ###
  1805. class ForwardDeclaredInstanceTrait(HasTraits):
  1806. value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True)
  1807. class ForwardDeclaredTypeTrait(HasTraits):
  1808. value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True)
  1809. class ForwardDeclaredInstanceListTrait(HasTraits):
  1810. value = List(ForwardDeclaredInstance("ForwardDeclaredBar"))
  1811. class ForwardDeclaredTypeListTrait(HasTraits):
  1812. value = List(ForwardDeclaredType("ForwardDeclaredBar"))
  1813. ###
  1814. # End Traits for Forward Declaration Tests
  1815. ###
  1816. ###
  1817. # Classes for Forward Declaration Tests
  1818. ###
  1819. class ForwardDeclaredBar:
  1820. pass
  1821. class ForwardDeclaredBarSub(ForwardDeclaredBar):
  1822. pass
  1823. ###
  1824. # End Classes for Forward Declaration Tests
  1825. ###
  1826. ###
  1827. # Forward Declaration Tests
  1828. ###
  1829. class TestForwardDeclaredInstanceTrait(TraitTestBase):
  1830. obj = ForwardDeclaredInstanceTrait()
  1831. _default_value = None
  1832. _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1833. _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1834. class TestForwardDeclaredTypeTrait(TraitTestBase):
  1835. obj = ForwardDeclaredTypeTrait()
  1836. _default_value = None
  1837. _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1838. _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1839. class TestForwardDeclaredInstanceList(TraitTestBase):
  1840. obj = ForwardDeclaredInstanceListTrait()
  1841. def test_klass(self):
  1842. """Test that the instance klass is properly assigned."""
  1843. self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
  1844. _default_value: t.List[t.Any] = []
  1845. _good_values = [
  1846. [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
  1847. [],
  1848. ]
  1849. _bad_values = [
  1850. ForwardDeclaredBar(),
  1851. [ForwardDeclaredBar(), 3, None],
  1852. "1",
  1853. # Note that this is the type, not an instance.
  1854. [ForwardDeclaredBar],
  1855. [None],
  1856. None,
  1857. ]
  1858. class TestForwardDeclaredTypeList(TraitTestBase):
  1859. obj = ForwardDeclaredTypeListTrait()
  1860. def test_klass(self):
  1861. """Test that the instance klass is properly assigned."""
  1862. self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
  1863. _default_value: t.List[t.Any] = []
  1864. _good_values = [
  1865. [ForwardDeclaredBar, ForwardDeclaredBarSub],
  1866. [],
  1867. ]
  1868. _bad_values = [
  1869. ForwardDeclaredBar,
  1870. [ForwardDeclaredBar, 3],
  1871. "1",
  1872. # Note that this is an instance, not the type.
  1873. [ForwardDeclaredBar()],
  1874. [None],
  1875. None,
  1876. ]
  1877. ###
  1878. # End Forward Declaration Tests
  1879. ###
  1880. class TestDynamicTraits(TestCase):
  1881. def setUp(self):
  1882. self._notify1 = []
  1883. def notify1(self, name, old, new):
  1884. self._notify1.append((name, old, new))
  1885. @t.no_type_check
  1886. def test_notify_all(self):
  1887. class A(HasTraits):
  1888. pass
  1889. a = A()
  1890. self.assertTrue(not hasattr(a, "x"))
  1891. self.assertTrue(not hasattr(a, "y"))
  1892. # Dynamically add trait x.
  1893. a.add_traits(x=Int())
  1894. self.assertTrue(hasattr(a, "x"))
  1895. self.assertTrue(isinstance(a, (A,)))
  1896. # Dynamically add trait y.
  1897. a.add_traits(y=Float())
  1898. self.assertTrue(hasattr(a, "y"))
  1899. self.assertTrue(isinstance(a, (A,)))
  1900. self.assertEqual(a.__class__.__name__, A.__name__)
  1901. # Create a new instance and verify that x and y
  1902. # aren't defined.
  1903. b = A()
  1904. self.assertTrue(not hasattr(b, "x"))
  1905. self.assertTrue(not hasattr(b, "y"))
  1906. # Verify that notification works like normal.
  1907. a.on_trait_change(self.notify1)
  1908. a.x = 0
  1909. self.assertEqual(len(self._notify1), 0)
  1910. a.y = 0.0
  1911. self.assertEqual(len(self._notify1), 0)
  1912. a.x = 10
  1913. self.assertTrue(("x", 0, 10) in self._notify1)
  1914. a.y = 10.0
  1915. self.assertTrue(("y", 0.0, 10.0) in self._notify1)
  1916. self.assertRaises(TraitError, setattr, a, "x", "bad string")
  1917. self.assertRaises(TraitError, setattr, a, "y", "bad string")
  1918. self._notify1 = []
  1919. a.on_trait_change(self.notify1, remove=True)
  1920. a.x = 20
  1921. a.y = 20.0
  1922. self.assertEqual(len(self._notify1), 0)
  1923. def test_enum_no_default():
  1924. class C(HasTraits):
  1925. t = Enum(["a", "b"])
  1926. c = C()
  1927. c.t = "a"
  1928. assert c.t == "a"
  1929. c = C()
  1930. with pytest.raises(TraitError):
  1931. t = c.t
  1932. c = C(t="b")
  1933. assert c.t == "b"
  1934. def test_default_value_repr():
  1935. class C(HasTraits):
  1936. t = Type("traitlets.HasTraits")
  1937. t2 = Type(HasTraits)
  1938. n = Integer(0)
  1939. lis = List()
  1940. d = Dict()
  1941. assert C.t.default_value_repr() == "'traitlets.HasTraits'"
  1942. assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'"
  1943. assert C.n.default_value_repr() == "0"
  1944. assert C.lis.default_value_repr() == "[]"
  1945. assert C.d.default_value_repr() == "{}"
class TransitionalClass(HasTraits):
    # Fixture with decorated (@default / @observe) handlers that SubClass
    # overrides by plain method name; exercised by test_subclass_compat.
    d = Any()

    @default("d")
    def _d_default(self):
        return TransitionalClass

    # Set by _calls_super_changed to the change it receives.
    parent_super = False
    calls_super = Integer(0)

    @default("calls_super")
    def _calls_super_default(self):
        return -1

    # observe_compat lets this handler also be invoked with legacy
    # (name, old, new) arguments, as SubClass does through super().
    @observe("calls_super")
    @observe_compat
    def _calls_super_changed(self, change):
        self.parent_super = change

    # Set by _overrides_changed; test_subclass_compat expects it to stay False
    # because SubClass overrides this handler without chaining to it.
    parent_override = False
    overrides = Integer(0)

    @observe("overrides")
    @observe_compat
    def _overrides_changed(self, change):
        self.parent_override = change
class SubClass(TransitionalClass):
    # Overrides the parent's decorated handlers by name only (no decorators)
    # and with legacy signatures; exercised by test_subclass_compat.
    def _d_default(self):
        return SubClass

    subclass_super = False

    def _calls_super_changed(self, name, old, new):
        # Legacy (name, old, new) signature; chains to the parent's handler.
        self.subclass_super = True
        super()._calls_super_changed(name, old, new)

    subclass_override = False

    def _overrides_changed(self, name, old, new):
        # Does not call super(): the parent's handler is expected not to run.
        self.subclass_override = True
  1976. def test_subclass_compat():
  1977. obj = SubClass()
  1978. obj.calls_super = 5
  1979. assert obj.parent_super
  1980. assert obj.subclass_super
  1981. obj.overrides = 5
  1982. assert obj.subclass_override
  1983. assert not obj.parent_override
  1984. assert obj.d is SubClass
class DefinesHandler(HasTraits):
    # Base fixture: registers ``handler`` as an observer of ``trait``. The
    # subclasses that follow override, drop, or add to this registration.
    parent_called = False  # set True when DefinesHandler.handler runs
    trait = Integer()

    @observe("trait")
    def handler(self, change):
        self.parent_called = True
class OverridesHandler(DefinesHandler):
    # Re-declares ``handler`` with @observe under the same method name;
    # test_subclass_override_observer expects only this version to run.
    child_called = False

    @observe("trait")
    def handler(self, change):
        self.child_called = True
  1996. def test_subclass_override_observer():
  1997. obj = OverridesHandler()
  1998. obj.trait = 5
  1999. assert obj.child_called
  2000. assert not obj.parent_called
class DoesntRegisterHandler(DefinesHandler):
    # Overrides ``handler`` by name WITHOUT @observe;
    # test_subclass_override_not_registered expects neither this nor the
    # parent's handler to run when ``trait`` changes.
    child_called = False

    def handler(self, change):
        self.child_called = True
  2005. def test_subclass_override_not_registered():
  2006. """Subclass that overrides observer and doesn't re-register unregisters both"""
  2007. obj = DoesntRegisterHandler()
  2008. obj.trait = 5
  2009. assert not obj.child_called
  2010. assert not obj.parent_called
class AddsHandler(DefinesHandler):
    # Adds a second observer under a new method name; test_subclass_add_observer
    # expects both it and the inherited ``handler`` to run.
    child_called = False

    @observe("trait")
    def child_handler(self, change):
        self.child_called = True
  2016. def test_subclass_add_observer():
  2017. obj = AddsHandler()
  2018. obj.trait = 5
  2019. assert obj.child_called
  2020. assert obj.parent_called
  2021. def test_observe_iterables():
  2022. class C(HasTraits):
  2023. i = Integer()
  2024. s = Unicode()
  2025. c = C()
  2026. recorded = {}
  2027. def record(change):
  2028. recorded["change"] = change
  2029. # observe with names=set
  2030. c.observe(record, names={"i", "s"})
  2031. c.i = 5
  2032. assert recorded["change"].name == "i"
  2033. assert recorded["change"].new == 5
  2034. c.s = "hi"
  2035. assert recorded["change"].name == "s"
  2036. assert recorded["change"].new == "hi"
  2037. # observe with names=custom container with iter, contains
  2038. class MyContainer:
  2039. def __init__(self, container):
  2040. self.container = container
  2041. def __iter__(self):
  2042. return iter(self.container)
  2043. def __contains__(self, key):
  2044. return key in self.container
  2045. c.observe(record, names=MyContainer({"i", "s"}))
  2046. c.i = 10
  2047. assert recorded["change"].name == "i"
  2048. assert recorded["change"].new == 10
  2049. c.s = "ok"
  2050. assert recorded["change"].name == "s"
  2051. assert recorded["change"].new == "ok"
  2052. def test_super_args():
  2053. class SuperRecorder:
  2054. def __init__(self, *args, **kwargs):
  2055. self.super_args = args
  2056. self.super_kwargs = kwargs
  2057. class SuperHasTraits(HasTraits, SuperRecorder):
  2058. i = Integer()
  2059. obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x")
  2060. assert obj.i == 5
  2061. assert not hasattr(obj, "b")
  2062. assert not hasattr(obj, "c")
  2063. assert obj.super_args == ("a1", "a2")
  2064. assert obj.super_kwargs == {"b": 10, "c": "x"}
  2065. def test_super_bad_args():
  2066. class SuperHasTraits(HasTraits):
  2067. a = Integer()
  2068. w = ["Passing unrecognized arguments"]
  2069. with expected_warnings(w):
  2070. obj = SuperHasTraits(a=1, b=2)
  2071. assert obj.a == 1
  2072. assert not hasattr(obj, "b")
  2073. def test_default_mro():
  2074. """Verify that default values follow mro"""
  2075. class Base(HasTraits):
  2076. trait = Unicode("base")
  2077. attr = "base"
  2078. class A(Base):
  2079. pass
  2080. class B(Base):
  2081. trait = Unicode("B")
  2082. attr = "B"
  2083. class AB(A, B):
  2084. pass
  2085. class BA(B, A):
  2086. pass
  2087. assert A().trait == "base"
  2088. assert A().attr == "base"
  2089. assert BA().trait == "B"
  2090. assert BA().attr == "B"
  2091. assert AB().trait == "B"
  2092. assert AB().attr == "B"
  2093. def test_cls_self_argument():
  2094. class X(HasTraits):
  2095. def __init__(__self, cls, self):
  2096. pass
  2097. x = X(cls=None, self=None)
  2098. def test_override_default():
  2099. class C(HasTraits):
  2100. a = Unicode("hard default")
  2101. def _a_default(self):
  2102. return "default method"
  2103. C._a_default = lambda self: "overridden" # type:ignore
  2104. c = C()
  2105. assert c.a == "overridden"
  2106. def test_override_default_decorator():
  2107. class C(HasTraits):
  2108. a = Unicode("hard default")
  2109. @default("a")
  2110. def _a_default(self):
  2111. return "default method"
  2112. C._a_default = lambda self: "overridden" # type:ignore
  2113. c = C()
  2114. assert c.a == "overridden"
  2115. def test_override_default_instance():
  2116. class C(HasTraits):
  2117. a = Unicode("hard default")
  2118. @default("a")
  2119. def _a_default(self):
  2120. return "default method"
  2121. c = C()
  2122. c._a_default = lambda self: "overridden"
  2123. assert c.a == "overridden"
  2124. def test_copy_HasTraits():
  2125. from copy import copy
  2126. class C(HasTraits):
  2127. a = Int()
  2128. c = C(a=1)
  2129. assert c.a == 1
  2130. cc = copy(c)
  2131. cc.a = 2
  2132. assert cc.a == 2
  2133. assert c.a == 1
  2134. def _from_string_test(traittype, s, expected):
  2135. """Run a test of trait.from_string"""
  2136. if isinstance(traittype, TraitType):
  2137. trait = traittype
  2138. else:
  2139. trait = traittype(allow_none=True)
  2140. if isinstance(s, list):
  2141. cast = trait.from_string_list # type:ignore
  2142. else:
  2143. cast = trait.from_string
  2144. if type(expected) is type and issubclass(expected, Exception):
  2145. with pytest.raises(expected): # noqa: PT012
  2146. value = cast(s)
  2147. trait.validate(CrossValidationStub(), value) # type:ignore
  2148. else:
  2149. value = cast(s)
  2150. assert value == expected
  2151. @pytest.mark.parametrize(
  2152. "s, expected",
  2153. [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
  2154. )
  2155. def test_unicode_from_string(s, expected):
  2156. _from_string_test(Unicode, s, expected)
  2157. @pytest.mark.parametrize(
  2158. "s, expected",
  2159. [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
  2160. )
  2161. def test_cunicode_from_string(s, expected):
  2162. _from_string_test(CUnicode, s, expected)
  2163. @pytest.mark.parametrize(
  2164. "s, expected",
  2165. [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
  2166. )
  2167. def test_bytes_from_string(s, expected):
  2168. _from_string_test(Bytes, s, expected)
  2169. @pytest.mark.parametrize(
  2170. "s, expected",
  2171. [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
  2172. )
  2173. def test_cbytes_from_string(s, expected):
  2174. _from_string_test(CBytes, s, expected)
  2175. @pytest.mark.parametrize(
  2176. "s, expected",
  2177. [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)],
  2178. )
  2179. def test_int_from_string(s, expected):
  2180. _from_string_test(Integer, s, expected)
  2181. @pytest.mark.parametrize(
  2182. "s, expected",
  2183. [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)],
  2184. )
  2185. def test_float_from_string(s, expected):
  2186. _from_string_test(Float, s, expected)
  2187. @pytest.mark.parametrize(
  2188. "s, expected",
  2189. [
  2190. ("x", ValueError),
  2191. ("1", 1.0),
  2192. ("123.5", 123.5),
  2193. ("2.5", 2.5),
  2194. ("1+2j", 1 + 2j),
  2195. ("None", None),
  2196. ],
  2197. )
  2198. def test_complex_from_string(s, expected):
  2199. _from_string_test(Complex, s, expected)
  2200. @pytest.mark.parametrize(
  2201. "s, expected",
  2202. [
  2203. ("true", True),
  2204. ("TRUE", True),
  2205. ("1", True),
  2206. ("0", False),
  2207. ("False", False),
  2208. ("false", False),
  2209. ("1.0", ValueError),
  2210. ("None", None),
  2211. ],
  2212. )
  2213. def test_bool_from_string(s, expected):
  2214. _from_string_test(Bool, s, expected)
  2215. @pytest.mark.parametrize(
  2216. "s, expected",
  2217. [
  2218. ("{}", {}),
  2219. ("1", TraitError),
  2220. ("{1: 2}", {1: 2}),
  2221. ('{"key": "value"}', {"key": "value"}),
  2222. ("x", TraitError),
  2223. ("None", None),
  2224. ],
  2225. )
  2226. def test_dict_from_string(s, expected):
  2227. _from_string_test(Dict, s, expected)
  2228. @pytest.mark.parametrize(
  2229. "s, expected",
  2230. [
  2231. ("[]", []),
  2232. ('[1, 2, "x"]', [1, 2, "x"]),
  2233. (["1", "x"], ["1", "x"]),
  2234. (["None"], None),
  2235. ],
  2236. )
  2237. def test_list_from_string(s, expected):
  2238. _from_string_test(List, s, expected)
  2239. @pytest.mark.parametrize(
  2240. "s, expected, value_trait",
  2241. [
  2242. (["1", "2", "3"], [1, 2, 3], Integer()),
  2243. (["x"], ValueError, Integer()),
  2244. (["1", "x"], ["1", "x"], Unicode()),
  2245. (["None"], [None], Unicode(allow_none=True)),
  2246. (["None"], ["None"], Unicode(allow_none=False)),
  2247. ],
  2248. )
  2249. def test_list_items_from_string(s, expected, value_trait):
  2250. _from_string_test(List(value_trait), s, expected)
  2251. @pytest.mark.parametrize(
  2252. "s, expected",
  2253. [
  2254. ("[]", set()),
  2255. ('[1, 2, "x"]', {1, 2, "x"}),
  2256. ('{1, 2, "x"}', {1, 2, "x"}),
  2257. (["1", "x"], {"1", "x"}),
  2258. (["None"], None),
  2259. ],
  2260. )
  2261. def test_set_from_string(s, expected):
  2262. _from_string_test(Set, s, expected)
  2263. @pytest.mark.parametrize(
  2264. "s, expected, value_trait",
  2265. [
  2266. (["1", "2", "3"], {1, 2, 3}, Integer()),
  2267. (["x"], ValueError, Integer()),
  2268. (["1", "x"], {"1", "x"}, Unicode()),
  2269. (["None"], {None}, Unicode(allow_none=True)),
  2270. ],
  2271. )
  2272. def test_set_items_from_string(s, expected, value_trait):
  2273. _from_string_test(Set(value_trait), s, expected)
  2274. @pytest.mark.parametrize(
  2275. "s, expected",
  2276. [
  2277. ("[]", ()),
  2278. ("()", ()),
  2279. ('[1, 2, "x"]', (1, 2, "x")),
  2280. ('(1, 2, "x")', (1, 2, "x")),
  2281. (["1", "x"], ("1", "x")),
  2282. (["None"], None),
  2283. ],
  2284. )
  2285. def test_tuple_from_string(s, expected):
  2286. _from_string_test(Tuple, s, expected)
  2287. @pytest.mark.parametrize(
  2288. "s, expected, value_traits",
  2289. [
  2290. (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]),
  2291. (["x"], ValueError, [Integer()]),
  2292. (["1", "x"], ("1", "x"), [Unicode()]),
  2293. (["None"], ("None",), [Unicode(allow_none=False)]),
  2294. (["None"], (None,), [Unicode(allow_none=True)]),
  2295. ],
  2296. )
  2297. def test_tuple_items_from_string(s, expected, value_traits):
  2298. _from_string_test(Tuple(*value_traits), s, expected)
  2299. @pytest.mark.parametrize(
  2300. "s, expected",
  2301. [
  2302. ("x", "x"),
  2303. ("mod.submod", "mod.submod"),
  2304. ("not an identifier", TraitError),
  2305. ("1", "1"),
  2306. ("None", None),
  2307. ],
  2308. )
  2309. def test_object_from_string(s, expected):
  2310. _from_string_test(DottedObjectName, s, expected)
  2311. @pytest.mark.parametrize(
  2312. "s, expected",
  2313. [
  2314. ("127.0.0.1:8000", ("127.0.0.1", 8000)),
  2315. ("host.tld:80", ("host.tld", 80)),
  2316. ("host:notaport", ValueError),
  2317. ("127.0.0.1", ValueError),
  2318. ("None", None),
  2319. ],
  2320. )
  2321. def test_tcp_from_string(s, expected):
  2322. _from_string_test(TCPAddress, s, expected)
  2323. @pytest.mark.parametrize(
  2324. "s, expected",
  2325. [("[]", []), ("{}", "{}")],
  2326. )
  2327. def test_union_of_list_and_unicode_from_string(s, expected):
  2328. _from_string_test(Union([List(), Unicode()]), s, expected)
  2329. @pytest.mark.parametrize(
  2330. "s, expected",
  2331. [("1", 1), ("1.5", 1.5)],
  2332. )
  2333. def test_union_of_int_and_float_from_string(s, expected):
  2334. _from_string_test(Union([Int(), Float()]), s, expected)
  2335. @pytest.mark.parametrize(
  2336. "s, expected, allow_none",
  2337. [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)],
  2338. )
  2339. def test_union_of_list_and_dict_from_string(s, expected, allow_none):
  2340. _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected)
  2341. def test_all_attribute():
  2342. """Verify all trait types are added to `traitlets.__all__`"""
  2343. names = dir(traitlets)
  2344. for name in names:
  2345. value = getattr(traitlets, name)
  2346. if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType):
  2347. if name not in traitlets.__all__:
  2348. raise ValueError(f"{name} not in __all__")
  2349. for name in traitlets.__all__:
  2350. if name not in names:
  2351. raise ValueError(f"{name} should be removed from __all__")