axes_grid.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. from numbers import Number
  2. import functools
  3. from types import MethodType
  4. import numpy as np
  5. from matplotlib import _api, cbook
  6. from matplotlib.gridspec import SubplotSpec
  7. from .axes_divider import Size, SubplotDivider, Divider
  8. from .mpl_axes import Axes, SimpleAxisArtist
  9. class CbarAxesBase:
  10. def __init__(self, *args, orientation, **kwargs):
  11. self.orientation = orientation
  12. super().__init__(*args, **kwargs)
  13. def colorbar(self, mappable, **kwargs):
  14. return self.figure.colorbar(
  15. mappable, cax=self, location=self.orientation, **kwargs)
  16. @_api.deprecated("3.8", alternative="ax.tick_params and colorbar.set_label")
  17. def toggle_label(self, b):
  18. axis = self.axis[self.orientation]
  19. axis.toggle(ticklabels=b, label=b)
  20. _cbaraxes_class_factory = cbook._make_class_factory(CbarAxesBase, "Cbar{}")
  21. class Grid:
  22. """
  23. A grid of Axes.
  24. In Matplotlib, the Axes location (and size) is specified in normalized
  25. figure coordinates. This may not be ideal for images that needs to be
  26. displayed with a given aspect ratio; for example, it is difficult to
  27. display multiple images of a same size with some fixed padding between
  28. them. AxesGrid can be used in such case.
  29. """
  30. _defaultAxesClass = Axes
  31. def __init__(self, fig,
  32. rect,
  33. nrows_ncols,
  34. ngrids=None,
  35. direction="row",
  36. axes_pad=0.02,
  37. *,
  38. share_all=False,
  39. share_x=True,
  40. share_y=True,
  41. label_mode="L",
  42. axes_class=None,
  43. aspect=False,
  44. ):
  45. """
  46. Parameters
  47. ----------
  48. fig : `.Figure`
  49. The parent figure.
  50. rect : (float, float, float, float), (int, int, int), int, or \
  51. `~.SubplotSpec`
  52. The axes position, as a ``(left, bottom, width, height)`` tuple,
  53. as a three-digit subplot position code (e.g., ``(1, 2, 1)`` or
  54. ``121``), or as a `~.SubplotSpec`.
  55. nrows_ncols : (int, int)
  56. Number of rows and columns in the grid.
  57. ngrids : int or None, default: None
  58. If not None, only the first *ngrids* axes in the grid are created.
  59. direction : {"row", "column"}, default: "row"
  60. Whether axes are created in row-major ("row by row") or
  61. column-major order ("column by column"). This also affects the
  62. order in which axes are accessed using indexing (``grid[index]``).
  63. axes_pad : float or (float, float), default: 0.02
  64. Padding or (horizontal padding, vertical padding) between axes, in
  65. inches.
  66. share_all : bool, default: False
  67. Whether all axes share their x- and y-axis. Overrides *share_x*
  68. and *share_y*.
  69. share_x : bool, default: True
  70. Whether all axes of a column share their x-axis.
  71. share_y : bool, default: True
  72. Whether all axes of a row share their y-axis.
  73. label_mode : {"L", "1", "all", "keep"}, default: "L"
  74. Determines which axes will get tick labels:
  75. - "L": All axes on the left column get vertical tick labels;
  76. all axes on the bottom row get horizontal tick labels.
  77. - "1": Only the bottom left axes is labelled.
  78. - "all": All axes are labelled.
  79. - "keep": Do not do anything.
  80. axes_class : subclass of `matplotlib.axes.Axes`, default: None
  81. aspect : bool, default: False
  82. Whether the axes aspect ratio follows the aspect ratio of the data
  83. limits.
  84. """
  85. self._nrows, self._ncols = nrows_ncols
  86. if ngrids is None:
  87. ngrids = self._nrows * self._ncols
  88. else:
  89. if not 0 < ngrids <= self._nrows * self._ncols:
  90. raise ValueError(
  91. "ngrids must be positive and not larger than nrows*ncols")
  92. self.ngrids = ngrids
  93. self._horiz_pad_size, self._vert_pad_size = map(
  94. Size.Fixed, np.broadcast_to(axes_pad, 2))
  95. _api.check_in_list(["column", "row"], direction=direction)
  96. self._direction = direction
  97. if axes_class is None:
  98. axes_class = self._defaultAxesClass
  99. elif isinstance(axes_class, (list, tuple)):
  100. cls, kwargs = axes_class
  101. axes_class = functools.partial(cls, **kwargs)
  102. kw = dict(horizontal=[], vertical=[], aspect=aspect)
  103. if isinstance(rect, (Number, SubplotSpec)):
  104. self._divider = SubplotDivider(fig, rect, **kw)
  105. elif len(rect) == 3:
  106. self._divider = SubplotDivider(fig, *rect, **kw)
  107. elif len(rect) == 4:
  108. self._divider = Divider(fig, rect, **kw)
  109. else:
  110. raise TypeError("Incorrect rect format")
  111. rect = self._divider.get_position()
  112. axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
  113. for i in range(self.ngrids):
  114. col, row = self._get_col_row(i)
  115. if share_all:
  116. sharex = sharey = axes_array[0, 0]
  117. else:
  118. sharex = axes_array[0, col] if share_x else None
  119. sharey = axes_array[row, 0] if share_y else None
  120. axes_array[row, col] = axes_class(
  121. fig, rect, sharex=sharex, sharey=sharey)
  122. self.axes_all = axes_array.ravel(
  123. order="C" if self._direction == "row" else "F").tolist()
  124. self.axes_column = axes_array.T.tolist()
  125. self.axes_row = axes_array.tolist()
  126. self.axes_llc = self.axes_column[0][-1]
  127. self._init_locators()
  128. for ax in self.axes_all:
  129. fig.add_axes(ax)
  130. self.set_label_mode(label_mode)
  131. def _init_locators(self):
  132. self._divider.set_horizontal(
  133. [Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)])
  134. self._divider.set_vertical(
  135. [Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)])
  136. for i in range(self.ngrids):
  137. col, row = self._get_col_row(i)
  138. self.axes_all[i].set_axes_locator(
  139. self._divider.new_locator(nx=2 * col, ny=2 * (self._nrows - 1 - row)))
  140. def _get_col_row(self, n):
  141. if self._direction == "column":
  142. col, row = divmod(n, self._nrows)
  143. else:
  144. row, col = divmod(n, self._ncols)
  145. return col, row
  146. # Good to propagate __len__ if we have __getitem__
  147. def __len__(self):
  148. return len(self.axes_all)
  149. def __getitem__(self, i):
  150. return self.axes_all[i]
  151. def get_geometry(self):
  152. """
  153. Return the number of rows and columns of the grid as (nrows, ncols).
  154. """
  155. return self._nrows, self._ncols
  156. def set_axes_pad(self, axes_pad):
  157. """
  158. Set the padding between the axes.
  159. Parameters
  160. ----------
  161. axes_pad : (float, float)
  162. The padding (horizontal pad, vertical pad) in inches.
  163. """
  164. self._horiz_pad_size.fixed_size = axes_pad[0]
  165. self._vert_pad_size.fixed_size = axes_pad[1]
  166. def get_axes_pad(self):
  167. """
  168. Return the axes padding.
  169. Returns
  170. -------
  171. hpad, vpad
  172. Padding (horizontal pad, vertical pad) in inches.
  173. """
  174. return (self._horiz_pad_size.fixed_size,
  175. self._vert_pad_size.fixed_size)
  176. def set_aspect(self, aspect):
  177. """Set the aspect of the SubplotDivider."""
  178. self._divider.set_aspect(aspect)
  179. def get_aspect(self):
  180. """Return the aspect of the SubplotDivider."""
  181. return self._divider.get_aspect()
  182. def set_label_mode(self, mode):
  183. """
  184. Define which axes have tick labels.
  185. Parameters
  186. ----------
  187. mode : {"L", "1", "all", "keep"}
  188. The label mode:
  189. - "L": All axes on the left column get vertical tick labels;
  190. all axes on the bottom row get horizontal tick labels.
  191. - "1": Only the bottom left axes is labelled.
  192. - "all": All axes are labelled.
  193. - "keep": Do not do anything.
  194. """
  195. is_last_row, is_first_col = (
  196. np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
  197. if mode == "all":
  198. bottom = left = np.full((self._nrows, self._ncols), True)
  199. elif mode == "L":
  200. bottom = is_last_row
  201. left = is_first_col
  202. elif mode == "1":
  203. bottom = left = is_last_row & is_first_col
  204. else:
  205. # Use _api.check_in_list at the top of the method when deprecation
  206. # period expires
  207. if mode != 'keep':
  208. _api.warn_deprecated(
  209. '3.7', name="Grid label_mode",
  210. message='Passing an undefined label_mode is deprecated '
  211. 'since %(since)s and will become an error '
  212. '%(removal)s. To silence this warning, pass '
  213. '"keep", which gives the same behaviour.')
  214. return
  215. for i in range(self._nrows):
  216. for j in range(self._ncols):
  217. ax = self.axes_row[i][j]
  218. if isinstance(ax.axis, MethodType):
  219. bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
  220. left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
  221. else:
  222. bottom_axis = ax.axis["bottom"]
  223. left_axis = ax.axis["left"]
  224. bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
  225. left_axis.toggle(ticklabels=left[i, j], label=left[i, j])
  226. def get_divider(self):
  227. return self._divider
  228. def set_axes_locator(self, locator):
  229. self._divider.set_locator(locator)
  230. def get_axes_locator(self):
  231. return self._divider.get_locator()
  232. class ImageGrid(Grid):
  233. """
  234. A grid of Axes for Image display.
  235. This class is a specialization of `~.axes_grid1.axes_grid.Grid` for displaying a
  236. grid of images. In particular, it forces all axes in a column to share their x-axis
  237. and all axes in a row to share their y-axis. It further provides helpers to add
  238. colorbars to some or all axes.
  239. """
  240. def __init__(self, fig,
  241. rect,
  242. nrows_ncols,
  243. ngrids=None,
  244. direction="row",
  245. axes_pad=0.02,
  246. *,
  247. share_all=False,
  248. aspect=True,
  249. label_mode="L",
  250. cbar_mode=None,
  251. cbar_location="right",
  252. cbar_pad=None,
  253. cbar_size="5%",
  254. cbar_set_cax=True,
  255. axes_class=None,
  256. ):
  257. """
  258. Parameters
  259. ----------
  260. fig : `.Figure`
  261. The parent figure.
  262. rect : (float, float, float, float) or int
  263. The axes position, as a ``(left, bottom, width, height)`` tuple or
  264. as a three-digit subplot position code (e.g., "121").
  265. nrows_ncols : (int, int)
  266. Number of rows and columns in the grid.
  267. ngrids : int or None, default: None
  268. If not None, only the first *ngrids* axes in the grid are created.
  269. direction : {"row", "column"}, default: "row"
  270. Whether axes are created in row-major ("row by row") or
  271. column-major order ("column by column"). This also affects the
  272. order in which axes are accessed using indexing (``grid[index]``).
  273. axes_pad : float or (float, float), default: 0.02in
  274. Padding or (horizontal padding, vertical padding) between axes, in
  275. inches.
  276. share_all : bool, default: False
  277. Whether all axes share their x- and y-axis. Note that in any case,
  278. all axes in a column share their x-axis and all axes in a row share
  279. their y-axis.
  280. aspect : bool, default: True
  281. Whether the axes aspect ratio follows the aspect ratio of the data
  282. limits.
  283. label_mode : {"L", "1", "all"}, default: "L"
  284. Determines which axes will get tick labels:
  285. - "L": All axes on the left column get vertical tick labels;
  286. all axes on the bottom row get horizontal tick labels.
  287. - "1": Only the bottom left axes is labelled.
  288. - "all": all axes are labelled.
  289. cbar_mode : {"each", "single", "edge", None}, default: None
  290. Whether to create a colorbar for "each" axes, a "single" colorbar
  291. for the entire grid, colorbars only for axes on the "edge"
  292. determined by *cbar_location*, or no colorbars. The colorbars are
  293. stored in the :attr:`cbar_axes` attribute.
  294. cbar_location : {"left", "right", "bottom", "top"}, default: "right"
  295. cbar_pad : float, default: None
  296. Padding between the image axes and the colorbar axes.
  297. cbar_size : size specification (see `.Size.from_any`), default: "5%"
  298. Colorbar size.
  299. cbar_set_cax : bool, default: True
  300. If True, each axes in the grid has a *cax* attribute that is bound
  301. to associated *cbar_axes*.
  302. axes_class : subclass of `matplotlib.axes.Axes`, default: None
  303. """
  304. _api.check_in_list(["each", "single", "edge", None],
  305. cbar_mode=cbar_mode)
  306. _api.check_in_list(["left", "right", "bottom", "top"],
  307. cbar_location=cbar_location)
  308. self._colorbar_mode = cbar_mode
  309. self._colorbar_location = cbar_location
  310. self._colorbar_pad = cbar_pad
  311. self._colorbar_size = cbar_size
  312. # The colorbar axes are created in _init_locators().
  313. super().__init__(
  314. fig, rect, nrows_ncols, ngrids,
  315. direction=direction, axes_pad=axes_pad,
  316. share_all=share_all, share_x=True, share_y=True, aspect=aspect,
  317. label_mode=label_mode, axes_class=axes_class)
  318. for ax in self.cbar_axes:
  319. fig.add_axes(ax)
  320. if cbar_set_cax:
  321. if self._colorbar_mode == "single":
  322. for ax in self.axes_all:
  323. ax.cax = self.cbar_axes[0]
  324. elif self._colorbar_mode == "edge":
  325. for index, ax in enumerate(self.axes_all):
  326. col, row = self._get_col_row(index)
  327. if self._colorbar_location in ("left", "right"):
  328. ax.cax = self.cbar_axes[row]
  329. else:
  330. ax.cax = self.cbar_axes[col]
  331. else:
  332. for ax, cax in zip(self.axes_all, self.cbar_axes):
  333. ax.cax = cax
  334. def _init_locators(self):
  335. # Slightly abusing this method to inject colorbar creation into init.
  336. if self._colorbar_pad is None:
  337. # horizontal or vertical arrangement?
  338. if self._colorbar_location in ("left", "right"):
  339. self._colorbar_pad = self._horiz_pad_size.fixed_size
  340. else:
  341. self._colorbar_pad = self._vert_pad_size.fixed_size
  342. self.cbar_axes = [
  343. _cbaraxes_class_factory(self._defaultAxesClass)(
  344. self.axes_all[0].figure, self._divider.get_position(),
  345. orientation=self._colorbar_location)
  346. for _ in range(self.ngrids)]
  347. cb_mode = self._colorbar_mode
  348. cb_location = self._colorbar_location
  349. h = []
  350. v = []
  351. h_ax_pos = []
  352. h_cb_pos = []
  353. if cb_mode == "single" and cb_location in ("left", "bottom"):
  354. if cb_location == "left":
  355. sz = self._nrows * Size.AxesX(self.axes_llc)
  356. h.append(Size.from_any(self._colorbar_size, sz))
  357. h.append(Size.from_any(self._colorbar_pad, sz))
  358. locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
  359. elif cb_location == "bottom":
  360. sz = self._ncols * Size.AxesY(self.axes_llc)
  361. v.append(Size.from_any(self._colorbar_size, sz))
  362. v.append(Size.from_any(self._colorbar_pad, sz))
  363. locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
  364. for i in range(self.ngrids):
  365. self.cbar_axes[i].set_visible(False)
  366. self.cbar_axes[0].set_axes_locator(locator)
  367. self.cbar_axes[0].set_visible(True)
  368. for col, ax in enumerate(self.axes_row[0]):
  369. if h:
  370. h.append(self._horiz_pad_size)
  371. if ax:
  372. sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
  373. else:
  374. sz = Size.AxesX(self.axes_all[0],
  375. aspect="axes", ref_ax=self.axes_all[0])
  376. if (cb_location == "left"
  377. and (cb_mode == "each"
  378. or (cb_mode == "edge" and col == 0))):
  379. h_cb_pos.append(len(h))
  380. h.append(Size.from_any(self._colorbar_size, sz))
  381. h.append(Size.from_any(self._colorbar_pad, sz))
  382. h_ax_pos.append(len(h))
  383. h.append(sz)
  384. if (cb_location == "right"
  385. and (cb_mode == "each"
  386. or (cb_mode == "edge" and col == self._ncols - 1))):
  387. h.append(Size.from_any(self._colorbar_pad, sz))
  388. h_cb_pos.append(len(h))
  389. h.append(Size.from_any(self._colorbar_size, sz))
  390. v_ax_pos = []
  391. v_cb_pos = []
  392. for row, ax in enumerate(self.axes_column[0][::-1]):
  393. if v:
  394. v.append(self._vert_pad_size)
  395. if ax:
  396. sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
  397. else:
  398. sz = Size.AxesY(self.axes_all[0],
  399. aspect="axes", ref_ax=self.axes_all[0])
  400. if (cb_location == "bottom"
  401. and (cb_mode == "each"
  402. or (cb_mode == "edge" and row == 0))):
  403. v_cb_pos.append(len(v))
  404. v.append(Size.from_any(self._colorbar_size, sz))
  405. v.append(Size.from_any(self._colorbar_pad, sz))
  406. v_ax_pos.append(len(v))
  407. v.append(sz)
  408. if (cb_location == "top"
  409. and (cb_mode == "each"
  410. or (cb_mode == "edge" and row == self._nrows - 1))):
  411. v.append(Size.from_any(self._colorbar_pad, sz))
  412. v_cb_pos.append(len(v))
  413. v.append(Size.from_any(self._colorbar_size, sz))
  414. for i in range(self.ngrids):
  415. col, row = self._get_col_row(i)
  416. locator = self._divider.new_locator(nx=h_ax_pos[col],
  417. ny=v_ax_pos[self._nrows-1-row])
  418. self.axes_all[i].set_axes_locator(locator)
  419. if cb_mode == "each":
  420. if cb_location in ("right", "left"):
  421. locator = self._divider.new_locator(
  422. nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
  423. elif cb_location in ("top", "bottom"):
  424. locator = self._divider.new_locator(
  425. nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])
  426. self.cbar_axes[i].set_axes_locator(locator)
  427. elif cb_mode == "edge":
  428. if (cb_location == "left" and col == 0
  429. or cb_location == "right" and col == self._ncols - 1):
  430. locator = self._divider.new_locator(
  431. nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
  432. self.cbar_axes[row].set_axes_locator(locator)
  433. elif (cb_location == "bottom" and row == self._nrows - 1
  434. or cb_location == "top" and row == 0):
  435. locator = self._divider.new_locator(nx=h_ax_pos[col],
  436. ny=v_cb_pos[0])
  437. self.cbar_axes[col].set_axes_locator(locator)
  438. if cb_mode == "single":
  439. if cb_location == "right":
  440. sz = self._nrows * Size.AxesX(self.axes_llc)
  441. h.append(Size.from_any(self._colorbar_pad, sz))
  442. h.append(Size.from_any(self._colorbar_size, sz))
  443. locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
  444. elif cb_location == "top":
  445. sz = self._ncols * Size.AxesY(self.axes_llc)
  446. v.append(Size.from_any(self._colorbar_pad, sz))
  447. v.append(Size.from_any(self._colorbar_size, sz))
  448. locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
  449. if cb_location in ("right", "top"):
  450. for i in range(self.ngrids):
  451. self.cbar_axes[i].set_visible(False)
  452. self.cbar_axes[0].set_axes_locator(locator)
  453. self.cbar_axes[0].set_visible(True)
  454. elif cb_mode == "each":
  455. for i in range(self.ngrids):
  456. self.cbar_axes[i].set_visible(True)
  457. elif cb_mode == "edge":
  458. if cb_location in ("right", "left"):
  459. count = self._nrows
  460. else:
  461. count = self._ncols
  462. for i in range(count):
  463. self.cbar_axes[i].set_visible(True)
  464. for j in range(i + 1, self.ngrids):
  465. self.cbar_axes[j].set_visible(False)
  466. else:
  467. for i in range(self.ngrids):
  468. self.cbar_axes[i].set_visible(False)
  469. self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
  470. which="active")
  471. self._divider.set_horizontal(h)
  472. self._divider.set_vertical(v)
  473. AxesGrid = ImageGrid