123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340 |
- from __future__ import (absolute_import, division, print_function,
- unicode_literals)
- import six
- import numpy as np
- from matplotlib.transforms import Bbox
- from . import clip_path
- clip_line_to_rect = clip_path.clip_line_to_rect
- import matplotlib.ticker as mticker
- from matplotlib.transforms import Transform
- # extremes finder
class ExtremeFinderSimple(object):
    """Estimate the transformed extent of a rectangle by grid sampling.

    The rectangle is sampled on an ``nx`` by ``ny`` grid, every sample is
    pushed through a caller-supplied transform, and the minimum/maximum of
    the transformed coordinates (plus a small pad) is reported.
    """

    def __init__(self, nx, ny):
        """Store the number of samples taken along each axis."""
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Return the padded extremes ``(lon_min, lon_max, lat_min, lat_max)``.

        transform_xy : callable taking two flat arrays ``(x, y)`` and
            returning the pair ``(lon, lat)`` of transformed arrays.
        x1, y1, x2, y2 : corners of the rectangle, in image coordinates
            (0-based).
        """
        xs = np.linspace(x1, x2, self.nx)
        ys = np.linspace(y1, y2, self.ny)
        grid_x, grid_y = np.meshgrid(xs, ys)
        lon, lat = transform_xy(grid_x.ravel(), grid_y.ravel())
        return self._add_pad(lon.min(), lon.max(), lat.min(), lat.max())

    def _add_pad(self, lon_min, lon_max, lat_min, lat_max):
        """
        Widen both ranges by one sampling step (range / number of samples).

        The padding is there because the clipping algorithm seems to fail
        when a gridline ends exactly at the bbox boundary.
        """
        pad_lon = (lon_max - lon_min) / self.nx
        pad_lat = (lat_max - lat_min) / self.ny
        return (lon_min - pad_lon, lon_max + pad_lon,
                lat_min - pad_lat, lat_max + pad_lat)
class GridFinderBase(object):
    """Compute grid lines and tick positions for a curvilinear grid.

    Instances must have ``transform_xy`` and ``inv_transform_xy``
    attributes before `get_grid_info` is called; `update_transform`
    installs them.
    """

    def __init__(self,
                 extreme_finder,
                 grid_locator1,
                 grid_locator2,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        extreme_finder : callable
            ``(inv_transform_xy, x1, y1, x2, y2)`` ->
            ``(lon_min, lon_max, lat_min, lat_max)``.
        grid_locator1, grid_locator2 : callable
            grid locators for the 1st and 2nd axis; called with
            ``(min, max)`` and returning ``(levels, n, factor)``.
        tick_formatter1, tick_formatter2 : callable, optional
            called with ``(direction, factor, levels)`` to build the
            tick labels of the 1st and 2nd axis.

        Derived classes must define "transform_xy, inv_transform_xy"
        (they may use update_transform).
        """
        super(GridFinderBase, self).__init__()
        self.extreme_finder = extreme_finder
        self.grid_locator1 = grid_locator1
        self.grid_locator2 = grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Build the grid description for the rectangle (x1, y1)-(x2, y2).

        Returns a dict with the keys "extremes", "lon_lines", "lat_lines",
        "lon" and "lat".  The "lon" and "lat" entries are the dicts made by
        `_clip_grid_lines_and_find_ticks`, each extended with a
        "tick_labels" dict keyed by "left", "bottom", "right" and "top".
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # Range over which each grid line is drawn; e.g. the gridline of
        # lon=0 runs from lat_min to lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes

        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

        # Levels are divided by the locator's factor (when it has one) to
        # obtain the values at which the lines are actually drawn.
        lon_values = lon_levs[:lon_n]
        if lon_factor is not None:
            lon_values = lon_values / lon_factor
        lon_values = np.asarray(lon_values)

        lat_values = lat_levs[:lat_n]
        if lat_factor is not None:
            lat_values = lat_values / lat_factor
        lat_values = np.asarray(lat_values)

        lon_lines, lat_lines = self._get_raw_grid_lines(
            lon_values, lat_values, lon_min, lon_max, lat_min, lat_max)

        # The clipping box is the requested rectangle grown by a tiny
        # fraction of its size on every side.
        eps_x = (x2 - x1) * 1.e-10
        eps_y = (y2 - y1) * 1.e-10
        clip_box = Bbox.from_extents(x1 - eps_x, y1 - eps_y,
                                     x2 + eps_x, y2 + eps_y)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": self._clip_grid_lines_and_find_ticks(
                lon_lines, lon_values, lon_levs, clip_box),
            "lat": self._clip_grid_lines_and_find_ticks(
                lat_lines, lat_values, lat_levs, clip_box),
        }

        sides = ["left", "bottom", "right", "top"]

        lon_info = grid_info["lon"]
        lon_info["tick_labels"] = {
            side: self.tick_formatter1(side, lon_factor,
                                       lon_info["tick_levels"][side])
            for side in sides}

        lat_info = grid_info["lat"]
        lat_info["tick_labels"] = {
            side: self.tick_formatter2(side, lat_factor,
                                       lat_info["tick_levels"][side])
            for side in sides}

        return grid_info

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return ``(lon_lines, lat_lines)``, the unclipped grid lines.

        Every line is whatever ``self.transform_xy`` returns for 100
        samples: a constant-lon line spans lat_min..lat_max and a
        constant-lat line spans lon_min..lon_max.
        """
        # sample positions along each line (for interpolation)
        lon_samples = np.linspace(lon_min, lon_max, 100)
        lat_samples = np.linspace(lat_min, lat_max, 100)

        lon_lines = []
        for lon in lon_values:
            const_lon = np.zeros_like(lat_samples) + lon
            lon_lines.append(self.transform_xy(const_lon, lat_samples))

        lat_lines = []
        for lat in lat_values:
            const_lat = np.zeros_like(lon_samples) + lat
            lat_lines.append(self.transform_xy(lon_samples, const_lat))

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Clip every line to the box *bb* and collect the resulting ticks.

        Returns a dict with the keys "values" (left empty here), "levels",
        "tick_levels", "tick_locs" and "lines".  "tick_levels" and
        "tick_locs" are dicts keyed by "left", "bottom", "right", "top".

        NOTE(review): clip_line_to_rect is assumed to return the clipped
        line pieces plus one tick collection per side, in the order left,
        bottom, right, top -- inferred from how its result is unpacked
        below; confirm against clip_path.
        """
        sides = ["left", "bottom", "right", "top"]
        tick_levels = {side: [] for side in sides}
        tick_locs = {side: [] for side in sides}
        gi = {
            "values": [],
            "levels": [],
            "tick_levels": tick_levels,
            "tick_locs": tick_locs,
            "lines": [],
        }

        for (line_x, line_y), value, level in zip(lines, values, levs):
            clipped, ticks_by_side = clip_line_to_rect(line_x, line_y, bb)
            if clipped:
                # only lines with something left after clipping are kept
                gi["levels"].append(value)
                gi["lines"].append(clipped)
                for ticks, side in zip(ticks_by_side, sides):
                    for tick in ticks:
                        tick_levels[side].append(level)
                        tick_locs[side].append(tick)

        return gi

    def update_transform(self, aux_trans):
        """
        Install ``self.transform_xy`` and ``self.inv_transform_xy``.

        aux_trans : either a `Transform` instance, in which case the two
            callables wrap its ``transform`` and the ``transform`` of its
            ``inverted()`` counterpart, or a pair
            ``(transform_xy, inv_transform_xy)`` of callables used as-is.
        """
        if isinstance(aux_trans, Transform):
            def _as_points(x, y):
                # pair the two coordinate arrays up as the columns of one
                # (N, 2) array
                x, y = np.asarray(x), np.asarray(y)
                return np.concatenate(
                    (x[:, np.newaxis], y[:, np.newaxis]), axis=1)

            def transform_xy(x, y):
                points = _as_points(x, y)
                out = aux_trans.transform(points)
                return out[:, 0], out[:, 1]

            def inv_transform_xy(x, y):
                points = _as_points(x, y)
                out = aux_trans.inverted().transform(points)
                return out[:, 0], out[:, 1]
        else:
            transform_xy, inv_transform_xy = aux_trans

        self.transform_xy = transform_xy
        self.inv_transform_xy = inv_transform_xy

    def update(self, **kw):
        """
        Replace any of extreme_finder, grid_locator1, grid_locator2,
        tick_formatter1 and tick_formatter2.

        Raises ValueError on the first keyword that is not one of these.
        """
        allowed = ("extreme_finder",
                   "grid_locator1",
                   "grid_locator2",
                   "tick_formatter1",
                   "tick_formatter2")
        for key in kw:
            if key not in allowed:
                raise ValueError("unknown update property '%s'" % key)
            setattr(self, key, kw[key])
