test_traitlets.py 79 KB


  1. """Tests for traitlets.traitlets."""
  2. # Copyright (c) IPython Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. #
  5. # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
  6. # also under the terms of the Modified BSD License.
  7. from __future__ import annotations
  8. import pickle
  9. import re
  10. import typing as t
  11. from unittest import TestCase
  12. import pytest
  13. from traitlets import (
  14. All,
  15. Any,
  16. BaseDescriptor,
  17. Bool,
  18. Bytes,
  19. Callable,
  20. CBytes,
  21. CFloat,
  22. CInt,
  23. CLong,
  24. Complex,
  25. CRegExp,
  26. CUnicode,
  27. Dict,
  28. DottedObjectName,
  29. Enum,
  30. Float,
  31. ForwardDeclaredInstance,
  32. ForwardDeclaredType,
  33. HasDescriptors,
  34. HasTraits,
  35. Instance,
  36. Int,
  37. Integer,
  38. List,
  39. Long,
  40. MetaHasTraits,
  41. ObjectName,
  42. Set,
  43. TCPAddress,
  44. This,
  45. TraitError,
  46. TraitType,
  47. Tuple,
  48. Type,
  49. Undefined,
  50. Unicode,
  51. Union,
  52. default,
  53. directional_link,
  54. link,
  55. observe,
  56. observe_compat,
  57. traitlets,
  58. validate,
  59. )
  60. from traitlets.utils import cast_unicode
  61. from ._warnings import expected_warnings
  62. def change_dict(*ordered_values):
  63. change_names = ("name", "old", "new", "owner", "type")
  64. return dict(zip(change_names, ordered_values))
  65. # -----------------------------------------------------------------------------
  66. # Helper classes for testing
  67. # -----------------------------------------------------------------------------
  68. class HasTraitsStub(HasTraits):
  69. def notify_change(self, change):
  70. self._notify_name = change["name"]
  71. self._notify_old = change["old"]
  72. self._notify_new = change["new"]
  73. self._notify_type = change["type"]
