123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486 |
- from __future__ import (absolute_import, division, print_function,
- unicode_literals)
- import six
- from matplotlib import (
- artist as martist, collections as mcoll, transforms as mtransforms,
- rcParams)
- from matplotlib.axes import subplot_class_factory
- from matplotlib.transforms import Bbox
- from .mpl_axes import Axes
- import numpy as np
class ParasiteAxesBase(object):
    """
    Mixin with the behaviour shared by parasite axes classes.

    It is combined with a concrete axes class by
    `parasite_axes_class_factory`, which also supplies the
    ``_get_base_axes_attr`` hook used here to reach the methods of that
    concrete axes class.
    """

    def get_images_artists(self):
        """
        Split the visible children into images and other artists.

        Returns
        -------
        images : list
            The visible members of ``self.images``, in their original order.
        artists : list
            The visible children (``self.get_children()``) that are not in
            *images*, in their original order.

        Notes
        -----
        Both lists are duplicate-free.  Order is kept deterministic (rather
        than going through unordered sets) because callers append these
        lists to their own artist lists before drawing.
        """
        seen = set()

        images = []
        for image in self.images:
            if image.get_visible() and image not in seen:
                seen.add(image)
                images.append(image)

        artists = []
        for child in self.get_children():
            # ``seen`` already holds every visible image, so images are
            # excluded here exactly like the former ``artists - images``.
            if child.get_visible() and child not in seen:
                seen.add(child)
                artists.append(child)

        return images, artists

    def __init__(self, parent_axes, **kargs):
        """
        Parameters
        ----------
        parent_axes
            The host axes; its ``figure`` and ``_position`` are reused for
            this axes.
        **kargs
            Passed on to the base axes ``__init__``; ``frameon`` is always
            forced to False.
        """
        self._parent_axes = parent_axes
        kargs["frameon"] = False
        base_init = self._get_base_axes_attr("__init__")
        base_init(self, parent_axes.figure, parent_axes._position, **kargs)

    def cla(self):
        """Clear the axes, hide all children and share the parent's
        ``_get_lines`` object."""
        self._get_base_axes_attr("cla")(self)

        martist.setp(self.get_children(), visible=False)
        self._get_lines = self._parent_axes._get_lines

        # In mpl's Axes, zorders of x- and y-axis are originally set
        # within Axes.draw().
        if self._axisbelow:
            self.xaxis.set_zorder(0.5)
            self.yaxis.set_zorder(0.5)
        else:
            self.xaxis.set_zorder(2.5)
            self.yaxis.set_zorder(2.5)
# Cache of generated parasite classes, keyed by the wrapped axes class.
_parasite_axes_classes = {}


def parasite_axes_class_factory(axes_class=None):
    """
    Return a parasite axes class derived from *axes_class*.

    Parameters
    ----------
    axes_class : type, optional
        The axes class to wrap.  Defaults to ``Axes``.

    Returns
    -------
    type
        A class named ``"<axes_class name>Parasite"`` that inherits from
        ``ParasiteAxesBase`` and *axes_class*.  One class is created per
        *axes_class* and then reused from the module-level cache.
    """
    base = Axes if axes_class is None else axes_class

    try:
        return _parasite_axes_classes[base]
    except KeyError:
        pass

    def _get_base_axes_attr(self, attrname):
        # Look the attribute up on the wrapped class, bypassing the
        # overrides that ParasiteAxesBase puts in front of it.
        return getattr(base, attrname)

    class_name = str("%sParasite" % (base.__name__))
    namespace = {'_get_base_axes_attr': _get_base_axes_attr}
    parasite_class = type(class_name, (ParasiteAxesBase, base), namespace)

    _parasite_axes_classes[base] = parasite_class
    return parasite_class


