123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531 |
- """
- Axislines includes modified implementation of the Axes class. The
- biggest difference is that the artists responsible for drawing the axis spine,
- ticks, ticklabels and axis labels are separated out from Matplotlib's Axis
- class. Originally, this change was motivated to support curvilinear
- grid. Here are a few reasons that I came up with a new axes class:
- * "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
- different ticks (tick locations and labels). This is not possible
- with the current Matplotlib, although some twin axes trick can help.
- * Curvilinear grid.
- * angled ticks.
- In the new axes class, xaxis and yaxis is set to not visible by
- default, and new set of artist (AxisArtist) are defined to draw axis
- line, ticks, ticklabels and axis label. Axes.axis attribute serves as
- a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist
- instance responsible to draw left y-axis. The default Axes.axis contains
- "bottom", "left", "top" and "right".
- AxisArtist can be considered as a container artist and has the following
- children artists which will draw ticks, labels, etc.
- * line
- * major_ticks, major_ticklabels
- * minor_ticks, minor_ticklabels
- * offsetText
- * label
- Note that these are separate artists from `matplotlib.axis.Axis`, thus most
- tick-related functions in Matplotlib won't work. For example, color and
- markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of
- Axes.xaxis unless explicitly specified.
- In addition to AxisArtist, the Axes will have *gridlines* attribute,
- which obviously draws grid lines. The gridlines needs to be separated
- from the axis as some gridlines can never pass any axis.
- """
- import numpy as np
- import matplotlib as mpl
- from matplotlib import _api
- import matplotlib.axes as maxes
- from matplotlib.path import Path
- from mpl_toolkits.axes_grid1 import mpl_axes
- from .axisline_style import AxislineStyle # noqa
- from .axis_artist import AxisArtist, GridlinesCollection
- class _AxisArtistHelperBase:
- """
- Base class for axis helper.
- Subclasses should define the methods listed below. The *axes*
- argument will be the ``.axes`` attribute of the caller artist. ::
- # Construct the spine.
- def get_line_transform(self, axes):
- return transform
- def get_line(self, axes):
- return path
- # Construct the label.
- def get_axislabel_transform(self, axes):
- return transform
- def get_axislabel_pos_angle(self, axes):
- return (x, y), angle
- # Construct the ticks.
- def get_tick_transform(self, axes):
- return transform
- def get_tick_iterators(self, axes):
- # A pair of iterables (one for major ticks, one for minor ticks)
- # that yield (tick_position, tick_angle, tick_label).
- return iter_major, iter_minor
- """
- def update_lim(self, axes):
- pass
- def _to_xy(self, values, const):
- """
- Create a (*values.shape, 2)-shape array representing (x, y) pairs.
- The other coordinate is filled with the constant *const*.
- Example::
- >>> self.nth_coord = 0
- >>> self._to_xy([1, 2, 3], const=0)
- array([[1, 0],
- [2, 0],
- [3, 0]])
- """
- if self.nth_coord == 0:
- return np.stack(np.broadcast_arrays(values, const), axis=-1)
- elif self.nth_coord == 1:
- return np.stack(np.broadcast_arrays(const, values), axis=-1)
- else:
- raise ValueError("Unexpected nth_coord")
- class _FixedAxisArtistHelperBase(_AxisArtistHelperBase):
- """Helper class for a fixed (in the axes coordinate) axis."""
- passthru_pt = _api.deprecated("3.7")(property(
- lambda self: {"left": (0, 0), "right": (1, 0),
- "bottom": (0, 0), "top": (0, 1)}[self._loc]))
- def __init__(self, loc, nth_coord=None):
- """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
- self.nth_coord = (
- nth_coord if nth_coord is not None else
- _api.check_getitem(
- {"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc))
- if (nth_coord == 0 and loc not in ["left", "right"]
- or nth_coord == 1 and loc not in ["bottom", "top"]):
- _api.warn_deprecated(
- "3.7", message=f"{loc=!r} is incompatible with "
- "{nth_coord=}; support is deprecated since %(since)s")
- self._loc = loc
- self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
- super().__init__()
- # axis line in transAxes
- self._path = Path(self._to_xy((0, 1), const=self._pos))
- def get_nth_coord(self):
- return self.nth_coord
- # LINE
- def get_line(self, axes):
- return self._path
- def get_line_transform(self, axes):
- return axes.transAxes
- # LABEL
- def get_axislabel_transform(self, axes):
- return axes.transAxes
- def get_axislabel_pos_angle(self, axes):
- """
- Return the label reference position in transAxes.
- get_label_transform() returns a transform of (transAxes+offset)
- """
- return dict(left=((0., 0.5), 90), # (position, angle_tangent)
- right=((1., 0.5), 90),
- bottom=((0.5, 0.), 0),
- top=((0.5, 1.), 0))[self._loc]
- # TICK
- def get_tick_transform(self, axes):
- return [axes.get_xaxis_transform(),
- axes.get_yaxis_transform()][self.nth_coord]
- class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase):
- def __init__(self, nth_coord, value):
- self.nth_coord = nth_coord
- self._value = value
- super().__init__()
- def get_nth_coord(self):
- return self.nth_coord
- def get_line(self, axes):
- raise RuntimeError(
- "get_line method should be defined by the derived class")
- class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase):
- def __init__(self, axes, loc, nth_coord=None):
- """
- nth_coord = along which coordinate value varies
- in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
- """
- super().__init__(loc, nth_coord)
- self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
- # TICK
- def get_tick_iterators(self, axes):
- """tick_loc, tick_angle, tick_label"""
- if self._loc in ["bottom", "top"]:
- angle_normal, angle_tangent = 90, 0
- else: # "left", "right"
- angle_normal, angle_tangent = 0, 90
- major = self.axis.major
- major_locs = major.locator()
- major_labels = major.formatter.format_ticks(major_locs)
- minor = self.axis.minor
- minor_locs = minor.locator()
- minor_labels = minor.formatter.format_ticks(minor_locs)
- tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
- def _f(locs, labels):
- for loc, label in zip(locs, labels):
- c = self._to_xy(loc, const=self._pos)
- # check if the tick point is inside axes
- c2 = tick_to_axes.transform(c)
- if mpl.transforms._interval_contains_close(
- (0, 1), c2[self.nth_coord]):
- yield c, angle_normal, angle_tangent, label
- return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
- class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase):
- def __init__(self, axes, nth_coord,
- passingthrough_point, axis_direction="bottom"):
- super().__init__(nth_coord, passingthrough_point)
- self._axis_direction = axis_direction
- self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
- def get_line(self, axes):
- fixed_coord = 1 - self.nth_coord
- data_to_axes = axes.transData - axes.transAxes
- p = data_to_axes.transform([self._value, self._value])
- return Path(self._to_xy((0, 1), const=p[fixed_coord]))
- def get_line_transform(self, axes):
- return axes.transAxes
- def get_axislabel_transform(self, axes):
- return axes.transAxes
- def get_axislabel_pos_angle(self, axes):
- """
- Return the label reference position in transAxes.
- get_label_transform() returns a transform of (transAxes+offset)
- """
- angle = [0, 90][self.nth_coord]
- fixed_coord = 1 - self.nth_coord
- data_to_axes = axes.transData - axes.transAxes
- p = data_to_axes.transform([self._value, self._value])
- verts = self._to_xy(0.5, const=p[fixed_coord])
- if 0 <= verts[fixed_coord] <= 1:
- return verts, angle
- else:
- return None, None
- def get_tick_transform(self, axes):
- return axes.transData
- def get_tick_iterators(self, axes):
- """tick_loc, tick_angle, tick_label"""
- if self.nth_coord == 0:
- angle_normal, angle_tangent = 90, 0
- else:
- angle_normal, angle_tangent = 0, 90
- major = self.axis.major
- major_locs = major.locator()
- major_labels = major.formatter.format_ticks(major_locs)
- minor = self.axis.minor
- minor_locs = minor.locator()
- minor_labels = minor.formatter.format_ticks(minor_locs)
- data_to_axes = axes.transData - axes.transAxes
- def _f(locs, labels):
- for loc, label in zip(locs, labels):
- c = self._to_xy(loc, const=self._value)
- c1, c2 = data_to_axes.transform(c)
- if 0 <= c1 <= 1 and 0 <= c2 <= 1:
- yield c, angle_normal, angle_tangent, label
- return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
- class AxisArtistHelper: # Backcompat.
- Fixed = _FixedAxisArtistHelperBase
- Floating = _FloatingAxisArtistHelperBase
- class AxisArtistHelperRectlinear: # Backcompat.
- Fixed = FixedAxisArtistHelperRectilinear
- Floating = FloatingAxisArtistHelperRectilinear
- class GridHelperBase:
- def __init__(self):
- self._old_limits = None
- super().__init__()
- def update_lim(self, axes):
- x1, x2 = axes.get_xlim()
- y1, y2 = axes.get_ylim()
- if self._old_limits != (x1, x2, y1, y2):
- self._update_grid(x1, y1, x2, y2)
- self._old_limits = (x1, x2, y1, y2)
- def _update_grid(self, x1, y1, x2, y2):
- """Cache relevant computations when the axes limits have changed."""
- def get_gridlines(self, which, axis):
- """
- Return list of grid lines as a list of paths (list of points).
- Parameters
- ----------
- which : {"both", "major", "minor"}
- axis : {"both", "x", "y"}
- """
- return []
- class GridHelperRectlinear(GridHelperBase):
- def __init__(self, axes):
- super().__init__()
- self.axes = axes
- def new_fixed_axis(self, loc,
- nth_coord=None,
- axis_direction=None,
- offset=None,
- axes=None,
- ):
- if axes is None:
- _api.warn_external(
- "'new_fixed_axis' explicitly requires the axes keyword.")
- axes = self.axes
- if axis_direction is None:
- axis_direction = loc
- helper = FixedAxisArtistHelperRectilinear(axes, loc, nth_coord)
- axisline = AxisArtist(axes, helper, offset=offset,
- axis_direction=axis_direction)
- return axisline
- def new_floating_axis(self, nth_coord, value,
- axis_direction="bottom",
- axes=None,
- ):
- if axes is None:
- _api.warn_external(
- "'new_floating_axis' explicitly requires the axes keyword.")
- axes = self.axes
- helper = FloatingAxisArtistHelperRectilinear(
- axes, nth_coord, value, axis_direction)
- axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
- axisline.line.set_clip_on(True)
- axisline.line.set_clip_box(axisline.axes.bbox)
- return axisline
- def get_gridlines(self, which="major", axis="both"):
- """
- Return list of gridline coordinates in data coordinates.
- Parameters
- ----------
- which : {"both", "major", "minor"}
- axis : {"both", "x", "y"}
- """
- _api.check_in_list(["both", "major", "minor"], which=which)
- _api.check_in_list(["both", "x", "y"], axis=axis)
- gridlines = []
- if axis in ("both", "x"):
- locs = []
- y1, y2 = self.axes.get_ylim()
- if which in ("both", "major"):
- locs.extend(self.axes.xaxis.major.locator())
- if which in ("both", "minor"):
- locs.extend(self.axes.xaxis.minor.locator())
- for x in locs:
- gridlines.append([[x, x], [y1, y2]])
- if axis in ("both", "y"):
- x1, x2 = self.axes.get_xlim()
- locs = []
- if self.axes.yaxis._major_tick_kw["gridOn"]:
- locs.extend(self.axes.yaxis.major.locator())
- if self.axes.yaxis._minor_tick_kw["gridOn"]:
- locs.extend(self.axes.yaxis.minor.locator())
- for y in locs:
- gridlines.append([[x1, x2], [y, y]])
- return gridlines
- class Axes(maxes.Axes):
- @_api.deprecated("3.8", alternative="ax.axis")
- def __call__(self, *args, **kwargs):
- return maxes.Axes.axis(self.axes, *args, **kwargs)
- def __init__(self, *args, grid_helper=None, **kwargs):
- self._axisline_on = True
- self._grid_helper = (grid_helper if grid_helper
- else GridHelperRectlinear(self))
- super().__init__(*args, **kwargs)
- self.toggle_axisline(True)
- def toggle_axisline(self, b=None):
- if b is None:
- b = not self._axisline_on
- if b:
- self._axisline_on = True
- self.spines[:].set_visible(False)
- self.xaxis.set_visible(False)
- self.yaxis.set_visible(False)
- else:
- self._axisline_on = False
- self.spines[:].set_visible(True)
- self.xaxis.set_visible(True)
- self.yaxis.set_visible(True)
- @property
- def axis(self):
- return self._axislines
- def clear(self):
- # docstring inherited
- # Init gridlines before clear() as clear() calls grid().
- self.gridlines = gridlines = GridlinesCollection(
- [],
- colors=mpl.rcParams['grid.color'],
- linestyles=mpl.rcParams['grid.linestyle'],
- linewidths=mpl.rcParams['grid.linewidth'])
- self._set_artist_props(gridlines)
- gridlines.set_grid_helper(self.get_grid_helper())
- super().clear()
- # clip_path is set after Axes.clear(): that's when a patch is created.
- gridlines.set_clip_path(self.axes.patch)
- # Init axis artists.
- self._axislines = mpl_axes.Axes.AxisDict(self)
- new_fixed_axis = self.get_grid_helper().new_fixed_axis
- self._axislines.update({
- loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc)
- for loc in ["bottom", "top", "left", "right"]})
- for axisline in [self._axislines["top"], self._axislines["right"]]:
- axisline.label.set_visible(False)
- axisline.major_ticklabels.set_visible(False)
- axisline.minor_ticklabels.set_visible(False)
- def get_grid_helper(self):
- return self._grid_helper
- def grid(self, visible=None, which='major', axis="both", **kwargs):
- """
- Toggle the gridlines, and optionally set the properties of the lines.
- """
- # There are some discrepancies in the behavior of grid() between
- # axes_grid and Matplotlib, because axes_grid explicitly sets the
- # visibility of the gridlines.
- super().grid(visible, which=which, axis=axis, **kwargs)
- if not self._axisline_on:
- return
- if visible is None:
- visible = (self.axes.xaxis._minor_tick_kw["gridOn"]
- or self.axes.xaxis._major_tick_kw["gridOn"]
- or self.axes.yaxis._minor_tick_kw["gridOn"]
- or self.axes.yaxis._major_tick_kw["gridOn"])
- self.gridlines.set(which=which, axis=axis, visible=visible)
- self.gridlines.set(**kwargs)
- def get_children(self):
- if self._axisline_on:
- children = [*self._axislines.values(), self.gridlines]
- else:
- children = []
- children.extend(super().get_children())
- return children
- def new_fixed_axis(self, loc, offset=None):
- gh = self.get_grid_helper()
- axis = gh.new_fixed_axis(loc,
- nth_coord=None,
- axis_direction=None,
- offset=offset,
- axes=self,
- )
- return axis
- def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
- gh = self.get_grid_helper()
- axis = gh.new_floating_axis(nth_coord, value,
- axis_direction=axis_direction,
- axes=self)
- return axis
- class AxesZero(Axes):
- def clear(self):
- super().clear()
- new_floating_axis = self.get_grid_helper().new_floating_axis
- self._axislines.update(
- xzero=new_floating_axis(
- nth_coord=0, value=0., axis_direction="bottom", axes=self),
- yzero=new_floating_axis(
- nth_coord=1, value=0., axis_direction="left", axes=self),
- )
- for k in ["xzero", "yzero"]:
- self._axislines[k].line.set_clip_path(self.patch)
- self._axislines[k].set_visible(False)
- Subplot = Axes
- SubplotZero = AxesZero
|