123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- from __future__ import (absolute_import, division, print_function,
- unicode_literals)
- import six
- import numpy as np
- from .axes_divider import make_axes_locatable, Size, locatable_axes_factory
- import sys
- from .mpl_axes import Axes
- def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True):
- """
- pad : fraction of the axes height.
- """
- divider = make_axes_locatable(ax)
- pad_size = Size.Fraction(pad, Size.AxesY(ax))
- xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
- ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
- divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
- divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
- ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
- ax_rgb = []
- if axes_class is None:
- try:
- axes_class = locatable_axes_factory(ax._axes_class)
- except AttributeError:
- axes_class = locatable_axes_factory(type(ax))
- for ny in [4, 2, 0]:
- ax1 = axes_class(ax.get_figure(),
- ax.get_position(original=True),
- sharex=ax, sharey=ax)
- locator = divider.new_locator(nx=2, ny=ny)
- ax1.set_axes_locator(locator)
- for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
- t.set_visible(False)
- try:
- for axis in ax1.axis.values():
- axis.major_ticklabels.set_visible(False)
- except AttributeError:
- pass
- ax_rgb.append(ax1)
- if add_all:
- fig = ax.get_figure()
- for ax1 in ax_rgb:
- fig.add_axes(ax1)
- return ax_rgb
- def imshow_rgb(ax, r, g, b, **kwargs):
- ny, nx = r.shape
- R = np.zeros([ny, nx, 3], dtype="d")
- R[:,:,0] = r
- G = np.zeros_like(R)
- G[:,:,1] = g
- B = np.zeros_like(R)
- B[:,:,2] = b
- RGB = R + G + B
- im_rgb = ax.imshow(RGB, **kwargs)
- return im_rgb
- class RGBAxesBase(object):
- """base class for a 4-panel imshow (RGB, R, G, B)
- Layout:
- +---------------+-----+
- | | R |
- + +-----+
- | RGB | G |
- + +-----+
- | | B |
- +---------------+-----+
- Attributes
- ----------
- _defaultAxesClass : matplotlib.axes.Axes
- defaults to 'Axes' in RGBAxes child class.
- No default in abstract base class
- RGB : _defaultAxesClass
- The axes object for the three-channel imshow
- R : _defaultAxesClass
- The axes object for the red channel imshow
- G : _defaultAxesClass
- The axes object for the green channel imshow
- B : _defaultAxesClass
- The axes object for the blue channel imshow
- """
- def __init__(self, *kl, **kwargs):
- """
- Parameters
- ----------
- pad : float
- fraction of the axes height to put as padding.
- defaults to 0.0
- add_all : bool
- True: Add the {rgb, r, g, b} axes to the figure
- defaults to True.
- axes_class : matplotlib.axes.Axes
- kl :
- Unpacked into axes_class() init for RGB
- kwargs :
- Unpacked into axes_class() init for RGB, R, G, B axes
- """
- pad = kwargs.pop("pad", 0.0)
- add_all = kwargs.pop("add_all", True)
- try:
- axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
- except AttributeError:
- new_msg = ("A subclass of RGBAxesBase must have a "
- "_defaultAxesClass attribute. If you are not sure which "
- "axes class to use, consider using "
- "mpl_toolkits.axes_grid1.mpl_axes.Axes.")
- six.reraise(AttributeError, AttributeError(new_msg),
- sys.exc_info()[2])
- ax = axes_class(*kl, **kwargs)
- divider = make_axes_locatable(ax)
- pad_size = Size.Fraction(pad, Size.AxesY(ax))
- xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
- ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
- divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
- divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
- ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
- ax_rgb = []
- for ny in [4, 2, 0]:
- ax1 = axes_class(ax.get_figure(),
- ax.get_position(original=True),
- sharex=ax, sharey=ax, **kwargs)
- locator = divider.new_locator(nx=2, ny=ny)
- ax1.set_axes_locator(locator)
- ax1.axis[:].toggle(ticklabels=False)
- ax_rgb.append(ax1)
- self.RGB = ax
- self.R, self.G, self.B = ax_rgb
- if add_all:
- fig = ax.get_figure()
- fig.add_axes(ax)
- self.add_RGB_to_figure()
- self._config_axes()
- def _config_axes(self, line_color='w', marker_edge_color='w'):
- """Set the line color and ticks for the axes
- Parameters
- ----------
- line_color : any matplotlib color
- marker_edge_color : any matplotlib color
- """
- for ax1 in [self.RGB, self.R, self.G, self.B]:
- ax1.axis[:].line.set_color(line_color)
- ax1.axis[:].major_ticks.set_markeredgecolor(marker_edge_color)
- def add_RGB_to_figure(self):
- """Add the red, green and blue axes to the RGB composite's axes figure
- """
- self.RGB.get_figure().add_axes(self.R)
- self.RGB.get_figure().add_axes(self.G)
- self.RGB.get_figure().add_axes(self.B)
- def imshow_rgb(self, r, g, b, **kwargs):
- """Create the four images {rgb, r, g, b}
- Parameters
- ----------
- r : array-like
- The red array
- g : array-like
- The green array
- b : array-like
- The blue array
- kwargs : imshow kwargs
- kwargs get unpacked into the imshow calls for the four images
- Returns
- -------
- rgb : matplotlib.image.AxesImage
- r : matplotlib.image.AxesImage
- g : matplotlib.image.AxesImage
- b : matplotlib.image.AxesImage
- """
- if not (r.shape == g.shape == b.shape):
- raise ValueError('Input shapes do not match.'
- '\nr.shape = {}'
- '\ng.shape = {}'
- '\nb.shape = {}'
- .format(r.shape, g.shape, b.shape))
- RGB = np.dstack([r, g, b])
- R = np.zeros_like(RGB)
- R[:,:,0] = r
- G = np.zeros_like(RGB)
- G[:,:,1] = g
- B = np.zeros_like(RGB)
- B[:,:,2] = b
- im_rgb = self.RGB.imshow(RGB, **kwargs)
- im_r = self.R.imshow(R, **kwargs)
- im_g = self.G.imshow(G, **kwargs)
- im_b = self.B.imshow(B, **kwargs)
- return im_rgb, im_r, im_g, im_b
- class RGBAxes(RGBAxesBase):
- _defaultAxesClass = Axes
|