class CrossValidationStub(HasTraits):
    """Minimal HasTraits subclass with ``_cross_validation_lock`` fixed at class level."""

    # Class-level default of False. Presumably used by tests later in the file
    # that need an object whose cross-validation lock starts released -- TODO confirm.
    _cross_validation_lock = False
  76. # -----------------------------------------------------------------------------
  77. # Test classes
  78. # -----------------------------------------------------------------------------
  79. class TestTraitType(TestCase):
  80. def test_get_undefined(self):
  81. class A(HasTraits):
  82. a = TraitType
  83. a = A()
  84. assert a.a is Undefined # type:ignore
  85. def test_set(self):
  86. class A(HasTraitsStub):
  87. a = TraitType
  88. a = A()
  89. a.a = 10 # type:ignore
  90. self.assertEqual(a.a, 10)
  91. self.assertEqual(a._notify_name, "a")
  92. self.assertEqual(a._notify_old, Undefined)
  93. self.assertEqual(a._notify_new, 10)
  94. def test_validate(self):
  95. class MyTT(TraitType[int, int]):
  96. def validate(self, inst, value):
  97. return -1
  98. class A(HasTraitsStub):
  99. tt = MyTT
  100. a = A()
  101. a.tt = 10 # type:ignore
  102. self.assertEqual(a.tt, -1)
  103. a = A(tt=11)
  104. self.assertEqual(a.tt, -1)
  105. def test_default_validate(self):
  106. class MyIntTT(TraitType[int, int]):
  107. def validate(self, obj, value):
  108. if isinstance(value, int):
  109. return value
  110. self.error(obj, value)
  111. class A(HasTraits):
  112. tt = MyIntTT(10)
  113. a = A()
  114. self.assertEqual(a.tt, 10)
  115. # Defaults are validated when the HasTraits is instantiated
  116. class B(HasTraits):
  117. tt = MyIntTT("bad default")
  118. self.assertRaises(TraitError, getattr, B(), "tt")
  119. def test_info(self):
  120. class A(HasTraits):
  121. tt = TraitType
  122. a = A()
  123. self.assertEqual(A.tt.info(), "any value") # type:ignore
  124. def test_error(self):
  125. class A(HasTraits):
  126. tt = TraitType[int, int]()
  127. a = A()
  128. self.assertRaises(TraitError, A.tt.error, a, 10)
  129. def test_deprecated_dynamic_initializer(self):
  130. class A(HasTraits):
  131. x = Int(10)
  132. def _x_default(self):
  133. return 11
  134. class B(A):
  135. x = Int(20)
  136. class C(A):
  137. def _x_default(self):
  138. return 21
  139. a = A()
  140. self.assertEqual(a._trait_values, {})
  141. self.assertEqual(a.x, 11)
  142. self.assertEqual(a._trait_values, {"x": 11})
  143. b = B()
  144. self.assertEqual(b.x, 20)
  145. self.assertEqual(b._trait_values, {"x": 20})
  146. c = C()
  147. self.assertEqual(c._trait_values, {})
  148. self.assertEqual(c.x, 21)
  149. self.assertEqual(c._trait_values, {"x": 21})
  150. # Ensure that the base class remains unmolested when the _default
  151. # initializer gets overridden in a subclass.
  152. a = A()
  153. c = C()
  154. self.assertEqual(a._trait_values, {})
  155. self.assertEqual(a.x, 11)
  156. self.assertEqual(a._trait_values, {"x": 11})
  157. def test_deprecated_method_warnings(self):
  158. with expected_warnings([]):
  159. class ShouldntWarn(HasTraits):
  160. x = Integer()
  161. @default("x")
  162. def _x_default(self):
  163. return 10
  164. @validate("x")
  165. def _x_validate(self, proposal):
  166. return proposal.value
  167. @observe("x")
  168. def _x_changed(self, change):
  169. pass
  170. obj = ShouldntWarn()
  171. obj.x = 5
  172. assert obj.x == 5
  173. with expected_warnings(["@validate", "@observe"]) as w:
  174. class ShouldWarn(HasTraits):
  175. x = Integer()
  176. def _x_default(self):
  177. return 10
  178. def _x_validate(self, value, _):
  179. return value
  180. def _x_changed(self):
  181. pass
  182. obj = ShouldWarn() # type:ignore
  183. obj.x = 5
  184. assert obj.x == 5
  185. def test_dynamic_initializer(self):
  186. class A(HasTraits):
  187. x = Int(10)
  188. @default("x")
  189. def _default_x(self):
  190. return 11
  191. class B(A):
  192. x = Int(20)
  193. class C(A):
  194. @default("x")
  195. def _default_x(self):
  196. return 21
  197. a = A()
  198. self.assertEqual(a._trait_values, {})
  199. self.assertEqual(a.x, 11)
  200. self.assertEqual(a._trait_values, {"x": 11})
  201. b = B()
  202. self.assertEqual(b.x, 20)
  203. self.assertEqual(b._trait_values, {"x": 20})
  204. c = C()
  205. self.assertEqual(c._trait_values, {})
  206. self.assertEqual(c.x, 21)
  207. self.assertEqual(c._trait_values, {"x": 21})
  208. # Ensure that the base class remains unmolested when the _default
  209. # initializer gets overridden in a subclass.
  210. a = A()
  211. c = C()
  212. self.assertEqual(a._trait_values, {})
  213. self.assertEqual(a.x, 11)
  214. self.assertEqual(a._trait_values, {"x": 11})
  215. def test_tag_metadata(self):
  216. class MyIntTT(TraitType[int, int]):
  217. metadata = {"a": 1, "b": 2}
  218. a = MyIntTT(10).tag(b=3, c=4)
  219. self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4})
  220. def test_metadata_localized_instance(self):
  221. class MyIntTT(TraitType[int, int]):
  222. metadata = {"a": 1, "b": 2}
  223. a = MyIntTT(10)
  224. b = MyIntTT(10)
  225. a.metadata["c"] = 3
  226. # make sure that changing a's metadata didn't change b's metadata
  227. self.assertNotIn("c", b.metadata)
  228. def test_union_metadata(self):
  229. class Foo(HasTraits):
  230. bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a")
  231. foo = Foo()
  232. # At this point, no value has been set for bar, so value-specific
  233. # is not set.
  234. self.assertEqual(foo.trait_metadata("bar", "ta"), None)
  235. self.assertEqual(foo.trait_metadata("bar", "ti"), "a")
  236. foo.bar = {}
  237. self.assertEqual(foo.trait_metadata("bar", "ta"), 2)
  238. self.assertEqual(foo.trait_metadata("bar", "ti"), "b")
  239. foo.bar = 1
  240. self.assertEqual(foo.trait_metadata("bar", "ta"), 1)
  241. self.assertEqual(foo.trait_metadata("bar", "ti"), "a")
  242. def test_union_default_value(self):
  243. class Foo(HasTraits):
  244. bar = Union([Dict(), Int()], default_value=1)
  245. foo = Foo()
  246. self.assertEqual(foo.bar, 1)
  247. def test_union_validation_priority(self):
  248. class Foo(HasTraits):
  249. bar = Union([CInt(), Unicode()])
  250. foo = Foo()
  251. foo.bar = "1"
  252. # validation in order of the TraitTypes given
  253. self.assertEqual(foo.bar, 1)
  254. def test_union_trait_default_value(self):
  255. class Foo(HasTraits):
  256. bar = Union([Dict(), Int()])
  257. self.assertEqual(Foo().bar, {})
  258. def test_deprecated_metadata_access(self):
  259. class MyIntTT(TraitType[int, int]):
  260. metadata = {"a": 1, "b": 2}
  261. a = MyIntTT(10)
  262. with expected_warnings(["use the instance .metadata dictionary directly"] * 2):
  263. a.set_metadata("key", "value")
  264. v = a.get_metadata("key")
  265. self.assertEqual(v, "value")
  266. with expected_warnings(["use the instance .help string directly"] * 2):
  267. a.set_metadata("help", "some help")
  268. v = a.get_metadata("help")
  269. self.assertEqual(v, "some help")
  270. def test_trait_types_deprecated(self):
  271. with expected_warnings(["Traits should be given as instances"]):
  272. class C(HasTraits):
  273. t = Int
  274. def test_trait_types_list_deprecated(self):
  275. with expected_warnings(["Traits should be given as instances"]):
  276. class C(HasTraits):
  277. t = List(Int)
  278. def test_trait_types_tuple_deprecated(self):
  279. with expected_warnings(["Traits should be given as instances"]):
  280. class C(HasTraits):
  281. t = Tuple(Int)
  282. def test_trait_types_dict_deprecated(self):
  283. with expected_warnings(["Traits should be given as instances"]):
  284. class C(HasTraits):
  285. t = Dict(Int)
  286. class TestHasDescriptorsMeta(TestCase):
  287. def test_metaclass(self):
  288. self.assertEqual(type(HasTraits), MetaHasTraits)
  289. class A(HasTraits):
  290. a = Int()
  291. a = A()
  292. self.assertEqual(type(a.__class__), MetaHasTraits)
  293. self.assertEqual(a.a, 0)
  294. a.a = 10
  295. self.assertEqual(a.a, 10)
  296. class B(HasTraits):
  297. b = Int()
  298. b = B()
  299. self.assertEqual(b.b, 0)
  300. b.b = 10
  301. self.assertEqual(b.b, 10)
  302. class C(HasTraits):
  303. c = Int(30)
  304. c = C()
  305. self.assertEqual(c.c, 30)
  306. c.c = 10
  307. self.assertEqual(c.c, 10)
  308. def test_this_class(self):
  309. class A(HasTraits):
  310. t = This["A"]()
  311. tt = This["A"]()
  312. class B(A):
  313. tt = This["A"]()
  314. ttt = This["A"]()
  315. self.assertEqual(A.t.this_class, A)
  316. self.assertEqual(B.t.this_class, A)
  317. self.assertEqual(B.tt.this_class, B)
  318. self.assertEqual(B.ttt.this_class, B)
  319. class TestHasDescriptors(TestCase):
  320. def test_setup_instance(self):
  321. class FooDescriptor(BaseDescriptor):
  322. def instance_init(self, inst):
  323. foo = inst.foo # instance should have the attr
  324. class HasFooDescriptors(HasDescriptors):
  325. fd = FooDescriptor()
  326. def setup_instance(self, *args, **kwargs):
  327. self.foo = kwargs.get("foo", None)
  328. super().setup_instance(*args, **kwargs)
  329. hfd = HasFooDescriptors(foo="bar")
  330. class TestHasTraitsNotify(TestCase):
  331. def setUp(self):
  332. self._notify1 = []
  333. self._notify2 = []
  334. def notify1(self, name, old, new):
  335. self._notify1.append((name, old, new))
  336. def notify2(self, name, old, new):
  337. self._notify2.append((name, old, new))
  338. def test_notify_all(self):
  339. class A(HasTraits):
  340. a = Int()
  341. b = Float()
  342. a = A()
  343. a.on_trait_change(self.notify1)
  344. a.a = 0
  345. self.assertEqual(len(self._notify1), 0)
  346. a.b = 0.0
  347. self.assertEqual(len(self._notify1), 0)
  348. a.a = 10
  349. self.assertTrue(("a", 0, 10) in self._notify1)
  350. a.b = 10.0
  351. self.assertTrue(("b", 0.0, 10.0) in self._notify1)
  352. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  353. self.assertRaises(TraitError, setattr, a, "b", "bad string")
  354. self._notify1 = []
  355. a.on_trait_change(self.notify1, remove=True)
  356. a.a = 20
  357. a.b = 20.0
  358. self.assertEqual(len(self._notify1), 0)
  359. def test_notify_one(self):
  360. class A(HasTraits):
  361. a = Int()
  362. b = Float()
  363. a = A()
  364. a.on_trait_change(self.notify1, "a")
  365. a.a = 0
  366. self.assertEqual(len(self._notify1), 0)
  367. a.a = 10
  368. self.assertTrue(("a", 0, 10) in self._notify1)
  369. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  370. def test_subclass(self):
  371. class A(HasTraits):
  372. a = Int()
  373. class B(A):
  374. b = Float()
  375. b = B()
  376. self.assertEqual(b.a, 0)
  377. self.assertEqual(b.b, 0.0)
  378. b.a = 100
  379. b.b = 100.0
  380. self.assertEqual(b.a, 100)
  381. self.assertEqual(b.b, 100.0)
  382. def test_notify_subclass(self):
  383. class A(HasTraits):
  384. a = Int()
  385. class B(A):
  386. b = Float()
  387. b = B()
  388. b.on_trait_change(self.notify1, "a")
  389. b.on_trait_change(self.notify2, "b")
  390. b.a = 0
  391. b.b = 0.0
  392. self.assertEqual(len(self._notify1), 0)
  393. self.assertEqual(len(self._notify2), 0)
  394. b.a = 10
  395. b.b = 10.0
  396. self.assertTrue(("a", 0, 10) in self._notify1)
  397. self.assertTrue(("b", 0.0, 10.0) in self._notify2)
  398. def test_static_notify(self):
  399. class A(HasTraits):
  400. a = Int()
  401. _notify1 = []
  402. def _a_changed(self, name, old, new):
  403. self._notify1.append((name, old, new))
  404. a = A()
  405. a.a = 0
  406. # This is broken!!!
  407. self.assertEqual(len(a._notify1), 0)
  408. a.a = 10
  409. self.assertTrue(("a", 0, 10) in a._notify1)
  410. class B(A):
  411. b = Float()
  412. _notify2 = []
  413. def _b_changed(self, name, old, new):
  414. self._notify2.append((name, old, new))
  415. b = B()
  416. b.a = 10
  417. b.b = 10.0
  418. self.assertTrue(("a", 0, 10) in b._notify1)
  419. self.assertTrue(("b", 0.0, 10.0) in b._notify2)
  420. def test_notify_args(self):
  421. def callback0():
  422. self.cb = ()
  423. def callback1(name):
  424. self.cb = (name,) # type:ignore
  425. def callback2(name, new):
  426. self.cb = (name, new) # type:ignore
  427. def callback3(name, old, new):
  428. self.cb = (name, old, new) # type:ignore
  429. def callback4(name, old, new, obj):
  430. self.cb = (name, old, new, obj) # type:ignore
  431. class A(HasTraits):
  432. a = Int()
  433. a = A()
  434. a.on_trait_change(callback0, "a")
  435. a.a = 10
  436. self.assertEqual(self.cb, ())
  437. a.on_trait_change(callback0, "a", remove=True)
  438. a.on_trait_change(callback1, "a")
  439. a.a = 100
  440. self.assertEqual(self.cb, ("a",))
  441. a.on_trait_change(callback1, "a", remove=True)
  442. a.on_trait_change(callback2, "a")
  443. a.a = 1000
  444. self.assertEqual(self.cb, ("a", 1000))
  445. a.on_trait_change(callback2, "a", remove=True)
  446. a.on_trait_change(callback3, "a")
  447. a.a = 10000
  448. self.assertEqual(self.cb, ("a", 1000, 10000))
  449. a.on_trait_change(callback3, "a", remove=True)
  450. a.on_trait_change(callback4, "a")
  451. a.a = 100000
  452. self.assertEqual(self.cb, ("a", 10000, 100000, a))
  453. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
  454. a.on_trait_change(callback4, "a", remove=True)
  455. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)
  456. def test_notify_only_once(self):
  457. class A(HasTraits):
  458. listen_to = ["a"]
  459. a = Int(0)
  460. b = 0
  461. def __init__(self, **kwargs):
  462. super().__init__(**kwargs)
  463. self.on_trait_change(self.listener1, ["a"])
  464. def listener1(self, name, old, new):
  465. self.b += 1
  466. class B(A):
  467. c = 0
  468. d = 0
  469. def __init__(self, **kwargs):
  470. super().__init__(**kwargs)
  471. self.on_trait_change(self.listener2)
  472. def listener2(self, name, old, new):
  473. self.c += 1
  474. def _a_changed(self, name, old, new):
  475. self.d += 1
  476. b = B()
  477. b.a += 1
  478. self.assertEqual(b.b, b.c)
  479. self.assertEqual(b.b, b.d)
  480. b.a += 1
  481. self.assertEqual(b.b, b.c)
  482. self.assertEqual(b.b, b.d)
  483. class TestObserveDecorator(TestCase):
  484. def setUp(self):
  485. self._notify1 = []
  486. self._notify2 = []
  487. def notify1(self, change):
  488. self._notify1.append(change)
  489. def notify2(self, change):
  490. self._notify2.append(change)
  491. def test_notify_all(self):
  492. class A(HasTraits):
  493. a = Int()
  494. b = Float()
  495. a = A()
  496. a.observe(self.notify1)
  497. a.a = 0
  498. self.assertEqual(len(self._notify1), 0)
  499. a.b = 0.0
  500. self.assertEqual(len(self._notify1), 0)
  501. a.a = 10
  502. change = change_dict("a", 0, 10, a, "change")
  503. self.assertTrue(change in self._notify1)
  504. a.b = 10.0
  505. change = change_dict("b", 0.0, 10.0, a, "change")
  506. self.assertTrue(change in self._notify1)
  507. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  508. self.assertRaises(TraitError, setattr, a, "b", "bad string")
  509. self._notify1 = []
  510. a.unobserve(self.notify1)
  511. a.a = 20
  512. a.b = 20.0
  513. self.assertEqual(len(self._notify1), 0)
  514. def test_notify_one(self):
  515. class A(HasTraits):
  516. a = Int()
  517. b = Float()
  518. a = A()
  519. a.observe(self.notify1, "a")
  520. a.a = 0
  521. self.assertEqual(len(self._notify1), 0)
  522. a.a = 10
  523. change = change_dict("a", 0, 10, a, "change")
  524. self.assertTrue(change in self._notify1)
  525. self.assertRaises(TraitError, setattr, a, "a", "bad string")
  526. def test_subclass(self):
  527. class A(HasTraits):
  528. a = Int()
  529. class B(A):
  530. b = Float()
  531. b = B()
  532. self.assertEqual(b.a, 0)
  533. self.assertEqual(b.b, 0.0)
  534. b.a = 100
  535. b.b = 100.0
  536. self.assertEqual(b.a, 100)
  537. self.assertEqual(b.b, 100.0)
  538. def test_notify_subclass(self):
  539. class A(HasTraits):
  540. a = Int()
  541. class B(A):
  542. b = Float()
  543. b = B()
  544. b.observe(self.notify1, "a")
  545. b.observe(self.notify2, "b")
  546. b.a = 0
  547. b.b = 0.0
  548. self.assertEqual(len(self._notify1), 0)
  549. self.assertEqual(len(self._notify2), 0)
  550. b.a = 10
  551. b.b = 10.0
  552. change = change_dict("a", 0, 10, b, "change")
  553. self.assertTrue(change in self._notify1)
  554. change = change_dict("b", 0.0, 10.0, b, "change")
  555. self.assertTrue(change in self._notify2)
  556. def test_static_notify(self):
  557. class A(HasTraits):
  558. a = Int()
  559. b = Int()
  560. _notify1 = []
  561. _notify_any = []
  562. @observe("a")
  563. def _a_changed(self, change):
  564. self._notify1.append(change)
  565. @observe(All)
  566. def _any_changed(self, change):
  567. self._notify_any.append(change)
  568. a = A()
  569. a.a = 0
  570. self.assertEqual(len(a._notify1), 0)
  571. a.a = 10
  572. change = change_dict("a", 0, 10, a, "change")
  573. self.assertTrue(change in a._notify1)
  574. a.b = 1
  575. self.assertEqual(len(a._notify_any), 2)
  576. change = change_dict("b", 0, 1, a, "change")
  577. self.assertTrue(change in a._notify_any)
  578. class B(A):
  579. b = Float() # type:ignore
  580. _notify2 = []
  581. @observe("b")
  582. def _b_changed(self, change):
  583. self._notify2.append(change)
  584. b = B()
  585. b.a = 10
  586. b.b = 10.0 # type:ignore
  587. change = change_dict("a", 0, 10, b, "change")
  588. self.assertTrue(change in b._notify1)
  589. change = change_dict("b", 0.0, 10.0, b, "change")
  590. self.assertTrue(change in b._notify2)
  591. def test_notify_args(self):
  592. def callback0():
  593. self.cb = ()
  594. def callback1(change):
  595. self.cb = change
  596. class A(HasTraits):
  597. a = Int()
  598. a = A()
  599. a.on_trait_change(callback0, "a")
  600. a.a = 10
  601. self.assertEqual(self.cb, ())
  602. a.unobserve(callback0, "a")
  603. a.observe(callback1, "a")
  604. a.a = 100
  605. change = change_dict("a", 10, 100, a, "change")
  606. self.assertEqual(self.cb, change)
  607. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
  608. a.unobserve(callback1, "a")
  609. self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)
  610. def test_notify_only_once(self):
  611. class A(HasTraits):
  612. listen_to = ["a"]
  613. a = Int(0)
  614. b = 0
  615. def __init__(self, **kwargs):
  616. super().__init__(**kwargs)
  617. self.observe(self.listener1, ["a"])
  618. def listener1(self, change):
  619. self.b += 1
  620. class B(A):
  621. c = 0
  622. d = 0
  623. def __init__(self, **kwargs):
  624. super().__init__(**kwargs)
  625. self.observe(self.listener2)
  626. def listener2(self, change):
  627. self.c += 1
  628. @observe("a")
  629. def _a_changed(self, change):
  630. self.d += 1
  631. b = B()
  632. b.a += 1
  633. self.assertEqual(b.b, b.c)
  634. self.assertEqual(b.b, b.d)
  635. b.a += 1
  636. self.assertEqual(b.b, b.c)
  637. self.assertEqual(b.b, b.d)
  638. class TestHasTraits(TestCase):
  639. def test_trait_names(self):
  640. class A(HasTraits):
  641. i = Int()
  642. f = Float()
  643. a = A()
  644. self.assertEqual(sorted(a.trait_names()), ["f", "i"])
  645. self.assertEqual(sorted(A.class_trait_names()), ["f", "i"])
  646. self.assertTrue(a.has_trait("f"))
  647. self.assertFalse(a.has_trait("g"))
  648. def test_trait_has_value(self):
  649. class A(HasTraits):
  650. i = Int()
  651. f = Float()
  652. a = A()
  653. self.assertFalse(a.trait_has_value("f"))
  654. self.assertFalse(a.trait_has_value("g"))
  655. a.i = 1
  656. a.f
  657. self.assertTrue(a.trait_has_value("i"))
  658. self.assertTrue(a.trait_has_value("f"))
  659. def test_trait_metadata_deprecated(self):
  660. with expected_warnings([r"metadata should be set using the \.tag\(\) method"]):
  661. class A(HasTraits):
  662. i = Int(config_key="MY_VALUE")
  663. a = A()
  664. self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")
  665. def test_trait_metadata(self):
  666. class A(HasTraits):
  667. i = Int().tag(config_key="MY_VALUE")
  668. a = A()
  669. self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")
  670. def test_trait_metadata_default(self):
  671. class A(HasTraits):
  672. i = Int()
  673. a = A()
  674. self.assertEqual(a.trait_metadata("i", "config_key"), None)
  675. self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default")
  676. def test_traits(self):
  677. class A(HasTraits):
  678. i = Int()
  679. f = Float()
  680. a = A()
  681. self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
  682. self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
  683. def test_traits_metadata(self):
  684. class A(HasTraits):
  685. i = Int().tag(config_key="VALUE1", other_thing="VALUE2")
  686. f = Float().tag(config_key="VALUE3", other_thing="VALUE2")
  687. j = Int(0)
  688. a = A()
  689. self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
  690. traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
  691. self.assertEqual(traits, dict(i=A.i))
  692. # This passes, but it shouldn't because I am replicating a bug in
  693. # traits.
  694. traits = a.traits(config_key=lambda v: True)
  695. self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
  696. def test_traits_metadata_deprecated(self):
  697. with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2):
  698. class A(HasTraits):
  699. i = Int(config_key="VALUE1", other_thing="VALUE2")
  700. f = Float(config_key="VALUE3", other_thing="VALUE2")
  701. j = Int(0)
  702. a = A()
  703. self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
  704. traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
  705. self.assertEqual(traits, dict(i=A.i))
  706. # This passes, but it shouldn't because I am replicating a bug in
  707. # traits.
  708. traits = a.traits(config_key=lambda v: True)
  709. self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
  710. def test_init(self):
  711. class A(HasTraits):
  712. i = Int()
  713. x = Float()
  714. a = A(i=1, x=10.0)
  715. self.assertEqual(a.i, 1)
  716. self.assertEqual(a.x, 10.0)
  717. def test_positional_args(self):
  718. class A(HasTraits):
  719. i = Int(0)
  720. def __init__(self, i):
  721. super().__init__()
  722. self.i = i
  723. a = A(5)
  724. self.assertEqual(a.i, 5)
  725. # should raise TypeError if no positional arg given
  726. self.assertRaises(TypeError, A)
  727. # -----------------------------------------------------------------------------
  728. # Tests for specific trait types
  729. # -----------------------------------------------------------------------------
  730. class TestType(TestCase):
  731. def test_default(self):
  732. class B:
  733. pass
  734. class A(HasTraits):
  735. klass = Type(allow_none=True)
  736. a = A()
  737. self.assertEqual(a.klass, object)
  738. a.klass = B
  739. self.assertEqual(a.klass, B)
  740. self.assertRaises(TraitError, setattr, a, "klass", 10)
  741. def test_default_options(self):
  742. class B:
  743. pass
  744. class C(B):
  745. pass
  746. class A(HasTraits):
  747. # Different possible combinations of options for default_value
  748. # and klass. default_value=None is only valid with allow_none=True.
  749. k1 = Type()
  750. k2 = Type(None, allow_none=True)
  751. k3 = Type(B)
  752. k4 = Type(klass=B)
  753. k5 = Type(default_value=None, klass=B, allow_none=True)
  754. k6 = Type(default_value=C, klass=B)
  755. self.assertIs(A.k1.default_value, object)
  756. self.assertIs(A.k1.klass, object)
  757. self.assertIs(A.k2.default_value, None)
  758. self.assertIs(A.k2.klass, object)
  759. self.assertIs(A.k3.default_value, B)
  760. self.assertIs(A.k3.klass, B)
  761. self.assertIs(A.k4.default_value, B)
  762. self.assertIs(A.k4.klass, B)
  763. self.assertIs(A.k5.default_value, None)
  764. self.assertIs(A.k5.klass, B)
  765. self.assertIs(A.k6.default_value, C)
  766. self.assertIs(A.k6.klass, B)
  767. a = A()
  768. self.assertIs(a.k1, object)
  769. self.assertIs(a.k2, None)
  770. self.assertIs(a.k3, B)
  771. self.assertIs(a.k4, B)
  772. self.assertIs(a.k5, None)
  773. self.assertIs(a.k6, C)
  774. def test_value(self):
  775. class B:
  776. pass
  777. class C:
  778. pass
  779. class A(HasTraits):
  780. klass = Type(B)
  781. a = A()
  782. self.assertEqual(a.klass, B)
  783. self.assertRaises(TraitError, setattr, a, "klass", C)
  784. self.assertRaises(TraitError, setattr, a, "klass", object)
  785. a.klass = B
  786. def test_allow_none(self):
  787. class B:
  788. pass
  789. class C(B):
  790. pass
  791. class A(HasTraits):
  792. klass = Type(B)
  793. a = A()
  794. self.assertEqual(a.klass, B)
  795. self.assertRaises(TraitError, setattr, a, "klass", None)
  796. a.klass = C
  797. self.assertEqual(a.klass, C)
  798. def test_validate_klass(self):
  799. class A(HasTraits):
  800. klass = Type("no strings allowed")
  801. self.assertRaises(ImportError, A)
  802. class A(HasTraits): # type:ignore
  803. klass = Type("rub.adub.Duck")
  804. self.assertRaises(ImportError, A)
  805. def test_validate_default(self):
  806. class B:
  807. pass
  808. class A(HasTraits):
  809. klass = Type("bad default", B)
  810. self.assertRaises(ImportError, A)
  811. class C(HasTraits):
  812. klass = Type(None, B)
  813. self.assertRaises(TraitError, getattr, C(), "klass")
  814. def test_str_klass(self):
  815. class A(HasTraits):
  816. klass = Type("traitlets.config.Config")
  817. from traitlets.config import Config
  818. a = A()
  819. a.klass = Config
  820. self.assertEqual(a.klass, Config)
  821. self.assertRaises(TraitError, setattr, a, "klass", 10)
  822. def test_set_str_klass(self):
  823. class A(HasTraits):
  824. klass = Type()
  825. a = A(klass="traitlets.config.Config")
  826. from traitlets.config import Config
  827. self.assertEqual(a.klass, Config)
  828. class TestInstance(TestCase):
  829. def test_basic(self):
  830. class Foo:
  831. pass
  832. class Bar(Foo):
  833. pass
  834. class Bah:
  835. pass
  836. class A(HasTraits):
  837. inst = Instance(Foo, allow_none=True)
  838. a = A()
  839. self.assertTrue(a.inst is None)
  840. a.inst = Foo()
  841. self.assertTrue(isinstance(a.inst, Foo))
  842. a.inst = Bar()
  843. self.assertTrue(isinstance(a.inst, Foo))
  844. self.assertRaises(TraitError, setattr, a, "inst", Foo)
  845. self.assertRaises(TraitError, setattr, a, "inst", Bar)
  846. self.assertRaises(TraitError, setattr, a, "inst", Bah())
  847. def test_default_klass(self):
  848. class Foo:
  849. pass
  850. class Bar(Foo):
  851. pass
  852. class Bah:
  853. pass
  854. class FooInstance(Instance[Foo]):
  855. klass = Foo
  856. class A(HasTraits):
  857. inst = FooInstance(allow_none=True)
  858. a = A()
  859. self.assertTrue(a.inst is None)
  860. a.inst = Foo()
  861. self.assertTrue(isinstance(a.inst, Foo))
  862. a.inst = Bar()
  863. self.assertTrue(isinstance(a.inst, Foo))
  864. self.assertRaises(TraitError, setattr, a, "inst", Foo)
  865. self.assertRaises(TraitError, setattr, a, "inst", Bar)
  866. self.assertRaises(TraitError, setattr, a, "inst", Bah())
  867. def test_unique_default_value(self):
  868. class Foo:
  869. pass
  870. class A(HasTraits):
  871. inst = Instance(Foo, (), {})
  872. a = A()
  873. b = A()
  874. self.assertTrue(a.inst is not b.inst)
  875. def test_args_kw(self):
  876. class Foo:
  877. def __init__(self, c):
  878. self.c = c
  879. class Bar:
  880. pass
  881. class Bah:
  882. def __init__(self, c, d):
  883. self.c = c
  884. self.d = d
  885. class A(HasTraits):
  886. inst = Instance(Foo, (10,))
  887. a = A()
  888. self.assertEqual(a.inst.c, 10)
  889. class B(HasTraits):
  890. inst = Instance(Bah, args=(10,), kw=dict(d=20))
  891. b = B()
  892. self.assertEqual(b.inst.c, 10)
  893. self.assertEqual(b.inst.d, 20)
  894. class C(HasTraits):
  895. inst = Instance(Foo, allow_none=True)
  896. c = C()
  897. self.assertTrue(c.inst is None)
  898. def test_bad_default(self):
  899. class Foo:
  900. pass
  901. class A(HasTraits):
  902. inst = Instance(Foo)
  903. a = A()
  904. with self.assertRaises(TraitError):
  905. a.inst
  906. def test_instance(self):
  907. class Foo:
  908. pass
  909. def inner():
  910. class A(HasTraits):
  911. inst = Instance(Foo()) # type:ignore
  912. self.assertRaises(TraitError, inner)
  913. class TestThis(TestCase):
  914. def test_this_class(self):
  915. class Foo(HasTraits):
  916. this = This["Foo"]()
  917. f = Foo()
  918. self.assertEqual(f.this, None)
  919. g = Foo()
  920. f.this = g
  921. self.assertEqual(f.this, g)
  922. self.assertRaises(TraitError, setattr, f, "this", 10)
  923. def test_this_inst(self):
  924. class Foo(HasTraits):
  925. this = This["Foo"]()
  926. f = Foo()
  927. f.this = Foo()
  928. self.assertTrue(isinstance(f.this, Foo))
  929. def test_subclass(self):
  930. class Foo(HasTraits):
  931. t = This["Foo"]()
  932. class Bar(Foo):
  933. pass
  934. f = Foo()
  935. b = Bar()
  936. f.t = b
  937. b.t = f
  938. self.assertEqual(f.t, b)
  939. self.assertEqual(b.t, f)
  940. def test_subclass_override(self):
  941. class Foo(HasTraits):
  942. t = This["Foo"]()
  943. class Bar(Foo):
  944. t = This()
  945. f = Foo()
  946. b = Bar()
  947. f.t = b
  948. self.assertEqual(f.t, b)
  949. self.assertRaises(TraitError, setattr, b, "t", f)
  950. def test_this_in_container(self):
  951. class Tree(HasTraits):
  952. value = Unicode()
  953. leaves = List(This())
  954. tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")])
  955. with self.assertRaises(TraitError):
  956. tree.leaves = [1, 2]
  957. class TraitTestBase(TestCase):
  958. """A best testing class for basic trait types."""
  959. def assign(self, value):
  960. self.obj.value = value # type:ignore
  961. def coerce(self, value):
  962. return value
  963. def test_good_values(self):
  964. if hasattr(self, "_good_values"):
  965. for value in self._good_values:
  966. self.assign(value)
  967. self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore
  968. def test_bad_values(self):
  969. if hasattr(self, "_bad_values"):
  970. for value in self._bad_values:
  971. try:
  972. self.assertRaises(TraitError, self.assign, value)
  973. except AssertionError:
  974. assert False, value # noqa: PT015
  975. def test_default_value(self):
  976. if hasattr(self, "_default_value"):
  977. self.assertEqual(self._default_value, self.obj.value) # type:ignore
  978. def test_allow_none(self):
  979. if (
  980. hasattr(self, "_bad_values")
  981. and hasattr(self, "_good_values")
  982. and None in self._bad_values
  983. ):
  984. trait = self.obj.traits()["value"] # type:ignore
  985. try:
  986. trait.allow_none = True
  987. self._bad_values.remove(None)
  988. # skip coerce. Allow None casts None to None.
  989. self.assign(None)
  990. self.assertEqual(self.obj.value, None) # type:ignore
  991. self.test_good_values()
  992. self.test_bad_values()
  993. finally:
  994. # tear down
  995. trait.allow_none = False
  996. self._bad_values.append(None)
  997. def tearDown(self):
  998. # restore default value after tests, if set
  999. if hasattr(self, "_default_value"):
  1000. self.obj.value = self._default_value # type:ignore
  1001. class AnyTrait(HasTraits):
  1002. value = Any()
  1003. class AnyTraitTest(TraitTestBase):
  1004. obj = AnyTrait()
  1005. _default_value = None
  1006. _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j]
  1007. _bad_values: t.Any = []
  1008. class UnionTrait(HasTraits):
  1009. value = Union([Type(), Bool()])
  1010. class UnionTraitTest(TraitTestBase):
  1011. obj = UnionTrait(value="traitlets.config.Config")
  1012. _good_values = [int, float, True]
  1013. _bad_values = [[], (0,), 1j]
  1014. class CallableTrait(HasTraits):
  1015. value = Callable()
  1016. class CallableTraitTest(TraitTestBase):
  1017. obj = CallableTrait(value=lambda x: type(x))
  1018. _good_values = [int, sorted, lambda x: print(x)]
  1019. _bad_values = [[], 1, ""]
  1020. class OrTrait(HasTraits):
  1021. value = Bool() | Unicode()
  1022. class OrTraitTest(TraitTestBase):
  1023. obj = OrTrait()
  1024. _good_values = [True, False, "ten"]
  1025. _bad_values = [[], (0,), 1j]
  1026. class IntTrait(HasTraits):
  1027. value = Int(99, min=-100)
  1028. class TestInt(TraitTestBase):
  1029. obj = IntTrait()
  1030. _default_value = 99
  1031. _good_values = [10, -10]
  1032. _bad_values = [
  1033. "ten",
  1034. [10],
  1035. {"ten": 10},
  1036. (10,),
  1037. None,
  1038. 1j,
  1039. 10.1,
  1040. -10.1,
  1041. "10L",
  1042. "-10L",
  1043. "10.1",
  1044. "-10.1",
  1045. "10",
  1046. "-10",
  1047. -200,
  1048. ]
  1049. class CIntTrait(HasTraits):
  1050. value = CInt("5")
  1051. class TestCInt(TraitTestBase):
  1052. obj = CIntTrait()
  1053. _default_value = 5
  1054. _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
  1055. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
  1056. def coerce(self, n):
  1057. return int(n)
  1058. class MinBoundCIntTrait(HasTraits):
  1059. value = CInt("5", min=3)
  1060. class TestMinBoundCInt(TestCInt):
  1061. obj = MinBoundCIntTrait() # type:ignore
  1062. _default_value = 5
  1063. _good_values = [3, 3.0, "3"]
  1064. _bad_values = [2.6, 2, -3, -3.0]
  1065. class LongTrait(HasTraits):
  1066. value = Long(99)
  1067. class TestLong(TraitTestBase):
  1068. obj = LongTrait()
  1069. _default_value = 99
  1070. _good_values = [10, -10]
  1071. _bad_values = [
  1072. "ten",
  1073. [10],
  1074. {"ten": 10},
  1075. (10,),
  1076. None,
  1077. 1j,
  1078. 10.1,
  1079. -10.1,
  1080. "10",
  1081. "-10",
  1082. "10L",
  1083. "-10L",
  1084. "10.1",
  1085. "-10.1",
  1086. ]
  1087. class MinBoundLongTrait(HasTraits):
  1088. value = Long(99, min=5)
  1089. class TestMinBoundLong(TraitTestBase):
  1090. obj = MinBoundLongTrait()
  1091. _default_value = 99
  1092. _good_values = [5, 10]
  1093. _bad_values = [4, -10]
  1094. class MaxBoundLongTrait(HasTraits):
  1095. value = Long(5, max=10)
  1096. class TestMaxBoundLong(TraitTestBase):
  1097. obj = MaxBoundLongTrait()
  1098. _default_value = 5
  1099. _good_values = [10, -2]
  1100. _bad_values = [11, 20]
  1101. class CLongTrait(HasTraits):
  1102. value = CLong("5")
  1103. class TestCLong(TraitTestBase):
  1104. obj = CLongTrait()
  1105. _default_value = 5
  1106. _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
  1107. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
  1108. def coerce(self, n):
  1109. return int(n)
  1110. class MaxBoundCLongTrait(HasTraits):
  1111. value = CLong("5", max=10)
  1112. class TestMaxBoundCLong(TestCLong):
  1113. obj = MaxBoundCLongTrait() # type:ignore
  1114. _default_value = 5
  1115. _good_values = [10, "10", 10.3]
  1116. _bad_values = [11.0, "11"]
  1117. class IntegerTrait(HasTraits):
  1118. value = Integer(1)
  1119. class TestInteger(TestLong):
  1120. obj = IntegerTrait() # type:ignore
  1121. _default_value = 1
  1122. def coerce(self, n):
  1123. return int(n)
  1124. class MinBoundIntegerTrait(HasTraits):
  1125. value = Integer(5, min=3)
  1126. class TestMinBoundInteger(TraitTestBase):
  1127. obj = MinBoundIntegerTrait()
  1128. _default_value = 5
  1129. _good_values = 3, 20
  1130. _bad_values = [2, -10]
  1131. class MaxBoundIntegerTrait(HasTraits):
  1132. value = Integer(1, max=3)
  1133. class TestMaxBoundInteger(TraitTestBase):
  1134. obj = MaxBoundIntegerTrait()
  1135. _default_value = 1
  1136. _good_values = 3, -2
  1137. _bad_values = [4, 10]
  1138. class FloatTrait(HasTraits):
  1139. value = Float(99.0, max=200.0)
  1140. class TestFloat(TraitTestBase):
  1141. obj = FloatTrait()
  1142. _default_value = 99.0
  1143. _good_values = [10, -10, 10.1, -10.1]
  1144. _bad_values = [
  1145. "ten",
  1146. [10],
  1147. {"ten": 10},
  1148. (10,),
  1149. None,
  1150. 1j,
  1151. "10",
  1152. "-10",
  1153. "10L",
  1154. "-10L",
  1155. "10.1",
  1156. "-10.1",
  1157. 201.0,
  1158. ]
  1159. class CFloatTrait(HasTraits):
  1160. value = CFloat("99.0", max=200.0)
  1161. class TestCFloat(TraitTestBase):
  1162. obj = CFloatTrait()
  1163. _default_value = 99.0
  1164. _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"]
  1165. _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"]
  1166. def coerce(self, v):
  1167. return float(v)
  1168. class ComplexTrait(HasTraits):
  1169. value = Complex(99.0 - 99.0j)
  1170. class TestComplex(TraitTestBase):
  1171. obj = ComplexTrait()
  1172. _default_value = 99.0 - 99.0j
  1173. _good_values = [
  1174. 10,
  1175. -10,
  1176. 10.1,
  1177. -10.1,
  1178. 10j,
  1179. 10 + 10j,
  1180. 10 - 10j,
  1181. 10.1j,
  1182. 10.1 + 10.1j,
  1183. 10.1 - 10.1j,
  1184. ]
  1185. _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None]
  1186. class BytesTrait(HasTraits):
  1187. value = Bytes(b"string")
  1188. class TestBytes(TraitTestBase):
  1189. obj = BytesTrait()
  1190. _default_value = b"string"
  1191. _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"]
  1192. _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"]
  1193. class UnicodeTrait(HasTraits):
  1194. value = Unicode("unicode")
  1195. class TestUnicode(TraitTestBase):
  1196. obj = UnicodeTrait()
  1197. _default_value = "unicode"
  1198. _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"]
  1199. _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None]
  1200. def coerce(self, v):
  1201. return cast_unicode(v)
  1202. class ObjectNameTrait(HasTraits):
  1203. value = ObjectName("abc")
  1204. class TestObjectName(TraitTestBase):
  1205. obj = ObjectNameTrait()
  1206. _default_value = "abc"
  1207. _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"]
  1208. _bad_values = [
  1209. 1,
  1210. "",
  1211. "€",
  1212. "9g",
  1213. "!",
  1214. "#abc",
  1215. "aj@",
  1216. "a.b",
  1217. "a()",
  1218. "a[0]",
  1219. None,
  1220. object(),
  1221. object,
  1222. ]
  1223. _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131).
  1224. class DottedObjectNameTrait(HasTraits):
  1225. value = DottedObjectName("a.b")
  1226. class TestDottedObjectName(TraitTestBase):
  1227. obj = DottedObjectNameTrait()
  1228. _default_value = "a.b"
  1229. _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"]
  1230. _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
  1231. _good_values.append("t.þ")
  1232. class TCPAddressTrait(HasTraits):
  1233. value = TCPAddress()
  1234. class TestTCPAddress(TraitTestBase):
  1235. obj = TCPAddressTrait()
  1236. _default_value = ("127.0.0.1", 0)
  1237. _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)]
  1238. _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None]
  1239. class ListTrait(HasTraits):
  1240. value = List(Int())
  1241. class TestList(TraitTestBase):
  1242. obj = ListTrait()
  1243. _default_value: t.List[t.Any] = []
  1244. _good_values = [[], [1], list(range(10)), (1, 2)]
  1245. _bad_values = [10, [1, "a"], "a"]
  1246. def coerce(self, value):
  1247. if value is not None:
  1248. value = list(value)
  1249. return value
  1250. class SetTrait(HasTraits):
  1251. value = Set(Unicode())
  1252. class TestSet(TraitTestBase):
  1253. obj = SetTrait()
  1254. _default_value: t.Set[str] = set()
  1255. _good_values = [{"a", "b"}, "ab"]
  1256. _bad_values = [1]
  1257. def coerce(self, value):
  1258. if isinstance(value, str):
  1259. # compatibility handling: convert string to set containing string
  1260. value = {value}
  1261. return value
  1262. class Foo:
  1263. pass
  1264. class NoneInstanceListTrait(HasTraits):
  1265. value = List(Instance(Foo))
  1266. class TestNoneInstanceList(TraitTestBase):
  1267. obj = NoneInstanceListTrait()
  1268. _default_value: t.List[t.Any] = []
  1269. _good_values = [[Foo(), Foo()], []]
  1270. _bad_values = [[None], [Foo(), None]]
  1271. class InstanceListTrait(HasTraits):
  1272. value = List(Instance(__name__ + ".Foo"))
  1273. class TestInstanceList(TraitTestBase):
  1274. obj = InstanceListTrait()
  1275. def test_klass(self):
  1276. """Test that the instance klass is properly assigned."""
  1277. self.assertIs(self.obj.traits()["value"]._trait.klass, Foo)
  1278. _default_value: t.List[t.Any] = []
  1279. _good_values = [[Foo(), Foo()], []]
  1280. _bad_values = [
  1281. [
  1282. "1",
  1283. 2,
  1284. ],
  1285. "1",
  1286. [Foo],
  1287. None,
  1288. ]
  1289. class UnionListTrait(HasTraits):
  1290. value = List(Int() | Bool())
  1291. class TestUnionListTrait(TraitTestBase):
  1292. obj = UnionListTrait()
  1293. _default_value: t.List[t.Any] = []
  1294. _good_values = [[True, 1], [False, True]]
  1295. _bad_values = [[1, "True"], False]
  1296. class LenListTrait(HasTraits):
  1297. value = List(Int(), [0], minlen=1, maxlen=2)
  1298. class TestLenList(TraitTestBase):
  1299. obj = LenListTrait()
  1300. _default_value = [0]
  1301. _good_values = [[1], [1, 2], (1, 2)]
  1302. _bad_values = [10, [1, "a"], "a", [], list(range(3))]
  1303. def coerce(self, value):
  1304. if value is not None:
  1305. value = list(value)
  1306. return value
  1307. class TupleTrait(HasTraits):
  1308. value = Tuple(Int(allow_none=True), default_value=(1,))
  1309. class TestTupleTrait(TraitTestBase):
  1310. obj = TupleTrait()
  1311. _default_value = (1,)
  1312. _good_values = [(1,), (0,), [1]]
  1313. _bad_values = [10, (1, 2), ("a"), (), None]
  1314. def coerce(self, value):
  1315. if value is not None:
  1316. value = tuple(value)
  1317. return value
  1318. def test_invalid_args(self):
  1319. self.assertRaises(TypeError, Tuple, 5)
  1320. self.assertRaises(TypeError, Tuple, default_value="hello")
  1321. t = Tuple(Int(), CBytes(), default_value=(1, 5))
  1322. class LooseTupleTrait(HasTraits):
  1323. value = Tuple((1, 2, 3))
  1324. class TestLooseTupleTrait(TraitTestBase):
  1325. obj = LooseTupleTrait()
  1326. _default_value = (1, 2, 3)
  1327. _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()]
  1328. _bad_values = [10, "hello", {}, None]
  1329. def coerce(self, value):
  1330. if value is not None:
  1331. value = tuple(value)
  1332. return value
  1333. def test_invalid_args(self):
  1334. self.assertRaises(TypeError, Tuple, 5)
  1335. self.assertRaises(TypeError, Tuple, default_value="hello")
  1336. t = Tuple(Int(), CBytes(), default_value=(1, 5))
  1337. class MultiTupleTrait(HasTraits):
  1338. value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"])
  1339. class TestMultiTuple(TraitTestBase):
  1340. obj = MultiTupleTrait()
  1341. _default_value = (99, b"bottles")
  1342. _good_values = [(1, b"a"), (2, b"b")]
  1343. _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a"))
  1344. @pytest.mark.parametrize(
  1345. "Trait",
  1346. ( # noqa: PT007
  1347. List,
  1348. Tuple,
  1349. Set,
  1350. Dict,
  1351. Integer,
  1352. Unicode,
  1353. ),
  1354. )
  1355. def test_allow_none_default_value(Trait):
  1356. class C(HasTraits):
  1357. t = Trait(default_value=None, allow_none=True)
  1358. # test default value
  1359. c = C()
  1360. assert c.t is None
  1361. # and in constructor
  1362. c = C(t=None)
  1363. assert c.t is None
  1364. @pytest.mark.parametrize(
  1365. "Trait, default_value",
  1366. ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), # noqa: PT007
  1367. )
  1368. def test_default_value(Trait, default_value):
  1369. class C(HasTraits):
  1370. t = Trait()
  1371. # test default value
  1372. c = C()
  1373. assert type(c.t) is type(default_value)
  1374. assert c.t == default_value
  1375. @pytest.mark.parametrize(
  1376. "Trait, default_value",
  1377. ((List, []), (Tuple, ()), (Set, set())), # noqa: PT007
  1378. )
  1379. def test_subclass_default_value(Trait, default_value):
  1380. """Test deprecated default_value=None behavior for Container subclass traits"""
  1381. class SubclassTrait(Trait): # type:ignore
  1382. def __init__(self, default_value=None):
  1383. super().__init__(default_value=default_value)
  1384. class C(HasTraits):
  1385. t = SubclassTrait()
  1386. # test default value
  1387. c = C()
  1388. assert type(c.t) is type(default_value)
  1389. assert c.t == default_value
  1390. class CRegExpTrait(HasTraits):
  1391. value = CRegExp(r"")
  1392. class TestCRegExp(TraitTestBase):
  1393. def coerce(self, value):
  1394. return re.compile(value)
  1395. obj = CRegExpTrait()
  1396. _default_value = re.compile(r"")
  1397. _good_values = [r"\d+", re.compile(r"\d+")]
  1398. _bad_values = ["(", None, ()]
  1399. class DictTrait(HasTraits):
  1400. value = Dict()
  1401. def test_dict_assignment():
  1402. d: t.Dict[str, int] = {}
  1403. c = DictTrait()
  1404. c.value = d
  1405. d["a"] = 5
  1406. assert d == c.value
  1407. assert c.value is d
  1408. class UniformlyValueValidatedDictTrait(HasTraits):
  1409. value = Dict(value_trait=Unicode(), default_value={"foo": "1"})
  1410. class TestInstanceUniformlyValueValidatedDict(TraitTestBase):
  1411. obj = UniformlyValueValidatedDictTrait()
  1412. _default_value = {"foo": "1"}
  1413. _good_values = [{"foo": "0", "bar": "1"}]
  1414. _bad_values = [{"foo": 0, "bar": "1"}]
  1415. class NonuniformlyValueValidatedDictTrait(HasTraits):
  1416. value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1})
  1417. class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase):
  1418. obj = NonuniformlyValueValidatedDictTrait()
  1419. _default_value = {"foo": 1}
  1420. _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}]
  1421. _bad_values = [{"foo": "0", "bar": "1"}]
  1422. class KeyValidatedDictTrait(HasTraits):
  1423. value = Dict(key_trait=Unicode(), default_value={"foo": "1"})
  1424. class TestInstanceKeyValidatedDict(TraitTestBase):
  1425. obj = KeyValidatedDictTrait()
  1426. _default_value = {"foo": "1"}
  1427. _good_values = [{"foo": "0", "bar": "1"}]
  1428. _bad_values = [{"foo": "0", 0: "1"}]
  1429. class FullyValidatedDictTrait(HasTraits):
  1430. value = Dict(
  1431. value_trait=Unicode(),
  1432. key_trait=Unicode(),
  1433. per_key_traits={"foo": Int()},
  1434. default_value={"foo": 1},
  1435. )
  1436. class TestInstanceFullyValidatedDict(TraitTestBase):
  1437. obj = FullyValidatedDictTrait()
  1438. _default_value = {"foo": 1}
  1439. _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}]
  1440. _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}]
  1441. def test_dict_default_value():
  1442. """Check that the `{}` default value of the Dict traitlet constructor is
  1443. actually copied."""
  1444. class Foo(HasTraits):
  1445. d1 = Dict()
  1446. d2 = Dict()
  1447. foo = Foo()
  1448. assert foo.d1 == {}
  1449. assert foo.d2 == {}
  1450. assert foo.d1 is not foo.d2
  1451. class TestValidationHook(TestCase):
  1452. def test_parity_trait(self):
  1453. """Verify that the early validation hook is effective"""
  1454. class Parity(HasTraits):
  1455. value = Int(0)
  1456. parity = Enum(["odd", "even"], default_value="even")
  1457. @validate("value")
  1458. def _value_validate(self, proposal):
  1459. value = proposal["value"]
  1460. if self.parity == "even" and value % 2:
  1461. raise TraitError("Expected an even number")
  1462. if self.parity == "odd" and (value % 2 == 0):
  1463. raise TraitError("Expected an odd number")
  1464. return value
  1465. u = Parity()
  1466. u.parity = "odd"
  1467. u.value = 1 # OK
  1468. with self.assertRaises(TraitError):
  1469. u.value = 2 # Trait Error
  1470. u.parity = "even"
  1471. u.value = 2 # OK
  1472. def test_multiple_validate(self):
  1473. """Verify that we can register the same validator to multiple names"""
  1474. class OddEven(HasTraits):
  1475. odd = Int(1)
  1476. even = Int(0)
  1477. @validate("odd", "even")
  1478. def check_valid(self, proposal):
  1479. if proposal["trait"].name == "odd" and not proposal["value"] % 2:
  1480. raise TraitError("odd should be odd")
  1481. if proposal["trait"].name == "even" and proposal["value"] % 2:
  1482. raise TraitError("even should be even")
  1483. u = OddEven()
  1484. u.odd = 3 # OK
  1485. with self.assertRaises(TraitError):
  1486. u.odd = 2 # Trait Error
  1487. u.even = 2 # OK
  1488. with self.assertRaises(TraitError):
  1489. u.even = 3 # Trait Error
  1490. def test_validate_used(self):
  1491. """Verify that the validate value is being used"""
  1492. class FixedValue(HasTraits):
  1493. value = Int(0)
  1494. @validate("value")
  1495. def _value_validate(self, proposal):
  1496. return -1
  1497. u = FixedValue(value=2)
  1498. assert u.value == -1
  1499. u = FixedValue()
  1500. u.value = 3
  1501. assert u.value == -1
