123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- from types import MethodType
- import numpy as np
- from .axes_divider import make_axes_locatable, Size
- from .mpl_axes import Axes, SimpleAxisArtist
def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
    """
    Split *ax* with an axes divider and add R, G and B Axes beside it.

    *ax* is kept in the left column spanning every row; a narrow right
    column holds three equally sized Axes separated by padding rows.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        Axes instance to create the RGB Axes in.
    pad : float, optional
        Fraction of the Axes height to pad.
    axes_class : `matplotlib.axes.Axes` or None, optional
        Axes class to use for the R, G, and B Axes. If None, use
        the same class as *ax*.
    **kwargs
        Forwarded to *axes_class* init for the R, G, and B Axes.

    Returns
    -------
    list
        The three new Axes, created for divider rows 4, 2 and 0 in that
        order (R, G, B).
    """
    divider = make_axes_locatable(ax)

    # Horizontal: [main column, gap, side column]; vertical: three equal
    # panels with a pad-sized gap between consecutive panels.
    gap = pad * Size.AxesY(ax)
    panel_w = ((1 - 2*pad) / 3) * Size.AxesX(ax)
    panel_h = ((1 - 2*pad) / 3) * Size.AxesY(ax)
    divider.set_horizontal([Size.AxesX(ax), gap, panel_w])
    divider.set_vertical([panel_h, gap, panel_h, gap, panel_h])

    # The main Axes occupies column 0, stretched across all rows (ny1=-1).
    ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

    if axes_class is None:
        axes_class = type(ax)

    fig = ax.get_figure()
    channel_axes = []
    # Even rows are the panels; odd rows (1 and 3) are the padding gaps.
    for row in (4, 2, 0):
        channel_ax = axes_class(fig, ax.get_position(original=True),
                                sharex=ax, sharey=ax, **kwargs)
        channel_ax.set_axes_locator(divider.new_locator(nx=2, ny=row))

        # Side panels share limits with *ax*, so their tick labels are hidden.
        tick_labels = (channel_ax.yaxis.get_ticklabels()
                       + channel_ax.xaxis.get_ticklabels())
        for label in tick_labels:
            label.set_visible(False)
        try:
            for side in channel_ax.axis.values():
                side.major_ticklabels.set_visible(False)
        except AttributeError:
            # e.g. a plain Axes, whose ``axis`` is a method with no .values().
            pass

        channel_axes.append(channel_ax)

    # Attach to the figure only after all three Axes have been built.
    for channel_ax in channel_axes:
        fig.add_axes(channel_ax)

    return channel_axes
class RGBAxes:
    """
    4-panel `~.Axes.imshow` (RGB, R, G, B).

    Layout::

        ┌───────────────┬─────┐
        │               │  R  │
        │               ├─────┤
        │      RGB      │  G  │
        │               ├─────┤
        │               │  B  │
        └───────────────┴─────┘

    Subclasses can override the ``_defaultAxesClass`` attribute.
    By default RGBAxes uses `.mpl_axes.Axes`.

    Attributes
    ----------
    RGB : ``_defaultAxesClass``
        The Axes object for the three-channel `~.Axes.imshow`.
    R : ``_defaultAxesClass``
        The Axes object for the red channel `~.Axes.imshow`.
    G : ``_defaultAxesClass``
        The Axes object for the green channel `~.Axes.imshow`.
    B : ``_defaultAxesClass``
        The Axes object for the blue channel `~.Axes.imshow`.
    """

    _defaultAxesClass = Axes

    def __init__(self, *args, pad=0, **kwargs):
        """
        Parameters
        ----------
        pad : float, default: 0
            Fraction of the Axes height to put as padding.
        axes_class : `~matplotlib.axes.Axes`
            Axes class to use. If not provided, ``_defaultAxesClass`` is used.
        *args
            Forwarded to *axes_class* init for the RGB Axes
        **kwargs
            Forwarded to *axes_class* init for the RGB, R, G, and B Axes
        """
        axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
        self.RGB = ax = axes_class(*args, **kwargs)
        ax.get_figure().add_axes(ax)
        self.R, self.G, self.B = make_rgb_axes(
            ax, pad=pad, axes_class=axes_class, **kwargs)
        # Set the line color and ticks for the axes.
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            if isinstance(ax1.axis, MethodType):
                # ``axis`` is a bound method here, not an AxisDict: build a
                # temporary AxisDict over this Axes' spines so all four sides
                # can be styled uniformly below. It must wrap *ax1* (the Axes),
                # not this RGBAxes container.
                ad = Axes.AxisDict(ax1)
                ad.update(
                    bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
                    top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
                    left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
                    right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
            else:
                ad = ax1.axis
            # ``ad[:]`` addresses every side at once.
            ad[:].line.set_color("w")
            ad[:].major_ticks.set_markeredgecolor("w")

    def imshow_rgb(self, r, g, b, **kwargs):
        """
        Create the four images {rgb, r, g, b}.

        Parameters
        ----------
        r, g, b : array-like
            The red, green, and blue arrays. They are converted with
            `numpy.asanyarray` and must all have the same shape.
        **kwargs
            Forwarded to `~.Axes.imshow` calls for the four images.

        Returns
        -------
        rgb : `~matplotlib.image.AxesImage`
        r : `~matplotlib.image.AxesImage`
        g : `~matplotlib.image.AxesImage`
        b : `~matplotlib.image.AxesImage`

        Raises
        ------
        ValueError
            If *r*, *g* and *b* do not all have the same shape.
        """
        # The contract is "array-like", but ``.shape`` only exists on arrays;
        # convert first so lists/tuples work. asanyarray keeps ndarray
        # subclasses (e.g. masked arrays) as they are.
        r, g, b = (np.asanyarray(channel) for channel in (r, g, b))
        if not (r.shape == g.shape == b.shape):
            raise ValueError(
                f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
        RGB = np.dstack([r, g, b])
        # Each single-channel image is the stacked array with the other two
        # channels zeroed, so it renders in that channel's color only.
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b
        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)
        return im_rgb, im_r, im_g, im_b
|