grid_finder.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. from __future__ import (absolute_import, division, print_function,
  2. unicode_literals)
  3. import six
  4. import numpy as np
  5. from matplotlib.transforms import Bbox
  6. from . import clip_path
  7. clip_line_to_rect = clip_path.clip_line_to_rect
  8. import matplotlib.ticker as mticker
  9. from matplotlib.transforms import Transform
  10. # extremes finder
  class ExtremeFinderSimple(object):
      """
      Find the extreme transformed coordinates of a rectangle by sampling
      it on a regular grid of nx by ny points.
      """

      def __init__(self, nx, ny):
          """
          nx, ny : number of division in each axis
          """
          self.nx = nx
          self.ny = ny

      def __call__(self, transform_xy, x1, y1, x2, y2):
          """
          Return the padded (lon_min, lon_max, lat_min, lat_max).

          transform_xy : callable taking flattened x and y sample arrays
              and returning the matching (lon, lat) arrays.
          x1, y1, x2, y2 in image coordinates (0-based)
          """
          xs = np.linspace(x1, x2, self.nx)
          ys = np.linspace(y1, y2, self.ny)
          grid_x, grid_y = np.meshgrid(xs, ys)
          lon, lat = transform_xy(grid_x.ravel(), grid_y.ravel())
          return self._add_pad(lon.min(), lon.max(), lat.min(), lat.max())

      def _add_pad(self, lon_min, lon_max, lat_min, lat_max):
          """
          Widen both ranges by one sampling step on each side.

          A small amount of padding is added because the current
          clipping algorithms seems to fail when the gridline ends at
          the bbox boundary.
          """
          pad_lon = (lon_max - lon_min) / self.nx
          pad_lat = (lat_max - lat_min) / self.ny
          return (lon_min - pad_lon, lon_max + pad_lon,
                  lat_min - pad_lat, lat_max + pad_lat)
  class GridFinderBase(object):
      """
      Compute grid lines, tick locations and tick labels for a rectangle
      given in image coordinates.

      Derived must define "transform_xy, inv_transform_xy"
      (may use update_transform)
      """

      def __init__(self, extreme_finder, grid_locator1, grid_locator2,
                   tick_formatter1=None, tick_formatter2=None):
          """
          extreme_finder : callable invoked as
              extreme_finder(inv_transform_xy, x1, y1, x2, y2) and returning
              (lon_min, lon_max, lat_min, lat_max).
          grid_locator1, grid_locator2 : grid locator for 1st and 2nd axis;
              each is called with (min, max) and must return
              (levels, number_of_levels, factor).
          tick_formatter1, tick_formatter2 : callables invoked as
              formatter(direction, factor, levels) to produce tick labels.
          """
          super(GridFinderBase, self).__init__()
          self.extreme_finder = extreme_finder
          self.grid_locator1 = grid_locator1
          self.grid_locator2 = grid_locator2
          self.tick_formatter1 = tick_formatter1
          self.tick_formatter2 = tick_formatter2

      def get_grid_info(self, x1, y1, x2, y2):
          """
          Build the grid description for the rectangle (x1, y1)-(x2, y2).

          Returns a dict with the keys "extremes", "lon_lines", "lat_lines",
          "lon" and "lat". The "lon" and "lat" entries are the dicts made by
          _clip_grid_lines_and_find_ticks, extended with a "tick_labels"
          dict keyed by "left", "bottom", "right" and "top".
          """
          extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

          # min & max range of lat (or lon) over which each grid line is
          # drawn, i.e., gridline of lon=0 will be drawn from lat_min to
          # lat_max.
          lon_min, lon_max, lat_min, lat_max = extremes

          lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
          lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

          # The locators return levels scaled by "factor"; undo the scaling
          # to get the actual coordinate values.
          lon_values = lon_levs[:lon_n]
          if lon_factor is not None:
              lon_values = lon_values/lon_factor
          lon_values = np.asarray(lon_values)

          lat_values = lat_levs[:lat_n]
          if lat_factor is not None:
              lat_values = lat_values/lat_factor
          lat_values = np.asarray(lat_values)

          lon_lines, lat_lines = self._get_raw_grid_lines(
              lon_values, lat_values, lon_min, lon_max, lat_min, lat_max)

          # Clip box: the requested rectangle enlarged by a tiny fraction
          # of its own width and height.
          eps_x = (x2-x1)*1.e-10
          eps_y = (y2-y1)*1.e-10
          bb = Bbox.from_extents(x1-eps_x, y1-eps_y, x2+eps_x, y2+eps_y)

          grid_info = {
              "extremes": extremes,
              "lon_lines": lon_lines,
              "lat_lines": lat_lines,
          }
          grid_info["lon"] = self._clip_grid_lines_and_find_ticks(
              lon_lines, lon_values, lon_levs, bb)
          grid_info["lat"] = self._clip_grid_lines_and_find_ticks(
              lat_lines, lat_values, lat_levs, bb)

          sides = ["left", "bottom", "right", "top"]

          lon_info = grid_info["lon"]
          lon_info["tick_labels"] = lon_labels = dict()
          for side in sides:
              lon_labels[side] = self.tick_formatter1(
                  side, lon_factor, lon_info["tick_levels"][side])

          lat_info = grid_info["lat"]
          lat_info["tick_labels"] = lat_labels = dict()
          for side in sides:
              lat_labels[side] = self.tick_formatter2(
                  side, lat_factor, lat_info["tick_levels"][side])

          return grid_info

      def _get_raw_grid_lines(self,
                              lon_values, lat_values,
                              lon_min, lon_max, lat_min, lat_max):
          """
          Return (lon_lines, lat_lines): one transform_xy result per entry
          of lon_values (resp. lat_values), each sampled at 100 points over
          the range of the other coordinate.
          """
          lon_samples = np.linspace(lon_min, lon_max, 100)  # for interpolation
          lat_samples = np.linspace(lat_min, lat_max, 100)

          lon_lines = []
          for lon in lon_values:
              const_lon = np.zeros_like(lat_samples) + lon
              lon_lines.append(self.transform_xy(const_lon, lat_samples))

          lat_lines = []
          for lat in lat_values:
              const_lat = np.zeros_like(lon_samples) + lat
              lat_lines.append(self.transform_xy(lon_samples, const_lat))

          return lon_lines, lat_lines

      def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
          """
          Clip every line to the box bb and collect its tick candidates.

          Returns a dict with the keys "values", "levels", "tick_levels",
          "tick_locs" and "lines". Lines for which clip_line_to_rect
          returns nothing are skipped entirely.
          """
          sides = ["left", "bottom", "right", "top"]
          side_levels = {"left": [], "bottom": [], "right": [], "top": []}
          side_locs = {"left": [], "bottom": [], "right": [], "top": []}
          clipped_lines = []
          kept = []

          for (line_x, line_y), value, level in zip(lines, values, levs):
              segments, ticks_per_side = clip_line_to_rect(line_x, line_y, bb)
              if not segments:
                  continue
              kept.append(value)
              clipped_lines.append(segments)
              for side_ticks, side in zip(ticks_per_side, sides):
                  for tick in side_ticks:
                      side_levels[side].append(level)
                      side_locs[side].append(tick)

          # NOTE: "values" is left empty and "levels" holds the entries of
          # *values* for the lines that survived clipping.
          return {
              "values": [],
              "levels": kept,
              "tick_levels": side_levels,
              "tick_locs": side_locs,
              "lines": clipped_lines,
          }

      def update_transform(self, aux_trans):
          """
          Set self.transform_xy and self.inv_transform_xy from aux_trans.

          aux_trans is either a Transform instance, which gets wrapped into
          a pair of functions, or a (transform_xy, inv_transform_xy) pair
          that is used as is.
          """
          if not isinstance(aux_trans, Transform):
              transform_xy, inv_transform_xy = aux_trans
          else:
              def _as_points(x, y):
                  # Stack the two coordinate arrays as the columns of one
                  # array.
                  x, y = np.asarray(x), np.asarray(y)
                  return np.concatenate(
                      (x[:, np.newaxis], y[:, np.newaxis]), 1)

              def transform_xy(x, y):
                  points = aux_trans.transform(_as_points(x, y))
                  return points[:, 0], points[:, 1]

              def inv_transform_xy(x, y):
                  points = _as_points(x, y)
                  # The inverse is looked up on every call.
                  points = aux_trans.inverted().transform(points)
                  return points[:, 0], points[:, 1]

          self.transform_xy = transform_xy
          self.inv_transform_xy = inv_transform_xy

      def update(self, **kw):
          """
          Replace any of extreme_finder, grid_locator1, grid_locator2,
          tick_formatter1 and tick_formatter2 by keyword.

          Raises ValueError on the first unknown keyword; keywords handled
          before it have already been applied.
          """
          allowed = ("extreme_finder",
                     "grid_locator1",
                     "grid_locator2",
                     "tick_formatter1",
                     "tick_formatter2")
          for key in kw:
              if key not in allowed:
                  raise ValueError("unknown update property '%s'" % key)
              setattr(self, key, kw[key])
  class GridFinder(GridFinderBase):
      """
      GridFinderBase with default helpers filled in and the coordinate
      transform installed at construction time.
      """

      def __init__(self,
                   transform,
                   extreme_finder=None,
                   grid_locator1=None,
                   grid_locator2=None,
                   tick_formatter1=None,
                   tick_formatter2=None):
          """
          transform : transform from the image coordinate (which will be
          the transData of the axes) to the world coordinate,
          or transform = (transform_xy, inv_transform_xy)

          locator1, locator2 : grid locator for 1st and 2nd axis.

          Any helper left as None is replaced by a default:
          ExtremeFinderSimple(20, 20), MaxNLocator() and
          FormatterPrettyPrint() respectively.
          """
          extreme_finder = (ExtremeFinderSimple(20, 20)
                            if extreme_finder is None else extreme_finder)
          grid_locator1 = (MaxNLocator()
                           if grid_locator1 is None else grid_locator1)
          grid_locator2 = (MaxNLocator()
                           if grid_locator2 is None else grid_locator2)
          tick_formatter1 = (FormatterPrettyPrint()
                             if tick_formatter1 is None else tick_formatter1)
          tick_formatter2 = (FormatterPrettyPrint()
                             if tick_formatter2 is None else tick_formatter2)

          super(GridFinder, self).__init__(extreme_finder,
                                           grid_locator1,
                                           grid_locator2,
                                           tick_formatter1,
                                           tick_formatter2)
          self.update_transform(transform)
  class MaxNLocator(mticker.MaxNLocator):
      """
      mticker.MaxNLocator adapted to the grid-locator call convention:
      called with (v1, v2), returns (levels, number_of_levels, factor).
      """

      def __init__(self, nbins=10, steps=None,
                   trim=True,
                   integer=False,
                   symmetric=False,
                   prune=None):
          # trim argument has no effect. It has been left for API compatibility
          mticker.MaxNLocator.__init__(self, nbins,
                                       steps=steps,
                                       integer=integer,
                                       symmetric=symmetric,
                                       prune=prune)
          self.create_dummy_axis()
          self._factor = None

      def __call__(self, v1, v2):
          """
          Return (levels, number_of_levels, factor) for the range v1..v2.

          When a factor has been set, both bounds are multiplied by it
          before the ticks are located; otherwise the factor is None.
          """
          factor = self._factor
          if factor is None:
              self.set_bounds(v1, v2)
          else:
              self.set_bounds(v1*factor, v2*factor)
          ticks = mticker.MaxNLocator.__call__(self)
          return np.array(ticks), len(ticks), factor

      def set_factor(self, f):
          """Set the factor applied to the bounds in __call__."""
          self._factor = f
  class FixedLocator(object):
      """
      Grid locator that picks, from a fixed list of locations, those that
      lie inside the requested range.
      """

      def __init__(self, locs):
          """
          locs : iterable of candidate locations.
          """
          self._locs = locs
          self._factor = None

      def __call__(self, v1, v2):
          """
          Return (levels, number_of_levels, factor).

          The bounds may be given in either order and are inclusive. When
          a factor has been set, the bounds are multiplied by it before
          being compared with the stored locations.
          """
          factor = self._factor
          if factor is not None:
              v1, v2 = v1*factor, v2*factor
          lower, upper = sorted([v1, v2])
          selected = np.array([loc for loc in self._locs
                               if lower <= loc <= upper])
          return selected, len(selected), factor

      def set_factor(self, f):
          """Set the factor applied to the bounds in __call__."""
          self._factor = f
  233. # Tick Formatter
  class FormatterPrettyPrint(object):
      """
      Tick formatter that delegates to a mticker.ScalarFormatter created
      with useOffset=False.
      """

      def __init__(self, useMathText=True):
          """
          useMathText : forwarded to mticker.ScalarFormatter.
          """
          self._fmt = mticker.ScalarFormatter(useMathText=useMathText,
                                              useOffset=False)
          self._fmt.create_dummy_axis()
          # While True, the factor passed to __call__ is not applied.
          self._ignore_factor = True

      def __call__(self, direction, factor, values):
          """
          Return one label per entry of values.

          direction is accepted but not used. Unless the factor is being
          ignored, each value is divided by factor first (a factor of None
          counts as 1.).
          """
          if not self._ignore_factor:
              divisor = 1. if factor is None else factor
              values = [value/divisor for value in values]
          formatter = self._fmt
          formatter.set_locs(values)
          return [formatter(value) for value in values]
  class DictFormatter(object):
      """
      Tick formatter that looks labels up in a dictionary and falls back
      to another formatter (or to empty strings) for missing values.
      """

      def __init__(self, format_dict, formatter=None):
          """
          format_dict : dictionary for format strings to be used.
          formatter : fall-back formatter
          """
          super(DictFormatter, self).__init__()
          self._format_dict = format_dict
          self._fallback_formatter = formatter

      def __call__(self, direction, factor, values):
          """
          Return one label per entry of values.

          factor is ignored if value is found in the dictionary
          """
          fallback = self._fallback_formatter
          if fallback:
              defaults = fallback(direction, factor, values)
          else:
              defaults = [""]*len(values)

          lookup = self._format_dict.get
          return [lookup(value, default)
                  for value, default in zip(values, defaults)]
  268. return r