axislines.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. """
  2. Axislines includes a modified implementation of the Axes class. The
  3. biggest difference is that the artists responsible for drawing the axis spine,
  4. ticks, ticklabels and axis labels are separated out from Matplotlib's Axis
  5. class. Originally, this change was motivated by the need to support
  6. curvilinear grids. Here are a few reasons why I came up with a new axes class:
  7. * "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
  8. different ticks (tick locations and labels). This is not possible
  9. with the current Matplotlib, although some twin axes trick can help.
  10. * Curvilinear grid.
  11. * angled ticks.
  12. In the new axes class, xaxis and yaxis are set to not visible by
  13. default, and a new set of artists (AxisArtist) is defined to draw the axis
  14. line, ticks, ticklabels and axis label. The Axes.axis attribute serves as
  15. a dictionary of these artists, i.e., ax.axis["left"] is an AxisArtist
  16. instance responsible for drawing the left y-axis. The default Axes.axis contains
  17. "bottom", "left", "top" and "right".
  18. AxisArtist can be considered as a container artist and has the following
  19. children artists which will draw ticks, labels, etc.
  20. * line
  21. * major_ticks, major_ticklabels
  22. * minor_ticks, minor_ticklabels
  23. * offsetText
  24. * label
  25. Note that these are separate artists from `matplotlib.axis.Axis`, thus most
  26. tick-related functions in Matplotlib won't work. For example, color and
  27. markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of
  28. Axes.xaxis unless explicitly specified.
  29. In addition to AxisArtist, the Axes will have a *gridlines* attribute,
  30. which draws the grid lines. The gridlines need to be separated
  31. from the axis as some gridlines can never pass any axis.
  32. """
  33. import numpy as np
  34. import matplotlib as mpl
  35. from matplotlib import _api
  36. import matplotlib.axes as maxes
  37. from matplotlib.path import Path
  38. from mpl_toolkits.axes_grid1 import mpl_axes
  39. from .axisline_style import AxislineStyle # noqa
  40. from .axis_artist import AxisArtist, GridlinesCollection
  41. class _AxisArtistHelperBase:
  42. """
  43. Base class for axis helper.
  44. Subclasses should define the methods listed below. The *axes*
  45. argument will be the ``.axes`` attribute of the caller artist. ::
  46. # Construct the spine.
  47. def get_line_transform(self, axes):
  48. return transform
  49. def get_line(self, axes):
  50. return path
  51. # Construct the label.
  52. def get_axislabel_transform(self, axes):
  53. return transform
  54. def get_axislabel_pos_angle(self, axes):
  55. return (x, y), angle
  56. # Construct the ticks.
  57. def get_tick_transform(self, axes):
  58. return transform
  59. def get_tick_iterators(self, axes):
  60. # A pair of iterables (one for major ticks, one for minor ticks)
  61. # that yield (tick_position, tick_angle, tick_label).
  62. return iter_major, iter_minor
  63. """
  64. def update_lim(self, axes):
  65. pass
  66. def _to_xy(self, values, const):
  67. """
  68. Create a (*values.shape, 2)-shape array representing (x, y) pairs.
  69. The other coordinate is filled with the constant *const*.
  70. Example::
  71. >>> self.nth_coord = 0
  72. >>> self._to_xy([1, 2, 3], const=0)
  73. array([[1, 0],
  74. [2, 0],
  75. [3, 0]])
  76. """
  77. if self.nth_coord == 0:
  78. return np.stack(np.broadcast_arrays(values, const), axis=-1)
  79. elif self.nth_coord == 1:
  80. return np.stack(np.broadcast_arrays(const, values), axis=-1)
  81. else:
  82. raise ValueError("Unexpected nth_coord")
  83. class _FixedAxisArtistHelperBase(_AxisArtistHelperBase):
  84. """Helper class for a fixed (in the axes coordinate) axis."""
  85. passthru_pt = _api.deprecated("3.7")(property(
  86. lambda self: {"left": (0, 0), "right": (1, 0),
  87. "bottom": (0, 0), "top": (0, 1)}[self._loc]))
  88. def __init__(self, loc, nth_coord=None):
  89. """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
  90. self.nth_coord = (
  91. nth_coord if nth_coord is not None else
  92. _api.check_getitem(
  93. {"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc))
  94. if (nth_coord == 0 and loc not in ["left", "right"]
  95. or nth_coord == 1 and loc not in ["bottom", "top"]):
  96. _api.warn_deprecated(
  97. "3.7", message=f"{loc=!r} is incompatible with "
  98. "{nth_coord=}; support is deprecated since %(since)s")
  99. self._loc = loc
  100. self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
  101. super().__init__()
  102. # axis line in transAxes
  103. self._path = Path(self._to_xy((0, 1), const=self._pos))
  104. def get_nth_coord(self):
  105. return self.nth_coord
  106. # LINE
  107. def get_line(self, axes):
  108. return self._path
  109. def get_line_transform(self, axes):
  110. return axes.transAxes
  111. # LABEL
  112. def get_axislabel_transform(self, axes):
  113. return axes.transAxes
  114. def get_axislabel_pos_angle(self, axes):
  115. """
  116. Return the label reference position in transAxes.
  117. get_label_transform() returns a transform of (transAxes+offset)
  118. """
  119. return dict(left=((0., 0.5), 90), # (position, angle_tangent)
  120. right=((1., 0.5), 90),
  121. bottom=((0.5, 0.), 0),
  122. top=((0.5, 1.), 0))[self._loc]
  123. # TICK
  124. def get_tick_transform(self, axes):
  125. return [axes.get_xaxis_transform(),
  126. axes.get_yaxis_transform()][self.nth_coord]
  127. class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase):
  128. def __init__(self, nth_coord, value):
  129. self.nth_coord = nth_coord
  130. self._value = value
  131. super().__init__()
  132. def get_nth_coord(self):
  133. return self.nth_coord
  134. def get_line(self, axes):
  135. raise RuntimeError(
  136. "get_line method should be defined by the derived class")
  137. class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase):
  138. def __init__(self, axes, loc, nth_coord=None):
  139. """
  140. nth_coord = along which coordinate value varies
  141. in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  142. """
  143. super().__init__(loc, nth_coord)
  144. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  145. # TICK
  146. def get_tick_iterators(self, axes):
  147. """tick_loc, tick_angle, tick_label"""
  148. if self._loc in ["bottom", "top"]:
  149. angle_normal, angle_tangent = 90, 0
  150. else: # "left", "right"
  151. angle_normal, angle_tangent = 0, 90
  152. major = self.axis.major
  153. major_locs = major.locator()
  154. major_labels = major.formatter.format_ticks(major_locs)
  155. minor = self.axis.minor
  156. minor_locs = minor.locator()
  157. minor_labels = minor.formatter.format_ticks(minor_locs)
  158. tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
  159. def _f(locs, labels):
  160. for loc, label in zip(locs, labels):
  161. c = self._to_xy(loc, const=self._pos)
  162. # check if the tick point is inside axes
  163. c2 = tick_to_axes.transform(c)
  164. if mpl.transforms._interval_contains_close(
  165. (0, 1), c2[self.nth_coord]):
  166. yield c, angle_normal, angle_tangent, label
  167. return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
  168. class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase):
  169. def __init__(self, axes, nth_coord,
  170. passingthrough_point, axis_direction="bottom"):
  171. super().__init__(nth_coord, passingthrough_point)
  172. self._axis_direction = axis_direction
  173. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  174. def get_line(self, axes):
  175. fixed_coord = 1 - self.nth_coord
  176. data_to_axes = axes.transData - axes.transAxes
  177. p = data_to_axes.transform([self._value, self._value])
  178. return Path(self._to_xy((0, 1), const=p[fixed_coord]))
  179. def get_line_transform(self, axes):
  180. return axes.transAxes
  181. def get_axislabel_transform(self, axes):
  182. return axes.transAxes
  183. def get_axislabel_pos_angle(self, axes):
  184. """
  185. Return the label reference position in transAxes.
  186. get_label_transform() returns a transform of (transAxes+offset)
  187. """
  188. angle = [0, 90][self.nth_coord]
  189. fixed_coord = 1 - self.nth_coord
  190. data_to_axes = axes.transData - axes.transAxes
  191. p = data_to_axes.transform([self._value, self._value])
  192. verts = self._to_xy(0.5, const=p[fixed_coord])
  193. if 0 <= verts[fixed_coord] <= 1:
  194. return verts, angle
  195. else:
  196. return None, None
  197. def get_tick_transform(self, axes):
  198. return axes.transData
  199. def get_tick_iterators(self, axes):
  200. """tick_loc, tick_angle, tick_label"""
  201. if self.nth_coord == 0:
  202. angle_normal, angle_tangent = 90, 0
  203. else:
  204. angle_normal, angle_tangent = 0, 90
  205. major = self.axis.major
  206. major_locs = major.locator()
  207. major_labels = major.formatter.format_ticks(major_locs)
  208. minor = self.axis.minor
  209. minor_locs = minor.locator()
  210. minor_labels = minor.formatter.format_ticks(minor_locs)
  211. data_to_axes = axes.transData - axes.transAxes
  212. def _f(locs, labels):
  213. for loc, label in zip(locs, labels):
  214. c = self._to_xy(loc, const=self._value)
  215. c1, c2 = data_to_axes.transform(c)
  216. if 0 <= c1 <= 1 and 0 <= c2 <= 1:
  217. yield c, angle_normal, angle_tangent, label
  218. return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
  219. class AxisArtistHelper: # Backcompat.
  220. Fixed = _FixedAxisArtistHelperBase
  221. Floating = _FloatingAxisArtistHelperBase
  222. class AxisArtistHelperRectlinear: # Backcompat.
  223. Fixed = FixedAxisArtistHelperRectilinear
  224. Floating = FloatingAxisArtistHelperRectilinear
  225. class GridHelperBase:
  226. def __init__(self):
  227. self._old_limits = None
  228. super().__init__()
  229. def update_lim(self, axes):
  230. x1, x2 = axes.get_xlim()
  231. y1, y2 = axes.get_ylim()
  232. if self._old_limits != (x1, x2, y1, y2):
  233. self._update_grid(x1, y1, x2, y2)
  234. self._old_limits = (x1, x2, y1, y2)
  235. def _update_grid(self, x1, y1, x2, y2):
  236. """Cache relevant computations when the axes limits have changed."""
  237. def get_gridlines(self, which, axis):
  238. """
  239. Return list of grid lines as a list of paths (list of points).
  240. Parameters
  241. ----------
  242. which : {"both", "major", "minor"}
  243. axis : {"both", "x", "y"}
  244. """
  245. return []
  246. class GridHelperRectlinear(GridHelperBase):
  247. def __init__(self, axes):
  248. super().__init__()
  249. self.axes = axes
  250. def new_fixed_axis(self, loc,
  251. nth_coord=None,
  252. axis_direction=None,
  253. offset=None,
  254. axes=None,
  255. ):
  256. if axes is None:
  257. _api.warn_external(
  258. "'new_fixed_axis' explicitly requires the axes keyword.")
  259. axes = self.axes
  260. if axis_direction is None:
  261. axis_direction = loc
  262. helper = FixedAxisArtistHelperRectilinear(axes, loc, nth_coord)
  263. axisline = AxisArtist(axes, helper, offset=offset,
  264. axis_direction=axis_direction)
  265. return axisline
  266. def new_floating_axis(self, nth_coord, value,
  267. axis_direction="bottom",
  268. axes=None,
  269. ):
  270. if axes is None:
  271. _api.warn_external(
  272. "'new_floating_axis' explicitly requires the axes keyword.")
  273. axes = self.axes
  274. helper = FloatingAxisArtistHelperRectilinear(
  275. axes, nth_coord, value, axis_direction)
  276. axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
  277. axisline.line.set_clip_on(True)
  278. axisline.line.set_clip_box(axisline.axes.bbox)
  279. return axisline
  280. def get_gridlines(self, which="major", axis="both"):
  281. """
  282. Return list of gridline coordinates in data coordinates.
  283. Parameters
  284. ----------
  285. which : {"both", "major", "minor"}
  286. axis : {"both", "x", "y"}
  287. """
  288. _api.check_in_list(["both", "major", "minor"], which=which)
  289. _api.check_in_list(["both", "x", "y"], axis=axis)
  290. gridlines = []
  291. if axis in ("both", "x"):
  292. locs = []
  293. y1, y2 = self.axes.get_ylim()
  294. if which in ("both", "major"):
  295. locs.extend(self.axes.xaxis.major.locator())
  296. if which in ("both", "minor"):
  297. locs.extend(self.axes.xaxis.minor.locator())
  298. for x in locs:
  299. gridlines.append([[x, x], [y1, y2]])
  300. if axis in ("both", "y"):
  301. x1, x2 = self.axes.get_xlim()
  302. locs = []
  303. if self.axes.yaxis._major_tick_kw["gridOn"]:
  304. locs.extend(self.axes.yaxis.major.locator())
  305. if self.axes.yaxis._minor_tick_kw["gridOn"]:
  306. locs.extend(self.axes.yaxis.minor.locator())
  307. for y in locs:
  308. gridlines.append([[x1, x2], [y, y]])
  309. return gridlines
  310. class Axes(maxes.Axes):
  311. @_api.deprecated("3.8", alternative="ax.axis")
  312. def __call__(self, *args, **kwargs):
  313. return maxes.Axes.axis(self.axes, *args, **kwargs)
  314. def __init__(self, *args, grid_helper=None, **kwargs):
  315. self._axisline_on = True
  316. self._grid_helper = (grid_helper if grid_helper
  317. else GridHelperRectlinear(self))
  318. super().__init__(*args, **kwargs)
  319. self.toggle_axisline(True)
  320. def toggle_axisline(self, b=None):
  321. if b is None:
  322. b = not self._axisline_on
  323. if b:
  324. self._axisline_on = True
  325. self.spines[:].set_visible(False)
  326. self.xaxis.set_visible(False)
  327. self.yaxis.set_visible(False)
  328. else:
  329. self._axisline_on = False
  330. self.spines[:].set_visible(True)
  331. self.xaxis.set_visible(True)
  332. self.yaxis.set_visible(True)
  333. @property
  334. def axis(self):
  335. return self._axislines
  336. def clear(self):
  337. # docstring inherited
  338. # Init gridlines before clear() as clear() calls grid().
  339. self.gridlines = gridlines = GridlinesCollection(
  340. [],
  341. colors=mpl.rcParams['grid.color'],
  342. linestyles=mpl.rcParams['grid.linestyle'],
  343. linewidths=mpl.rcParams['grid.linewidth'])
  344. self._set_artist_props(gridlines)
  345. gridlines.set_grid_helper(self.get_grid_helper())
  346. super().clear()
  347. # clip_path is set after Axes.clear(): that's when a patch is created.
  348. gridlines.set_clip_path(self.axes.patch)
  349. # Init axis artists.
  350. self._axislines = mpl_axes.Axes.AxisDict(self)
  351. new_fixed_axis = self.get_grid_helper().new_fixed_axis
  352. self._axislines.update({
  353. loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc)
  354. for loc in ["bottom", "top", "left", "right"]})
  355. for axisline in [self._axislines["top"], self._axislines["right"]]:
  356. axisline.label.set_visible(False)
  357. axisline.major_ticklabels.set_visible(False)
  358. axisline.minor_ticklabels.set_visible(False)
  359. def get_grid_helper(self):
  360. return self._grid_helper
  361. def grid(self, visible=None, which='major', axis="both", **kwargs):
  362. """
  363. Toggle the gridlines, and optionally set the properties of the lines.
  364. """
  365. # There are some discrepancies in the behavior of grid() between
  366. # axes_grid and Matplotlib, because axes_grid explicitly sets the
  367. # visibility of the gridlines.
  368. super().grid(visible, which=which, axis=axis, **kwargs)
  369. if not self._axisline_on:
  370. return
  371. if visible is None:
  372. visible = (self.axes.xaxis._minor_tick_kw["gridOn"]
  373. or self.axes.xaxis._major_tick_kw["gridOn"]
  374. or self.axes.yaxis._minor_tick_kw["gridOn"]
  375. or self.axes.yaxis._major_tick_kw["gridOn"])
  376. self.gridlines.set(which=which, axis=axis, visible=visible)
  377. self.gridlines.set(**kwargs)
  378. def get_children(self):
  379. if self._axisline_on:
  380. children = [*self._axislines.values(), self.gridlines]
  381. else:
  382. children = []
  383. children.extend(super().get_children())
  384. return children
  385. def new_fixed_axis(self, loc, offset=None):
  386. gh = self.get_grid_helper()
  387. axis = gh.new_fixed_axis(loc,
  388. nth_coord=None,
  389. axis_direction=None,
  390. offset=offset,
  391. axes=self,
  392. )
  393. return axis
  394. def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
  395. gh = self.get_grid_helper()
  396. axis = gh.new_floating_axis(nth_coord, value,
  397. axis_direction=axis_direction,
  398. axes=self)
  399. return axis
  400. class AxesZero(Axes):
  401. def clear(self):
  402. super().clear()
  403. new_floating_axis = self.get_grid_helper().new_floating_axis
  404. self._axislines.update(
  405. xzero=new_floating_axis(
  406. nth_coord=0, value=0., axis_direction="bottom", axes=self),
  407. yzero=new_floating_axis(
  408. nth_coord=1, value=0., axis_direction="left", axes=self),
  409. )
  410. for k in ["xzero", "yzero"]:
  411. self._axislines[k].line.set_clip_path(self.patch)
  412. self._axislines[k].set_visible(False)
  413. Subplot = Axes
  414. SubplotZero = AxesZero