class TestLink(TestCase):
    """Tests for the bidirectional ``link`` helper."""

    def test_connect_same(self):
        """Verify two traitlets of the same type can be linked together using link."""
        # Create a simple class with an Int traitlet.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)
        # Connect the two instances.
        c = link((a, "value"), (b, "value"))
        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.value)
        # Change one of the values to make sure they stay in sync.
        a.value = 5
        self.assertEqual(a.value, b.value)
        b.value = 6
        self.assertEqual(a.value, b.value)

    def test_link_different(self):
        """Verify traitlets on two different classes can be linked together using link."""
        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)
        # Connect the two instances.
        c = link((a, "value"), (b, "count"))
        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.count)
        # Change one of the values to make sure they stay in sync.
        a.value = 5
        self.assertEqual(a.value, b.count)
        b.count = 4
        self.assertEqual(a.value, b.count)

    def test_unlink_link(self):
        """Verify two linked traitlets can be unlinked and relinked."""
        # Create a simple class with an Int traitlet.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)
        # Connect the two instances.
        c = link((a, "value"), (b, "value"))
        a.value = 4
        c.unlink()
        # Change one of the values to make sure they don't stay in sync.
        a.value = 5
        self.assertNotEqual(a.value, b.value)
        # Relinking re-synchronizes immediately.
        c.link()
        self.assertEqual(a.value, b.value)
        a.value += 1
        self.assertEqual(a.value, b.value)

    def test_callbacks(self):
        """Verify two linked traitlets have their callbacks called once."""
        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)
        # Register callbacks that record which side fired, in order.
        callback_count = []

        def a_callback(name, old, new):
            callback_count.append("a")

        a.on_trait_change(a_callback, "value")

        def b_callback(name, old, new):
            callback_count.append("b")

        b.on_trait_change(b_callback, "count")
        # Connect the two instances.
        c = link((a, "value"), (b, "count"))
        # Make sure b's count was set to a's value once.
        self.assertEqual("".join(callback_count), "b")
        del callback_count[:]
        # Make sure a's value was set to b's count once.
        b.count = 5
        self.assertEqual("".join(callback_count), "ba")
        del callback_count[:]
        # Make sure b's count was set to a's value once.
        a.value = 4
        self.assertEqual("".join(callback_count), "ab")
        del callback_count[:]

    # NOTE(review): the method name misspells "transform"; kept as-is so the
    # test ID stays stable.
    def test_tranform(self):
        """Test transform link."""
        # Create a simple class with an Int traitlet.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)
        # Connect the two instances; the transform pair is (forward, backward).
        c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0)))
        # Make sure the values are correct at the point of linking.
        self.assertEqual(b.value, 2 * a.value)
        # Change the value of the source and check that it modifies the target.
        a.value = 5
        self.assertEqual(b.value, 10)
        # Change the value of the target and check that it modifies the source.
        b.value = 6
        self.assertEqual(a.value, 3)

    def test_link_broken_at_source(self):
        """Setting ``i`` raises TraitError because an observer on ``j`` keeps
        the two linked values from agreeing."""

        class MyClass(HasTraits):
            i = Int()
            j = Int()

            @observe("j")
            def another_update(self, change):
                self.i = change.new * 2

        mc = MyClass()
        l = link((mc, "i"), (mc, "j"))  # noqa: E741
        self.assertRaises(TraitError, setattr, mc, "i", 2)

    def test_link_broken_at_target(self):
        """Setting ``j`` raises TraitError because an observer on ``i`` keeps
        the two linked values from agreeing."""

        class MyClass(HasTraits):
            i = Int()
            j = Int()

            @observe("i")
            def another_update(self, change):
                self.j = change.new * 2

        mc = MyClass()
        l = link((mc, "i"), (mc, "j"))  # noqa: E741
        self.assertRaises(TraitError, setattr, mc, "j", 2)
  1623. class TestDirectionalLink(TestCase):
  1624. def test_connect_same(self):
  1625. """Verify two traitlets of the same type can be linked together using directional_link."""
  1626. # Create two simple classes with Int traitlets.
  1627. class A(HasTraits):
  1628. value = Int()
  1629. a = A(value=9)
  1630. b = A(value=8)
  1631. # Connect the two classes.
  1632. c = directional_link((a, "value"), (b, "value"))
  1633. # Make sure the values are the same at the point of linking.
  1634. self.assertEqual(a.value, b.value)
  1635. # Change one the value of the source and check that it synchronizes the target.
  1636. a.value = 5
  1637. self.assertEqual(b.value, 5)
  1638. # Change one the value of the target and check that it has no impact on the source
  1639. b.value = 6
  1640. self.assertEqual(a.value, 5)
  1641. def test_tranform(self):
  1642. """Test transform link."""
  1643. # Create two simple classes with Int traitlets.
  1644. class A(HasTraits):
  1645. value = Int()
  1646. a = A(value=9)
  1647. b = A(value=8)
  1648. # Connect the two classes.
  1649. c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x)
  1650. # Make sure the values are correct at the point of linking.
  1651. self.assertEqual(b.value, 2 * a.value)
  1652. # Change one the value of the source and check that it modifies the target.
  1653. a.value = 5
  1654. self.assertEqual(b.value, 10)
  1655. # Change one the value of the target and check that it has no impact on the source
  1656. b.value = 6
  1657. self.assertEqual(a.value, 5)
  1658. def test_link_different(self):
  1659. """Verify two traitlets of different types can be linked together using link."""
  1660. # Create two simple classes with Int traitlets.
  1661. class A(HasTraits):
  1662. value = Int()
  1663. class B(HasTraits):
  1664. count = Int()
  1665. a = A(value=9)
  1666. b = B(count=8)
  1667. # Connect the two classes.
  1668. c = directional_link((a, "value"), (b, "count"))
  1669. # Make sure the values are the same at the point of linking.
  1670. self.assertEqual(a.value, b.count)
  1671. # Change one the value of the source and check that it synchronizes the target.
  1672. a.value = 5
  1673. self.assertEqual(b.count, 5)
  1674. # Change one the value of the target and check that it has no impact on the source
  1675. b.value = 6 # type:ignore
  1676. self.assertEqual(a.value, 5)
  1677. def test_unlink_link(self):
  1678. """Verify two linked traitlets can be unlinked and relinked."""
  1679. # Create two simple classes with Int traitlets.
  1680. class A(HasTraits):
  1681. value = Int()
  1682. a = A(value=9)
  1683. b = A(value=8)
  1684. # Connect the two classes.
  1685. c = directional_link((a, "value"), (b, "value"))
  1686. a.value = 4
  1687. c.unlink()
  1688. # Change one of the values to make sure they don't stay in sync.
  1689. a.value = 5
  1690. self.assertNotEqual(a.value, b.value)
  1691. c.link()
  1692. self.assertEqual(a.value, b.value)
  1693. a.value += 1
  1694. self.assertEqual(a.value, b.value)
