123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- """
- Experimental support for curvilinear grids.
- """
- import functools
- from itertools import chain
- import numpy as np
- import matplotlib as mpl
- from matplotlib.path import Path
- from matplotlib.transforms import Affine2D, IdentityTransform
- from .axislines import (
- _FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase)
- from .axis_artist import AxisArtist
- from .grid_finder import GridFinder
- def _value_and_jacobian(func, xs, ys, xlims, ylims):
- """
- Compute *func* and its derivatives along x and y at positions *xs*, *ys*,
- while ensuring that finite difference calculations don't try to evaluate
- values outside of *xlims*, *ylims*.
- """
- eps = np.finfo(float).eps ** (1/2) # see e.g. scipy.optimize.approx_fprime
- val = func(xs, ys)
- # Take the finite difference step in the direction where the bound is the
- # furthest; the step size is min of epsilon and distance to that bound.
- xlo, xhi = sorted(xlims)
- dxlo = xs - xlo
- dxhi = xhi - xs
- xeps = (np.take([-1, 1], dxhi >= dxlo)
- * np.minimum(eps, np.maximum(dxlo, dxhi)))
- val_dx = func(xs + xeps, ys)
- ylo, yhi = sorted(ylims)
- dylo = ys - ylo
- dyhi = yhi - ys
- yeps = (np.take([-1, 1], dyhi >= dylo)
- * np.minimum(eps, np.maximum(dylo, dyhi)))
- val_dy = func(xs, ys + yeps)
- return (val, (val_dx - val) / xeps, (val_dy - val) / yeps)
class FixedAxisArtistHelper(_FixedAxisArtistHelperBase):
    """
    Axis-artist helper for an axis fixed to one side of the axes, whose ticks
    are supplied by a grid helper.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        Parameters
        ----------
        grid_helper
            Object whose ``update_lim`` and ``get_tick_iterator`` methods
            are used by this helper.
        side
            Forwarded to the base class as *loc* and used as the side for
            which ticks are requested from *grid_helper*.
        nth_coord_ticks
            Coordinate index (0 or 1) whose ticks are drawn with labels.
            Defaults to ``self.nth_coord`` as set up by the base class.
        """
        super().__init__(loc=side)
        self.grid_helper = grid_helper
        self.nth_coord_ticks = (
            self.nth_coord if nth_coord_ticks is None else nth_coord_ticks)
        self.side = side

    def update_lim(self, axes):
        """Delegate the limit update to the grid helper."""
        self.grid_helper.update_lim(axes)

    def get_tick_transform(self, axes):
        """Return the data transform of *axes*."""
        return axes.transData

    def get_tick_iterators(self, axes):
        """
        Return a pair of iterators ``(ticks, empty)``.

        *ticks* chains the grid helper's ticks for coordinate
        ``nth_coord_ticks`` with the ticks of the other coordinate requested
        with ``minor=True``; the second iterator is always empty.
        """
        if self.nth_coord == 0:
            lower, upper = axes.get_ylim()
        else:
            lower, upper = axes.get_xlim()

        side = self.side
        if lower > upper:
            # Inverted limits: ticks are looked up on the opposite side.
            side = {"left": "right", "right": "left",
                    "top": "bottom", "bottom": "top"}[side]

        helper = self.grid_helper
        labelled = helper.get_tick_iterator(self.nth_coord_ticks, side)
        other = helper.get_tick_iterator(
            1-self.nth_coord_ticks, side, minor=True)
        return chain(labelled, other), iter([])
