parasite_axes.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from matplotlib import _api, cbook
  2. import matplotlib.artist as martist
  3. import matplotlib.transforms as mtransforms
  4. from matplotlib.transforms import Bbox
  5. from .mpl_axes import Axes
  6. class ParasiteAxesBase:
  7. def __init__(self, parent_axes, aux_transform=None,
  8. *, viewlim_mode=None, **kwargs):
  9. self._parent_axes = parent_axes
  10. self.transAux = aux_transform
  11. self.set_viewlim_mode(viewlim_mode)
  12. kwargs["frameon"] = False
  13. super().__init__(parent_axes.figure, parent_axes._position, **kwargs)
  14. def clear(self):
  15. super().clear()
  16. martist.setp(self.get_children(), visible=False)
  17. self._get_lines = self._parent_axes._get_lines
  18. self._parent_axes.callbacks._connect_picklable(
  19. "xlim_changed", self._sync_lims)
  20. self._parent_axes.callbacks._connect_picklable(
  21. "ylim_changed", self._sync_lims)
  22. def pick(self, mouseevent):
  23. # This most likely goes to Artist.pick (depending on axes_class given
  24. # to the factory), which only handles pick events registered on the
  25. # axes associated with each child:
  26. super().pick(mouseevent)
  27. # But parasite axes are additionally given pick events from their host
  28. # axes (cf. HostAxesBase.pick), which we handle here:
  29. for a in self.get_children():
  30. if (hasattr(mouseevent.inaxes, "parasites")
  31. and self in mouseevent.inaxes.parasites):
  32. a.pick(mouseevent)
  33. # aux_transform support
  34. def _set_lim_and_transforms(self):
  35. if self.transAux is not None:
  36. self.transAxes = self._parent_axes.transAxes
  37. self.transData = self.transAux + self._parent_axes.transData
  38. self._xaxis_transform = mtransforms.blended_transform_factory(
  39. self.transData, self.transAxes)
  40. self._yaxis_transform = mtransforms.blended_transform_factory(
  41. self.transAxes, self.transData)
  42. else:
  43. super()._set_lim_and_transforms()
  44. def set_viewlim_mode(self, mode):
  45. _api.check_in_list([None, "equal", "transform"], mode=mode)
  46. self._viewlim_mode = mode
  47. def get_viewlim_mode(self):
  48. return self._viewlim_mode
  49. def _sync_lims(self, parent):
  50. viewlim = parent.viewLim.frozen()
  51. mode = self.get_viewlim_mode()
  52. if mode is None:
  53. pass
  54. elif mode == "equal":
  55. self.viewLim.set(viewlim)
  56. elif mode == "transform":
  57. self.viewLim.set(viewlim.transformed(self.transAux.inverted()))
  58. else:
  59. _api.check_in_list([None, "equal", "transform"], mode=mode)
  60. # end of aux_transform support
  61. parasite_axes_class_factory = cbook._make_class_factory(
  62. ParasiteAxesBase, "{}Parasite")
  63. ParasiteAxes = parasite_axes_class_factory(Axes)
  64. class HostAxesBase:
  65. def __init__(self, *args, **kwargs):
  66. self.parasites = []
  67. super().__init__(*args, **kwargs)
  68. def get_aux_axes(
  69. self, tr=None, viewlim_mode="equal", axes_class=None, **kwargs):
  70. """
  71. Add a parasite axes to this host.
  72. Despite this method's name, this should actually be thought of as an
  73. ``add_parasite_axes`` method.
  74. .. versionchanged:: 3.7
  75. Defaults to same base axes class as host axes.
  76. Parameters
  77. ----------
  78. tr : `~matplotlib.transforms.Transform` or None, default: None
  79. If a `.Transform`, the following relation will hold:
  80. ``parasite.transData = tr + host.transData``.
  81. If None, the parasite's and the host's ``transData`` are unrelated.
  82. viewlim_mode : {"equal", "transform", None}, default: "equal"
  83. How the parasite's view limits are set: directly equal to the
  84. parent axes ("equal"), equal after application of *tr*
  85. ("transform"), or independently (None).
  86. axes_class : subclass type of `~matplotlib.axes.Axes`, optional
  87. The `~.axes.Axes` subclass that is instantiated. If None, the base
  88. class of the host axes is used.
  89. **kwargs
  90. Other parameters are forwarded to the parasite axes constructor.
  91. """
  92. if axes_class is None:
  93. axes_class = self._base_axes_class
  94. parasite_axes_class = parasite_axes_class_factory(axes_class)
  95. ax2 = parasite_axes_class(
  96. self, tr, viewlim_mode=viewlim_mode, **kwargs)
  97. # note that ax2.transData == tr + ax1.transData
  98. # Anything you draw in ax2 will match the ticks and grids of ax1.
  99. self.parasites.append(ax2)
  100. ax2._remove_method = self.parasites.remove
  101. return ax2
  102. def draw(self, renderer):
  103. orig_children_len = len(self._children)
  104. locator = self.get_axes_locator()
  105. if locator:
  106. pos = locator(self, renderer)
  107. self.set_position(pos, which="active")
  108. self.apply_aspect(pos)
  109. else:
  110. self.apply_aspect()
  111. rect = self.get_position()
  112. for ax in self.parasites:
  113. ax.apply_aspect(rect)
  114. self._children.extend(ax.get_children())
  115. super().draw(renderer)
  116. del self._children[orig_children_len:]
  117. def clear(self):
  118. super().clear()
  119. for ax in self.parasites:
  120. ax.clear()
  121. def pick(self, mouseevent):
  122. super().pick(mouseevent)
  123. # Also pass pick events on to parasite axes and, in turn, their
  124. # children (cf. ParasiteAxesBase.pick)
  125. for a in self.parasites:
  126. a.pick(mouseevent)
  127. def twinx(self, axes_class=None):
  128. """
  129. Create a twin of Axes with a shared x-axis but independent y-axis.
  130. The y-axis of self will have ticks on the left and the returned axes
  131. will have ticks on the right.
  132. """
  133. ax = self._add_twin_axes(axes_class, sharex=self)
  134. self.axis["right"].set_visible(False)
  135. ax.axis["right"].set_visible(True)
  136. ax.axis["left", "top", "bottom"].set_visible(False)
  137. return ax
  138. def twiny(self, axes_class=None):
  139. """
  140. Create a twin of Axes with a shared y-axis but independent x-axis.
  141. The x-axis of self will have ticks on the bottom and the returned axes
  142. will have ticks on the top.
  143. """
  144. ax = self._add_twin_axes(axes_class, sharey=self)
  145. self.axis["top"].set_visible(False)
  146. ax.axis["top"].set_visible(True)
  147. ax.axis["left", "right", "bottom"].set_visible(False)
  148. return ax
  149. def twin(self, aux_trans=None, axes_class=None):
  150. """
  151. Create a twin of Axes with no shared axis.
  152. While self will have ticks on the left and bottom axis, the returned
  153. axes will have ticks on the top and right axis.
  154. """
  155. if aux_trans is None:
  156. aux_trans = mtransforms.IdentityTransform()
  157. ax = self._add_twin_axes(
  158. axes_class, aux_transform=aux_trans, viewlim_mode="transform")
  159. self.axis["top", "right"].set_visible(False)
  160. ax.axis["top", "right"].set_visible(True)
  161. ax.axis["left", "bottom"].set_visible(False)
  162. return ax
  163. def _add_twin_axes(self, axes_class, **kwargs):
  164. """
  165. Helper for `.twinx`/`.twiny`/`.twin`.
  166. *kwargs* are forwarded to the parasite axes constructor.
  167. """
  168. if axes_class is None:
  169. axes_class = self._base_axes_class
  170. ax = parasite_axes_class_factory(axes_class)(self, **kwargs)
  171. self.parasites.append(ax)
  172. ax._remove_method = self._remove_any_twin
  173. return ax
  174. def _remove_any_twin(self, ax):
  175. self.parasites.remove(ax)
  176. restore = ["top", "right"]
  177. if ax._sharex:
  178. restore.remove("top")
  179. if ax._sharey:
  180. restore.remove("right")
  181. self.axis[tuple(restore)].set_visible(True)
  182. self.axis[tuple(restore)].toggle(ticklabels=False, label=False)
  183. @_api.make_keyword_only("3.8", "call_axes_locator")
  184. def get_tightbbox(self, renderer=None, call_axes_locator=True,
  185. bbox_extra_artists=None):
  186. bbs = [
  187. *[ax.get_tightbbox(renderer, call_axes_locator=call_axes_locator)
  188. for ax in self.parasites],
  189. super().get_tightbbox(renderer,
  190. call_axes_locator=call_axes_locator,
  191. bbox_extra_artists=bbox_extra_artists)]
  192. return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])
  193. host_axes_class_factory = host_subplot_class_factory = \
  194. cbook._make_class_factory(HostAxesBase, "{}HostAxes", "_base_axes_class")
  195. HostAxes = SubplotHost = host_axes_class_factory(Axes)
  196. def host_axes(*args, axes_class=Axes, figure=None, **kwargs):
  197. """
  198. Create axes that can act as a hosts to parasitic axes.
  199. Parameters
  200. ----------
  201. figure : `~matplotlib.figure.Figure`
  202. Figure to which the axes will be added. Defaults to the current figure
  203. `.pyplot.gcf()`.
  204. *args, **kwargs
  205. Will be passed on to the underlying `~.axes.Axes` object creation.
  206. """
  207. import matplotlib.pyplot as plt
  208. host_axes_class = host_axes_class_factory(axes_class)
  209. if figure is None:
  210. figure = plt.gcf()
  211. ax = host_axes_class(figure, *args, **kwargs)
  212. figure.add_axes(ax)
  213. return ax
  214. host_subplot = host_axes