ParasiteAxes = parasite_axes_class_factory()
- # #class ParasiteAxes(ParasiteAxesBase, Axes):
- # @classmethod
- # def _get_base_axes_attr(cls, attrname):
- # return getattr(Axes, attrname)
class ParasiteAxesAuxTransBase(object):
    """
    Mixin for parasite axes whose data coordinates are related to the
    parent's data coordinates through an auxiliary transform.

    The concrete class is built by `parasite_axes_auxtrans_class_factory`,
    which supplies the ``_parasite_axes_class`` attribute used in
    ``__init__``.
    """

    def __init__(self, parent_axes, aux_transform, viewlim_mode=None,
                 **kwargs):
        """
        Parameters
        ----------
        parent_axes
            The host axes.
        aux_transform
            Transform stored as ``transAux``; it is composed with the
            parent's ``transData`` to form this axes' ``transData``.
        viewlim_mode : {None, "equal", "transform"}, optional
            See `set_viewlim_mode`.
        **kwargs
            Passed on to the parasite axes class ``__init__``.
        """
        self.transAux = aux_transform
        self.set_viewlim_mode(viewlim_mode)

        self._parasite_axes_class.__init__(self, parent_axes, **kwargs)

    def _set_lim_and_transforms(self):
        """Build the transforms from the parent's transforms and
        ``transAux``."""
        self.transAxes = self._parent_axes.transAxes

        self.transData = self.transAux + self._parent_axes.transData

        self._xaxis_transform = mtransforms.blended_transform_factory(
            self.transData, self.transAxes)
        self._yaxis_transform = mtransforms.blended_transform_factory(
            self.transAxes, self.transData)

    def set_viewlim_mode(self, mode):
        """
        Set how `update_viewlim` derives the view limits from the parent.

        Parameters
        ----------
        mode : {None, "equal", "transform"}

        Raises
        ------
        ValueError
            If *mode* is not one of the values above.
        """
        if mode not in [None, "equal", "transform"]:
            raise ValueError("Unknown mode : %s" % (mode,))
        self._viewlim_mode = mode

    def get_viewlim_mode(self):
        """Return the mode set by `set_viewlim_mode`."""
        return self._viewlim_mode

    def update_viewlim(self):
        """
        Update the view limits from the parent's, according to the mode.

        ``None`` leaves the limits alone, ``"equal"`` copies the parent's
        view limits, and ``"transform"`` copies them after mapping through
        the inverse of ``transAux``.

        Raises
        ------
        ValueError
            If the stored mode is not a known one.
        """
        viewlim = self._parent_axes.viewLim.frozen()
        mode = self.get_viewlim_mode()
        if mode is None:
            pass
        elif mode == "equal":
            self.axes.viewLim.set(viewlim)
        elif mode == "transform":
            self.axes.viewLim.set(
                viewlim.transformed(self.transAux.inverted()))
        else:
            raise ValueError("Unknown mode : %s" % (self._viewlim_mode,))

    def _aux_transform_grid(self, X, Y):
        """
        Map the grid (*X*, *Y*) through ``transAux``.

        *X* and *Y* may be any array-likes; both results are reshaped to
        the shape of *X*.
        """
        X = np.asanyarray(X)
        Y = np.asanyarray(Y)
        grid_shape = X.shape
        points = np.vstack([X.flat, Y.flat]).transpose()
        mapped = self.transAux.transform(points)
        return (mapped[:, 0].reshape(grid_shape),
                mapped[:, 1].reshape(grid_shape))

    def _pcolor(self, method_name, *XYC, **kwargs):
        """
        Shared implementation of `pcolor` and `pcolormesh`.

        Accepts either ``(C,)`` or ``(X, Y, C)``.  Unless a *transform*
        keyword is given, the grid is mapped through ``transAux`` first and
        the result is drawn in the parent's data coordinates.
        """
        if len(XYC) == 1:
            C = XYC[0]
            ny, nx = np.shape(C)

            # Cell edges at half-integer positions around each element.
            gx = np.arange(-0.5, nx, 1.)
            gy = np.arange(-0.5, ny, 1.)

            X, Y = np.meshgrid(gx, gy)
        else:
            X, Y, C = XYC

        pcolor_routine = self._get_base_axes_attr(method_name)

        if "transform" in kwargs:
            mesh = pcolor_routine(self, X, Y, C, **kwargs)
        else:
            gx, gy = self._aux_transform_grid(X, Y)
            mesh = pcolor_routine(self, gx, gy, C, **kwargs)
            mesh.set_transform(self._parent_axes.transData)

        return mesh

    def pcolormesh(self, *XYC, **kwargs):
        """``pcolormesh`` of the base axes, applied in aux coordinates."""
        return self._pcolor("pcolormesh", *XYC, **kwargs)

    def pcolor(self, *XYC, **kwargs):
        """``pcolor`` of the base axes, applied in aux coordinates."""
        return self._pcolor("pcolor", *XYC, **kwargs)

    def _contour(self, method_name, *XYCL, **kwargs):
        """
        Shared implementation of `contour` and `contourf`.

        With one or two positional arguments they are taken as
        ``(C[, levels])`` and an integer grid is generated; otherwise the
        first two are *X* and *Y*.  Unless a *transform* keyword is given,
        the grid is mapped through ``transAux`` first and every resulting
        collection is drawn in the parent's data coordinates.
        """
        if len(XYCL) <= 2:
            C = XYCL[0]
            ny, nx = np.shape(C)

            gx = np.arange(0., nx, 1.)
            gy = np.arange(0., ny, 1.)

            X, Y = np.meshgrid(gx, gy)
            CL = XYCL
        else:
            X, Y = XYCL[:2]
            CL = XYCL[2:]

        contour_routine = self._get_base_axes_attr(method_name)

        if "transform" in kwargs:
            cont = contour_routine(self, X, Y, *CL, **kwargs)
        else:
            gx, gy = self._aux_transform_grid(X, Y)
            cont = contour_routine(self, gx, gy, *CL, **kwargs)
            for c in cont.collections:
                c.set_transform(self._parent_axes.transData)

        return cont

    def contour(self, *XYCL, **kwargs):
        """``contour`` of the base axes, applied in aux coordinates."""
        return self._contour("contour", *XYCL, **kwargs)

    def contourf(self, *XYCL, **kwargs):
        """``contourf`` of the base axes, applied in aux coordinates."""
        return self._contour("contourf", *XYCL, **kwargs)

    def apply_aspect(self, position=None):
        """
        Refresh the view limits from the parent, then apply the aspect.

        *position* is accepted for call compatibility but is not forwarded
        to the base ``apply_aspect``.
        """
        self.update_viewlim()
        self._get_base_axes_attr("apply_aspect")(self)