class Pickleable(HasTraits):
    """HasTraits subclass with an observer, a validator and a handler registered
    at runtime; test_pickle_hastraits round-trips it through pickle."""

    i = Int()

    @observe("i")
    def _i_changed(self, change):
        pass

    # Pass-through validator: accepts the proposed value unchanged.
    @validate("i")
    def _i_validate(self, commit):
        return commit["value"]

    j = Int()

    def __init__(self):
        # Set i while notifications are held, then additionally register
        # _i_changed through on_trait_change (on top of the @observe hook).
        with self.hold_trait_notifications():
            self.i = 1
        self.on_trait_change(self._i_changed, "i")
  1708. def test_pickle_hastraits():
  1709. c = Pickleable()
  1710. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1711. p = pickle.dumps(c, protocol)
  1712. c2 = pickle.loads(p)
  1713. assert c2.i == c.i
  1714. assert c2.j == c.j
  1715. c.i = 5
  1716. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1717. p = pickle.dumps(c, protocol)
  1718. c2 = pickle.loads(p)
  1719. assert c2.i == c.i
  1720. assert c2.j == c.j
  1721. def test_hold_trait_notifications():
  1722. changes = []
  1723. class Test(HasTraits):
  1724. a = Integer(0)
  1725. b = Integer(0)
  1726. def _a_changed(self, name, old, new):
  1727. changes.append((old, new))
  1728. def _b_validate(self, value, trait):
  1729. if value != 0:
  1730. raise TraitError("Only 0 is a valid value")
  1731. return value
  1732. # Test context manager and nesting
  1733. t = Test()
  1734. with t.hold_trait_notifications():
  1735. with t.hold_trait_notifications():
  1736. t.a = 1
  1737. assert t.a == 1
  1738. assert changes == []
  1739. t.a = 2
  1740. assert t.a == 2
  1741. with t.hold_trait_notifications():
  1742. t.a = 3
  1743. assert t.a == 3
  1744. assert changes == []
  1745. t.a = 4
  1746. assert t.a == 4
  1747. assert changes == []
  1748. t.a = 4
  1749. assert t.a == 4
  1750. assert changes == []
  1751. assert changes == [(0, 4)]
  1752. # Test roll-back
  1753. try:
  1754. with t.hold_trait_notifications():
  1755. t.b = 1 # raises a Trait error
  1756. except Exception:
  1757. pass
  1758. assert t.b == 0
  1759. class RollBack(HasTraits):
  1760. bar = Int()
  1761. def _bar_validate(self, value, trait):
  1762. if value:
  1763. raise TraitError("foobar")
  1764. return value
  1765. class TestRollback(TestCase):
  1766. def test_roll_back(self):
  1767. def assign_rollback():
  1768. RollBack(bar=1)
  1769. self.assertRaises(TraitError, assign_rollback)
