123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- import numpy as np
- from matplotlib import ticker as mticker
- from matplotlib.transforms import Bbox, Transform
- def _find_line_box_crossings(xys, bbox):
- """
- Find the points where a polyline crosses a bbox, and the crossing angles.
- Parameters
- ----------
- xys : (N, 2) array
- The polyline coordinates.
- bbox : `.Bbox`
- The bounding box.
- Returns
- -------
- list of ((float, float), float)
- Four separate lists of crossings, for the left, right, bottom, and top
- sides of the bbox, respectively. For each list, the entries are the
- ``((x, y), ccw_angle_in_degrees)`` of the crossing, where an angle of 0
- means that the polyline is moving to the right at the crossing point.
- The entries are computed by linearly interpolating at each crossing
- between the nearest points on either side of the bbox edges.
- """
- crossings = []
- dxys = xys[1:] - xys[:-1]
- for sl in [slice(None), slice(None, None, -1)]:
- us, vs = xys.T[sl] # "this" coord, "other" coord
- dus, dvs = dxys.T[sl]
- umin, vmin = bbox.min[sl]
- umax, vmax = bbox.max[sl]
- for u0, inside in [(umin, us > umin), (umax, us < umax)]:
- crossings.append([])
- idxs, = (inside[:-1] ^ inside[1:]).nonzero()
- for idx in idxs:
- v = vs[idx] + (u0 - us[idx]) * dvs[idx] / dus[idx]
- if not vmin <= v <= vmax:
- continue
- crossing = (u0, v)[sl]
- theta = np.degrees(np.arctan2(*dxys[idx][::-1]))
- crossings[-1].append((crossing, theta))
- return crossings
class ExtremeFinderSimple:
    """
    Helper that estimates the coordinate range grid lines have to cover.
    """

    def __init__(self, nx, ny):
        """
        Parameters
        ----------
        nx, ny : int
            Number of sample points along each direction.
        """
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Approximate the bounding box of the image of the box
        ``(x1, y1, x2, y2)`` under *transform_xy*.

        Typically ``(x1, y1, x2, y2)`` is given in axes coordinates and
        *transform_xy* maps axes coordinates to data coordinates, so that the
        result is the range of data coordinates spanned by the axes.

        ``nx * ny`` equispaced points of the box are transformed and the
        extremal transformed coordinates are taken.  Because the sampling is
        finite, the resulting span is then widened: a sampling step covers a
        relative range of *1/nx* (resp. *1/ny*), so the span is expanded by
        that fraction on each side.
        """
        sample_x = np.linspace(x1, x2, self.nx)
        sample_y = np.linspace(y1, y2, self.ny)
        grid_x, grid_y = np.meshgrid(sample_x, sample_y)
        tr_x, tr_y = transform_xy(np.ravel(grid_x), np.ravel(grid_y))
        return self._add_pad(tr_x.min(), tr_x.max(), tr_y.min(), tr_y.max())

    def _add_pad(self, x_min, x_max, y_min, y_max):
        """Widen the given extents by ``span / nx`` and ``span / ny``."""
        pad_x = (x_max - x_min) / self.nx
        pad_y = (y_max - y_min) / self.ny
        return x_min - pad_x, x_max + pad_x, y_min - pad_y, y_max + pad_y
class _User2DTransform(Transform):
    """A 2D transform built from a pair of user-supplied functions."""

    input_dims = 2
    output_dims = 2

    def __init__(self, forward, backward):
        """
        Parameters
        ----------
        forward, backward : callable
            The forward and backward mappings.  Each is called with ``x`` and
            ``y`` as separate arguments and returns ``(tr_x, tr_y)``.
        """
        # Matplotlib transforms normally consume and produce (N, 2) arrays;
        # axisartist works with the transposed layout, hence the separate
        # x/y arguments of the callables.
        super().__init__()
        self._forward = forward
        self._backward = backward

    def transform_non_affine(self, values):
        # docstring inherited
        columns = np.transpose(values)
        mapped = self._forward(*columns)
        return np.transpose(mapped)

    def inverted(self):
        # docstring inherited
        cls = type(self)
        return cls(self._backward, self._forward)
class GridFinder:
    """
    Internal helper for `~.grid_helper_curvelinear.GridHelperCurveLinear`, with
    the same constructor parameters; should not be directly instantiated.
    """

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        # Any helper that is not supplied falls back to a default instance.
        self.extreme_finder = (
            ExtremeFinderSimple(20, 20) if extreme_finder is None
            else extreme_finder)
        self.grid_locator1 = (
            MaxNLocator() if grid_locator1 is None else grid_locator1)
        self.grid_locator2 = (
            MaxNLocator() if grid_locator2 is None else grid_locator2)
        self.tick_formatter1 = (
            FormatterPrettyPrint() if tick_formatter1 is None
            else tick_formatter1)
        self.tick_formatter2 = (
            FormatterPrettyPrint() if tick_formatter2 is None
            else tick_formatter2)
        self.set_transform(transform)

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Compute grid lines, tick positions and tick labels for the box
        ``(x1, y1, x2, y2)``.

        Returns a dict with the keys ``"extremes"``, ``"lon_lines"``,
        ``"lat_lines"``, ``"lon"`` and ``"lat"``.  The ``"lon"`` and ``"lat"``
        entries are the dicts built by `_clip_grid_lines_and_find_ticks`,
        extended with a ``"tick_labels"`` dict keyed by side name.
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # Range over which each grid line is drawn: e.g. the lon=0 grid line
        # runs from lat_min to lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes

        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lon_levs = np.asarray(lon_levs)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
        lat_levs = np.asarray(lat_levs)

        lon_values = lon_levs[:lon_n] / lon_factor
        lat_values = lat_levs[:lat_n] / lat_factor

        lon_lines, lat_lines = self._get_raw_grid_lines(
            lon_values, lat_values, lon_min, lon_max, lat_min, lat_max)

        # Clip against a box enlarged by a tiny relative margin.
        margin_x = (x2 - x1) * 1.e-10
        margin_y = (y2 - y1) * 1.e-10
        bb = Bbox.from_extents(
            x1 - margin_x, y1 - margin_y, x2 + margin_x, y2 + margin_y)

        lon_info = self._clip_grid_lines_and_find_ticks(
            lon_lines, lon_values, lon_levs, bb)
        lat_info = self._clip_grid_lines_and_find_ticks(
            lat_lines, lat_values, lat_levs, bb)

        sides = ("left", "bottom", "right", "top")
        lon_labels = lon_info["tick_labels"] = {}
        for side in sides:
            lon_labels[side] = self.tick_formatter1(
                side, lon_factor, lon_info["tick_levels"][side])
        lat_labels = lat_info["tick_labels"] = {}
        for side in sides:
            lat_labels[side] = self.tick_formatter2(
                side, lat_factor, lat_info["tick_levels"][side])

        return {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": lon_info,
            "lat": lat_info,
        }

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return ``(lon_lines, lat_lines)``: for every grid value, the
        transformed coordinates of 100 points sampled along that grid line.
        """
        lon_samples = np.linspace(lon_min, lon_max, 100)  # for interpolation
        lat_samples = np.linspace(lat_min, lat_max, 100)

        # Constant-longitude lines are swept along the latitude samples ...
        lon_lines = [
            self.transform_xy(np.full_like(lat_samples, lon_value),
                              lat_samples)
            for lon_value in lon_values]
        # ... and constant-latitude lines along the longitude samples.
        lat_lines = [
            self.transform_xy(lon_samples,
                              np.full_like(lon_samples, lat_value))
            for lat_value in lat_values]
        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Collect, for each grid line, where it crosses the sides of *bb*.

        Returns a dict with the keys ``"values"``, ``"levels"``,
        ``"tick_levels"``, ``"tick_locs"`` and ``"lines"``; ``"values"`` is
        created empty and is not filled here.
        """
        tick_levels = dict(left=[], bottom=[], right=[], top=[])
        tick_locs = dict(left=[], bottom=[], right=[], top=[])
        levels = []
        kept_lines = []
        # Order in which _find_line_box_crossings reports the bbox sides.
        crossing_sides = ("left", "right", "bottom", "top")

        for (lx, ly), value, lev in zip(lines, values, levs):
            crossings = _find_line_box_crossings(
                np.column_stack([lx, ly]), bb)
            levels.append(value)
            kept_lines.append([(lx, ly)])
            for side, side_crossings in zip(crossing_sides, crossings):
                # One tick (level + location) per crossing on this side.
                tick_levels[side].extend([lev] * len(side_crossings))
                tick_locs[side].extend(side_crossings)

        return {
            "values": [],
            "levels": levels,
            "tick_levels": tick_levels,
            "tick_locs": tick_locs,
            "lines": kept_lines,
        }

    def set_transform(self, aux_trans):
        """
        Set the auxiliary transform from a `.Transform` instance or from a
        ``(forward, backward)`` pair of callables.

        Raises `TypeError` for any other argument.
        """
        if isinstance(aux_trans, Transform):
            self._aux_transform = aux_trans
            return
        if len(aux_trans) == 2 and all(callable(f) for f in aux_trans):
            self._aux_transform = _User2DTransform(*aux_trans)
            return
        raise TypeError("'aux_trans' must be either a Transform "
                        "instance or a pair of callables")

    def get_transform(self):
        """Return the auxiliary transform set by `set_transform`."""
        return self._aux_transform

    update_transform = set_transform  # backcompat alias.

    def transform_xy(self, x, y):
        """Apply the auxiliary transform to *x*, *y*; return a (2, N) array."""
        apply = self._aux_transform.transform
        return apply(np.column_stack([x, y])).T

    def inv_transform_xy(self, x, y):
        """Apply the inverted auxiliary transform; return a (2, N) array."""
        inverse = self._aux_transform.inverted()
        return inverse.transform(np.column_stack([x, y])).T

    def update(self, **kwargs):
        """
        Replace any of the extreme finder, grid locators or tick formatters.

        Raises `ValueError` on the first keyword that is not one of them.
        """
        settable = ("extreme_finder",
                    "grid_locator1",
                    "grid_locator2",
                    "tick_formatter1",
                    "tick_formatter2")
        for name, value in kwargs.items():
            if name not in settable:
                raise ValueError(f"Unknown update property {name!r}")
            setattr(self, name, value)
class MaxNLocator(mticker.MaxNLocator):
    """
    `matplotlib.ticker.MaxNLocator` variant callable as ``locator(v1, v2)``,
    returning ``(levels, number_of_levels, factor)``.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # *trim* is ignored; it is only kept for API compatibility.
        base_kwargs = dict(steps=steps, integer=integer,
                           symmetric=symmetric, prune=prune)
        super().__init__(nbins, **base_kwargs)
        self.create_dummy_axis()

    def __call__(self, v1, v2):
        ticks = super().tick_values(v1, v2)
        factor = 1  # see angle_helper for locators using other factors
        return np.array(ticks), len(ticks), factor