# Cache of generated aux-transform parasite classes, keyed by the parasite
# axes class they extend.
_parasite_axes_auxtrans_classes = {}


def parasite_axes_auxtrans_class_factory(axes_class=None):
    """
    Return an aux-transform parasite axes class built on *axes_class*.

    Parameters
    ----------
    axes_class : type, optional
        If omitted, ``ParasiteAxes`` is used.  A class that already derives
        from ``ParasiteAxesBase`` is used as is; any other class is first
        wrapped with `parasite_axes_class_factory`.

    Returns
    -------
    type
        A class named ``"<parasite class name>ParasiteAuxTrans"`` that
        inherits from ``ParasiteAxesAuxTransBase`` and the parasite class.
        One class is created per parasite class and then reused.
    """
    if axes_class is None:
        parasite_axes_class = ParasiteAxes
    elif issubclass(axes_class, ParasiteAxesBase):
        parasite_axes_class = axes_class
    else:
        parasite_axes_class = parasite_axes_class_factory(axes_class)

    cached = _parasite_axes_auxtrans_classes.get(parasite_axes_class)
    if cached is not None:
        return cached

    class_name = str("%sParasiteAuxTrans" % (parasite_axes_class.__name__))
    bases = (ParasiteAxesAuxTransBase, parasite_axes_class)
    namespace = {'_parasite_axes_class': parasite_axes_class,
                 'name': 'parasite_axes'}
    auxtrans_class = type(class_name, bases, namespace)

    _parasite_axes_auxtrans_classes[parasite_axes_class] = auxtrans_class
    return auxtrans_class


