axes_rgb.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from __future__ import (absolute_import, division, print_function,
  2. unicode_literals)
  3. import six
  4. import numpy as np
  5. from .axes_divider import make_axes_locatable, Size, locatable_axes_factory
  6. import sys
  7. from .mpl_axes import Axes
  8. def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True):
  9. """
  10. pad : fraction of the axes height.
  11. """
  12. divider = make_axes_locatable(ax)
  13. pad_size = Size.Fraction(pad, Size.AxesY(ax))
  14. xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
  15. ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
  16. divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
  17. divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
  18. ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
  19. ax_rgb = []
  20. if axes_class is None:
  21. try:
  22. axes_class = locatable_axes_factory(ax._axes_class)
  23. except AttributeError:
  24. axes_class = locatable_axes_factory(type(ax))
  25. for ny in [4, 2, 0]:
  26. ax1 = axes_class(ax.get_figure(),
  27. ax.get_position(original=True),
  28. sharex=ax, sharey=ax)
  29. locator = divider.new_locator(nx=2, ny=ny)
  30. ax1.set_axes_locator(locator)
  31. for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
  32. t.set_visible(False)
  33. try:
  34. for axis in ax1.axis.values():
  35. axis.major_ticklabels.set_visible(False)
  36. except AttributeError:
  37. pass
  38. ax_rgb.append(ax1)
  39. if add_all:
  40. fig = ax.get_figure()
  41. for ax1 in ax_rgb:
  42. fig.add_axes(ax1)
  43. return ax_rgb
  44. def imshow_rgb(ax, r, g, b, **kwargs):
  45. ny, nx = r.shape
  46. R = np.zeros([ny, nx, 3], dtype="d")
  47. R[:,:,0] = r
  48. G = np.zeros_like(R)
  49. G[:,:,1] = g
  50. B = np.zeros_like(R)
  51. B[:,:,2] = b
  52. RGB = R + G + B
  53. im_rgb = ax.imshow(RGB, **kwargs)
  54. return im_rgb
  55. class RGBAxesBase(object):
  56. """base class for a 4-panel imshow (RGB, R, G, B)
  57. Layout:
  58. +---------------+-----+
  59. | | R |
  60. + +-----+
  61. | RGB | G |
  62. + +-----+
  63. | | B |
  64. +---------------+-----+
  65. Attributes
  66. ----------
  67. _defaultAxesClass : matplotlib.axes.Axes
  68. defaults to 'Axes' in RGBAxes child class.
  69. No default in abstract base class
  70. RGB : _defaultAxesClass
  71. The axes object for the three-channel imshow
  72. R : _defaultAxesClass
  73. The axes object for the red channel imshow
  74. G : _defaultAxesClass
  75. The axes object for the green channel imshow
  76. B : _defaultAxesClass
  77. The axes object for the blue channel imshow
  78. """
  79. def __init__(self, *kl, **kwargs):
  80. """
  81. Parameters
  82. ----------
  83. pad : float
  84. fraction of the axes height to put as padding.
  85. defaults to 0.0
  86. add_all : bool
  87. True: Add the {rgb, r, g, b} axes to the figure
  88. defaults to True.
  89. axes_class : matplotlib.axes.Axes
  90. kl :
  91. Unpacked into axes_class() init for RGB
  92. kwargs :
  93. Unpacked into axes_class() init for RGB, R, G, B axes
  94. """
  95. pad = kwargs.pop("pad", 0.0)
  96. add_all = kwargs.pop("add_all", True)
  97. try:
  98. axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
  99. except AttributeError:
  100. new_msg = ("A subclass of RGBAxesBase must have a "
  101. "_defaultAxesClass attribute. If you are not sure which "
  102. "axes class to use, consider using "
  103. "mpl_toolkits.axes_grid1.mpl_axes.Axes.")
  104. six.reraise(AttributeError, AttributeError(new_msg),
  105. sys.exc_info()[2])
  106. ax = axes_class(*kl, **kwargs)
  107. divider = make_axes_locatable(ax)
  108. pad_size = Size.Fraction(pad, Size.AxesY(ax))
  109. xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
  110. ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
  111. divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
  112. divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
  113. ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
  114. ax_rgb = []
  115. for ny in [4, 2, 0]:
  116. ax1 = axes_class(ax.get_figure(),
  117. ax.get_position(original=True),
  118. sharex=ax, sharey=ax, **kwargs)
  119. locator = divider.new_locator(nx=2, ny=ny)
  120. ax1.set_axes_locator(locator)
  121. ax1.axis[:].toggle(ticklabels=False)
  122. ax_rgb.append(ax1)
  123. self.RGB = ax
  124. self.R, self.G, self.B = ax_rgb
  125. if add_all:
  126. fig = ax.get_figure()
  127. fig.add_axes(ax)
  128. self.add_RGB_to_figure()
  129. self._config_axes()
  130. def _config_axes(self, line_color='w', marker_edge_color='w'):
  131. """Set the line color and ticks for the axes
  132. Parameters
  133. ----------
  134. line_color : any matplotlib color
  135. marker_edge_color : any matplotlib color
  136. """
  137. for ax1 in [self.RGB, self.R, self.G, self.B]:
  138. ax1.axis[:].line.set_color(line_color)
  139. ax1.axis[:].major_ticks.set_markeredgecolor(marker_edge_color)
  140. def add_RGB_to_figure(self):
  141. """Add the red, green and blue axes to the RGB composite's axes figure
  142. """
  143. self.RGB.get_figure().add_axes(self.R)
  144. self.RGB.get_figure().add_axes(self.G)
  145. self.RGB.get_figure().add_axes(self.B)
  146. def imshow_rgb(self, r, g, b, **kwargs):
  147. """Create the four images {rgb, r, g, b}
  148. Parameters
  149. ----------
  150. r : array-like
  151. The red array
  152. g : array-like
  153. The green array
  154. b : array-like
  155. The blue array
  156. kwargs : imshow kwargs
  157. kwargs get unpacked into the imshow calls for the four images
  158. Returns
  159. -------
  160. rgb : matplotlib.image.AxesImage
  161. r : matplotlib.image.AxesImage
  162. g : matplotlib.image.AxesImage
  163. b : matplotlib.image.AxesImage
  164. """
  165. if not (r.shape == g.shape == b.shape):
  166. raise ValueError('Input shapes do not match.'
  167. '\nr.shape = {}'
  168. '\ng.shape = {}'
  169. '\nb.shape = {}'
  170. .format(r.shape, g.shape, b.shape))
  171. RGB = np.dstack([r, g, b])
  172. R = np.zeros_like(RGB)
  173. R[:,:,0] = r
  174. G = np.zeros_like(RGB)
  175. G[:,:,1] = g
  176. B = np.zeros_like(RGB)
  177. B[:,:,2] = b
  178. im_rgb = self.RGB.imshow(RGB, **kwargs)
  179. im_r = self.R.imshow(R, **kwargs)
  180. im_g = self.G.imshow(G, **kwargs)
  181. im_b = self.B.imshow(B, **kwargs)
  182. return im_rgb, im_r, im_g, im_b
class RGBAxes(RGBAxesBase):
    """
    Concrete RGBAxesBase whose four axes are, unless an ``axes_class``
    keyword is given, instances of the ``Axes`` class imported from
    ``.mpl_axes``.
    """
    _defaultAxesClass = Axes