class CacheModification(HasTraits):
    """Two traits whose ``_<name>_validate`` hooks each assign the other trait,
    used by test_cache_modification."""

    foo = Int()
    bar = Int()

    def _bar_validate(self, value, trait):
        # Validating bar writes the same value into foo before accepting it.
        self.foo = value
        return value

    def _foo_validate(self, value, trait):
        # Validating foo writes the same value into bar before accepting it.
        self.bar = value
        return value
  1779. def test_cache_modification():
  1780. CacheModification(foo=1)
  1781. CacheModification(bar=1)
class OrderTraits(HasTraits):
    """Records, in ``notified``, a snapshot of traits a-l for each trait change
    notification delivered to ``_notify``."""

    notified = Dict()
    a = Unicode()
    b = Unicode()
    c = Unicode()
    d = Unicode()
    e = Unicode()
    f = Unicode()
    g = Unicode()
    h = Unicode()
    i = Unicode()
    j = Unicode()
    k = Unicode()
    l = Unicode()  # noqa: E741

    def _notify(self, name, old, new):
        """Check the value of all traits when each trait change is triggered.

        This verifies that the values are not sensitive
        to dict ordering when loaded from kwargs.
        """
        # check the value of the other traits
        # when a given trait change notification fires
        # (item assignment mutates the dict in place rather than reassigning
        # the ``notified`` trait)
        self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"}

    def __init__(self, **kwargs):
        # Register the handler before super().__init__ applies kwargs, so the
        # constructor's own assignments are observed.
        self.on_trait_change(self._notify)
        super().__init__(**kwargs)
  1807. def test_notification_order():
  1808. d = {c: c for c in "abcdefghijkl"}
  1809. obj = OrderTraits()
  1810. assert obj.notified == {}
  1811. obj = OrderTraits(**d)
  1812. notifications = {c: d for c in "abcdefghijkl"}
  1813. assert obj.notified == notifications
