parasite_axes.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. from __future__ import (absolute_import, division, print_function,
  2. unicode_literals)
  3. import six
  4. from matplotlib import (
  5. artist as martist, collections as mcoll, transforms as mtransforms,
  6. rcParams)
  7. from matplotlib.axes import subplot_class_factory
  8. from matplotlib.transforms import Bbox
  9. from .mpl_axes import Axes
  10. import numpy as np
  11. class ParasiteAxesBase(object):
  12. def get_images_artists(self):
  13. artists = {a for a in self.get_children() if a.get_visible()}
  14. images = {a for a in self.images if a.get_visible()}
  15. return list(images), list(artists - images)
  16. def __init__(self, parent_axes, **kargs):
  17. self._parent_axes = parent_axes
  18. kargs.update(dict(frameon=False))
  19. self._get_base_axes_attr("__init__")(self, parent_axes.figure,
  20. parent_axes._position, **kargs)
  21. def cla(self):
  22. self._get_base_axes_attr("cla")(self)
  23. martist.setp(self.get_children(), visible=False)
  24. self._get_lines = self._parent_axes._get_lines
  25. # In mpl's Axes, zorders of x- and y-axis are originally set
  26. # within Axes.draw().
  27. if self._axisbelow:
  28. self.xaxis.set_zorder(0.5)
  29. self.yaxis.set_zorder(0.5)
  30. else:
  31. self.xaxis.set_zorder(2.5)
  32. self.yaxis.set_zorder(2.5)
  33. _parasite_axes_classes = {}
  34. def parasite_axes_class_factory(axes_class=None):
  35. if axes_class is None:
  36. axes_class = Axes
  37. new_class = _parasite_axes_classes.get(axes_class)
  38. if new_class is None:
  39. def _get_base_axes_attr(self, attrname):
  40. return getattr(axes_class, attrname)
  41. new_class = type(str("%sParasite" % (axes_class.__name__)),
  42. (ParasiteAxesBase, axes_class),
  43. {'_get_base_axes_attr': _get_base_axes_attr})
  44. _parasite_axes_classes[axes_class] = new_class
  45. return new_class
  46. ParasiteAxes = parasite_axes_class_factory()
  47. # #class ParasiteAxes(ParasiteAxesBase, Axes):
  48. # @classmethod
  49. # def _get_base_axes_attr(cls, attrname):
  50. # return getattr(Axes, attrname)
  51. class ParasiteAxesAuxTransBase(object):
  52. def __init__(self, parent_axes, aux_transform, viewlim_mode=None,
  53. **kwargs):
  54. self.transAux = aux_transform
  55. self.set_viewlim_mode(viewlim_mode)
  56. self._parasite_axes_class.__init__(self, parent_axes, **kwargs)
  57. def _set_lim_and_transforms(self):
  58. self.transAxes = self._parent_axes.transAxes
  59. self.transData = \
  60. self.transAux + \
  61. self._parent_axes.transData
  62. self._xaxis_transform = mtransforms.blended_transform_factory(
  63. self.transData, self.transAxes)
  64. self._yaxis_transform = mtransforms.blended_transform_factory(
  65. self.transAxes, self.transData)
  66. def set_viewlim_mode(self, mode):
  67. if mode not in [None, "equal", "transform"]:
  68. raise ValueError("Unknown mode : %s" % (mode,))
  69. else:
  70. self._viewlim_mode = mode
  71. def get_viewlim_mode(self):
  72. return self._viewlim_mode
  73. def update_viewlim(self):
  74. viewlim = self._parent_axes.viewLim.frozen()
  75. mode = self.get_viewlim_mode()
  76. if mode is None:
  77. pass
  78. elif mode == "equal":
  79. self.axes.viewLim.set(viewlim)
  80. elif mode == "transform":
  81. self.axes.viewLim.set(viewlim.transformed(self.transAux.inverted()))
  82. else:
  83. raise ValueError("Unknown mode : %s" % (self._viewlim_mode,))
  84. def _pcolor(self, method_name, *XYC, **kwargs):
  85. if len(XYC) == 1:
  86. C = XYC[0]
  87. ny, nx = C.shape
  88. gx = np.arange(-0.5, nx, 1.)
  89. gy = np.arange(-0.5, ny, 1.)
  90. X, Y = np.meshgrid(gx, gy)
  91. else:
  92. X, Y, C = XYC
  93. pcolor_routine = self._get_base_axes_attr(method_name)
  94. if "transform" in kwargs:
  95. mesh = pcolor_routine(self, X, Y, C, **kwargs)
  96. else:
  97. orig_shape = X.shape
  98. xy = np.vstack([X.flat, Y.flat])
  99. xyt=xy.transpose()
  100. wxy = self.transAux.transform(xyt)
  101. gx, gy = wxy[:,0].reshape(orig_shape), wxy[:,1].reshape(orig_shape)
  102. mesh = pcolor_routine(self, gx, gy, C, **kwargs)
  103. mesh.set_transform(self._parent_axes.transData)
  104. return mesh
  105. def pcolormesh(self, *XYC, **kwargs):
  106. return self._pcolor("pcolormesh", *XYC, **kwargs)
  107. def pcolor(self, *XYC, **kwargs):
  108. return self._pcolor("pcolor", *XYC, **kwargs)
  109. def _contour(self, method_name, *XYCL, **kwargs):
  110. if len(XYCL) <= 2:
  111. C = XYCL[0]
  112. ny, nx = C.shape
  113. gx = np.arange(0., nx, 1.)
  114. gy = np.arange(0., ny, 1.)
  115. X,Y = np.meshgrid(gx, gy)
  116. CL = XYCL
  117. else:
  118. X, Y = XYCL[:2]
  119. CL = XYCL[2:]
  120. contour_routine = self._get_base_axes_attr(method_name)
  121. if "transform" in kwargs:
  122. cont = contour_routine(self, X, Y, *CL, **kwargs)
  123. else:
  124. orig_shape = X.shape
  125. xy = np.vstack([X.flat, Y.flat])
  126. xyt=xy.transpose()
  127. wxy = self.transAux.transform(xyt)
  128. gx, gy = wxy[:,0].reshape(orig_shape), wxy[:,1].reshape(orig_shape)
  129. cont = contour_routine(self, gx, gy, *CL, **kwargs)
  130. for c in cont.collections:
  131. c.set_transform(self._parent_axes.transData)
  132. return cont
  133. def contour(self, *XYCL, **kwargs):
  134. return self._contour("contour", *XYCL, **kwargs)
  135. def contourf(self, *XYCL, **kwargs):
  136. return self._contour("contourf", *XYCL, **kwargs)
  137. def apply_aspect(self, position=None):
  138. self.update_viewlim()
  139. self._get_base_axes_attr("apply_aspect")(self)
  140. #ParasiteAxes.apply_aspect()
  141. _parasite_axes_auxtrans_classes = {}
  142. def parasite_axes_auxtrans_class_factory(axes_class=None):
  143. if axes_class is None:
  144. parasite_axes_class = ParasiteAxes
  145. elif not issubclass(axes_class, ParasiteAxesBase):
  146. parasite_axes_class = parasite_axes_class_factory(axes_class)
  147. else:
  148. parasite_axes_class = axes_class
  149. new_class = _parasite_axes_auxtrans_classes.get(parasite_axes_class)
  150. if new_class is None:
  151. new_class = type(str("%sParasiteAuxTrans" % (parasite_axes_class.__name__)),
  152. (ParasiteAxesAuxTransBase, parasite_axes_class),
  153. {'_parasite_axes_class': parasite_axes_class,
  154. 'name': 'parasite_axes'})
  155. _parasite_axes_auxtrans_classes[parasite_axes_class] = new_class
  156. return new_class
  157. ParasiteAxesAuxTrans = parasite_axes_auxtrans_class_factory(axes_class=ParasiteAxes)
  158. def _get_handles(ax):
  159. handles = ax.lines[:]
  160. handles.extend(ax.patches)
  161. handles.extend([c for c in ax.collections
  162. if isinstance(c, mcoll.LineCollection)])
  163. handles.extend([c for c in ax.collections
  164. if isinstance(c, mcoll.RegularPolyCollection)])
  165. handles.extend([c for c in ax.collections
  166. if isinstance(c, mcoll.CircleCollection)])
  167. return handles
  168. class HostAxesBase(object):
  169. def __init__(self, *args, **kwargs):
  170. self.parasites = []
  171. self._get_base_axes_attr("__init__")(self, *args, **kwargs)
  172. def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
  173. parasite_axes_class = parasite_axes_auxtrans_class_factory(axes_class)
  174. ax2 = parasite_axes_class(self, tr, viewlim_mode)
  175. # note that ax2.transData == tr + ax1.transData
  176. # Anthing you draw in ax2 will match the ticks and grids of ax1.
  177. self.parasites.append(ax2)
  178. ax2._remove_method = lambda h: self.parasites.remove(h)
  179. return ax2
  180. def _get_legend_handles(self, legend_handler_map=None):
  181. # don't use this!
  182. Axes_get_legend_handles = self._get_base_axes_attr("_get_legend_handles")
  183. all_handles = list(Axes_get_legend_handles(self, legend_handler_map))
  184. for ax in self.parasites:
  185. all_handles.extend(ax._get_legend_handles(legend_handler_map))
  186. return all_handles
  187. def draw(self, renderer):
  188. orig_artists = list(self.artists)
  189. orig_images = list(self.images)
  190. if hasattr(self, "get_axes_locator"):
  191. locator = self.get_axes_locator()
  192. if locator:
  193. pos = locator(self, renderer)
  194. self.set_position(pos, which="active")
  195. self.apply_aspect(pos)
  196. else:
  197. self.apply_aspect()
  198. else:
  199. self.apply_aspect()
  200. rect = self.get_position()
  201. for ax in self.parasites:
  202. ax.apply_aspect(rect)
  203. images, artists = ax.get_images_artists()
  204. self.images.extend(images)
  205. self.artists.extend(artists)
  206. self._get_base_axes_attr("draw")(self, renderer)
  207. self.artists = orig_artists
  208. self.images = orig_images
  209. def cla(self):
  210. for ax in self.parasites:
  211. ax.cla()
  212. self._get_base_axes_attr("cla")(self)
  213. #super(HostAxes, self).cla()
  214. def twinx(self, axes_class=None):
  215. """
  216. create a twin of Axes for generating a plot with a sharex
  217. x-axis but independent y axis. The y-axis of self will have
  218. ticks on left and the returned axes will have ticks on the
  219. right
  220. """
  221. if axes_class is None:
  222. axes_class = self._get_base_axes()
  223. parasite_axes_class = parasite_axes_class_factory(axes_class)
  224. ax2 = parasite_axes_class(self, sharex=self, frameon=False)
  225. self.parasites.append(ax2)
  226. self.axis["right"].set_visible(False)
  227. ax2.axis["right"].set_visible(True)
  228. ax2.axis["left", "top", "bottom"].set_visible(False)
  229. def _remove_method(h):
  230. self.parasites.remove(h)
  231. self.axis["right"].set_visible(True)
  232. self.axis["right"].toggle(ticklabels=False, label=False)
  233. ax2._remove_method = _remove_method
  234. return ax2
  235. def twiny(self, axes_class=None):
  236. """
  237. create a twin of Axes for generating a plot with a shared
  238. y-axis but independent x axis. The x-axis of self will have
  239. ticks on bottom and the returned axes will have ticks on the
  240. top
  241. """
  242. if axes_class is None:
  243. axes_class = self._get_base_axes()
  244. parasite_axes_class = parasite_axes_class_factory(axes_class)
  245. ax2 = parasite_axes_class(self, sharey=self, frameon=False)
  246. self.parasites.append(ax2)
  247. self.axis["top"].set_visible(False)
  248. ax2.axis["top"].set_visible(True)
  249. ax2.axis["left", "right", "bottom"].set_visible(False)
  250. def _remove_method(h):
  251. self.parasites.remove(h)
  252. self.axis["top"].set_visible(True)
  253. self.axis["top"].toggle(ticklabels=False, label=False)
  254. ax2._remove_method = _remove_method
  255. return ax2
  256. def twin(self, aux_trans=None, axes_class=None):
  257. """
  258. create a twin of Axes for generating a plot with a sharex
  259. x-axis but independent y axis. The y-axis of self will have
  260. ticks on left and the returned axes will have ticks on the
  261. right
  262. """
  263. if axes_class is None:
  264. axes_class = self._get_base_axes()
  265. parasite_axes_auxtrans_class = parasite_axes_auxtrans_class_factory(axes_class)
  266. if aux_trans is None:
  267. ax2 = parasite_axes_auxtrans_class(self, mtransforms.IdentityTransform(),
  268. viewlim_mode="equal",
  269. )
  270. else:
  271. ax2 = parasite_axes_auxtrans_class(self, aux_trans,
  272. viewlim_mode="transform",
  273. )
  274. self.parasites.append(ax2)
  275. ax2._remove_method = lambda h: self.parasites.remove(h)
  276. self.axis["top", "right"].set_visible(False)
  277. ax2.axis["top", "right"].set_visible(True)
  278. ax2.axis["left", "bottom"].set_visible(False)
  279. def _remove_method(h):
  280. self.parasites.remove(h)
  281. self.axis["top", "right"].set_visible(True)
  282. self.axis["top", "right"].toggle(ticklabels=False, label=False)
  283. ax2._remove_method = _remove_method
  284. return ax2
  285. def get_tightbbox(self, renderer, call_axes_locator=True):
  286. bbs = [ax.get_tightbbox(renderer, call_axes_locator)
  287. for ax in self.parasites]
  288. get_tightbbox = self._get_base_axes_attr("get_tightbbox")
  289. bbs.append(get_tightbbox(self, renderer, call_axes_locator))
  290. _bbox = Bbox.union([b for b in bbs if b.width!=0 or b.height!=0])
  291. return _bbox
  292. _host_axes_classes = {}
  293. def host_axes_class_factory(axes_class=None):
  294. if axes_class is None:
  295. axes_class = Axes
  296. new_class = _host_axes_classes.get(axes_class)
  297. if new_class is None:
  298. def _get_base_axes(self):
  299. return axes_class
  300. def _get_base_axes_attr(self, attrname):
  301. return getattr(axes_class, attrname)
  302. new_class = type(str("%sHostAxes" % (axes_class.__name__)),
  303. (HostAxesBase, axes_class),
  304. {'_get_base_axes_attr': _get_base_axes_attr,
  305. '_get_base_axes': _get_base_axes})
  306. _host_axes_classes[axes_class] = new_class
  307. return new_class
  308. def host_subplot_class_factory(axes_class):
  309. host_axes_class = host_axes_class_factory(axes_class=axes_class)
  310. subplot_host_class = subplot_class_factory(host_axes_class)
  311. return subplot_host_class
  312. HostAxes = host_axes_class_factory(axes_class=Axes)
  313. SubplotHost = subplot_class_factory(HostAxes)
  314. def host_axes(*args, **kwargs):
  315. """
  316. Create axes that can act as a hosts to parasitic axes.
  317. Parameters
  318. ----------
  319. figure : `matplotlib.figure.Figure`
  320. Figure to which the axes will be added. Defaults to the current figure
  321. `pyplot.gcf()`.
  322. *args, **kwargs :
  323. Will be passed on to the underlying ``Axes`` object creation.
  324. """
  325. import matplotlib.pyplot as plt
  326. axes_class = kwargs.pop("axes_class", None)
  327. host_axes_class = host_axes_class_factory(axes_class)
  328. fig = kwargs.get("figure", None)
  329. if fig is None:
  330. fig = plt.gcf()
  331. ax = host_axes_class(fig, *args, **kwargs)
  332. fig.add_axes(ax)
  333. plt.draw_if_interactive()
  334. return ax
  335. def host_subplot(*args, **kwargs):
  336. """
  337. Create a subplot that can act as a host to parasitic axes.
  338. Parameters
  339. ----------
  340. figure : `matplotlib.figure.Figure`
  341. Figure to which the subplot will be added. Defaults to the current
  342. figure `pyplot.gcf()`.
  343. *args, **kwargs :
  344. Will be passed on to the underlying ``Axes`` object creation.
  345. """
  346. import matplotlib.pyplot as plt
  347. axes_class = kwargs.pop("axes_class", None)
  348. host_subplot_class = host_subplot_class_factory(axes_class)
  349. fig = kwargs.get("figure", None)
  350. if fig is None:
  351. fig = plt.gcf()
  352. ax = host_subplot_class(fig, *args, **kwargs)
  353. fig.add_subplot(ax)
  354. plt.draw_if_interactive()
  355. return ax