floating_axes.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. """
  2. An experimental support for curvilinear grid.
  3. """
  4. from __future__ import (absolute_import, division, print_function,
  5. unicode_literals)
  6. import six
  7. from six.moves import zip
  8. # TODO :
  9. # see if tick_iterator method can be simplified by reusing the parent method.
  10. import numpy as np
  11. from matplotlib.transforms import Affine2D, IdentityTransform
  12. from . import grid_helper_curvelinear
  13. from .axislines import AxisArtistHelper, GridHelperBase
  14. from .axis_artist import AxisArtist
  15. from .grid_finder import GridFinder
  16. class FloatingAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
  17. pass
  18. class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
  19. def __init__(self, grid_helper, side, nth_coord_ticks=None):
  20. """
  21. nth_coord = along which coordinate value varies.
  22. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  23. """
  24. value, nth_coord = grid_helper.get_data_boundary(side) # return v= 0 , nth=1, extremes of the other coordinate.
  25. super(FixedAxisArtistHelper, self).__init__(grid_helper,
  26. nth_coord,
  27. value,
  28. axis_direction=side,
  29. )
  30. #self.grid_helper = grid_helper
  31. if nth_coord_ticks is None:
  32. nth_coord_ticks = nth_coord
  33. self.nth_coord_ticks = nth_coord_ticks
  34. self.value = value
  35. self.grid_helper = grid_helper
  36. self._side = side
  37. def update_lim(self, axes):
  38. self.grid_helper.update_lim(axes)
  39. self.grid_info = self.grid_helper.grid_info
  40. def get_axislabel_pos_angle(self, axes):
  41. extremes = self.grid_info["extremes"]
  42. if self.nth_coord == 0:
  43. xx0 = self.value
  44. yy0 = (extremes[2]+extremes[3])/2.
  45. dxx, dyy = 0., abs(extremes[2]-extremes[3])/1000.
  46. elif self.nth_coord == 1:
  47. xx0 = (extremes[0]+extremes[1])/2.
  48. yy0 = self.value
  49. dxx, dyy = abs(extremes[0]-extremes[1])/1000., 0.
  50. grid_finder = self.grid_helper.grid_finder
  51. xx1, yy1 = grid_finder.transform_xy([xx0], [yy0])
  52. trans_passingthrough_point = axes.transData + axes.transAxes.inverted()
  53. p = trans_passingthrough_point.transform_point([xx1[0], yy1[0]])
  54. if (0. <= p[0] <= 1.) and (0. <= p[1] <= 1.):
  55. xx1c, yy1c = axes.transData.transform_point([xx1[0], yy1[0]])
  56. xx2, yy2 = grid_finder.transform_xy([xx0+dxx], [yy0+dyy])
  57. xx2c, yy2c = axes.transData.transform_point([xx2[0], yy2[0]])
  58. return (xx1c, yy1c), np.arctan2(yy2c-yy1c, xx2c-xx1c)/np.pi*180.
  59. else:
  60. return None, None
  61. def get_tick_transform(self, axes):
  62. return IdentityTransform() #axes.transData
  63. def get_tick_iterators(self, axes):
  64. """tick_loc, tick_angle, tick_label, (optionally) tick_label"""
  65. grid_finder = self.grid_helper.grid_finder
  66. lat_levs, lat_n, lat_factor = self.grid_info["lat_info"]
  67. lon_levs, lon_n, lon_factor = self.grid_info["lon_info"]
  68. lon_levs, lat_levs = np.asarray(lon_levs), np.asarray(lat_levs)
  69. if lat_factor is not None:
  70. yy0 = lat_levs / lat_factor
  71. dy = 0.001 / lat_factor
  72. else:
  73. yy0 = lat_levs
  74. dy = 0.001
  75. if lon_factor is not None:
  76. xx0 = lon_levs / lon_factor
  77. dx = 0.001 / lon_factor
  78. else:
  79. xx0 = lon_levs
  80. dx = 0.001
  81. _extremes = self.grid_helper._extremes
  82. xmin, xmax = sorted(_extremes[:2])
  83. ymin, ymax = sorted(_extremes[2:])
  84. if self.nth_coord == 0:
  85. mask = (ymin <= yy0) & (yy0 <= ymax)
  86. yy0 = yy0[mask]
  87. elif self.nth_coord == 1:
  88. mask = (xmin <= xx0) & (xx0 <= xmax)
  89. xx0 = xx0[mask]
  90. def transform_xy(x, y):
  91. x1, y1 = grid_finder.transform_xy(x, y)
  92. x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
  93. x2, y2 = x2y2.transpose()
  94. return x2, y2
  95. # find angles
  96. if self.nth_coord == 0:
  97. xx0 = np.empty_like(yy0)
  98. xx0.fill(self.value)
  99. #yy0_ = yy0.copy()
  100. xx1, yy1 = transform_xy(xx0, yy0)
  101. xx00 = xx0.astype(float, copy=True)
  102. xx00[xx0+dx>xmax] -= dx
  103. xx1a, yy1a = transform_xy(xx00, yy0)
  104. xx1b, yy1b = transform_xy(xx00+dx, yy0)
  105. yy00 = yy0.astype(float, copy=True)
  106. yy00[yy0+dy>ymax] -= dy
  107. xx2a, yy2a = transform_xy(xx0, yy00)
  108. xx2b, yy2b = transform_xy(xx0, yy00+dy)
  109. labels = self.grid_info["lat_labels"]
  110. labels = [l for l, m in zip(labels, mask) if m]
  111. elif self.nth_coord == 1:
  112. yy0 = np.empty_like(xx0)
  113. yy0.fill(self.value)
  114. #xx0_ = xx0.copy()
  115. xx1, yy1 = transform_xy(xx0, yy0)
  116. yy00 = yy0.astype(float, copy=True)
  117. yy00[yy0+dy>ymax] -= dy
  118. xx1a, yy1a = transform_xy(xx0, yy00)
  119. xx1b, yy1b = transform_xy(xx0, yy00+dy)
  120. xx00 = xx0.astype(float, copy=True)
  121. xx00[xx0+dx>xmax] -= dx
  122. xx2a, yy2a = transform_xy(xx00, yy0)
  123. xx2b, yy2b = transform_xy(xx00+dx, yy0)
  124. labels = self.grid_info["lon_labels"]
  125. labels = [l for l, m in zip(labels, mask) if m]
  126. def f1():
  127. dd = np.arctan2(yy1b-yy1a, xx1b-xx1a) # angle normal
  128. dd2 = np.arctan2(yy2b-yy2a, xx2b-xx2a) # angle tangent
  129. mm = ((yy1b-yy1a)==0.) & ((xx1b-xx1a)==0.) # mask where dd1 is not defined
  130. dd[mm] = dd2[mm] + np.pi / 2
  131. #dd += np.pi
  132. #dd = np.arctan2(xx2-xx1, angle_tangent-yy1)
  133. trans_tick = self.get_tick_transform(axes)
  134. tr2ax = trans_tick + axes.transAxes.inverted()
  135. for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels):
  136. c2 = tr2ax.transform_point((x, y))
  137. delta=0.00001
  138. if (0. -delta<= c2[0] <= 1.+delta) and \
  139. (0. -delta<= c2[1] <= 1.+delta):
  140. d1 = d/3.14159*180.
  141. d2 = d2/3.14159*180.
  142. #_mod = (d2-d1+180)%360
  143. #if _mod < 180:
  144. # d1 += 180
  145. ##_div, _mod = divmod(d2-d1, 360)
  146. yield [x, y], d1, d2, lab
  147. #, d2/3.14159*180.+da)
  148. return f1(), iter([])
  149. def get_line_transform(self, axes):
  150. return axes.transData
  151. def get_line(self, axes):
  152. self.update_lim(axes)
  153. from matplotlib.path import Path
  154. k, v = dict(left=("lon_lines0", 0),
  155. right=("lon_lines0", 1),
  156. bottom=("lat_lines0", 0),
  157. top=("lat_lines0", 1))[self._side]
  158. xx, yy = self.grid_info[k][v]
  159. return Path(np.column_stack([xx, yy]))
  160. from .grid_finder import ExtremeFinderSimple
  161. class ExtremeFinderFixed(ExtremeFinderSimple):
  162. def __init__(self, extremes):
  163. self._extremes = extremes
  164. def __call__(self, transform_xy, x1, y1, x2, y2):
  165. """
  166. get extreme values.
  167. x1, y1, x2, y2 in image coordinates (0-based)
  168. nx, ny : number of division in each axis
  169. """
  170. #lon_min, lon_max, lat_min, lat_max = self._extremes
  171. return self._extremes
  172. class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
  173. def __init__(self, aux_trans, extremes,
  174. grid_locator1=None,
  175. grid_locator2=None,
  176. tick_formatter1=None,
  177. tick_formatter2=None):
  178. """
  179. aux_trans : a transform from the source (curved) coordinate to
  180. target (rectilinear) coordinate. An instance of MPL's Transform
  181. (inverse transform should be defined) or a tuple of two callable
  182. objects which defines the transform and its inverse. The callables
  183. need take two arguments of array of source coordinates and
  184. should return two target coordinates:
  185. e.g., *x2, y2 = trans(x1, y1)*
  186. """
  187. self._old_values = None
  188. self._extremes = extremes
  189. extreme_finder = ExtremeFinderFixed(extremes)
  190. super(GridHelperCurveLinear, self).__init__(aux_trans,
  191. extreme_finder,
  192. grid_locator1=grid_locator1,
  193. grid_locator2=grid_locator2,
  194. tick_formatter1=tick_formatter1,
  195. tick_formatter2=tick_formatter2)
  196. # def update_grid_finder(self, aux_trans=None, **kw):
  197. # if aux_trans is not None:
  198. # self.grid_finder.update_transform(aux_trans)
  199. # self.grid_finder.update(**kw)
  200. # self.invalidate()
  201. # def _update(self, x1, x2, y1, y2):
  202. # "bbox in 0-based image coordinates"
  203. # # update wcsgrid
  204. # if self.valid() and self._old_values == (x1, x2, y1, y2):
  205. # return
  206. # self._update_grid(x1, y1, x2, y2)
  207. # self._old_values = (x1, x2, y1, y2)
  208. # self._force_update = False
  209. def get_data_boundary(self, side):
  210. """
  211. return v= 0 , nth=1
  212. """
  213. lon1, lon2, lat1, lat2 = self._extremes
  214. return dict(left=(lon1, 0),
  215. right=(lon2, 0),
  216. bottom=(lat1, 1),
  217. top=(lat2, 1))[side]
  218. def new_fixed_axis(self, loc,
  219. nth_coord=None,
  220. axis_direction=None,
  221. offset=None,
  222. axes=None):
  223. if axes is None:
  224. axes = self.axes
  225. if axis_direction is None:
  226. axis_direction = loc
  227. _helper = FixedAxisArtistHelper(self, loc,
  228. nth_coord_ticks=nth_coord)
  229. axisline = AxisArtist(axes, _helper, axis_direction=axis_direction)
  230. axisline.line.set_clip_on(True)
  231. axisline.line.set_clip_box(axisline.axes.bbox)
  232. return axisline
  233. # new_floating_axis will inherit the grid_helper's extremes.
  234. # def new_floating_axis(self, nth_coord,
  235. # value,
  236. # axes=None,
  237. # axis_direction="bottom"
  238. # ):
  239. # axis = super(GridHelperCurveLinear,
  240. # self).new_floating_axis(nth_coord,
  241. # value, axes=axes,
  242. # axis_direction=axis_direction)
  243. # # set extreme values of the axis helper
  244. # if nth_coord == 1:
  245. # axis.get_helper().set_extremes(*self._extremes[:2])
  246. # elif nth_coord == 0:
  247. # axis.get_helper().set_extremes(*self._extremes[2:])
  248. # return axis
  249. def _update_grid(self, x1, y1, x2, y2):
  250. #self.grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)
  251. if self.grid_info is None:
  252. self.grid_info = dict()
  253. grid_info = self.grid_info
  254. grid_finder = self.grid_finder
  255. extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
  256. x1, y1, x2, y2)
  257. lon_min, lon_max = sorted(extremes[:2])
  258. lat_min, lat_max = sorted(extremes[2:])
  259. lon_levs, lon_n, lon_factor = \
  260. grid_finder.grid_locator1(lon_min, lon_max)
  261. lat_levs, lat_n, lat_factor = \
  262. grid_finder.grid_locator2(lat_min, lat_max)
  263. grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max #extremes
  264. grid_info["lon_info"] = lon_levs, lon_n, lon_factor
  265. grid_info["lat_info"] = lat_levs, lat_n, lat_factor
  266. grid_info["lon_labels"] = grid_finder.tick_formatter1("bottom",
  267. lon_factor,
  268. lon_levs)
  269. grid_info["lat_labels"] = grid_finder.tick_formatter2("bottom",
  270. lat_factor,
  271. lat_levs)
  272. if lon_factor is None:
  273. lon_values = np.asarray(lon_levs[:lon_n])
  274. else:
  275. lon_values = np.asarray(lon_levs[:lon_n]/lon_factor)
  276. if lat_factor is None:
  277. lat_values = np.asarray(lat_levs[:lat_n])
  278. else:
  279. lat_values = np.asarray(lat_levs[:lat_n]/lat_factor)
  280. lon_values0 = lon_values[(lon_min<lon_values) & (lon_values<lon_max)]
  281. lat_values0 = lat_values[(lat_min<lat_values) & (lat_values<lat_max)]
  282. lon_lines, lat_lines = grid_finder._get_raw_grid_lines(lon_values0,
  283. lat_values0,
  284. lon_min, lon_max,
  285. lat_min, lat_max)
  286. grid_info["lon_lines"] = lon_lines
  287. grid_info["lat_lines"] = lat_lines
  288. lon_lines, lat_lines = grid_finder._get_raw_grid_lines(extremes[:2],
  289. extremes[2:],
  290. *extremes)
  291. #lon_min, lon_max,
  292. # lat_min, lat_max)
  293. grid_info["lon_lines0"] = lon_lines
  294. grid_info["lat_lines0"] = lat_lines
  295. def get_gridlines(self, which="major", axis="both"):
  296. grid_lines = []
  297. if axis in ["both", "x"]:
  298. for gl in self.grid_info["lon_lines"]:
  299. grid_lines.extend([gl])
  300. if axis in ["both", "y"]:
  301. for gl in self.grid_info["lat_lines"]:
  302. grid_lines.extend([gl])
  303. return grid_lines
  304. def get_boundary(self):
  305. """
  306. return Nx2 array of x,y coordinate of the boundary
  307. """
  308. x0, x1, y0, y1 = self._extremes
  309. tr = self._aux_trans
  310. xx = np.linspace(x0, x1, 100)
  311. yy0, yy1 = np.empty_like(xx), np.empty_like(xx)
  312. yy0.fill(y0)
  313. yy1.fill(y1)
  314. yy = np.linspace(y0, y1, 100)
  315. xx0, xx1 = np.empty_like(yy), np.empty_like(yy)
  316. xx0.fill(x0)
  317. xx1.fill(x1)
  318. xxx = np.concatenate([xx[:-1], xx1[:-1], xx[-1:0:-1], xx0])
  319. yyy = np.concatenate([yy0[:-1], yy[:-1], yy1[:-1], yy[::-1]])
  320. t = tr.transform(np.array([xxx, yyy]).transpose())
  321. return t
  322. class FloatingAxesBase(object):
  323. def __init__(self, *kl, **kwargs):
  324. grid_helper = kwargs.get("grid_helper", None)
  325. if grid_helper is None:
  326. raise ValueError("FloatingAxes requires grid_helper argument")
  327. if not hasattr(grid_helper, "get_boundary"):
  328. raise ValueError("grid_helper must implement get_boundary method")
  329. self._axes_class_floating.__init__(self, *kl, **kwargs)
  330. self.set_aspect(1.)
  331. self.adjust_axes_lim()
  332. def _gen_axes_patch(self):
  333. """
  334. Returns the patch used to draw the background of the axes. It
  335. is also used as the clipping path for any data elements on the
  336. axes.
  337. In the standard axes, this is a rectangle, but in other
  338. projections it may not be.
  339. .. note::
  340. Intended to be overridden by new projection types.
  341. """
  342. import matplotlib.patches as mpatches
  343. grid_helper = self.get_grid_helper()
  344. t = grid_helper.get_boundary()
  345. return mpatches.Polygon(t)
  346. def cla(self):
  347. self._axes_class_floating.cla(self)
  348. #HostAxes.cla(self)
  349. self.patch.set_transform(self.transData)
  350. patch = self._axes_class_floating._gen_axes_patch(self)
  351. patch.set_figure(self.figure)
  352. patch.set_visible(False)
  353. patch.set_transform(self.transAxes)
  354. self.patch.set_clip_path(patch)
  355. self.gridlines.set_clip_path(patch)
  356. self._original_patch = patch
  357. def adjust_axes_lim(self):
  358. #t = self.get_boundary()
  359. grid_helper = self.get_grid_helper()
  360. t = grid_helper.get_boundary()
  361. x, y = t[:,0], t[:,1]
  362. xmin, xmax = min(x), max(x)
  363. ymin, ymax = min(y), max(y)
  364. dx = (xmax-xmin)/100.
  365. dy = (ymax-ymin)/100.
  366. self.set_xlim(xmin-dx, xmax+dx)
  367. self.set_ylim(ymin-dy, ymax+dy)
  368. _floatingaxes_classes = {}
  369. def floatingaxes_class_factory(axes_class):
  370. new_class = _floatingaxes_classes.get(axes_class)
  371. if new_class is None:
  372. new_class = type(str("Floating %s" % (axes_class.__name__)),
  373. (FloatingAxesBase, axes_class),
  374. {'_axes_class_floating': axes_class})
  375. _floatingaxes_classes[axes_class] = new_class
  376. return new_class
  377. from .axislines import Axes
  378. from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
  379. FloatingAxes = floatingaxes_class_factory(host_axes_class_factory(Axes))
  380. import matplotlib.axes as maxes
  381. FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)