###
# Traits for Forward Declaration Tests
###
# Each trait names its target class by string: ForwardDeclaredBar is only
# defined further down in this module.
class ForwardDeclaredInstanceTrait(HasTraits):
    # Accepts instances of ForwardDeclaredBar (or None).
    value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredTypeTrait(HasTraits):
    # Accepts the ForwardDeclaredBar class or subclasses (or None).
    value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredInstanceListTrait(HasTraits):
    value = List(ForwardDeclaredInstance("ForwardDeclaredBar"))


class ForwardDeclaredTypeListTrait(HasTraits):
    value = List(ForwardDeclaredType("ForwardDeclaredBar"))


###
# End Traits for Forward Declaration Tests
###
  1828. ###
  1829. # Classes for Forward Declaration Tests
  1830. ###
  1831. class ForwardDeclaredBar:
  1832. pass
  1833. class ForwardDeclaredBarSub(ForwardDeclaredBar):
  1834. pass
  1835. ###
  1836. # End Classes for Forward Declaration Tests
  1837. ###
  1838. ###
  1839. # Forward Declaration Tests
  1840. ###
  1841. class TestForwardDeclaredInstanceTrait(TraitTestBase):
  1842. obj = ForwardDeclaredInstanceTrait()
  1843. _default_value = None
  1844. _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1845. _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1846. class TestForwardDeclaredTypeTrait(TraitTestBase):
  1847. obj = ForwardDeclaredTypeTrait()
  1848. _default_value = None
  1849. _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1850. _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1851. class TestForwardDeclaredInstanceList(TraitTestBase):
  1852. obj = ForwardDeclaredInstanceListTrait()
  1853. def test_klass(self):
  1854. """Test that the instance klass is properly assigned."""
  1855. self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
  1856. _default_value: t.List[t.Any] = []
  1857. _good_values = [
  1858. [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
  1859. [],
  1860. ]
  1861. _bad_values = [
  1862. ForwardDeclaredBar(),
  1863. [ForwardDeclaredBar(), 3, None],
  1864. "1",
  1865. # Note that this is the type, not an instance.
  1866. [ForwardDeclaredBar],
  1867. [None],
  1868. None,
  1869. ]
  1870. class TestForwardDeclaredTypeList(TraitTestBase):
  1871. obj = ForwardDeclaredTypeListTrait()
  1872. def test_klass(self):
  1873. """Test that the instance klass is properly assigned."""
  1874. self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
  1875. _default_value: t.List[t.Any] = []
  1876. _good_values = [
  1877. [ForwardDeclaredBar, ForwardDeclaredBarSub],
  1878. [],
  1879. ]
  1880. _bad_values = [
  1881. ForwardDeclaredBar,
  1882. [ForwardDeclaredBar, 3],
  1883. "1",
  1884. # Note that this is an instance, not the type.
  1885. [ForwardDeclaredBar()],
  1886. [None],
  1887. None,
  1888. ]
  1889. ###
  1890. # End Forward Declaration Tests
  1891. ###
  1892. class TestDynamicTraits(TestCase):
  1893. def setUp(self):
  1894. self._notify1 = []
  1895. def notify1(self, name, old, new):
  1896. self._notify1.append((name, old, new))
  1897. @t.no_type_check
  1898. def test_notify_all(self):
  1899. class A(HasTraits):
  1900. pass
  1901. a = A()
  1902. self.assertTrue(not hasattr(a, "x"))
  1903. self.assertTrue(not hasattr(a, "y"))
  1904. # Dynamically add trait x.
  1905. a.add_traits(x=Int())
  1906. self.assertTrue(hasattr(a, "x"))
  1907. self.assertTrue(isinstance(a, (A,)))
  1908. # Dynamically add trait y.
  1909. a.add_traits(y=Float())
  1910. self.assertTrue(hasattr(a, "y"))
  1911. self.assertTrue(isinstance(a, (A,)))
  1912. self.assertEqual(a.__class__.__name__, A.__name__)
  1913. # Create a new instance and verify that x and y
  1914. # aren't defined.
  1915. b = A()
  1916. self.assertTrue(not hasattr(b, "x"))
  1917. self.assertTrue(not hasattr(b, "y"))
  1918. # Verify that notification works like normal.
  1919. a.on_trait_change(self.notify1)
  1920. a.x = 0
  1921. self.assertEqual(len(self._notify1), 0)
  1922. a.y = 0.0
  1923. self.assertEqual(len(self._notify1), 0)
  1924. a.x = 10
  1925. self.assertTrue(("x", 0, 10) in self._notify1)
  1926. a.y = 10.0
  1927. self.assertTrue(("y", 0.0, 10.0) in self._notify1)
  1928. self.assertRaises(TraitError, setattr, a, "x", "bad string")
  1929. self.assertRaises(TraitError, setattr, a, "y", "bad string")
  1930. self._notify1 = []
  1931. a.on_trait_change(self.notify1, remove=True)
  1932. a.x = 20
  1933. a.y = 20.0
  1934. self.assertEqual(len(self._notify1), 0)
  1935. def test_enum_no_default():
  1936. class C(HasTraits):
  1937. t = Enum(["a", "b"])
  1938. c = C()
  1939. c.t = "a"
  1940. assert c.t == "a"
  1941. c = C()
  1942. with pytest.raises(TraitError):
  1943. t = c.t
  1944. c = C(t="b")
  1945. assert c.t == "b"
  1946. def test_default_value_repr():
  1947. class C(HasTraits):
  1948. t = Type("traitlets.HasTraits")
  1949. t2 = Type(HasTraits)
  1950. n = Integer(0)
  1951. lis = List()
  1952. d = Dict()
  1953. assert C.t.default_value_repr() == "'traitlets.HasTraits'"
  1954. assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'"
  1955. assert C.n.default_value_repr() == "0"
  1956. assert C.lis.default_value_repr() == "[]"
  1957. assert C.d.default_value_repr() == "{}"
class TransitionalClass(HasTraits):
    """Parent class whose hooks ``SubClass`` overrides using legacy conventions.

    ``test_subclass_compat`` checks which of these hooks still run once the
    subclass overrides them.
    """

    d = Any()

    @default("d")
    def _d_default(self):
        return TransitionalClass

    # Set to the change object by ``_calls_super_changed``; False until then.
    parent_super = False
    calls_super = Integer(0)

    @default("calls_super")
    def _calls_super_default(self):
        return -1

    # observe_compat presumably adapts the call so that a subclass can invoke
    # this with the legacy (name, old, new) arguments — TODO confirm.
    @observe("calls_super")
    @observe_compat
    def _calls_super_changed(self, change):
        self.parent_super = change

    # Set to the change object by ``_overrides_changed``; False until then.
    parent_override = False
    overrides = Integer(0)

    @observe("overrides")
    @observe_compat
    def _overrides_changed(self, change):
        self.parent_override = change
class SubClass(TransitionalClass):
    """Overrides TransitionalClass hooks by name only, without re-applying decorators."""

    # Plain override with no @default; test_subclass_compat expects it to win.
    def _d_default(self):
        return SubClass

    subclass_super = False

    # Legacy (name, old, new) signature that still chains up to the parent handler.
    def _calls_super_changed(self, name, old, new):
        self.subclass_super = True
        super()._calls_super_changed(name, old, new)

    subclass_override = False

    # Does not call super(), so the parent's handler is expected not to run.
    def _overrides_changed(self, name, old, new):
        self.subclass_override = True
  1988. def test_subclass_compat():
  1989. obj = SubClass()
  1990. obj.calls_super = 5
  1991. assert obj.parent_super
  1992. assert obj.subclass_super
  1993. obj.overrides = 5
  1994. assert obj.subclass_override
  1995. assert not obj.parent_override
  1996. assert obj.d is SubClass
class DefinesHandler(HasTraits):
    """Base class that observes ``trait`` with ``handler`` and records that it ran."""

    # Flipped to True by ``handler``; inspected by the subclass-override tests.
    parent_called = False
    trait = Integer()

    @observe("trait")
    def handler(self, change):
        self.parent_called = True
class OverridesHandler(DefinesHandler):
    """Overrides ``handler`` under the same name and re-registers it with @observe."""

    # Flipped to True by this class's ``handler``.
    child_called = False

    @observe("trait")
    def handler(self, change):
        self.child_called = True
  2008. def test_subclass_override_observer():
  2009. obj = OverridesHandler()
  2010. obj.trait = 5
  2011. assert obj.child_called
  2012. assert not obj.parent_called
class DoesntRegisterHandler(DefinesHandler):
    """Overrides ``handler`` under the same name WITHOUT re-applying @observe."""

    # Would be flipped by ``handler``; the test expects it to stay False.
    child_called = False

    def handler(self, change):
        self.child_called = True
  2017. def test_subclass_override_not_registered():
  2018. """Subclass that overrides observer and doesn't re-register unregisters both"""
  2019. obj = DoesntRegisterHandler()
  2020. obj.trait = 5
  2021. assert not obj.child_called
  2022. assert not obj.parent_called
class AddsHandler(DefinesHandler):
    """Adds a second, differently named observer of ``trait`` alongside the inherited one."""

    # Flipped to True by ``child_handler``.
    child_called = False

    @observe("trait")
    def child_handler(self, change):
        self.child_called = True
  2028. def test_subclass_add_observer():
  2029. obj = AddsHandler()
  2030. obj.trait = 5
  2031. assert obj.child_called
  2032. assert obj.parent_called
  2033. def test_observe_iterables():
  2034. class C(HasTraits):
  2035. i = Integer()
  2036. s = Unicode()
  2037. c = C()
  2038. recorded = {}
  2039. def record(change):
  2040. recorded["change"] = change
  2041. # observe with names=set
  2042. c.observe(record, names={"i", "s"})
  2043. c.i = 5
  2044. assert recorded["change"].name == "i"
  2045. assert recorded["change"].new == 5
  2046. c.s = "hi"
  2047. assert recorded["change"].name == "s"
  2048. assert recorded["change"].new == "hi"
  2049. # observe with names=custom container with iter, contains
  2050. class MyContainer:
  2051. def __init__(self, container):
  2052. self.container = container
  2053. def __iter__(self):
  2054. return iter(self.container)
  2055. def __contains__(self, key):
  2056. return key in self.container
  2057. c.observe(record, names=MyContainer({"i", "s"}))
  2058. c.i = 10
  2059. assert recorded["change"].name == "i"
  2060. assert recorded["change"].new == 10
  2061. c.s = "ok"
  2062. assert recorded["change"].name == "s"
  2063. assert recorded["change"].new == "ok"
  2064. def test_super_args():
  2065. class SuperRecorder:
  2066. def __init__(self, *args, **kwargs):
  2067. self.super_args = args
  2068. self.super_kwargs = kwargs
  2069. class SuperHasTraits(HasTraits, SuperRecorder):
  2070. i = Integer()
  2071. obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x")
  2072. assert obj.i == 5
  2073. assert not hasattr(obj, "b")
  2074. assert not hasattr(obj, "c")
  2075. assert obj.super_args == ("a1", "a2")
  2076. assert obj.super_kwargs == {"b": 10, "c": "x"}
  2077. def test_super_bad_args():
  2078. class SuperHasTraits(HasTraits):
  2079. a = Integer()
  2080. w = ["Passing unrecognized arguments"]
  2081. with expected_warnings(w):
  2082. obj = SuperHasTraits(a=1, b=2)
  2083. assert obj.a == 1
  2084. assert not hasattr(obj, "b")
  2085. def test_default_mro():
  2086. """Verify that default values follow mro"""
  2087. class Base(HasTraits):
  2088. trait = Unicode("base")
  2089. attr = "base"
  2090. class A(Base):
  2091. pass
  2092. class B(Base):
  2093. trait = Unicode("B")
  2094. attr = "B"
  2095. class AB(A, B):
  2096. pass
  2097. class BA(B, A):
  2098. pass
  2099. assert A().trait == "base"
  2100. assert A().attr == "base"
  2101. assert BA().trait == "B"
  2102. assert BA().attr == "B"
  2103. assert AB().trait == "B"
  2104. assert AB().attr == "B"
  2105. def test_cls_self_argument():
  2106. class X(HasTraits):
  2107. def __init__(__self, cls, self):
  2108. pass
  2109. x = X(cls=None, self=None)
  2110. def test_override_default():
  2111. class C(HasTraits):
  2112. a = Unicode("hard default")
  2113. def _a_default(self):
  2114. return "default method"
  2115. C._a_default = lambda self: "overridden" # type:ignore
  2116. c = C()
  2117. assert c.a == "overridden"
  2118. def test_override_default_decorator():
  2119. class C(HasTraits):
  2120. a = Unicode("hard default")
  2121. @default("a")
  2122. def _a_default(self):
  2123. return "default method"
  2124. C._a_default = lambda self: "overridden" # type:ignore
  2125. c = C()
  2126. assert c.a == "overridden"
  2127. def test_override_default_instance():
  2128. class C(HasTraits):
  2129. a = Unicode("hard default")
  2130. @default("a")
  2131. def _a_default(self):
  2132. return "default method"
  2133. c = C()
  2134. c._a_default = lambda self: "overridden"
  2135. assert c.a == "overridden"
  2136. def test_copy_HasTraits():
  2137. from copy import copy
  2138. class C(HasTraits):
  2139. a = Int()
  2140. c = C(a=1)
  2141. assert c.a == 1
  2142. cc = copy(c)
  2143. cc.a = 2
  2144. assert cc.a == 2
  2145. assert c.a == 1
  2146. def _from_string_test(traittype, s, expected):
  2147. """Run a test of trait.from_string"""
  2148. if isinstance(traittype, TraitType):
  2149. trait = traittype
  2150. else:
  2151. trait = traittype(allow_none=True)
  2152. if isinstance(s, list):
  2153. cast = trait.from_string_list # type:ignore
  2154. else:
  2155. cast = trait.from_string
  2156. if type(expected) is type and issubclass(expected, Exception):
  2157. with pytest.raises(expected): # noqa: PT012
  2158. value = cast(s)
  2159. trait.validate(CrossValidationStub(), value) # type:ignore
  2160. else:
  2161. value = cast(s)
  2162. assert value == expected
  2163. @pytest.mark.parametrize(
  2164. "s, expected",
  2165. [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
  2166. )
  2167. def test_unicode_from_string(s, expected):
  2168. _from_string_test(Unicode, s, expected)
  2169. @pytest.mark.parametrize(
  2170. "s, expected",
  2171. [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
  2172. )
  2173. def test_cunicode_from_string(s, expected):
  2174. _from_string_test(CUnicode, s, expected)
  2175. @pytest.mark.parametrize(
  2176. "s, expected",
  2177. [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
  2178. )
  2179. def test_bytes_from_string(s, expected):
  2180. _from_string_test(Bytes, s, expected)
  2181. @pytest.mark.parametrize(
  2182. "s, expected",
  2183. [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
  2184. )
  2185. def test_cbytes_from_string(s, expected):
  2186. _from_string_test(CBytes, s, expected)
  2187. @pytest.mark.parametrize(
  2188. "s, expected",
  2189. [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)],
  2190. )
  2191. def test_int_from_string(s, expected):
  2192. _from_string_test(Integer, s, expected)
  2193. @pytest.mark.parametrize(
  2194. "s, expected",
  2195. [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)],
  2196. )
  2197. def test_float_from_string(s, expected):
  2198. _from_string_test(Float, s, expected)
  2199. @pytest.mark.parametrize(
  2200. "s, expected",
  2201. [
  2202. ("x", ValueError),
  2203. ("1", 1.0),
  2204. ("123.5", 123.5),
  2205. ("2.5", 2.5),
  2206. ("1+2j", 1 + 2j),
  2207. ("None", None),
  2208. ],
  2209. )
  2210. def test_complex_from_string(s, expected):
  2211. _from_string_test(Complex, s, expected)
  2212. @pytest.mark.parametrize(
  2213. "s, expected",
  2214. [
  2215. ("true", True),
  2216. ("TRUE", True),
  2217. ("1", True),
  2218. ("0", False),
  2219. ("False", False),
  2220. ("false", False),
  2221. ("1.0", ValueError),
  2222. ("None", None),
  2223. ],
  2224. )
  2225. def test_bool_from_string(s, expected):
  2226. _from_string_test(Bool, s, expected)
  2227. @pytest.mark.parametrize(
  2228. "s, expected",
  2229. [
  2230. ("{}", {}),
  2231. ("1", TraitError),
  2232. ("{1: 2}", {1: 2}),
  2233. ('{"key": "value"}', {"key": "value"}),
  2234. ("x", TraitError),
  2235. ("None", None),
  2236. ],
  2237. )
  2238. def test_dict_from_string(s, expected):
  2239. _from_string_test(Dict, s, expected)
  2240. @pytest.mark.parametrize(
  2241. "s, expected",
  2242. [
  2243. ("[]", []),
  2244. ('[1, 2, "x"]', [1, 2, "x"]),
  2245. (["1", "x"], ["1", "x"]),
  2246. (["None"], None),
  2247. ],
  2248. )
  2249. def test_list_from_string(s, expected):
  2250. _from_string_test(List, s, expected)
  2251. @pytest.mark.parametrize(
  2252. "s, expected, value_trait",
  2253. [
  2254. (["1", "2", "3"], [1, 2, 3], Integer()),
  2255. (["x"], ValueError, Integer()),
  2256. (["1", "x"], ["1", "x"], Unicode()),
  2257. (["None"], [None], Unicode(allow_none=True)),
  2258. (["None"], ["None"], Unicode(allow_none=False)),
  2259. ],
  2260. )
  2261. def test_list_items_from_string(s, expected, value_trait):
  2262. _from_string_test(List(value_trait), s, expected)
  2263. @pytest.mark.parametrize(
  2264. "s, expected",
  2265. [
  2266. ("[]", set()),
  2267. ('[1, 2, "x"]', {1, 2, "x"}),
  2268. ('{1, 2, "x"}', {1, 2, "x"}),
  2269. (["1", "x"], {"1", "x"}),
  2270. (["None"], None),
  2271. ],
  2272. )
  2273. def test_set_from_string(s, expected):
  2274. _from_string_test(Set, s, expected)
  2275. @pytest.mark.parametrize(
  2276. "s, expected, value_trait",
  2277. [
  2278. (["1", "2", "3"], {1, 2, 3}, Integer()),
  2279. (["x"], ValueError, Integer()),
  2280. (["1", "x"], {"1", "x"}, Unicode()),
  2281. (["None"], {None}, Unicode(allow_none=True)),
  2282. ],
  2283. )
  2284. def test_set_items_from_string(s, expected, value_trait):
  2285. _from_string_test(Set(value_trait), s, expected)
  2286. @pytest.mark.parametrize(
  2287. "s, expected",
  2288. [
  2289. ("[]", ()),
  2290. ("()", ()),
  2291. ('[1, 2, "x"]', (1, 2, "x")),
  2292. ('(1, 2, "x")', (1, 2, "x")),
  2293. (["1", "x"], ("1", "x")),
  2294. (["None"], None),
  2295. ],
  2296. )
  2297. def test_tuple_from_string(s, expected):
  2298. _from_string_test(Tuple, s, expected)
  2299. @pytest.mark.parametrize(
  2300. "s, expected, value_traits",
  2301. [
  2302. (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]),
  2303. (["x"], ValueError, [Integer()]),
  2304. (["1", "x"], ("1", "x"), [Unicode()]),
  2305. (["None"], ("None",), [Unicode(allow_none=False)]),
  2306. (["None"], (None,), [Unicode(allow_none=True)]),
  2307. ],
  2308. )
  2309. def test_tuple_items_from_string(s, expected, value_traits):
  2310. _from_string_test(Tuple(*value_traits), s, expected)
  2311. @pytest.mark.parametrize(
  2312. "s, expected",
  2313. [
  2314. ("x", "x"),
  2315. ("mod.submod", "mod.submod"),
  2316. ("not an identifier", TraitError),
  2317. ("1", "1"),
  2318. ("None", None),
  2319. ],
  2320. )
  2321. def test_object_from_string(s, expected):
  2322. _from_string_test(DottedObjectName, s, expected)
  2323. @pytest.mark.parametrize(
  2324. "s, expected",
  2325. [
  2326. ("127.0.0.1:8000", ("127.0.0.1", 8000)),
  2327. ("host.tld:80", ("host.tld", 80)),
  2328. ("host:notaport", ValueError),
  2329. ("127.0.0.1", ValueError),
  2330. ("None", None),
  2331. ],
  2332. )
  2333. def test_tcp_from_string(s, expected):
  2334. _from_string_test(TCPAddress, s, expected)
  2335. @pytest.mark.parametrize(
  2336. "s, expected",
  2337. [("[]", []), ("{}", "{}")],
  2338. )
  2339. def test_union_of_list_and_unicode_from_string(s, expected):
  2340. _from_string_test(Union([List(), Unicode()]), s, expected)
  2341. @pytest.mark.parametrize(
  2342. "s, expected",
  2343. [("1", 1), ("1.5", 1.5)],
  2344. )
  2345. def test_union_of_int_and_float_from_string(s, expected):
  2346. _from_string_test(Union([Int(), Float()]), s, expected)
  2347. @pytest.mark.parametrize(
  2348. "s, expected, allow_none",
  2349. [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)],
  2350. )
  2351. def test_union_of_list_and_dict_from_string(s, expected, allow_none):
  2352. _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected)
  2353. def test_all_attribute():
  2354. """Verify all trait types are added to `traitlets.__all__`"""
  2355. names = dir(traitlets)
  2356. for name in names:
  2357. value = getattr(traitlets, name)
  2358. if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType):
  2359. if name not in traitlets.__all__:
  2360. raise ValueError(f"{name} not in __all__")
  2361. for name in traitlets.__all__:
  2362. if name not in names:
  2363. raise ValueError(f"{name} should be removed from __all__")