class GridFinder(GridFinderBase):
    """A `GridFinderBase` that supplies a default for every helper."""

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        transform : transform from the image coordinate (which will be
            the transData of the axes) to the world coordinate,
            or transform = (transform_xy, inv_transform_xy)
        extreme_finder : defaults to ``ExtremeFinderSimple(20, 20)``.
        grid_locator1, grid_locator2 : grid locator for 1st and 2nd axis;
            each defaults to ``MaxNLocator()``.
        tick_formatter1, tick_formatter2 : tick formatter for 1st and 2nd
            axis; each defaults to ``FormatterPrettyPrint()``.
        """
        extreme_finder = (ExtremeFinderSimple(20, 20)
                          if extreme_finder is None else extreme_finder)
        grid_locator1 = (MaxNLocator()
                         if grid_locator1 is None else grid_locator1)
        grid_locator2 = (MaxNLocator()
                         if grid_locator2 is None else grid_locator2)
        tick_formatter1 = (FormatterPrettyPrint()
                           if tick_formatter1 is None else tick_formatter1)
        tick_formatter2 = (FormatterPrettyPrint()
                           if tick_formatter2 is None else tick_formatter2)

        super(GridFinder, self).__init__(extreme_finder,
                                         grid_locator1,
                                         grid_locator2,
                                         tick_formatter1,
                                         tick_formatter2)
        self.update_transform(transform)
class MaxNLocator(mticker.MaxNLocator):
    """`mticker.MaxNLocator` variant that is called with explicit bounds.

    Calling the instance returns ``(levels, number_of_levels, factor)``.
    When a factor has been set, the bounds are multiplied by it before
    the levels are located.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # The trim argument has no effect; it is only kept for API
        # compatibility.
        mticker.MaxNLocator.__init__(self, nbins,
                                     steps=steps,
                                     integer=integer,
                                     symmetric=symmetric,
                                     prune=prune)
        self.create_dummy_axis()
        self._factor = None

    def __call__(self, v1, v2):
        """Locate levels between *v1* and *v2* (scaled by the factor, if
        any) and return ``(levels, len(levels), factor)``."""
        factor = self._factor
        if factor is not None:
            v1, v2 = v1 * factor, v2 * factor
        self.set_bounds(v1, v2)
        levels = mticker.MaxNLocator.__call__(self)
        return np.array(levels), len(levels), factor

    def set_factor(self, f):
        """Set the factor applied to the bounds (None disables scaling)."""
        self._factor = f