class FloatingAxisArtistHelper(_FloatingAxisArtistHelperBase):
    """
    Helper for an axis drawn along the curve on which one curvilinear
    coordinate is held at a fixed value.

    `update_lim` must have run before the tick and label queries, because it
    populates ``self._grid_info``, which they read.
    """

    def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
        """
        Parameters
        ----------
        grid_helper
            Grid helper; its ``grid_finder`` supplies the transforms,
            locators and formatters used by this class.
        nth_coord : {0, 1}
            Index of the coordinate held fixed at *value*.  With 0, the
            first coordinate is fixed and the axis line runs along the
            second one; with 1, the roles are swapped.
        value
            The fixed value of coordinate *nth_coord*.
        axis_direction
            Accepted but not used by this class.
        """
        super().__init__(nth_coord, value)
        self.value = value
        self.grid_helper = grid_helper
        self._extremes = -np.inf, np.inf
        self._line_num_points = 100  # number of points to create a line

    def set_extremes(self, e1, e2):
        """
        Restrict the range of the varying coordinate to ``[e1, e2]``.

        ``None`` for *e1* (resp. *e2*) means unbounded below (resp. above).
        """
        if e1 is None:
            e1 = -np.inf
        if e2 is None:
            e2 = np.inf
        self._extremes = e1, e2

    def update_lim(self, axes):
        """
        Recompute tick levels, tick labels and the sampled axis line for the
        current view limits of *axes*, storing them in ``self._grid_info``.
        """
        self.grid_helper.update_lim(axes)

        x1, x2 = axes.get_xlim()
        y1, y2 = axes.get_ylim()
        grid_finder = self.grid_helper.grid_finder
        lon_min, lon_max, lat_min, lat_max = grid_finder.extreme_finder(
            grid_finder.inv_transform_xy, x1, y1, x2, y2)

        # The extremes restrict the coordinate that varies along the axis,
        # i.e. the one that is *not* nth_coord.
        e_min, e_max = self._extremes
        if self.nth_coord == 0:
            lat_min = max(e_min, lat_min)
            lat_max = min(e_max, lat_max)
        elif self.nth_coord == 1:
            lon_min = max(e_min, lon_min)
            lon_max = min(e_max, lon_max)

        lon_levs, lon_n, lon_factor = \
            grid_finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = \
            grid_finder.grid_locator2(lat_min, lat_max)

        # Sample the axis line in curvilinear coordinates.
        if self.nth_coord == 0:
            xx0 = np.full(self._line_num_points, self.value)
            yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
        elif self.nth_coord == 1:
            xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
            yy0 = np.full(self._line_num_points, self.value)
        xx, yy = grid_finder.transform_xy(xx0, yy0)

        self._grid_info = {
            "extremes": (lon_min, lon_max, lat_min, lat_max),
            "lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
            "lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
            "lon_labels": grid_finder.tick_formatter1(
                "bottom", lon_factor, lon_levs),
            "lat_labels": grid_finder.tick_formatter2(
                "bottom", lat_factor, lat_levs),
            "line_xy": (xx, yy),
        }

    def get_axislabel_transform(self, axes):
        """Return a fresh identity `Affine2D`."""
        return Affine2D()  # axes.transData

    def get_axislabel_pos_angle(self, axes):
        """
        Return ``(position, angle_in_degrees)`` for the axis label.

        The label is placed at the midpoint of the varying coordinate's
        range, mapped through the grid transform followed by
        ``axes.transData``.  The angle is that of the derivative of this
        mapping with respect to the varying coordinate.  If the position
        falls outside the unit square in axes coordinates,
        ``(None, None)`` is returned.
        """
        # Build the composite transform once; it does not depend on (x, y).
        trf = self.grid_helper.grid_finder.get_transform() + axes.transData

        def trf_xy(x, y):
            return trf.transform([x, y]).T

        xmin, xmax, ymin, ymax = self._grid_info["extremes"]
        if self.nth_coord == 0:
            xx0 = self.value
            yy0 = (ymin + ymax) / 2
        elif self.nth_coord == 1:
            xx0 = (xmin + xmax) / 2
            yy0 = self.value
        xy1, dxy1_dx, dxy1_dy = _value_and_jacobian(
            trf_xy, xx0, yy0, (xmin, xmax), (ymin, ymax))
        p = axes.transAxes.inverted().transform(xy1)

        if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
            # Derivative along the varying coordinate: y for nth_coord == 0,
            # x for nth_coord == 1.
            d = [dxy1_dy, dxy1_dx][self.nth_coord]
            return xy1, np.rad2deg(np.arctan2(*d[::-1]))
        else:
            return None, None

    def get_tick_transform(self, axes):
        """Return an `IdentityTransform`."""
        return IdentityTransform()  # axes.transData

    def get_tick_iterators(self, axes):
        """
        Return a pair of iterators ``(ticks, empty)``.

        *ticks* yields ``([x, y], angle_normal, angle_tangent, label)`` (both
        angles in degrees) for every tick level of the varying coordinate
        that lies within the extremes and whose transformed position falls
        inside the unit square in axes coordinates.  The second iterator is
        always empty.
        """
        lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
        yy0 = lat_levs / lat_factor

        lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
        xx0 = lon_levs / lon_factor

        e0, e1 = self._extremes

        # Build the composite transform once; it does not depend on (x, y).
        trf = self.grid_helper.grid_finder.get_transform() + axes.transData

        def trf_xy(x, y):
            return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T

        # find angles
        if self.nth_coord == 0:
            # x is fixed at self.value; ticks sit at the y levels, which are
            # the values bounded by the extremes.
            mask = (e0 <= yy0) & (yy0 <= e1)
            (xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = _value_and_jacobian(
                trf_xy, self.value, yy0[mask], (-np.inf, np.inf), (e0, e1))
            labels = self._grid_info["lat_labels"]

        elif self.nth_coord == 1:
            # y is fixed at self.value; ticks sit at the x levels, so the
            # extremes bound x here (not y).  Passing them as the x-limits
            # keeps the finite-difference step inside [e0, e1].
            mask = (e0 <= xx0) & (xx0 <= e1)
            (xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = _value_and_jacobian(
                trf_xy, xx0[mask], self.value, (e0, e1), (-np.inf, np.inf))
            labels = self._grid_info["lon_labels"]

        labels = [l for l, m in zip(labels, mask) if m]

        # (dxx1, dyy1) is the derivative along the fixed coordinate (normal
        # to the axis line); (dxx2, dyy2) is along the varying one (tangent).
        angle_normal = np.arctan2(dyy1, dxx1)
        angle_tangent = np.arctan2(dyy2, dxx2)
        mm = (dyy1 == 0) & (dxx1 == 0)  # points with degenerate normal
        angle_normal[mm] = angle_tangent[mm] + np.pi / 2

        tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
        in_01 = functools.partial(
            mpl.transforms._interval_contains_close, (0, 1))

        def f1():
            for x, y, normal, tangent, lab \
                    in zip(xx1, yy1, angle_normal, angle_tangent, labels):
                c2 = tick_to_axes.transform((x, y))
                if in_01(c2[0]) and in_01(c2[1]):
                    yield [x, y], *np.rad2deg([normal, tangent]), lab

        return f1(), iter([])

    def get_line_transform(self, axes):
        """Return the data transform of *axes*."""
        return axes.transData

    def get_line(self, axes):
        """Refresh the grid info and return the sampled axis line as a Path."""
        self.update_lim(axes)
        x, y = self._grid_info["line_xy"]
        return Path(np.column_stack([x, y]))
class GridHelperCurveLinear(GridHelperBase):
    """
    Grid helper whose grid lines and ticks are computed by a `GridFinder`
    built on a curved-to-rectilinear coordinate transform.
    """

    def __init__(self, aux_trans,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        aux_trans : `.Transform` or tuple[Callable, Callable]
            The transform from curved coordinates to rectilinear coordinates:
            either a `.Transform` instance (which also provides its inverse),
            or a pair of callables ``(trans, inv_trans)`` defining the
            transform and its inverse.  The callables should have
            signature::

                x_rect, y_rect = trans(x_curved, y_curved)
                x_curved, y_curved = inv_trans(x_rect, y_rect)

        extreme_finder
            Forwarded to `GridFinder`.
        grid_locator1, grid_locator2
            Grid locators for each axis.
        tick_formatter1, tick_formatter2
            Tick formatters for each axis.
        """
        super().__init__()
        self._grid_info = None
        self.grid_finder = GridFinder(
            aux_trans, extreme_finder,
            grid_locator1, grid_locator2,
            tick_formatter1, tick_formatter2)

    def update_grid_finder(self, aux_trans=None, **kwargs):
        """
        Update the grid finder's transform (when *aux_trans* is given) and
        its other settings from *kwargs*, then drop the cached limits.
        """
        if aux_trans is not None:
            self.grid_finder.update_transform(aux_trans)
        self.grid_finder.update(**kwargs)
        self._old_limits = None  # Force revalidation.

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Create an `AxisArtist` fixed at side *loc*.

        *axes* defaults to ``self.axes`` and *axis_direction* to *loc*;
        *nth_coord* selects the coordinate whose ticks are labelled.
        *offset* is accepted but not used here.
        """
        target_axes = self.axes if axes is None else axes
        direction = loc if axis_direction is None else axis_direction
        helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
        # Why is clip not set on the axis line, unlike in new_floating_axis
        # or in the floating_axes.GridHelperCurveLinear subclass?
        return AxisArtist(target_axes, helper, axis_direction=direction)

    def new_floating_axis(self, nth_coord,
                          value,
                          axes=None,
                          axis_direction="bottom"
                          ):
        """
        Create an `AxisArtist` along the curve where coordinate *nth_coord*
        equals *value*; its line is clipped to the axes bounding box.

        *axes* defaults to ``self.axes``.
        """
        target_axes = self.axes if axes is None else axes
        helper = FloatingAxisArtistHelper(
            self, nth_coord, value, axis_direction)
        artist = AxisArtist(target_axes, helper)
        artist.line.set_clip_on(True)
        artist.line.set_clip_box(artist.axes.bbox)
        return artist

    def _update_grid(self, x1, y1, x2, y2):
        """Refresh the cached grid info for the given limits."""
        self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)

    def get_gridlines(self, which="major", axis="both"):
        """
        Return the flat list of grid lines for *axis* ("x", "y" or "both").

        "x" selects the "lon" lines and "y" the "lat" lines of the cached
        grid info.  *which* is not consulted.
        """
        collected = []
        if axis in ["both", "x"]:
            collected.extend(
                line
                for group in self._grid_info["lon"]["lines"]
                for line in group)
        if axis in ["both", "y"]:
            collected.extend(
                line
                for group in self._grid_info["lat"]["lines"]
                for line in group)
        return collected

    def get_tick_iterator(self, nth_coord, axis_side, minor=False):
        """
        Yield ``(xy, angle_normal, angle_tangent, label)`` for each tick of
        coordinate *nth_coord* (0: "lon", 1: "lat") on side *axis_side*.

        With a true *minor*, the same tick locations are yielded but every
        label is the empty string.
        """
        tangent_angle = {
            "left": 90, "right": 90, "bottom": 0, "top": 0}[axis_side]
        coord_key = ("lon", "lat")[nth_coord]
        locations = self._grid_info[coord_key]["tick_locs"][axis_side]
        labels = self._grid_info[coord_key]["tick_labels"][axis_side]
        for (xy, normal_angle), label in zip(locations, labels):
            yield xy, normal_angle, tangent_angle, ("" if minor else label)
|