mpl_axes.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import matplotlib.axes as maxes
  2. from matplotlib.artist import Artist
  3. from matplotlib.axis import XAxis, YAxis
  4. class SimpleChainedObjects:
  5. def __init__(self, objects):
  6. self._objects = objects
  7. def __getattr__(self, k):
  8. _a = SimpleChainedObjects([getattr(a, k) for a in self._objects])
  9. return _a
  10. def __call__(self, *args, **kwargs):
  11. for m in self._objects:
  12. m(*args, **kwargs)
  13. class Axes(maxes.Axes):
  14. class AxisDict(dict):
  15. def __init__(self, axes):
  16. self.axes = axes
  17. super().__init__()
  18. def __getitem__(self, k):
  19. if isinstance(k, tuple):
  20. r = SimpleChainedObjects(
  21. # super() within a list comprehension needs explicit args.
  22. [super(Axes.AxisDict, self).__getitem__(k1) for k1 in k])
  23. return r
  24. elif isinstance(k, slice):
  25. if k.start is None and k.stop is None and k.step is None:
  26. return SimpleChainedObjects(list(self.values()))
  27. else:
  28. raise ValueError("Unsupported slice")
  29. else:
  30. return dict.__getitem__(self, k)
  31. def __call__(self, *v, **kwargs):
  32. return maxes.Axes.axis(self.axes, *v, **kwargs)
  33. @property
  34. def axis(self):
  35. return self._axislines
  36. def clear(self):
  37. # docstring inherited
  38. super().clear()
  39. # Init axis artists.
  40. self._axislines = self.AxisDict(self)
  41. self._axislines.update(
  42. bottom=SimpleAxisArtist(self.xaxis, 1, self.spines["bottom"]),
  43. top=SimpleAxisArtist(self.xaxis, 2, self.spines["top"]),
  44. left=SimpleAxisArtist(self.yaxis, 1, self.spines["left"]),
  45. right=SimpleAxisArtist(self.yaxis, 2, self.spines["right"]))
  46. class SimpleAxisArtist(Artist):
  47. def __init__(self, axis, axisnum, spine):
  48. self._axis = axis
  49. self._axisnum = axisnum
  50. self.line = spine
  51. if isinstance(axis, XAxis):
  52. self._axis_direction = ["bottom", "top"][axisnum-1]
  53. elif isinstance(axis, YAxis):
  54. self._axis_direction = ["left", "right"][axisnum-1]
  55. else:
  56. raise ValueError(
  57. f"axis must be instance of XAxis or YAxis, but got {axis}")
  58. super().__init__()
  59. @property
  60. def major_ticks(self):
  61. tickline = "tick%dline" % self._axisnum
  62. return SimpleChainedObjects([getattr(tick, tickline)
  63. for tick in self._axis.get_major_ticks()])
  64. @property
  65. def major_ticklabels(self):
  66. label = "label%d" % self._axisnum
  67. return SimpleChainedObjects([getattr(tick, label)
  68. for tick in self._axis.get_major_ticks()])
  69. @property
  70. def label(self):
  71. return self._axis.label
  72. def set_visible(self, b):
  73. self.toggle(all=b)
  74. self.line.set_visible(b)
  75. self._axis.set_visible(True)
  76. super().set_visible(b)
  77. def set_label(self, txt):
  78. self._axis.set_label_text(txt)
  79. def toggle(self, all=None, ticks=None, ticklabels=None, label=None):
  80. if all:
  81. _ticks, _ticklabels, _label = True, True, True
  82. elif all is not None:
  83. _ticks, _ticklabels, _label = False, False, False
  84. else:
  85. _ticks, _ticklabels, _label = None, None, None
  86. if ticks is not None:
  87. _ticks = ticks
  88. if ticklabels is not None:
  89. _ticklabels = ticklabels
  90. if label is not None:
  91. _label = label
  92. if _ticks is not None:
  93. tickparam = {f"tick{self._axisnum}On": _ticks}
  94. self._axis.set_tick_params(**tickparam)
  95. if _ticklabels is not None:
  96. tickparam = {f"label{self._axisnum}On": _ticklabels}
  97. self._axis.set_tick_params(**tickparam)
  98. if _label is not None:
  99. pos = self._axis.get_label_position()
  100. if (pos == self._axis_direction) and not _label:
  101. self._axis.label.set_visible(False)
  102. elif _label:
  103. self._axis.label.set_visible(True)
  104. self._axis.set_label_position(self._axis_direction)