mpl_axes.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from __future__ import (absolute_import, division, print_function,
  2. unicode_literals)
  3. import six
  4. import matplotlib.axes as maxes
  5. from matplotlib.artist import Artist
  6. from matplotlib.axis import XAxis, YAxis
  7. class SimpleChainedObjects(object):
  8. def __init__(self, objects):
  9. self._objects = objects
  10. def __getattr__(self, k):
  11. _a = SimpleChainedObjects([getattr(a, k) for a in self._objects])
  12. return _a
  13. def __call__(self, *kl, **kwargs):
  14. for m in self._objects:
  15. m(*kl, **kwargs)
  16. class Axes(maxes.Axes):
  17. class AxisDict(dict):
  18. def __init__(self, axes):
  19. self.axes = axes
  20. super(Axes.AxisDict, self).__init__()
  21. def __getitem__(self, k):
  22. if isinstance(k, tuple):
  23. r = SimpleChainedObjects(
  24. [super(Axes.AxisDict, self).__getitem__(k1) for k1 in k])
  25. return r
  26. elif isinstance(k, slice):
  27. if k.start is None and k.stop is None and k.step is None:
  28. r = SimpleChainedObjects(list(six.itervalues(self)))
  29. return r
  30. else:
  31. raise ValueError("Unsupported slice")
  32. else:
  33. return dict.__getitem__(self, k)
  34. def __call__(self, *v, **kwargs):
  35. return maxes.Axes.axis(self.axes, *v, **kwargs)
  36. def __init__(self, *kl, **kw):
  37. super(Axes, self).__init__(*kl, **kw)
  38. def _init_axis_artists(self, axes=None):
  39. if axes is None:
  40. axes = self
  41. self._axislines = self.AxisDict(self)
  42. self._axislines["bottom"] = SimpleAxisArtist(self.xaxis, 1, self.spines["bottom"])
  43. self._axislines["top"] = SimpleAxisArtist(self.xaxis, 2, self.spines["top"])
  44. self._axislines["left"] = SimpleAxisArtist(self.yaxis, 1, self.spines["left"])
  45. self._axislines["right"] = SimpleAxisArtist(self.yaxis, 2, self.spines["right"])
  46. def _get_axislines(self):
  47. return self._axislines
  48. axis = property(_get_axislines)
  49. def cla(self):
  50. super(Axes, self).cla()
  51. self._init_axis_artists()
  52. class SimpleAxisArtist(Artist):
  53. def __init__(self, axis, axisnum, spine):
  54. self._axis = axis
  55. self._axisnum = axisnum
  56. self.line = spine
  57. if isinstance(axis, XAxis):
  58. self._axis_direction = ["bottom", "top"][axisnum-1]
  59. elif isinstance(axis, YAxis):
  60. self._axis_direction = ["left", "right"][axisnum-1]
  61. else:
  62. raise ValueError("axis must be instance of XAxis or YAxis : %s is provided" % (axis,))
  63. Artist.__init__(self)
  64. def _get_major_ticks(self):
  65. tickline = "tick%dline" % self._axisnum
  66. return SimpleChainedObjects([getattr(tick, tickline)
  67. for tick in self._axis.get_major_ticks()])
  68. def _get_major_ticklabels(self):
  69. label = "label%d" % self._axisnum
  70. return SimpleChainedObjects([getattr(tick, label)
  71. for tick in self._axis.get_major_ticks()])
  72. def _get_label(self):
  73. return self._axis.label
  74. major_ticks = property(_get_major_ticks)
  75. major_ticklabels = property(_get_major_ticklabels)
  76. label = property(_get_label)
  77. def set_visible(self, b):
  78. self.toggle(all=b)
  79. self.line.set_visible(b)
  80. self._axis.set_visible(True)
  81. Artist.set_visible(self, b)
  82. def set_label(self, txt):
  83. self._axis.set_label_text(txt)
  84. def toggle(self, all=None, ticks=None, ticklabels=None, label=None):
  85. if all:
  86. _ticks, _ticklabels, _label = True, True, True
  87. elif all is not None:
  88. _ticks, _ticklabels, _label = False, False, False
  89. else:
  90. _ticks, _ticklabels, _label = None, None, None
  91. if ticks is not None:
  92. _ticks = ticks
  93. if ticklabels is not None:
  94. _ticklabels = ticklabels
  95. if label is not None:
  96. _label = label
  97. tickOn = "tick%dOn" % self._axisnum
  98. labelOn = "label%dOn" % self._axisnum
  99. if _ticks is not None:
  100. tickparam = {tickOn: _ticks}
  101. self._axis.set_tick_params(**tickparam)
  102. if _ticklabels is not None:
  103. tickparam = {labelOn: _ticklabels}
  104. self._axis.set_tick_params(**tickparam)
  105. if _label is not None:
  106. pos = self._axis.get_label_position()
  107. if (pos == self._axis_direction) and not _label:
  108. self._axis.label.set_visible(False)
  109. elif _label:
  110. self._axis.label.set_visible(True)
  111. self._axis.set_label_position(self._axis_direction)
  112. if __name__ == '__main__':
  113. import matplotlib.pyplot as plt
  114. fig = plt.figure()
  115. ax = Axes(fig, [0.1, 0.1, 0.8, 0.8])
  116. fig.add_axes(ax)
  117. ax.cla()