floating_axes.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. """
  2. Experimental support for curvilinear grids.
  3. """
  4. # TODO :
  5. # see if tick_iterator method can be simplified by reusing the parent method.
  6. import functools
  7. import numpy as np
  8. import matplotlib as mpl
  9. from matplotlib import _api, cbook
  10. import matplotlib.patches as mpatches
  11. from matplotlib.path import Path
  12. from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
  13. from . import axislines, grid_helper_curvelinear
  14. from .axis_artist import AxisArtist
  15. from .grid_finder import ExtremeFinderSimple
  16. class FloatingAxisArtistHelper(
  17. grid_helper_curvelinear.FloatingAxisArtistHelper):
  18. pass
  19. class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
  20. def __init__(self, grid_helper, side, nth_coord_ticks=None):
  21. """
  22. nth_coord = along which coordinate value varies.
  23. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  24. """
  25. lon1, lon2, lat1, lat2 = grid_helper.grid_finder.extreme_finder(*[None] * 5)
  26. value, nth_coord = _api.check_getitem(
  27. dict(left=(lon1, 0), right=(lon2, 0), bottom=(lat1, 1), top=(lat2, 1)),
  28. side=side)
  29. super().__init__(grid_helper, nth_coord, value, axis_direction=side)
  30. if nth_coord_ticks is None:
  31. nth_coord_ticks = nth_coord
  32. self.nth_coord_ticks = nth_coord_ticks
  33. self.value = value
  34. self.grid_helper = grid_helper
  35. self._side = side
  36. def update_lim(self, axes):
  37. self.grid_helper.update_lim(axes)
  38. self._grid_info = self.grid_helper._grid_info
  39. def get_tick_iterators(self, axes):
  40. """tick_loc, tick_angle, tick_label, (optionally) tick_label"""
  41. grid_finder = self.grid_helper.grid_finder
  42. lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
  43. yy0 = lat_levs / lat_factor
  44. lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
  45. xx0 = lon_levs / lon_factor
  46. extremes = self.grid_helper.grid_finder.extreme_finder(*[None] * 5)
  47. xmin, xmax = sorted(extremes[:2])
  48. ymin, ymax = sorted(extremes[2:])
  49. def trf_xy(x, y):
  50. trf = grid_finder.get_transform() + axes.transData
  51. return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
  52. if self.nth_coord == 0:
  53. mask = (ymin <= yy0) & (yy0 <= ymax)
  54. (xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = \
  55. grid_helper_curvelinear._value_and_jacobian(
  56. trf_xy, self.value, yy0[mask], (xmin, xmax), (ymin, ymax))
  57. labels = self._grid_info["lat_labels"]
  58. elif self.nth_coord == 1:
  59. mask = (xmin <= xx0) & (xx0 <= xmax)
  60. (xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = \
  61. grid_helper_curvelinear._value_and_jacobian(
  62. trf_xy, xx0[mask], self.value, (xmin, xmax), (ymin, ymax))
  63. labels = self._grid_info["lon_labels"]
  64. labels = [l for l, m in zip(labels, mask) if m]
  65. angle_normal = np.arctan2(dyy1, dxx1)
  66. angle_tangent = np.arctan2(dyy2, dxx2)
  67. mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
  68. angle_normal[mm] = angle_tangent[mm] + np.pi / 2
  69. tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
  70. in_01 = functools.partial(
  71. mpl.transforms._interval_contains_close, (0, 1))
  72. def f1():
  73. for x, y, normal, tangent, lab \
  74. in zip(xx1, yy1, angle_normal, angle_tangent, labels):
  75. c2 = tick_to_axes.transform((x, y))
  76. if in_01(c2[0]) and in_01(c2[1]):
  77. yield [x, y], *np.rad2deg([normal, tangent]), lab
  78. return f1(), iter([])
  79. def get_line(self, axes):
  80. self.update_lim(axes)
  81. k, v = dict(left=("lon_lines0", 0),
  82. right=("lon_lines0", 1),
  83. bottom=("lat_lines0", 0),
  84. top=("lat_lines0", 1))[self._side]
  85. xx, yy = self._grid_info[k][v]
  86. return Path(np.column_stack([xx, yy]))
  87. class ExtremeFinderFixed(ExtremeFinderSimple):
  88. # docstring inherited
  89. def __init__(self, extremes):
  90. """
  91. This subclass always returns the same bounding box.
  92. Parameters
  93. ----------
  94. extremes : (float, float, float, float)
  95. The bounding box that this helper always returns.
  96. """
  97. self._extremes = extremes
  98. def __call__(self, transform_xy, x1, y1, x2, y2):
  99. # docstring inherited
  100. return self._extremes
  101. class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
  102. def __init__(self, aux_trans, extremes,
  103. grid_locator1=None,
  104. grid_locator2=None,
  105. tick_formatter1=None,
  106. tick_formatter2=None):
  107. # docstring inherited
  108. super().__init__(aux_trans,
  109. extreme_finder=ExtremeFinderFixed(extremes),
  110. grid_locator1=grid_locator1,
  111. grid_locator2=grid_locator2,
  112. tick_formatter1=tick_formatter1,
  113. tick_formatter2=tick_formatter2)
  114. @_api.deprecated("3.8")
  115. def get_data_boundary(self, side):
  116. """
  117. Return v=0, nth=1.
  118. """
  119. lon1, lon2, lat1, lat2 = self.grid_finder.extreme_finder(*[None] * 5)
  120. return dict(left=(lon1, 0),
  121. right=(lon2, 0),
  122. bottom=(lat1, 1),
  123. top=(lat2, 1))[side]
  124. def new_fixed_axis(self, loc,
  125. nth_coord=None,
  126. axis_direction=None,
  127. offset=None,
  128. axes=None):
  129. if axes is None:
  130. axes = self.axes
  131. if axis_direction is None:
  132. axis_direction = loc
  133. # This is not the same as the FixedAxisArtistHelper class used by
  134. # grid_helper_curvelinear.GridHelperCurveLinear.new_fixed_axis!
  135. helper = FixedAxisArtistHelper(
  136. self, loc, nth_coord_ticks=nth_coord)
  137. axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
  138. # Perhaps should be moved to the base class?
  139. axisline.line.set_clip_on(True)
  140. axisline.line.set_clip_box(axisline.axes.bbox)
  141. return axisline
  142. # new_floating_axis will inherit the grid_helper's extremes.
  143. # def new_floating_axis(self, nth_coord,
  144. # value,
  145. # axes=None,
  146. # axis_direction="bottom"
  147. # ):
  148. # axis = super(GridHelperCurveLinear,
  149. # self).new_floating_axis(nth_coord,
  150. # value, axes=axes,
  151. # axis_direction=axis_direction)
  152. # # set extreme values of the axis helper
  153. # if nth_coord == 1:
  154. # axis.get_helper().set_extremes(*self._extremes[:2])
  155. # elif nth_coord == 0:
  156. # axis.get_helper().set_extremes(*self._extremes[2:])
  157. # return axis
  158. def _update_grid(self, x1, y1, x2, y2):
  159. if self._grid_info is None:
  160. self._grid_info = dict()
  161. grid_info = self._grid_info
  162. grid_finder = self.grid_finder
  163. extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
  164. x1, y1, x2, y2)
  165. lon_min, lon_max = sorted(extremes[:2])
  166. lat_min, lat_max = sorted(extremes[2:])
  167. grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
  168. lon_levs, lon_n, lon_factor = \
  169. grid_finder.grid_locator1(lon_min, lon_max)
  170. lon_levs = np.asarray(lon_levs)
  171. lat_levs, lat_n, lat_factor = \
  172. grid_finder.grid_locator2(lat_min, lat_max)
  173. lat_levs = np.asarray(lat_levs)
  174. grid_info["lon_info"] = lon_levs, lon_n, lon_factor
  175. grid_info["lat_info"] = lat_levs, lat_n, lat_factor
  176. grid_info["lon_labels"] = grid_finder.tick_formatter1(
  177. "bottom", lon_factor, lon_levs)
  178. grid_info["lat_labels"] = grid_finder.tick_formatter2(
  179. "bottom", lat_factor, lat_levs)
  180. lon_values = lon_levs[:lon_n] / lon_factor
  181. lat_values = lat_levs[:lat_n] / lat_factor
  182. lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
  183. lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
  184. lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
  185. lon_min, lon_max, lat_min, lat_max)
  186. grid_info["lon_lines"] = lon_lines
  187. grid_info["lat_lines"] = lat_lines
  188. lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
  189. # lon_min, lon_max, lat_min, lat_max)
  190. extremes[:2], extremes[2:], *extremes)
  191. grid_info["lon_lines0"] = lon_lines
  192. grid_info["lat_lines0"] = lat_lines
  193. def get_gridlines(self, which="major", axis="both"):
  194. grid_lines = []
  195. if axis in ["both", "x"]:
  196. grid_lines.extend(self._grid_info["lon_lines"])
  197. if axis in ["both", "y"]:
  198. grid_lines.extend(self._grid_info["lat_lines"])
  199. return grid_lines
  200. class FloatingAxesBase:
  201. def __init__(self, *args, grid_helper, **kwargs):
  202. _api.check_isinstance(GridHelperCurveLinear, grid_helper=grid_helper)
  203. super().__init__(*args, grid_helper=grid_helper, **kwargs)
  204. self.set_aspect(1.)
  205. def _gen_axes_patch(self):
  206. # docstring inherited
  207. x0, x1, y0, y1 = self.get_grid_helper().grid_finder.extreme_finder(*[None] * 5)
  208. patch = mpatches.Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
  209. patch.get_path()._interpolation_steps = 100
  210. return patch
  211. def clear(self):
  212. super().clear()
  213. self.patch.set_transform(
  214. self.get_grid_helper().grid_finder.get_transform()
  215. + self.transData)
  216. # The original patch is not in the draw tree; it is only used for
  217. # clipping purposes.
  218. orig_patch = super()._gen_axes_patch()
  219. orig_patch.set_figure(self.figure)
  220. orig_patch.set_transform(self.transAxes)
  221. self.patch.set_clip_path(orig_patch)
  222. self.gridlines.set_clip_path(orig_patch)
  223. self.adjust_axes_lim()
  224. def adjust_axes_lim(self):
  225. bbox = self.patch.get_path().get_extents(
  226. # First transform to pixel coords, then to parent data coords.
  227. self.patch.get_transform() - self.transData)
  228. bbox = bbox.expanded(1.02, 1.02)
  229. self.set_xlim(bbox.xmin, bbox.xmax)
  230. self.set_ylim(bbox.ymin, bbox.ymax)
  231. floatingaxes_class_factory = cbook._make_class_factory(
  232. FloatingAxesBase, "Floating{}")
  233. FloatingAxes = floatingaxes_class_factory(
  234. host_axes_class_factory(axislines.Axes))
  235. FloatingSubplot = FloatingAxes