ParasiteAxesAuxTrans = parasite_axes_auxtrans_class_factory(
    axes_class=ParasiteAxes)
def _get_handles(ax):
    """
    Collect the artists of *ax* that are candidates for legend handles.

    Returns a new list holding, in this order: the lines, the patches, and
    then the members of ``ax.collections`` that are line collections,
    regular-polygon collections and circle collections.
    """
    handles = ax.lines[:]
    handles.extend(ax.patches)

    # One pass per collection type keeps the results grouped by type.
    collection_types = (mcoll.LineCollection,
                        mcoll.RegularPolyCollection,
                        mcoll.CircleCollection)
    for collection_type in collection_types:
        handles.extend([c for c in ax.collections
                        if isinstance(c, collection_type)])

    return handles
class HostAxesBase(object):
    """
    Mixin for axes that host parasite axes.

    The parasites are kept in ``self.parasites`` and are drawn together
    with the host.  The concrete class is built by
    `host_axes_class_factory`, which supplies ``_get_base_axes_attr`` and
    ``_get_base_axes``.
    """

    def __init__(self, *args, **kwargs):
        """Create the (empty) parasite list and initialise the base axes
        with the given arguments."""
        self.parasites = []
        self._get_base_axes_attr("__init__")(self, *args, **kwargs)

    def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
        """
        Add and return a parasite axes related to this one through *tr*.

        Parameters
        ----------
        tr
            The auxiliary transform handed to the parasite axes.
        viewlim_mode : {None, "equal", "transform"}, optional
            Handed to the parasite axes.
        axes_class : type, optional
            Passed to `parasite_axes_auxtrans_class_factory`.
        """
        parasite_axes_class = parasite_axes_auxtrans_class_factory(axes_class)
        ax2 = parasite_axes_class(self, tr, viewlim_mode)
        # note that ax2.transData == tr + ax1.transData
        # Anything you draw in ax2 will match the ticks and grids of ax1.
        self.parasites.append(ax2)
        ax2._remove_method = lambda h: self.parasites.remove(h)
        return ax2

    def _get_legend_handles(self, legend_handler_map=None):
        """Return the legend handles of this axes followed by those of
        every parasite."""
        # don't use this!
        Axes_get_legend_handles = self._get_base_axes_attr(
            "_get_legend_handles")
        all_handles = list(Axes_get_legend_handles(self, legend_handler_map))

        for ax in self.parasites:
            all_handles.extend(ax._get_legend_handles(legend_handler_map))

        return all_handles

    def draw(self, renderer):
        """
        Draw the host together with the visible artists of its parasites.

        The parasites' images and artists are temporarily appended to the
        host's own lists so the base ``draw`` renders them; the original
        lists are put back afterwards, also when drawing raises.
        """
        orig_artists = list(self.artists)
        orig_images = list(self.images)

        try:
            if hasattr(self, "get_axes_locator"):
                locator = self.get_axes_locator()
                if locator:
                    pos = locator(self, renderer)
                    self.set_position(pos, which="active")
                    self.apply_aspect(pos)
                else:
                    self.apply_aspect()
            else:
                self.apply_aspect()

            rect = self.get_position()

            for ax in self.parasites:
                ax.apply_aspect(rect)
                images, artists = ax.get_images_artists()
                self.images.extend(images)
                self.artists.extend(artists)

            self._get_base_axes_attr("draw")(self, renderer)
        finally:
            # Always undo the temporary merge, otherwise a failed draw
            # would leave the parasites' artists in the host's lists.
            self.artists = orig_artists
            self.images = orig_images

    def cla(self):
        """Clear every parasite, then the host itself."""
        for ax in self.parasites:
            ax.cla()

        self._get_base_axes_attr("cla")(self)

    def twinx(self, axes_class=None):
        """
        create a twin of Axes for generating a plot with a sharex
        x-axis but independent y axis.  The y-axis of self will have
        ticks on left and the returned axes will have ticks on the
        right
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_class = parasite_axes_class_factory(axes_class)

        ax2 = parasite_axes_class(self, sharex=self, frameon=False)
        self.parasites.append(ax2)

        self.axis["right"].set_visible(False)

        ax2.axis["right"].set_visible(True)
        ax2.axis["left", "top", "bottom"].set_visible(False)

        def _remove_method(h):
            # Give the right axis line back to the host, without labels.
            self.parasites.remove(h)
            self.axis["right"].set_visible(True)
            self.axis["right"].toggle(ticklabels=False, label=False)
        ax2._remove_method = _remove_method

        return ax2

    def twiny(self, axes_class=None):
        """
        create a twin of Axes for generating a plot with a shared
        y-axis but independent x axis.  The x-axis of self will have
        ticks on bottom and the returned axes will have ticks on the
        top
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_class = parasite_axes_class_factory(axes_class)

        ax2 = parasite_axes_class(self, sharey=self, frameon=False)
        self.parasites.append(ax2)

        self.axis["top"].set_visible(False)

        ax2.axis["top"].set_visible(True)
        ax2.axis["left", "right", "bottom"].set_visible(False)

        def _remove_method(h):
            # Give the top axis line back to the host, without labels.
            self.parasites.remove(h)
            self.axis["top"].set_visible(True)
            self.axis["top"].toggle(ticklabels=False, label=False)
        ax2._remove_method = _remove_method

        return ax2

    def twin(self, aux_trans=None, axes_class=None):
        """
        create a twin of Axes whose coordinates are related to those of
        self through *aux_trans*.  If *aux_trans* is None, an identity
        transform is used and the view limits are kept equal to those of
        self; otherwise the view limits are those of self mapped through
        the inverse of *aux_trans*.  The top and right axis of self are
        hidden and shown on the returned axes instead.
        """
        if axes_class is None:
            axes_class = self._get_base_axes()

        parasite_axes_auxtrans_class = \
            parasite_axes_auxtrans_class_factory(axes_class)

        if aux_trans is None:
            ax2 = parasite_axes_auxtrans_class(
                self, mtransforms.IdentityTransform(),
                viewlim_mode="equal")
        else:
            ax2 = parasite_axes_auxtrans_class(
                self, aux_trans,
                viewlim_mode="transform")
        self.parasites.append(ax2)

        self.axis["top", "right"].set_visible(False)

        ax2.axis["top", "right"].set_visible(True)
        ax2.axis["left", "bottom"].set_visible(False)

        def _remove_method(h):
            # Give the top and right axis lines back to the host, without
            # labels.
            self.parasites.remove(h)
            self.axis["top", "right"].set_visible(True)
            self.axis["top", "right"].toggle(ticklabels=False, label=False)
        ax2._remove_method = _remove_method

        return ax2

    def get_tightbbox(self, renderer, call_axes_locator=True):
        """
        Return the union of the tight bounding boxes of the host and all
        of its parasites, ignoring boxes of zero width and height.
        """
        bbs = [ax.get_tightbbox(renderer, call_axes_locator)
               for ax in self.parasites]
        get_tightbbox = self._get_base_axes_attr("get_tightbbox")
        bbs.append(get_tightbbox(self, renderer, call_axes_locator))

        # NOTE(review): an axes' get_tightbbox presumably returns None when
        # the axes is not visible -- confirm against the matplotlib version
        # in use.  Such entries are skipped instead of failing on `.width`.
        _bbox = Bbox.union([b for b in bbs
                            if b is not None
                            and (b.width != 0 or b.height != 0)])

        return _bbox
