axes_rgb.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from types import MethodType
  2. import numpy as np
  3. from .axes_divider import make_axes_locatable, Size
  4. from .mpl_axes import Axes, SimpleAxisArtist
  5. def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
  6. """
  7. Parameters
  8. ----------
  9. ax : `~matplotlib.axes.Axes`
  10. Axes instance to create the RGB Axes in.
  11. pad : float, optional
  12. Fraction of the Axes height to pad.
  13. axes_class : `matplotlib.axes.Axes` or None, optional
  14. Axes class to use for the R, G, and B Axes. If None, use
  15. the same class as *ax*.
  16. **kwargs
  17. Forwarded to *axes_class* init for the R, G, and B Axes.
  18. """
  19. divider = make_axes_locatable(ax)
  20. pad_size = pad * Size.AxesY(ax)
  21. xsize = ((1-2*pad)/3) * Size.AxesX(ax)
  22. ysize = ((1-2*pad)/3) * Size.AxesY(ax)
  23. divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
  24. divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
  25. ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
  26. ax_rgb = []
  27. if axes_class is None:
  28. axes_class = type(ax)
  29. for ny in [4, 2, 0]:
  30. ax1 = axes_class(ax.get_figure(), ax.get_position(original=True),
  31. sharex=ax, sharey=ax, **kwargs)
  32. locator = divider.new_locator(nx=2, ny=ny)
  33. ax1.set_axes_locator(locator)
  34. for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
  35. t.set_visible(False)
  36. try:
  37. for axis in ax1.axis.values():
  38. axis.major_ticklabels.set_visible(False)
  39. except AttributeError:
  40. pass
  41. ax_rgb.append(ax1)
  42. fig = ax.get_figure()
  43. for ax1 in ax_rgb:
  44. fig.add_axes(ax1)
  45. return ax_rgb
  46. class RGBAxes:
  47. """
  48. 4-panel `~.Axes.imshow` (RGB, R, G, B).
  49. Layout::
  50. ┌───────────────┬─────┐
  51. │ │ R │
  52. │ ├─────┤
  53. │ RGB │ G │
  54. │ ├─────┤
  55. │ │ B │
  56. └───────────────┴─────┘
  57. Subclasses can override the ``_defaultAxesClass`` attribute.
  58. By default RGBAxes uses `.mpl_axes.Axes`.
  59. Attributes
  60. ----------
  61. RGB : ``_defaultAxesClass``
  62. The Axes object for the three-channel `~.Axes.imshow`.
  63. R : ``_defaultAxesClass``
  64. The Axes object for the red channel `~.Axes.imshow`.
  65. G : ``_defaultAxesClass``
  66. The Axes object for the green channel `~.Axes.imshow`.
  67. B : ``_defaultAxesClass``
  68. The Axes object for the blue channel `~.Axes.imshow`.
  69. """
  70. _defaultAxesClass = Axes
  71. def __init__(self, *args, pad=0, **kwargs):
  72. """
  73. Parameters
  74. ----------
  75. pad : float, default: 0
  76. Fraction of the Axes height to put as padding.
  77. axes_class : `~matplotlib.axes.Axes`
  78. Axes class to use. If not provided, ``_defaultAxesClass`` is used.
  79. *args
  80. Forwarded to *axes_class* init for the RGB Axes
  81. **kwargs
  82. Forwarded to *axes_class* init for the RGB, R, G, and B Axes
  83. """
  84. axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
  85. self.RGB = ax = axes_class(*args, **kwargs)
  86. ax.get_figure().add_axes(ax)
  87. self.R, self.G, self.B = make_rgb_axes(
  88. ax, pad=pad, axes_class=axes_class, **kwargs)
  89. # Set the line color and ticks for the axes.
  90. for ax1 in [self.RGB, self.R, self.G, self.B]:
  91. if isinstance(ax1.axis, MethodType):
  92. ad = Axes.AxisDict(self)
  93. ad.update(
  94. bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
  95. top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
  96. left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
  97. right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
  98. else:
  99. ad = ax1.axis
  100. ad[:].line.set_color("w")
  101. ad[:].major_ticks.set_markeredgecolor("w")
  102. def imshow_rgb(self, r, g, b, **kwargs):
  103. """
  104. Create the four images {rgb, r, g, b}.
  105. Parameters
  106. ----------
  107. r, g, b : array-like
  108. The red, green, and blue arrays.
  109. **kwargs
  110. Forwarded to `~.Axes.imshow` calls for the four images.
  111. Returns
  112. -------
  113. rgb : `~matplotlib.image.AxesImage`
  114. r : `~matplotlib.image.AxesImage`
  115. g : `~matplotlib.image.AxesImage`
  116. b : `~matplotlib.image.AxesImage`
  117. """
  118. if not (r.shape == g.shape == b.shape):
  119. raise ValueError(
  120. f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
  121. RGB = np.dstack([r, g, b])
  122. R = np.zeros_like(RGB)
  123. R[:, :, 0] = r
  124. G = np.zeros_like(RGB)
  125. G[:, :, 1] = g
  126. B = np.zeros_like(RGB)
  127. B[:, :, 2] = b
  128. im_rgb = self.RGB.imshow(RGB, **kwargs)
  129. im_r = self.R.imshow(R, **kwargs)
  130. im_g = self.G.imshow(G, **kwargs)
  131. im_b = self.B.imshow(B, **kwargs)
  132. return im_rgb, im_r, im_g, im_b