class FixedLocator:
    """
    Locator that returns those of a fixed set of locations which fall
    inside the requested range.
    """

    def __init__(self, locs):
        self._locs = locs

    def __call__(self, v1, v2):
        """
        Return ``(levels, number_of_levels, factor)`` for the range spanned
        by *v1* and *v2* (given in either order, bounds included).
        """
        lower, upper = (v2, v1) if v2 < v1 else (v1, v2)
        selected = np.array(
            [loc for loc in self._locs if lower <= loc <= upper])
        factor = 1  # see angle_helper for locators using other factors
        return selected, len(selected), factor
- # Tick Formatter
class FormatterPrettyPrint:
    """
    Tick formatter that delegates to a `matplotlib.ticker.ScalarFormatter`
    created with offsets disabled.
    """

    def __init__(self, useMathText=True):
        scalar_formatter = mticker.ScalarFormatter(
            useMathText=useMathText, useOffset=False)
        scalar_formatter.create_dummy_axis()
        self._fmt = scalar_formatter

    def __call__(self, direction, factor, values):
        # *direction* and *factor* are accepted for the common formatter
        # signature but are not used here.
        return self._fmt.format_ticks(values)
class DictFormatter:
    """Tick formatter that looks labels up in a dictionary."""

    def __init__(self, format_dict, formatter=None):
        """
        format_dict : dictionary mapping tick values to their label strings.
        formatter : fall-back formatter, used for values not in the dictionary.
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per entry of *values*.

        *factor* is ignored for values found in the dictionary; other values
        get the fall-back formatter's label, or an empty string when no
        fall-back formatter was given.
        """
        fallback = self._fallback_formatter
        defaults = (fallback(direction, factor, values) if fallback
                    else [""] * len(values))
        labels = []
        for value, default in zip(values, defaults):
            labels.append(self._format_dict.get(value, default))
        return labels
|