# Cache of generated host classes, keyed by the wrapped axes class.
_host_axes_classes = {}


def host_axes_class_factory(axes_class=None):
    """
    Return a host axes class derived from *axes_class*.

    Parameters
    ----------
    axes_class : type, optional
        The axes class to wrap.  Defaults to ``Axes``.

    Returns
    -------
    type
        A class named ``"<axes_class name>HostAxes"`` that inherits from
        ``HostAxesBase`` and *axes_class*.  One class is created per
        *axes_class* and then reused from the module-level cache.
    """
    base = Axes if axes_class is None else axes_class

    try:
        return _host_axes_classes[base]
    except KeyError:
        pass

    def _get_base_axes(self):
        # The wrapped class itself; used as default for twin axes.
        return base

    def _get_base_axes_attr(self, attrname):
        # Attribute lookup on the wrapped class, bypassing HostAxesBase.
        return getattr(base, attrname)

    class_name = str("%sHostAxes" % (base.__name__))
    namespace = {'_get_base_axes_attr': _get_base_axes_attr,
                 '_get_base_axes': _get_base_axes}
    host_class = type(class_name, (HostAxesBase, base), namespace)

    _host_axes_classes[base] = host_class
    return host_class
def host_subplot_class_factory(axes_class):
    """
    Return the subplot version of the host axes class for *axes_class*.

    The host class comes from `host_axes_class_factory` and is turned into
    a subplot class with ``subplot_class_factory``.
    """
    return subplot_class_factory(
        host_axes_class_factory(axes_class=axes_class))