class FixedLocator(object):
    """Grid locator that picks levels out of a fixed collection.

    Calling the instance returns ``(levels, number_of_levels, factor)``
    where *levels* are the stored locations lying inside the requested
    range (inclusive at both ends).
    """

    def __init__(self, locs):
        """Store the candidate locations *locs*; no factor is set."""
        self._locs = locs
        self._factor = None

    def __call__(self, v1, v2):
        """Return the stored locations between *v1* and *v2*.

        The bounds may be given in either order.  When a factor is set,
        both bounds are multiplied by it before the comparison.
        """
        factor = self._factor
        if factor is not None:
            v1, v2 = v1 * factor, v2 * factor
        lower, upper = sorted([v1, v2])
        inside = np.array([loc for loc in self._locs
                           if lower <= loc <= upper])
        return inside, len(inside), factor

    def set_factor(self, f):
        """Set the factor applied to the bounds (None disables scaling)."""
        self._factor = f
- # Tick Formatter
class FormatterPrettyPrint(object):
    """Tick formatter backed by a `mticker.ScalarFormatter`."""

    def __init__(self, useMathText=True):
        """Create the underlying ScalarFormatter (offsets disabled) and
        give it a dummy axis."""
        scalar_formatter = mticker.ScalarFormatter(useMathText=useMathText,
                                                   useOffset=False)
        scalar_formatter.create_dummy_axis()
        self._fmt = scalar_formatter
        # While this flag is True the *factor* passed to __call__ is
        # ignored.
        self._ignore_factor = True

    def __call__(self, direction, factor, values):
        """
        Return one label string per entry of *values*.

        *direction* is not used.  *factor* only matters when
        ``_ignore_factor`` is False: the values are then divided by it
        (a factor of None counts as 1) before being formatted.
        """
        if not self._ignore_factor:
            divisor = 1. if factor is None else factor
            values = [value / divisor for value in values]
        self._fmt.set_locs(values)
        return [self._fmt(value) for value in values]
class DictFormatter(object):
    """Tick formatter that looks labels up in a dictionary."""

    def __init__(self, format_dict, formatter=None):
        """
        format_dict : dictionary mapping a value to the label string to
            be used for it.
        formatter : fall-back formatter, called as
            ``formatter(direction, factor, values)`` to label the values
            that are missing from *format_dict*.
        """
        super(DictFormatter, self).__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per entry of *values*.

        A value found in the dictionary gets the dictionary's label (the
        factor is ignored for it); any other value gets the label made by
        the fall-back formatter, or "" when there is none.
        """
        fallback = self._fallback_formatter
        if fallback:
            defaults = fallback(direction, factor, values)
        else:
            defaults = [""] * len(values)
        return [self._format_dict.get(value, default)
                for value, default in zip(values, defaults)]
|