grid_helper_curvelinear.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. """
  2. An experimental support for curvilinear grid.
  3. """
  4. from __future__ import (absolute_import, division, print_function,
  5. unicode_literals)
  6. import six
  7. from six.moves import zip
  8. from itertools import chain
  9. from .grid_finder import GridFinder
  10. from .axislines import AxisArtistHelper, GridHelperBase
  11. from .axis_artist import AxisArtist
  12. from matplotlib.transforms import Affine2D, IdentityTransform
  13. import numpy as np
  14. from matplotlib.path import Path
  15. class FixedAxisArtistHelper(AxisArtistHelper.Fixed):
  16. """
  17. Helper class for a fixed axis.
  18. """
  19. def __init__(self, grid_helper, side, nth_coord_ticks=None):
  20. """
  21. nth_coord = along which coordinate value varies.
  22. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  23. """
  24. super(FixedAxisArtistHelper, self).__init__(loc=side)
  25. self.grid_helper = grid_helper
  26. if nth_coord_ticks is None:
  27. nth_coord_ticks = self.nth_coord
  28. self.nth_coord_ticks = nth_coord_ticks
  29. self.side = side
  30. self._limits_inverted = False
  31. def update_lim(self, axes):
  32. self.grid_helper.update_lim(axes)
  33. if self.nth_coord == 0:
  34. xy1, xy2 = axes.get_ylim()
  35. else:
  36. xy1, xy2 = axes.get_xlim()
  37. if xy1 > xy2:
  38. self._limits_inverted = True
  39. else:
  40. self._limits_inverted = False
  41. def change_tick_coord(self, coord_number=None):
  42. if coord_number is None:
  43. self.nth_coord_ticks = 1 - self.nth_coord_ticks
  44. elif coord_number in [0, 1]:
  45. self.nth_coord_ticks = coord_number
  46. else:
  47. raise Exception("wrong coord number")
  48. def get_tick_transform(self, axes):
  49. return axes.transData
  50. def get_tick_iterators(self, axes):
  51. """tick_loc, tick_angle, tick_label"""
  52. g = self.grid_helper
  53. if self._limits_inverted:
  54. side = {"left":"right","right":"left",
  55. "top":"bottom", "bottom":"top"}[self.side]
  56. else:
  57. side = self.side
  58. ti1 = g.get_tick_iterator(self.nth_coord_ticks, side)
  59. ti2 = g.get_tick_iterator(1-self.nth_coord_ticks, side, minor=True)
  60. #ti2 = g.get_tick_iterator(1-self.nth_coord_ticks, self.side, minor=True)
  61. return chain(ti1, ti2), iter([])
  62. class FloatingAxisArtistHelper(AxisArtistHelper.Floating):
  63. def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
  64. """
  65. nth_coord = along which coordinate value varies.
  66. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  67. """
  68. super(FloatingAxisArtistHelper, self).__init__(nth_coord,
  69. value,
  70. )
  71. self.value = value
  72. self.grid_helper = grid_helper
  73. self._extremes = None, None
  74. self._get_line_path = None # a method that returns a Path.
  75. self._line_num_points = 100 # number of points to create a line
  76. def set_extremes(self, e1, e2):
  77. self._extremes = e1, e2
  78. def update_lim(self, axes):
  79. self.grid_helper.update_lim(axes)
  80. x1, x2 = axes.get_xlim()
  81. y1, y2 = axes.get_ylim()
  82. grid_finder = self.grid_helper.grid_finder
  83. extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
  84. x1, y1, x2, y2)
  85. extremes = list(extremes)
  86. e1, e2 = self._extremes # ranges of other coordinates
  87. if self.nth_coord == 0:
  88. if e1 is not None:
  89. extremes[2] = max(e1, extremes[2])
  90. if e2 is not None:
  91. extremes[3] = min(e2, extremes[3])
  92. elif self.nth_coord == 1:
  93. if e1 is not None:
  94. extremes[0] = max(e1, extremes[0])
  95. if e2 is not None:
  96. extremes[1] = min(e2, extremes[1])
  97. grid_info = dict()
  98. lon_min, lon_max, lat_min, lat_max = extremes
  99. lon_levs, lon_n, lon_factor = \
  100. grid_finder.grid_locator1(lon_min, lon_max)
  101. lat_levs, lat_n, lat_factor = \
  102. grid_finder.grid_locator2(lat_min, lat_max)
  103. grid_info["extremes"] = extremes
  104. grid_info["lon_info"] = lon_levs, lon_n, lon_factor
  105. grid_info["lat_info"] = lat_levs, lat_n, lat_factor
  106. grid_info["lon_labels"] = grid_finder.tick_formatter1("bottom",
  107. lon_factor,
  108. lon_levs)
  109. grid_info["lat_labels"] = grid_finder.tick_formatter2("bottom",
  110. lat_factor,
  111. lat_levs)
  112. grid_finder = self.grid_helper.grid_finder
  113. #e1, e2 = self._extremes # ranges of other coordinates
  114. if self.nth_coord == 0:
  115. xx0 = np.linspace(self.value, self.value, self._line_num_points)
  116. yy0 = np.linspace(extremes[2], extremes[3], self._line_num_points)
  117. xx, yy = grid_finder.transform_xy(xx0, yy0)
  118. elif self.nth_coord == 1:
  119. xx0 = np.linspace(extremes[0], extremes[1], self._line_num_points)
  120. yy0 = np.linspace(self.value, self.value, self._line_num_points)
  121. xx, yy = grid_finder.transform_xy(xx0, yy0)
  122. grid_info["line_xy"] = xx, yy
  123. self.grid_info = grid_info
  124. def get_axislabel_transform(self, axes):
  125. return Affine2D() #axes.transData
  126. def get_axislabel_pos_angle(self, axes):
  127. extremes = self.grid_info["extremes"]
  128. if self.nth_coord == 0:
  129. xx0 = self.value
  130. yy0 = (extremes[2]+extremes[3])/2.
  131. dxx, dyy = 0., abs(extremes[2]-extremes[3])/1000.
  132. elif self.nth_coord == 1:
  133. xx0 = (extremes[0]+extremes[1])/2.
  134. yy0 = self.value
  135. dxx, dyy = abs(extremes[0]-extremes[1])/1000., 0.
  136. grid_finder = self.grid_helper.grid_finder
  137. xx1, yy1 = grid_finder.transform_xy([xx0], [yy0])
  138. trans_passingthrough_point = axes.transData + axes.transAxes.inverted()
  139. p = trans_passingthrough_point.transform_point([xx1[0], yy1[0]])
  140. if (0. <= p[0] <= 1.) and (0. <= p[1] <= 1.):
  141. xx1c, yy1c = axes.transData.transform_point([xx1[0], yy1[0]])
  142. xx2, yy2 = grid_finder.transform_xy([xx0+dxx], [yy0+dyy])
  143. xx2c, yy2c = axes.transData.transform_point([xx2[0], yy2[0]])
  144. return (xx1c, yy1c), np.arctan2(yy2c-yy1c, xx2c-xx1c)/np.pi*180.
  145. else:
  146. return None, None
  147. def get_tick_transform(self, axes):
  148. return IdentityTransform() #axes.transData
  149. def get_tick_iterators(self, axes):
  150. """tick_loc, tick_angle, tick_label, (optionally) tick_label"""
  151. grid_finder = self.grid_helper.grid_finder
  152. lat_levs, lat_n, lat_factor = self.grid_info["lat_info"]
  153. lat_levs = np.asarray(lat_levs)
  154. if lat_factor is not None:
  155. yy0 = lat_levs / lat_factor
  156. dy = 0.01 / lat_factor
  157. else:
  158. yy0 = lat_levs
  159. dy = 0.01
  160. lon_levs, lon_n, lon_factor = self.grid_info["lon_info"]
  161. lon_levs = np.asarray(lon_levs)
  162. if lon_factor is not None:
  163. xx0 = lon_levs / lon_factor
  164. dx = 0.01 / lon_factor
  165. else:
  166. xx0 = lon_levs
  167. dx = 0.01
  168. if None in self._extremes:
  169. e0, e1 = self._extremes
  170. else:
  171. e0, e1 = sorted(self._extremes)
  172. if e0 is None:
  173. e0 = -np.inf
  174. if e1 is None:
  175. e1 = np.inf
  176. if self.nth_coord == 0:
  177. mask = (e0 <= yy0) & (yy0 <= e1)
  178. #xx0, yy0 = xx0[mask], yy0[mask]
  179. yy0 = yy0[mask]
  180. elif self.nth_coord == 1:
  181. mask = (e0 <= xx0) & (xx0 <= e1)
  182. #xx0, yy0 = xx0[mask], yy0[mask]
  183. xx0 = xx0[mask]
  184. def transform_xy(x, y):
  185. x1, y1 = grid_finder.transform_xy(x, y)
  186. x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
  187. x2, y2 = x2y2.transpose()
  188. return x2, y2
  189. # find angles
  190. if self.nth_coord == 0:
  191. xx0 = np.empty_like(yy0)
  192. xx0.fill(self.value)
  193. xx1, yy1 = transform_xy(xx0, yy0)
  194. xx00 = xx0.copy()
  195. xx00[xx0+dx>e1] -= dx
  196. xx1a, yy1a = transform_xy(xx00, yy0)
  197. xx1b, yy1b = transform_xy(xx00+dx, yy0)
  198. xx2a, yy2a = transform_xy(xx0, yy0)
  199. xx2b, yy2b = transform_xy(xx0, yy0+dy)
  200. labels = self.grid_info["lat_labels"]
  201. labels = [l for l, m in zip(labels, mask) if m]
  202. elif self.nth_coord == 1:
  203. yy0 = np.empty_like(xx0)
  204. yy0.fill(self.value)
  205. xx1, yy1 = transform_xy(xx0, yy0)
  206. xx1a, yy1a = transform_xy(xx0, yy0)
  207. xx1b, yy1b = transform_xy(xx0, yy0+dy)
  208. xx00 = xx0.copy()
  209. xx00[xx0+dx>e1] -= dx
  210. xx2a, yy2a = transform_xy(xx00, yy0)
  211. xx2b, yy2b = transform_xy(xx00+dx, yy0)
  212. labels = self.grid_info["lon_labels"]
  213. labels = [l for l, m in zip(labels, mask) if m]
  214. def f1():
  215. dd = np.arctan2(yy1b-yy1a, xx1b-xx1a) # angle normal
  216. dd2 = np.arctan2(yy2b-yy2a, xx2b-xx2a) # angle tangent
  217. mm = ((yy1b-yy1a)==0.) & ((xx1b-xx1a)==0.) # mask where dd1 is not defined
  218. dd[mm] = dd2[mm] + np.pi / 2
  219. #dd = np.arctan2(yy2-yy1, xx2-xx1) # angle normal
  220. #dd2 = np.arctan2(yy3-yy1, xx3-xx1) # angle tangent
  221. #mm = ((yy2-yy1)==0.) & ((xx2-xx1)==0.) # mask where dd1 is not defined
  222. #dd[mm] = dd2[mm] + np.pi / 2
  223. #dd += np.pi
  224. #dd = np.arctan2(xx2-xx1, angle_tangent-yy1)
  225. trans_tick = self.get_tick_transform(axes)
  226. tr2ax = trans_tick + axes.transAxes.inverted()
  227. for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels):
  228. c2 = tr2ax.transform_point((x, y))
  229. delta=0.00001
  230. if (0. -delta<= c2[0] <= 1.+delta) and \
  231. (0. -delta<= c2[1] <= 1.+delta):
  232. d1 = d/3.14159*180.
  233. d2 = d2/3.14159*180.
  234. yield [x, y], d1, d2, lab
  235. return f1(), iter([])
  236. def get_line_transform(self, axes):
  237. return axes.transData
  238. def get_line(self, axes):
  239. self.update_lim(axes)
  240. x, y = self.grid_info["line_xy"]
  241. if self._get_line_path is None:
  242. return Path(np.column_stack([x, y]))
  243. else:
  244. return self._get_line_path(axes, x, y)
  245. class GridHelperCurveLinear(GridHelperBase):
  246. def __init__(self, aux_trans,
  247. extreme_finder=None,
  248. grid_locator1=None,
  249. grid_locator2=None,
  250. tick_formatter1=None,
  251. tick_formatter2=None):
  252. """
  253. aux_trans : a transform from the source (curved) coordinate to
  254. target (rectilinear) coordinate. An instance of MPL's Transform
  255. (inverse transform should be defined) or a tuple of two callable
  256. objects which defines the transform and its inverse. The callables
  257. need take two arguments of array of source coordinates and
  258. should return two target coordinates.
  259. e.g., ``x2, y2 = trans(x1, y1)``
  260. """
  261. super(GridHelperCurveLinear, self).__init__()
  262. self.grid_info = None
  263. self._old_values = None
  264. #self._grid_params = dict()
  265. self._aux_trans = aux_trans
  266. self.grid_finder = GridFinder(aux_trans,
  267. extreme_finder,
  268. grid_locator1,
  269. grid_locator2,
  270. tick_formatter1,
  271. tick_formatter2)
  272. def update_grid_finder(self, aux_trans=None, **kw):
  273. if aux_trans is not None:
  274. self.grid_finder.update_transform(aux_trans)
  275. self.grid_finder.update(**kw)
  276. self.invalidate()
  277. def _update(self, x1, x2, y1, y2):
  278. "bbox in 0-based image coordinates"
  279. # update wcsgrid
  280. if self.valid() and self._old_values == (x1, x2, y1, y2):
  281. return
  282. self._update_grid(x1, y1, x2, y2)
  283. self._old_values = (x1, x2, y1, y2)
  284. self._force_update = False
  285. def new_fixed_axis(self, loc,
  286. nth_coord=None,
  287. axis_direction=None,
  288. offset=None,
  289. axes=None):
  290. if axes is None:
  291. axes = self.axes
  292. if axis_direction is None:
  293. axis_direction = loc
  294. _helper = FixedAxisArtistHelper(self, loc,
  295. #nth_coord,
  296. nth_coord_ticks=nth_coord,
  297. )
  298. axisline = AxisArtist(axes, _helper, axis_direction=axis_direction)
  299. return axisline
  300. def new_floating_axis(self, nth_coord,
  301. value,
  302. axes=None,
  303. axis_direction="bottom"
  304. ):
  305. if axes is None:
  306. axes = self.axes
  307. _helper = FloatingAxisArtistHelper(
  308. self, nth_coord, value, axis_direction)
  309. axisline = AxisArtist(axes, _helper)
  310. #_helper = FloatingAxisArtistHelper(self, nth_coord,
  311. # value,
  312. # label_direction=label_direction,
  313. # )
  314. #axisline = AxisArtistFloating(axes, _helper,
  315. # axis_direction=axis_direction)
  316. axisline.line.set_clip_on(True)
  317. axisline.line.set_clip_box(axisline.axes.bbox)
  318. #axisline.major_ticklabels.set_visible(True)
  319. #axisline.minor_ticklabels.set_visible(False)
  320. #axisline.major_ticklabels.set_rotate_along_line(True)
  321. #axisline.set_rotate_label_along_line(True)
  322. return axisline
  323. def _update_grid(self, x1, y1, x2, y2):
  324. self.grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)
  325. def get_gridlines(self, which="major", axis="both"):
  326. grid_lines = []
  327. if axis in ["both", "x"]:
  328. for gl in self.grid_info["lon"]["lines"]:
  329. grid_lines.extend(gl)
  330. if axis in ["both", "y"]:
  331. for gl in self.grid_info["lat"]["lines"]:
  332. grid_lines.extend(gl)
  333. return grid_lines
  334. def get_tick_iterator(self, nth_coord, axis_side, minor=False):
  335. #axisnr = dict(left=0, bottom=1, right=2, top=3)[axis_side]
  336. angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
  337. #angle = [0, 90, 180, 270][axisnr]
  338. lon_or_lat = ["lon", "lat"][nth_coord]
  339. if not minor: # major ticks
  340. def f():
  341. for (xy, a), l in zip(self.grid_info[lon_or_lat]["tick_locs"][axis_side],
  342. self.grid_info[lon_or_lat]["tick_labels"][axis_side]):
  343. angle_normal = a
  344. yield xy, angle_normal, angle_tangent, l
  345. else:
  346. def f():
  347. for (xy, a), l in zip(self.grid_info[lon_or_lat]["tick_locs"][axis_side],
  348. self.grid_info[lon_or_lat]["tick_labels"][axis_side]):
  349. angle_normal = a
  350. yield xy, angle_normal, angle_tangent, ""
  351. #for xy, a, l in self.grid_info[lon_or_lat]["ticks"][axis_side]:
  352. # yield xy, a, ""
  353. return f()