# Ready-made host classes based on the default ``Axes``.
HostAxes = host_axes_class_factory(axes_class=Axes)
SubplotHost = subplot_class_factory(HostAxes)
def host_axes(*args, **kwargs):
    """
    Create axes that can act as a host to parasitic axes.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        Figure to which the axes will be added. Defaults to the current
        figure `pyplot.gcf()`.
    axes_class : type, optional
        Axes class handed to `host_axes_class_factory`; removed from the
        keyword arguments before the axes is created.
    *args, **kwargs :
        Will be passed on to the underlying ``Axes`` object creation.

    Returns
    -------
    The new host axes, already added to the figure.
    """
    import matplotlib.pyplot as plt

    host_axes_class = host_axes_class_factory(kwargs.pop("axes_class", None))

    # "figure" is only looked up here, so it stays in kwargs and is also
    # forwarded to the axes constructor below.
    fig = kwargs.get("figure", None)
    if fig is None:
        fig = plt.gcf()

    ax = host_axes_class(fig, *args, **kwargs)
    fig.add_axes(ax)
    plt.draw_if_interactive()
    return ax
def host_subplot(*args, **kwargs):
    """
    Create a subplot that can act as a host to parasitic axes.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        Figure to which the subplot will be added. Defaults to the current
        figure `pyplot.gcf()`.
    axes_class : type, optional
        Axes class handed to `host_subplot_class_factory`; removed from the
        keyword arguments before the subplot is created.
    *args, **kwargs :
        Will be passed on to the underlying ``Axes`` object creation.

    Returns
    -------
    The new host subplot, already added to the figure.
    """
    import matplotlib.pyplot as plt

    host_subplot_class = host_subplot_class_factory(
        kwargs.pop("axes_class", None))

    # "figure" is only looked up here, so it stays in kwargs and is also
    # forwarded to the subplot constructor below.
    fig = kwargs.get("figure", None)
    if fig is None:
        fig = plt.gcf()

    ax = host_subplot_class(fig, *args, **kwargs)
    fig.add_subplot(ax)
    plt.draw_if_interactive()
    return ax
|