axes3d.py 127 KB


  1. """
  2. axes3d.py, original mplot3d version by John Porter
  3. Created: 23 Sep 2005
  4. Parts fixed by Reinier Heeres <reinier@heeres.eu>
  5. Minor additions by Ben Axelrod <baxelrod@coroware.com>
  6. Significant updates and revisions by Ben Root <ben.v.root@gmail.com>
  7. Module containing Axes3D, an object which can plot 3D objects on a
  8. 2D matplotlib figure.
  9. """
  10. from collections import defaultdict
  11. import functools
  12. import itertools
  13. import math
  14. import textwrap
  15. import numpy as np
  16. import matplotlib as mpl
  17. from matplotlib import _api, cbook, _docstring, _preprocess_data
  18. import matplotlib.artist as martist
  19. import matplotlib.axes as maxes
  20. import matplotlib.collections as mcoll
  21. import matplotlib.colors as mcolors
  22. import matplotlib.image as mimage
  23. import matplotlib.lines as mlines
  24. import matplotlib.patches as mpatches
  25. import matplotlib.container as mcontainer
  26. import matplotlib.transforms as mtransforms
  27. from matplotlib.axes import Axes
  28. from matplotlib.axes._base import _axis_method_wrapper, _process_plot_format
  29. from matplotlib.transforms import Bbox
  30. from matplotlib.tri._triangulation import Triangulation
  31. from . import art3d
  32. from . import proj3d
  33. from . import axis3d
  34. @_docstring.interpd
  35. @_api.define_aliases({
  36. "xlim": ["xlim3d"], "ylim": ["ylim3d"], "zlim": ["zlim3d"]})
  37. class Axes3D(Axes):
  """
  3D Axes object.

  .. note::

      As a user, you do not instantiate Axes directly, but use Axes creation
      methods instead; e.g. from `.pyplot` or `.Figure`:
      `~.pyplot.subplots`, `~.pyplot.subplot_mosaic` or `.Figure.add_axes`.
  """
  # Projection name of this Axes class (presumably what ``projection='3d'``
  # selects -- confirm against the projection registry).
  name = '3d'
  # Names of the axes managed by this class, in x, y, z order.
  _axis_names = ("x", "y", "z")
  # Share groups for z-limits and for view angles, joined by the *sharez*
  # and *shareview* constructor arguments. Note that these assignments write
  # into the ``_shared_axes`` mapping reached through the base ``Axes`` class
  # at the moment this class body executes.
  Axes._shared_axes["z"] = cbook.Grouper()
  Axes._shared_axes["view"] = cbook.Grouper()
  # Formerly public attributes, deprecated in 3.7; the values now live in
  # underscore-prefixed attributes (``self._vvec`` and ``self._eye`` are
  # assigned in get_proj).
  vvec = _api.deprecate_privatize_attribute("3.7")
  eye = _api.deprecate_privatize_attribute("3.7")
  sx = _api.deprecate_privatize_attribute("3.7")
  sy = _api.deprecate_privatize_attribute("3.7")
  def __init__(
          self, fig, rect=None, *args,
          elev=30, azim=-60, roll=0, sharez=None, proj_type='persp',
          box_aspect=None, computed_zorder=True, focal_length=None,
          shareview=None,
          **kwargs):
      """
      Parameters
      ----------
      fig : Figure
          The parent figure.
      rect : tuple (left, bottom, width, height), default: None.
          The ``(left, bottom, width, height)`` axes position.
      elev : float, default: 30
          The elevation angle in degrees rotates the camera above and below
          the x-y plane, with a positive angle corresponding to a location
          above the plane.
      azim : float, default: -60
          The azimuthal angle in degrees rotates the camera about the z axis,
          with a positive angle corresponding to a right-handed rotation. In
          other words, a positive azimuth rotates the camera about the origin
          from its location along the +x axis towards the +y axis.
      roll : float, default: 0
          The roll angle in degrees rotates the camera about the viewing
          axis. A positive angle spins the camera clockwise, causing the
          scene to rotate counter-clockwise.
      sharez : Axes3D, optional
          Other Axes to share z-limits with.
      proj_type : {'persp', 'ortho'}
          The projection type, default 'persp'.
      box_aspect : 3-tuple of floats, default: None
          Changes the physical dimensions of the Axes3D, such that the ratio
          of the axis lengths in display units is x:y:z.
          If None, defaults to 4:4:3
      computed_zorder : bool, default: True
          If True, the draw order is computed based on the average position
          of the `.Artist`\\s along the view direction.
          Set to False if you want to manually control the order in which
          Artists are drawn on top of each other using their *zorder*
          attribute. This can be used for fine-tuning if the automatic order
          does not produce the desired result. Note however, that a manual
          zorder will only be correct for a limited view angle. If the figure
          is rotated by the user, it will look wrong from certain angles.
      focal_length : float, default: None
          For a projection type of 'persp', the focal length of the virtual
          camera. Must be > 0. If None, defaults to 1.
          For a projection type of 'ortho', must be set to either None
          or infinity (numpy.inf). If None, defaults to infinity.
          The focal length can be computed from a desired Field Of View via
          the equation: focal_length = 1/tan(FOV/2)
      shareview : Axes3D, optional
          Other Axes to share view angles with.

      **kwargs
          Other optional keyword arguments:

          %(Axes3D:kwdoc)s

      Raises
      ------
      AttributeError
          If the removed ``auto_add_to_figure`` keyword is passed as true.
      """
      if rect is None:
          rect = [0.0, 0.0, 1.0, 1.0]

      # Remember the constructor's view angles; view_init() falls back to
      # them whenever it is called with None.
      self.initial_azim = azim
      self.initial_elev = elev
      self.initial_roll = roll
      self.set_proj_type(proj_type, focal_length)
      self.computed_zorder = computed_zorder

      self.xy_viewLim = Bbox.unit()
      self.zz_viewLim = Bbox.unit()
      self.xy_dataLim = Bbox.unit()
      # z-limits are encoded in the x-component of the Bbox, y is un-used
      self.zz_dataLim = Bbox.unit()

      # inhibit autoscale_view until the axes are defined
      # they can't be defined until Axes.__init__ has been called
      # (view_init also sets self._dist, which set_top_view() relies on.)
      self.view_init(self.initial_elev, self.initial_azim, self.initial_roll)

      self._sharez = sharez
      if sharez is not None:
          self._shared_axes["z"].join(self, sharez)
          self._adjustable = 'datalim'

      self._shareview = shareview
      if shareview is not None:
          self._shared_axes["view"].join(self, shareview)

      # Reject the removed keyword loudly instead of forwarding it on.
      if kwargs.pop('auto_add_to_figure', False):
          raise AttributeError(
              'auto_add_to_figure is no longer supported for Axes3D. '
              'Use fig.add_axes(ax) instead.'
          )

      # NOTE: *args is unpacked positionally (after fig, rect) even though
      # it is written after the keyword arguments.
      super().__init__(
          fig, rect, frameon=True, box_aspect=box_aspect, *args, **kwargs
      )
      # Disable drawing of axes by base class
      super().set_axis_off()
      # Enable drawing of axes by Axes3D class
      self.set_axis_on()
      # Projection matrix and its inverse; computed in draw().
      self.M = None
      self.invM = None

      # func used to format z -- fall back on major formatters
      self.fmt_zdata = None

      # Hook up interactive rotation / zoom handlers.
      self.mouse_init()
      self.figure.canvas.callbacks._connect_picklable(
          'motion_notify_event', self._on_move)
      self.figure.canvas.callbacks._connect_picklable(
          'button_press_event', self._button_press)
      self.figure.canvas.callbacks._connect_picklable(
          'button_release_event', self._button_release)
      self.set_top_view()

      self.patch.set_linewidth(0)
      # Calculate the pseudo-data width and height
      pseudo_bbox = self.transLimits.inverted().transform([(0, 0), (1, 1)])
      self._pseudo_w, self._pseudo_h = pseudo_bbox[1] - pseudo_bbox[0]

      # mplot3d currently manages its own spines and needs these turned off
      # for bounding box calculations
      self.spines[:].set_visible(False)
  def set_axis_off(self):
      """
      Turn off drawing of the 3D axis decorations.

      Clears the flag that `draw` checks before drawing the panes, gridlines
      and axes, then marks the Axes stale so that it is redrawn.
      """
      self._axis3don = False
      self.stale = True
  def set_axis_on(self):
      """
      Turn on drawing of the 3D axis decorations.

      Sets the flag that `draw` checks before drawing the panes, gridlines
      and axes, then marks the Axes stale so that it is redrawn.
      """
      self._axis3don = True
      self.stale = True
  def convert_zunits(self, z):
      """
      Convert *z* with the z-axis unit converter.

      For artists in an Axes: if the z-axis has units support, *z* is
      converted using the z-axis unit type; the result of
      ``zaxis.convert_units`` is returned unchanged.
      """
      zaxis = self.zaxis
      return zaxis.convert_units(z)
  def set_top_view(self):
      """
      Reset the 2D view limits used to frame the projected scene.

      Both the x and y intervals become ``(-0.95/d, 0.9/d)`` with
      ``d = self._dist``. The window is shifted up and to the left slightly
      to leave room for labels and axes.
      """
      camera_dist = self._dist
      pane = (-0.95 / camera_dist, 0.9 / camera_dist)
      # The same extents are used horizontally and vertically.
      self.viewLim.intervalx = pane
      self.viewLim.intervaly = pane
      self.stale = True
  def _init_axis(self):
      """Init 3D axes; overrides creation of regular X/Y axes."""
      # Each axis object receives this Axes3D. They are created and assigned
      # one at a time in x, y, z order; the order is kept in case building a
      # later axis reads an earlier one from ``self``.
      self.xaxis = axis3d.XAxis(self)
      self.yaxis = axis3d.YAxis(self)
      self.zaxis = axis3d.ZAxis(self)
  def get_zaxis(self):
      """Return the ``ZAxis`` (`~.axis3d.Axis`) instance."""
      # The z axis object is created in _init_axis.
      zaxis = self.zaxis
      return zaxis
  # Expose ZAxis.get_gridlines and ZAxis.get_ticklines as Axes3D methods
  # through _axis_method_wrapper.
  get_zgridlines = _axis_method_wrapper("zaxis", "get_gridlines")
  get_zticklines = _axis_method_wrapper("zaxis", "get_ticklines")
  @_api.deprecated("3.7")
  def unit_cube(self, vals=None):
      # Deprecated public entry point; the implementation is _unit_cube.
      return self._unit_cube(vals=vals)
  198. def _unit_cube(self, vals=None):
  199. minx, maxx, miny, maxy, minz, maxz = vals or self.get_w_lims()
  200. return [(minx, miny, minz),
  201. (maxx, miny, minz),
  202. (maxx, maxy, minz),
  203. (minx, maxy, minz),
  204. (minx, miny, maxz),
  205. (maxx, miny, maxz),
  206. (maxx, maxy, maxz),
  207. (minx, maxy, maxz)]
  @_api.deprecated("3.7")
  def tunit_cube(self, vals=None, M=None):
      # Deprecated public entry point; the implementation is _tunit_cube.
      return self._tunit_cube(vals=vals, M=M)
  def _tunit_cube(self, vals=None, M=None):
      """
      Project the corners from `_unit_cube` through a projection matrix.

      *M* defaults to the current ``self.M`` when None.
      """
      corners = self._unit_cube(vals)
      matrix = self.M if M is None else M
      return proj3d._proj_points(corners, matrix)
  @_api.deprecated("3.7")
  def tunit_edges(self, vals=None, M=None):
      # Deprecated public entry point; the implementation is _tunit_edges.
      return self._tunit_edges(vals=vals, M=M)
  220. def _tunit_edges(self, vals=None, M=None):
  221. tc = self._tunit_cube(vals, M)
  222. edges = [(tc[0], tc[1]),
  223. (tc[1], tc[2]),
  224. (tc[2], tc[3]),
  225. (tc[3], tc[0]),
  226. (tc[0], tc[4]),
  227. (tc[1], tc[5]),
  228. (tc[2], tc[6]),
  229. (tc[3], tc[7]),
  230. (tc[4], tc[5]),
  231. (tc[5], tc[6]),
  232. (tc[6], tc[7]),
  233. (tc[7], tc[4])]
  234. return edges
  def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
      """
      Set the aspect ratios.

      Parameters
      ----------
      aspect : {'auto', 'equal', 'equalxy', 'equalxz', 'equalyz'}
          Possible values:

          =========   ==================================================
          value       description
          =========   ==================================================
          'auto'      automatic; fill the position rectangle with data.
          'equal'     adapt all the axes to have equal aspect ratios.
          'equalxy'   adapt the x and y axes to have equal aspect ratios.
          'equalxz'   adapt the x and z axes to have equal aspect ratios.
          'equalyz'   adapt the y and z axes to have equal aspect ratios.
          =========   ==================================================

      adjustable : None or {'box', 'datalim'}, optional
          If not *None*, this defines which parameter will be adjusted to
          meet the required aspect. See `.set_adjustable` for further
          details.

      anchor : None or str or 2-tuple of float, optional
          If not *None*, this defines where the Axes will be drawn if there
          is extra space due to aspect constraints. The most common way to
          specify the anchor are abbreviations of cardinal directions:

          =====   =====================
          value   description
          =====   =====================
          'C'     centered
          'SW'    lower left corner
          'S'     middle of bottom edge
          'SE'    lower right corner
          etc.
          =====   =====================

          See `~.Axes.set_anchor` for further details.

      share : bool, default: False
          If ``True``, apply the settings to all shared Axes.

      See Also
      --------
      mpl_toolkits.mplot3d.axes3d.Axes3D.set_box_aspect
      """
      _api.check_in_list(('auto', 'equal', 'equalxy', 'equalyz', 'equalxz'),
                         aspect=aspect)
      # Let the 2D machinery handle adjustable/anchor/share, but always with
      # 'auto'; the 3D-specific aspect is recorded and applied below.
      super().set_aspect(
          aspect='auto', adjustable=adjustable, anchor=anchor, share=share)
      self._aspect = aspect

      if aspect in ('equal', 'equalxy', 'equalxz', 'equalyz'):
          ax_indices = self._equal_aspect_axis_indices(aspect)

          view_intervals = np.array([self.xaxis.get_view_interval(),
                                     self.yaxis.get_view_interval(),
                                     self.zaxis.get_view_interval()])
          # Current extent (max - min) of each axis' view interval.
          ptp = np.ptp(view_intervals, axis=1)
          if self._adjustable == 'datalim':
              # Keep each constrained axis centred and widen its range so
              # that ranges are proportional to the box aspect; the common
              # scale is the largest needed, so no constrained range shrinks.
              mean = np.mean(view_intervals, axis=1)
              scale = max(ptp[ax_indices] / self._box_aspect[ax_indices])
              deltas = scale * self._box_aspect

              for i, set_lim in enumerate((self.set_xlim3d,
                                           self.set_ylim3d,
                                           self.set_zlim3d)):
                  if i in ax_indices:
                      set_lim(mean[i] - deltas[i]/2., mean[i] + deltas[i]/2.)
          else:  # 'box'
              # Change the box aspect such that the ratio of the length of
              # the unmodified axis to the length of the diagonal
              # perpendicular to it remains unchanged.
              box_aspect = np.array(self._box_aspect)
              box_aspect[ax_indices] = ptp[ax_indices]
              remaining_ax_indices = {0, 1, 2}.difference(ax_indices)
              if remaining_ax_indices:
                  remaining = remaining_ax_indices.pop()
                  old_diag = np.linalg.norm(self._box_aspect[ax_indices])
                  new_diag = np.linalg.norm(box_aspect[ax_indices])
                  box_aspect[remaining] *= new_diag / old_diag
              self.set_box_aspect(box_aspect)
  308. def _equal_aspect_axis_indices(self, aspect):
  309. """
  310. Get the indices for which of the x, y, z axes are constrained to have
  311. equal aspect ratios.
  312. Parameters
  313. ----------
  314. aspect : {'auto', 'equal', 'equalxy', 'equalxz', 'equalyz'}
  315. See descriptions in docstring for `.set_aspect()`.
  316. """
  317. ax_indices = [] # aspect == 'auto'
  318. if aspect == 'equal':
  319. ax_indices = [0, 1, 2]
  320. elif aspect == 'equalxy':
  321. ax_indices = [0, 1]
  322. elif aspect == 'equalxz':
  323. ax_indices = [0, 2]
  324. elif aspect == 'equalyz':
  325. ax_indices = [1, 2]
  326. return ax_indices
  def set_box_aspect(self, aspect, *, zoom=1):
      """
      Set the Axes box aspect.

      The box aspect is the ratio of height to width in display
      units for each face of the box when viewed perpendicular to
      that face.  This is not to be confused with the data aspect (see
      `~.Axes3D.set_aspect`). The default ratios are 4:4:3 (x:y:z).

      To simulate having equal aspect in data space, set the box
      aspect to match your data range in each dimension.

      *zoom* controls the overall size of the Axes3D in the figure.

      Parameters
      ----------
      aspect : 3-tuple of floats or None
          Changes the physical dimensions of the Axes3D, such that the ratio
          of the axis lengths in display units is x:y:z.
          If None, defaults to (4, 4, 3).

      zoom : float, default: 1
          Control overall size of the Axes3D in the figure. Must be > 0.

      Raises
      ------
      ValueError
          If *zoom* is not > 0 or *aspect* does not have shape (3,).
      """
      if zoom <= 0:
          raise ValueError(f'Argument zoom = {zoom} must be > 0')

      if aspect is None:
          aspect = (4, 4, 3)
      # np.array always copies. np.asarray returned a float ndarray passed by
      # the caller as-is, so the in-place scaling below mutated the caller's
      # array and self._box_aspect aliased it.
      aspect = np.array(aspect, dtype=float)
      _api.check_shape((3,), aspect=aspect)
      # default scale tuned to match the mpl32 appearance.
      aspect *= 1.8294640721620434 * zoom / np.linalg.norm(aspect)

      self._box_aspect = aspect
      self.stale = True
  def apply_aspect(self, position=None):
      """
      Make the active position square within *position*.

      *position* defaults to the original position. Unlike the 2D
      superclass, no axis scales or box/datalim adjustment are handled here:
      the only requirement is that the coordinate system is square.
      """
      if position is None:
          position = self.get_position(original=True)

      panel = mtransforms.Bbox.unit().transformed(
          self.get_figure().transSubfigure)
      # Physical aspect (height / width) of the enclosing panel or figure.
      panel_aspect = panel.height / panel.width
      original = position.frozen()
      square = original.shrunk_to_aspect(1, original, panel_aspect)
      self._set_position(square.anchored(self.get_anchor(), original),
                         'active')
  @martist.allow_rasterization
  def draw(self, renderer):
      """
      Draw the 3D Axes with *renderer*.

      Updates the projection matrix (``self.M`` / ``self.invM``). It then
      projects collections and patches, assigning them depth-sorted zorders
      when ``computed_zorder`` is set. Next it draws the 3D panes, gridlines
      and axes if enabled. Finally it delegates to the 2D ``draw``.
      """
      if not self.get_visible():
          return
      self._unstale_viewLim()

      # draw the background patch
      self.patch.draw(renderer)
      self._frameon = False

      # first, set the aspect
      # this is duplicated from `axes._base._AxesBase.draw`
      # but must be called before any of the artist are drawn as
      # it adjusts the view limits and the size of the bounding box
      # of the Axes
      locator = self.get_axes_locator()
      self.apply_aspect(locator(self, renderer) if locator else None)

      # add the projection matrix to the renderer
      self.M = self.get_proj()
      self.invM = np.linalg.inv(self.M)

      # Generator: consumed exactly once by whichever branch runs below.
      collections_and_patches = (
          artist for artist in self._children
          if isinstance(artist, (mcoll.Collection, mpatches.Patch))
          and artist.get_visible())
      if self.computed_zorder:
          # Calculate projection of collections and patches and zorder
          # them. Make sure they are drawn above the grids.
          zorder_offset = max(axis.get_zorder()
                              for axis in self._axis_map.values()) + 1
          # Collections and patches use separate counters that both start at
          # zorder_offset, so a collection and a patch may share a zorder.
          collection_zorder = patch_zorder = zorder_offset

          # do_3d_projection() both projects the artist and returns the
          # sort key; reverse=True assigns the largest key the lowest zorder.
          for artist in sorted(collections_and_patches,
                               key=lambda artist: artist.do_3d_projection(),
                               reverse=True):
              if isinstance(artist, mcoll.Collection):
                  artist.zorder = collection_zorder
                  collection_zorder += 1
              elif isinstance(artist, mpatches.Patch):
                  artist.zorder = patch_zorder
                  patch_zorder += 1
      else:
          # Still project every artist; keep the user-supplied zorders.
          for artist in collections_and_patches:
              artist.do_3d_projection()

      if self._axis3don:
          # Draw panes first
          for axis in self._axis_map.values():
              axis.draw_pane(renderer)
          # Then gridlines
          for axis in self._axis_map.values():
              axis.draw_grid(renderer)
          # Then axes, labels, text, and ticks
          for axis in self._axis_map.values():
              axis.draw(renderer)

      # Then rest
      super().draw(renderer)
  def get_axis_position(self):
      """
      Return ``(xhigh, yhigh, zhigh)`` comparing projected corner depths.

      Each flag is True when the third projected coordinate of cube corner 1
      (x), 3 (y) or 0 (z) exceeds that of corner 2.
      """
      corners = self._tunit_cube(self.get_w_lims(), self.M)
      reference = corners[2][2]
      return (corners[1][2] > reference,
              corners[3][2] > reference,
              corners[0][2] > reference)
  def update_datalim(self, xys, **kwargs):
      """
      Not implemented in `~mpl_toolkits.mplot3d.axes3d.Axes3D`.
      """
      # Deliberate no-op: the 3D data limits appear to be tracked through
      # xy_dataLim / zz_dataLim in auto_scale_xyz instead.
  # Expose ZAxis._get_autoscale_on / ZAxis._set_autoscale_on as the z
  # counterparts of the per-axis autoscale flag accessors.
  get_autoscalez_on = _axis_method_wrapper("zaxis", "_get_autoscale_on")
  set_autoscalez_on = _axis_method_wrapper("zaxis", "_set_autoscale_on")
  def set_zmargin(self, m):
      """
      Set padding of Z data limits prior to autoscaling.

      *m* times the data interval will be added to each end of that interval
      before it is used in autoscaling.  If *m* is negative, this will clip
      the data range instead of expanding it.

      For example, if your data is in the range [0, 2], a margin of 0.1 will
      result in a range [-0.2, 2.2]; a margin of -0.1 will result in a range
      of [0.2, 1.8].

      Parameters
      ----------
      m : float greater than -0.5

      Raises
      ------
      ValueError
          If *m* <= -0.5.
      """
      if m <= -0.5:
          raise ValueError("margin must be greater than -0.5")
      self._zmargin = m
      # Ask for a z re-autoscale so the new margin takes effect.
      self._request_autoscale_view("z")
      self.stale = True
  def margins(self, *margins, x=None, y=None, z=None, tight=True):
      """
      Set or retrieve autoscaling margins.

      See `.Axes.margins` for full documentation.  Because this function
      applies to 3D Axes, it also takes a *z* argument, and returns
      ``(xmargin, ymargin, zmargin)``.
      """
      keyword_given = any(v is not None for v in (x, y, z))
      if margins and keyword_given:
          raise TypeError('Cannot pass both positional and keyword '
                          'arguments for x, y, and/or z.')
      if len(margins) == 1:
          x = y = z = margins[0]
      elif len(margins) == 3:
          x, y, z = margins
      elif margins:
          raise TypeError('Must pass a single positional argument for all '
                          'margins, or one for each margin (x, y, z).')

      # Getter mode: nothing to set, just report the current margins.
      if all(v is None for v in (x, y, z)):
          if tight is not True:
              _api.warn_external(f'ignoring tight={tight!r} in get mode')
          return self._xmargin, self._ymargin, self._zmargin

      for value, setter in ((x, self.set_xmargin),
                            (y, self.set_ymargin),
                            (z, self.set_zmargin)):
          if value is not None:
              setter(value)

      self.autoscale_view(
          tight=tight,
          scalex=x is not None,
          scaley=y is not None,
          scalez=z is not None,
      )
  def autoscale(self, enable=True, axis='both', tight=None):
      """
      Convenience method for simple axis view autoscaling.

      See `.Axes.autoscale` for full documentation.  Because this function
      applies to 3D Axes, *axis* can also be set to 'z', and setting *axis*
      to 'both' autoscales all three axes.
      """
      if enable is None:
          # Leave the autoscale flags alone; rescale every axis once.
          requested = {"x": True, "y": True, "z": True}
      else:
          requested = {}
          for name in ("x", "y", "z"):
              if axis in (name, "both"):
                  getattr(self, f"set_autoscale{name}_on")(bool(enable))
                  requested[name] = getattr(self, f"get_autoscale{name}_on")()
              else:
                  requested[name] = False
      for name, wanted in requested.items():
          if wanted:
              self._request_autoscale_view(name, tight=tight)
  def auto_scale_xyz(self, X, Y, Z=None, had_data=None):
      """
      Grow the recorded data limits to include *X*, *Y* (and *Z*), then
      autoscale.

      When *had_data* is falsy, ``True`` is passed as the *ignore* argument
      of the Bbox update methods.
      """
      ignore_previous = not had_data
      xy_lim = self.xy_dataLim
      if np.shape(X) == np.shape(Y):
          # Matching shapes: treat X and Y as paired coordinates.
          points = np.column_stack([np.ravel(X), np.ravel(Y)])
          xy_lim.update_from_data_xy(points, ignore_previous)
      else:
          xy_lim.update_from_data_x(X, ignore_previous)
          xy_lim.update_from_data_y(Y, ignore_previous)
      if Z is not None:
          # z limits are stored in the x-interval of zz_dataLim.
          self.zz_dataLim.update_from_data_x(Z, ignore_previous)
      # Let autoscale_view figure out how to use this data.
      self.autoscale_view()
  def autoscale_view(self, tight=None, scalex=True, scaley=True,
                     scalez=True):
      """
      Autoscale the view limits using the data limits.

      See `.Axes.autoscale_view` for full documentation.  Because this
      function applies to 3D Axes, it also takes a *scalez* argument.

      Margins (``_xmargin``, ``_ymargin``, ``_zmargin``) are applied whenever
      they are non-zero. A negative margin shrinks the data range, as
      documented in `set_zmargin`.
      """
      # This method looks at the rectangular volume (see auto_scale_xyz)
      # of data and decides how to scale the view portal to fit it.
      if tight is None:
          _tight = self._tight
          if not _tight:
              # if image data only just use the datalim
              for artist in self._children:
                  if isinstance(artist, mimage.AxesImage):
                      _tight = True
                  elif isinstance(artist, (mlines.Line2D, mpatches.Patch)):
                      _tight = False
                      break
      else:
          _tight = self._tight = bool(tight)

      def padded_limits(lo, hi, locator, margin):
          # Make the interval non-singular and pad it by *margin* times its
          # width; negative margins shrink it. Then snap to the locator's
          # view limits unless tight.
          lo, hi = locator.nonsingular(lo, hi)
          if margin != 0:
              delta = (hi - lo) * margin
              lo -= delta
              hi += delta
          if not _tight:
              lo, hi = locator.view_limits(lo, hi)
          return lo, hi

      if scalex and self.get_autoscalex_on():
          x0, x1 = self.xy_dataLim.intervalx
          self.set_xbound(*padded_limits(
              x0, x1, self.xaxis.get_major_locator(), self._xmargin))
      if scaley and self.get_autoscaley_on():
          y0, y1 = self.xy_dataLim.intervaly
          self.set_ybound(*padded_limits(
              y0, y1, self.yaxis.get_major_locator(), self._ymargin))
      if scalez and self.get_autoscalez_on():
          # z limits are stored in the x-interval of zz_dataLim.
          z0, z1 = self.zz_dataLim.intervalx
          self.set_zbound(*padded_limits(
              z0, z1, self.zaxis.get_major_locator(), self._zmargin))
  def get_w_lims(self):
      """
      Get 3D world limits.

      Returns
      -------
      tuple
          ``(minx, maxx, miny, maxy, minz, maxz)`` from the current x, y and
          z limits.
      """
      (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = (
          self.get_xlim3d(), self.get_ylim3d(), self.get_zlim3d())
      return x_lo, x_hi, y_lo, y_hi, z_lo, z_hi
  # set_xlim, set_ylim are directly inherited from base Axes.
  def set_zlim(self, bottom=None, top=None, *, emit=True, auto=False,
               zmin=None, zmax=None):
      """
      Set 3D z limits.

      See `.Axes.set_ylim` for full documentation.  *zmin* / *zmax* are
      aliases for *bottom* / *top*; passing both forms of the same limit
      raises `TypeError`.
      """
      # Accept a single (bottom, top) pair as the first argument.
      if top is None and np.iterable(bottom):
          bottom, top = bottom
      if zmin is not None:
          if bottom is not None:
              raise TypeError("Cannot pass both 'bottom' and 'zmin'")
          bottom = zmin
      if zmax is not None:
          if top is not None:
              raise TypeError("Cannot pass both 'top' and 'zmax'")
          top = zmax
      zaxis = self.zaxis
      return zaxis._set_lim(bottom, top, emit=emit, auto=auto)
  # "3d"-suffixed names for the limit setters: the x/y ones are the base
  # ``Axes`` implementations, and set_zlim3d aliases set_zlim.
  set_xlim3d = maxes.Axes.set_xlim
  set_ylim3d = maxes.Axes.set_ylim
  set_zlim3d = set_zlim
  def get_xlim(self):
      # docstring inherited
      # The x view limits live in the x-interval of xy_viewLim.
      left, right = self.xy_viewLim.intervalx
      return left, right
  def get_ylim(self):
      # docstring inherited
      # The y view limits live in the y-interval of xy_viewLim.
      bottom, top = self.xy_viewLim.intervaly
      return bottom, top
  def get_zlim(self):
      """
      Return the 3D z-axis view limits.

      Returns
      -------
      bottom, top : (float, float)
          The current z-axis limits in data coordinates.

      See Also
      --------
      set_zlim
      set_zbound, get_zbound
      invert_zaxis, zaxis_inverted

      Notes
      -----
      The z-axis may be inverted, in which case the *bottom* value will
      be greater than the *top* value.
      """
      # The z view limits are stored in the x-interval of zz_viewLim
      # (mirroring zz_dataLim, whose y-component is unused).
      return tuple(self.zz_viewLim.intervalx)
  # Expose ZAxis.get_scale as Axes3D.get_zscale.
  get_zscale = _axis_method_wrapper("zaxis", "get_scale")

  # Redefine all three methods to overwrite their docstrings.
  # The shared template below is filled in with "x", "y" and "z" via
  # str.format, so the literal braces around "linear" are doubled.
  set_xscale = _axis_method_wrapper("xaxis", "_set_axes_scale")
  set_yscale = _axis_method_wrapper("yaxis", "_set_axes_scale")
  set_zscale = _axis_method_wrapper("zaxis", "_set_axes_scale")
  set_xscale.__doc__, set_yscale.__doc__, set_zscale.__doc__ = map(
      """
      Set the {}-axis scale.

      Parameters
      ----------
      value : {{"linear"}}
          The axis scale type to apply.  3D axes currently only support
          linear scales; other scales yield nonsensical results.

      **kwargs
          Keyword arguments are nominally forwarded to the scale class, but
          none of them is applicable for linear scales.
      """.format,
      ["x", "y", "z"])

  # z-axis tick accessors, exposed from the corresponding ZAxis methods.
  get_zticks = _axis_method_wrapper("zaxis", "get_ticklocs")
  set_zticks = _axis_method_wrapper("zaxis", "set_ticks")
  get_zmajorticklabels = _axis_method_wrapper("zaxis", "get_majorticklabels")
  get_zminorticklabels = _axis_method_wrapper("zaxis", "get_minorticklabels")
  get_zticklabels = _axis_method_wrapper("zaxis", "get_ticklabels")
  # doc_sub rewrites the "Axis.set_ticks" reference in the wrapped
  # docstring to point at the Axes3D method instead.
  set_zticklabels = _axis_method_wrapper(
      "zaxis", "set_ticklabels",
      doc_sub={"Axis.set_ticks": "Axes3D.set_zticks"})

  zaxis_date = _axis_method_wrapper("zaxis", "axis_date")
  # The docstring may be None (e.g. when docstrings are stripped), so only
  # append the note when there is something to append to.
  if zaxis_date.__doc__:
      zaxis_date.__doc__ += textwrap.dedent("""

      Notes
      -----
      This function is merely provided for completeness, but 3D axes do not
      support dates for ticks, and so this may not work as expected.
      """)
  def clabel(self, *args, **kwargs):
      """Currently not implemented for 3D axes, and returns *None*."""
      # All arguments are accepted and ignored; falling off the end of the
      # function yields None.
  def view_init(self, elev=None, azim=None, roll=None, vertical_axis="z",
                share=False):
      """
      Set the elevation, azimuth and roll of the axes in degrees (not
      radians), and the axis to align vertically.

      This can be used to rotate the axes programmatically.

      To look normal to the primary planes, the following elevation and
      azimuth angles can be used. A roll angle of 0, 90, 180, or 270 deg
      will rotate these views while keeping the axes at right angles.

      ==========   ====  ====
      view plane   elev  azim
      ==========   ====  ====
      XY           90    -90
      XZ           0     -90
      YZ           0     0
      -XY          -90   90
      -XZ          0     90
      -YZ          0     180
      ==========   ====  ====

      Parameters
      ----------
      elev : float, default: None
          The elevation angle in degrees rotates the camera above the plane
          pierced by the vertical axis, with a positive angle corresponding
          to a location above that plane. For example, with the default
          vertical axis of 'z', the elevation defines the angle of the camera
          location above the x-y plane.
          If None, then the initial value as specified in the `Axes3D`
          constructor is used.
      azim : float, default: None
          The azimuthal angle in degrees rotates the camera about the
          vertical axis, with a positive angle corresponding to a
          right-handed rotation. For example, with the default vertical axis
          of 'z', a positive azimuth rotates the camera about the origin from
          its location along the +x axis towards the +y axis.
          If None, then the initial value as specified in the `Axes3D`
          constructor is used.
      roll : float, default: None
          The roll angle in degrees rotates the camera about the viewing
          axis. A positive angle spins the camera clockwise, causing the
          scene to rotate counter-clockwise.
          If None, then the initial value as specified in the `Axes3D`
          constructor is used.
      vertical_axis : {"z", "x", "y"}, default: "z"
          The axis to align vertically. *azim* rotates about this axis.
      share : bool, default: False
          If ``True``, apply the settings to all Axes with shared views.
      """
      # Reset the camera distance from the origin (behaves like zoom).
      # NOTE(review): only reset on ``self``, even when *share* is True --
      # confirm whether shared siblings should be reset as well.
      self._dist = 10

      # Missing angles fall back to this Axes' constructor values, also
      # when the result is applied to shared siblings.
      if elev is None:
          elev = self.initial_elev
      if azim is None:
          azim = self.initial_azim
      if roll is None:
          roll = self.initial_roll
      # Map the axis name to its index (x=0, y=1, z=2); invalid names raise.
      vertical_axis = _api.check_getitem(
          dict(x=0, y=1, z=2), vertical_axis=vertical_axis
      )

      if share:
          axes = set(self._shared_axes['view'].get_siblings(self))
      else:
          axes = [self]

      for ax in axes:
          ax.elev = elev
          ax.azim = azim
          ax.roll = roll
          ax._vertical_axis = vertical_axis
  def set_proj_type(self, proj_type, focal_length=None):
      """
      Set the projection type.

      Parameters
      ----------
      proj_type : {'persp', 'ortho'}
          The projection type.
      focal_length : float, default: None
          For a projection type of 'persp', the focal length of the virtual
          camera. Must be > 0. If None, defaults to 1.
          For a projection type of 'ortho', must be None or infinity
          (numpy.inf); the stored focal length is always infinity.
          The focal length can be computed from a desired Field Of View via
          the equation: focal_length = 1/tan(FOV/2)

      Raises
      ------
      ValueError
          If *proj_type* is unknown or *focal_length* is invalid for it.
      """
      _api.check_in_list(['persp', 'ortho'], proj_type=proj_type)
      if proj_type == 'ortho':
          # An orthographic camera is modelled as an infinite focal length.
          if focal_length not in (None, np.inf):
              raise ValueError(f"focal_length = {focal_length} must be "
                               f"None for proj_type = {proj_type}")
          self._focal_length = np.inf
          return

      if focal_length is None:
          focal_length = 1
      elif focal_length <= 0:
          raise ValueError(f"focal_length = {focal_length} must be "
                           "greater than 0")
      self._focal_length = focal_length
  767. def _roll_to_vertical(self, arr):
  768. """Roll arrays to match the different vertical axis."""
  769. return np.roll(arr, self._vertical_axis - 2)
def get_proj(self):
    """
    Create the projection matrix from the current viewing position.

    The matrix is built from the current axis limits, box aspect, view
    angles (``self.elev``, ``self.azim``), camera distance ``self._dist``
    and focal length ``self._focal_length`` (``np.inf`` selects an
    orthographic projection, any other value a perspective one).

    As a side effect this updates ``self._eye``, ``self._vvec`` and the
    view-axis vectors ``self._view_u``, ``self._view_v``, ``self._view_w``.

    Returns
    -------
    M : ndarray
        The combined transformation ``projM @ (viewM @ worldM)``.
    """
    # Transform to uniform world coordinates 0-1, 0-1, 0-1
    box_aspect = self._roll_to_vertical(self._box_aspect)
    worldM = proj3d.world_transformation(
        *self.get_xlim3d(),
        *self.get_ylim3d(),
        *self.get_zlim3d(),
        pb_aspect=box_aspect,
    )
    # Look into the middle of the world coordinates:
    R = 0.5 * box_aspect
    # elev: elevation angle in the z plane.
    # azim: azimuth angle in the xy plane.
    # Coordinates for a point that rotates around the box of data.
    # p0, p1 corresponds to rotating the box only around the vertical axis.
    # p2 corresponds to rotating the box only around the horizontal axis.
    elev_rad = np.deg2rad(self.elev)
    azim_rad = np.deg2rad(self.azim)
    p0 = np.cos(elev_rad) * np.cos(azim_rad)
    p1 = np.cos(elev_rad) * np.sin(azim_rad)
    p2 = np.sin(elev_rad)
    # When changing vertical axis the coordinates change as well.
    # Roll the values to get the same behaviour as the default:
    ps = self._roll_to_vertical([p0, p1, p2])
    # The coordinates for the eye viewing point. The eye is looking
    # towards the middle of the box of data from a distance:
    eye = R + self._dist * ps
    # vvec, self._vvec and self._eye are unused, remove when deprecated
    vvec = R - eye
    self._eye = eye
    self._vvec = vvec / np.linalg.norm(vvec)
    # Calculate the viewing axes for the eye position
    u, v, w = self._calc_view_axes(eye)
    self._view_u = u  # _view_u is towards the right of the screen
    self._view_v = v  # _view_v is towards the top of the screen
    self._view_w = w  # _view_w is out of the screen
    # Generate the view and projection transformation matrices
    if self._focal_length == np.inf:
        # Orthographic projection
        viewM = proj3d._view_transformation_uvw(u, v, w, eye)
        projM = proj3d._ortho_transformation(-self._dist, self._dist)
    else:
        # Perspective projection
        # Scale the eye dist to compensate for the focal length zoom effect
        eye_focal = R + self._dist * ps * self._focal_length
        viewM = proj3d._view_transformation_uvw(u, v, w, eye_focal)
        projM = proj3d._persp_transformation(-self._dist,
                                             self._dist,
                                             self._focal_length)
    # Combine all the transformation matrices to get the final projection
    M0 = np.dot(viewM, worldM)
    M = np.dot(projM, M0)
    return M
  824. def mouse_init(self, rotate_btn=1, pan_btn=2, zoom_btn=3):
  825. """
  826. Set the mouse buttons for 3D rotation and zooming.
  827. Parameters
  828. ----------
  829. rotate_btn : int or list of int, default: 1
  830. The mouse button or buttons to use for 3D rotation of the axes.
  831. pan_btn : int or list of int, default: 2
  832. The mouse button or buttons to use to pan the 3D axes.
  833. zoom_btn : int or list of int, default: 3
  834. The mouse button or buttons to use to zoom the 3D axes.
  835. """
  836. self.button_pressed = None
  837. # coerce scalars into array-like, then convert into
  838. # a regular list to avoid comparisons against None
  839. # which breaks in recent versions of numpy.
  840. self._rotate_btn = np.atleast_1d(rotate_btn).tolist()
  841. self._pan_btn = np.atleast_1d(pan_btn).tolist()
  842. self._zoom_btn = np.atleast_1d(zoom_btn).tolist()
  843. def disable_mouse_rotation(self):
  844. """Disable mouse buttons for 3D rotation, panning, and zooming."""
  845. self.mouse_init(rotate_btn=[], pan_btn=[], zoom_btn=[])
  846. def can_zoom(self):
  847. # doc-string inherited
  848. return True
  849. def can_pan(self):
  850. # doc-string inherited
  851. return True
  852. def sharez(self, other):
  853. """
  854. Share the z-axis with *other*.
  855. This is equivalent to passing ``sharez=other`` when constructing the
  856. Axes, and cannot be used if the z-axis is already being shared with
  857. another Axes.
  858. """
  859. _api.check_isinstance(Axes3D, other=other)
  860. if self._sharez is not None and other is not self._sharez:
  861. raise ValueError("z-axis is already shared")
  862. self._shared_axes["z"].join(self, other)
  863. self._sharez = other
  864. self.zaxis.major = other.zaxis.major # Ticker instances holding
  865. self.zaxis.minor = other.zaxis.minor # locator and formatter.
  866. z0, z1 = other.get_zlim()
  867. self.set_zlim(z0, z1, emit=False, auto=other.get_autoscalez_on())
  868. self.zaxis._scale = other.zaxis._scale
  869. def shareview(self, other):
  870. """
  871. Share the view angles with *other*.
  872. This is equivalent to passing ``shareview=other`` when
  873. constructing the Axes, and cannot be used if the view angles are
  874. already being shared with another Axes.
  875. """
  876. _api.check_isinstance(Axes3D, other=other)
  877. if self._shareview is not None and other is not self._shareview:
  878. raise ValueError("view angles are already shared")
  879. self._shared_axes["view"].join(self, other)
  880. self._shareview = other
  881. vertical_axis = {0: "x", 1: "y", 2: "z"}[other._vertical_axis]
  882. self.view_init(elev=other.elev, azim=other.azim, roll=other.roll,
  883. vertical_axis=vertical_axis, share=True)
  884. def clear(self):
  885. # docstring inherited.
  886. super().clear()
  887. if self._focal_length == np.inf:
  888. self._zmargin = mpl.rcParams['axes.zmargin']
  889. else:
  890. self._zmargin = 0.
  891. self.grid(mpl.rcParams['axes3d.grid'])
  892. def _button_press(self, event):
  893. if event.inaxes == self:
  894. self.button_pressed = event.button
  895. self._sx, self._sy = event.xdata, event.ydata
  896. toolbar = self.figure.canvas.toolbar
  897. if toolbar and toolbar._nav_stack() is None:
  898. toolbar.push_current()
  899. def _button_release(self, event):
  900. self.button_pressed = None
  901. toolbar = self.figure.canvas.toolbar
  902. # backend_bases.release_zoom and backend_bases.release_pan call
  903. # push_current, so check the navigation mode so we don't call it twice
  904. if toolbar and self.get_navigate_mode() is None:
  905. toolbar.push_current()
  906. def _get_view(self):
  907. # docstring inherited
  908. return {
  909. "xlim": self.get_xlim(), "autoscalex_on": self.get_autoscalex_on(),
  910. "ylim": self.get_ylim(), "autoscaley_on": self.get_autoscaley_on(),
  911. "zlim": self.get_zlim(), "autoscalez_on": self.get_autoscalez_on(),
  912. }, (self.elev, self.azim, self.roll)
  913. def _set_view(self, view):
  914. # docstring inherited
  915. props, (elev, azim, roll) = view
  916. self.set(**props)
  917. self.elev = elev
  918. self.azim = azim
  919. self.roll = roll
  920. def format_zdata(self, z):
  921. """
  922. Return *z* string formatted. This function will use the
  923. :attr:`fmt_zdata` attribute if it is callable, else will fall
  924. back on the zaxis major formatter
  925. """
  926. try:
  927. return self.fmt_zdata(z)
  928. except (AttributeError, TypeError):
  929. func = self.zaxis.get_major_formatter().format_data_short
  930. val = func(z)
  931. return val
  932. def format_coord(self, xv, yv, renderer=None):
  933. """
  934. Return a string giving the current view rotation angles, or the x, y, z
  935. coordinates of the point on the nearest axis pane underneath the mouse
  936. cursor, depending on the mouse button pressed.
  937. """
  938. coords = ''
  939. if self.button_pressed in self._rotate_btn:
  940. # ignore xv and yv and display angles instead
  941. coords = self._rotation_coords()
  942. elif self.M is not None:
  943. coords = self._location_coords(xv, yv, renderer)
  944. return coords
  945. def _rotation_coords(self):
  946. """
  947. Return the rotation angles as a string.
  948. """
  949. norm_elev = art3d._norm_angle(self.elev)
  950. norm_azim = art3d._norm_angle(self.azim)
  951. norm_roll = art3d._norm_angle(self.roll)
  952. coords = (f"elevation={norm_elev:.0f}\N{DEGREE SIGN}, "
  953. f"azimuth={norm_azim:.0f}\N{DEGREE SIGN}, "
  954. f"roll={norm_roll:.0f}\N{DEGREE SIGN}"
  955. ).replace("-", "\N{MINUS SIGN}")
  956. return coords
  957. def _location_coords(self, xv, yv, renderer):
  958. """
  959. Return the location on the axis pane underneath the cursor as a string.
  960. """
  961. p1, pane_idx = self._calc_coord(xv, yv, renderer)
  962. xs = self.format_xdata(p1[0])
  963. ys = self.format_ydata(p1[1])
  964. zs = self.format_zdata(p1[2])
  965. if pane_idx == 0:
  966. coords = f'x pane={xs}, y={ys}, z={zs}'
  967. elif pane_idx == 1:
  968. coords = f'x={xs}, y pane={ys}, z={zs}'
  969. elif pane_idx == 2:
  970. coords = f'x={xs}, y={ys}, z pane={zs}'
  971. return coords
  972. def _get_camera_loc(self):
  973. """
  974. Returns the current camera location in data coordinates.
  975. """
  976. cx, cy, cz, dx, dy, dz = self._get_w_centers_ranges()
  977. c = np.array([cx, cy, cz])
  978. r = np.array([dx, dy, dz])
  979. if self._focal_length == np.inf: # orthographic projection
  980. focal_length = 1e9 # large enough to be effectively infinite
  981. else: # perspective projection
  982. focal_length = self._focal_length
  983. eye = c + self._view_w * self._dist * r / self._box_aspect * focal_length
  984. return eye
  985. def _calc_coord(self, xv, yv, renderer=None):
  986. """
  987. Given the 2D view coordinates, find the point on the nearest axis pane
  988. that lies directly below those coordinates. Returns a 3D point in data
  989. coordinates.
  990. """
  991. if self._focal_length == np.inf: # orthographic projection
  992. zv = 1
  993. else: # perspective projection
  994. zv = -1 / self._focal_length
  995. # Convert point on view plane to data coordinates
  996. p1 = np.array(proj3d.inv_transform(xv, yv, zv, self.invM)).ravel()
  997. # Get the vector from the camera to the point on the view plane
  998. vec = self._get_camera_loc() - p1
  999. # Get the pane locations for each of the axes
  1000. pane_locs = []
  1001. for axis in self._axis_map.values():
  1002. xys, loc = axis.active_pane(renderer)
  1003. pane_locs.append(loc)
  1004. # Find the distance to the nearest pane by projecting the view vector
  1005. scales = np.zeros(3)
  1006. for i in range(3):
  1007. if vec[i] == 0:
  1008. scales[i] = np.inf
  1009. else:
  1010. scales[i] = (p1[i] - pane_locs[i]) / vec[i]
  1011. pane_idx = np.argmin(abs(scales))
  1012. scale = scales[pane_idx]
  1013. # Calculate the point on the closest pane
  1014. p2 = p1 - scale*vec
  1015. return p2, pane_idx
  1016. def _on_move(self, event):
  1017. """
  1018. Mouse moving.
  1019. By default, button-1 rotates, button-2 pans, and button-3 zooms;
  1020. these buttons can be modified via `mouse_init`.
  1021. """
  1022. if not self.button_pressed:
  1023. return
  1024. if self.get_navigate_mode() is not None:
  1025. # we don't want to rotate if we are zooming/panning
  1026. # from the toolbar
  1027. return
  1028. if self.M is None:
  1029. return
  1030. x, y = event.xdata, event.ydata
  1031. # In case the mouse is out of bounds.
  1032. if x is None or event.inaxes != self:
  1033. return
  1034. dx, dy = x - self._sx, y - self._sy
  1035. w = self._pseudo_w
  1036. h = self._pseudo_h
  1037. # Rotation
  1038. if self.button_pressed in self._rotate_btn:
  1039. # rotate viewing point
  1040. # get the x and y pixel coords
  1041. if dx == 0 and dy == 0:
  1042. return
  1043. roll = np.deg2rad(self.roll)
  1044. delev = -(dy/h)*180*np.cos(roll) + (dx/w)*180*np.sin(roll)
  1045. dazim = -(dy/h)*180*np.sin(roll) - (dx/w)*180*np.cos(roll)
  1046. elev = self.elev + delev
  1047. azim = self.azim + dazim
  1048. self.view_init(elev=elev, azim=azim, roll=roll, share=True)
  1049. self.stale = True
  1050. # Pan
  1051. elif self.button_pressed in self._pan_btn:
  1052. # Start the pan event with pixel coordinates
  1053. px, py = self.transData.transform([self._sx, self._sy])
  1054. self.start_pan(px, py, 2)
  1055. # pan view (takes pixel coordinate input)
  1056. self.drag_pan(2, None, event.x, event.y)
  1057. self.end_pan()
  1058. # Zoom
  1059. elif self.button_pressed in self._zoom_btn:
  1060. # zoom view (dragging down zooms in)
  1061. scale = h/(h - dy)
  1062. self._scale_axis_limits(scale, scale, scale)
  1063. # Store the event coordinates for the next time through.
  1064. self._sx, self._sy = x, y
  1065. # Always request a draw update at the end of interaction
  1066. self.figure.canvas.draw_idle()
  1067. def drag_pan(self, button, key, x, y):
  1068. # docstring inherited
  1069. # Get the coordinates from the move event
  1070. p = self._pan_start
  1071. (xdata, ydata), (xdata_start, ydata_start) = p.trans_inverse.transform(
  1072. [(x, y), (p.x, p.y)])
  1073. self._sx, self._sy = xdata, ydata
  1074. # Calling start_pan() to set the x/y of this event as the starting
  1075. # move location for the next event
  1076. self.start_pan(x, y, button)
  1077. du, dv = xdata - xdata_start, ydata - ydata_start
  1078. dw = 0
  1079. if key == 'x':
  1080. dv = 0
  1081. elif key == 'y':
  1082. du = 0
  1083. if du == 0 and dv == 0:
  1084. return
  1085. # Transform the pan from the view axes to the data axes
  1086. R = np.array([self._view_u, self._view_v, self._view_w])
  1087. R = -R / self._box_aspect * self._dist
  1088. duvw_projected = R.T @ np.array([du, dv, dw])
  1089. # Calculate pan distance
  1090. minx, maxx, miny, maxy, minz, maxz = self.get_w_lims()
  1091. dx = (maxx - minx) * duvw_projected[0]
  1092. dy = (maxy - miny) * duvw_projected[1]
  1093. dz = (maxz - minz) * duvw_projected[2]
  1094. # Set the new axis limits
  1095. self.set_xlim3d(minx + dx, maxx + dx)
  1096. self.set_ylim3d(miny + dy, maxy + dy)
  1097. self.set_zlim3d(minz + dz, maxz + dz)
  1098. def _calc_view_axes(self, eye):
  1099. """
  1100. Get the unit vectors for the viewing axes in data coordinates.
  1101. `u` is towards the right of the screen
  1102. `v` is towards the top of the screen
  1103. `w` is out of the screen
  1104. """
  1105. elev_rad = np.deg2rad(art3d._norm_angle(self.elev))
  1106. roll_rad = np.deg2rad(art3d._norm_angle(self.roll))
  1107. # Look into the middle of the world coordinates
  1108. R = 0.5 * self._roll_to_vertical(self._box_aspect)
  1109. # Define which axis should be vertical. A negative value
  1110. # indicates the plot is upside down and therefore the values
  1111. # have been reversed:
  1112. V = np.zeros(3)
  1113. V[self._vertical_axis] = -1 if abs(elev_rad) > np.pi/2 else 1
  1114. u, v, w = proj3d._view_axes(eye, R, V, roll_rad)
  1115. return u, v, w
  1116. def _set_view_from_bbox(self, bbox, direction='in',
  1117. mode=None, twinx=False, twiny=False):
  1118. """
  1119. Zoom in or out of the bounding box.
  1120. Will center the view in the center of the bounding box, and zoom by
  1121. the ratio of the size of the bounding box to the size of the Axes3D.
  1122. """
  1123. (start_x, start_y, stop_x, stop_y) = bbox
  1124. if mode == 'x':
  1125. start_y = self.bbox.min[1]
  1126. stop_y = self.bbox.max[1]
  1127. elif mode == 'y':
  1128. start_x = self.bbox.min[0]
  1129. stop_x = self.bbox.max[0]
  1130. # Clip to bounding box limits
  1131. start_x, stop_x = np.clip(sorted([start_x, stop_x]),
  1132. self.bbox.min[0], self.bbox.max[0])
  1133. start_y, stop_y = np.clip(sorted([start_y, stop_y]),
  1134. self.bbox.min[1], self.bbox.max[1])
  1135. # Move the center of the view to the center of the bbox
  1136. zoom_center_x = (start_x + stop_x)/2
  1137. zoom_center_y = (start_y + stop_y)/2
  1138. ax_center_x = (self.bbox.max[0] + self.bbox.min[0])/2
  1139. ax_center_y = (self.bbox.max[1] + self.bbox.min[1])/2
  1140. self.start_pan(zoom_center_x, zoom_center_y, 2)
  1141. self.drag_pan(2, None, ax_center_x, ax_center_y)
  1142. self.end_pan()
  1143. # Calculate zoom level
  1144. dx = abs(start_x - stop_x)
  1145. dy = abs(start_y - stop_y)
  1146. scale_u = dx / (self.bbox.max[0] - self.bbox.min[0])
  1147. scale_v = dy / (self.bbox.max[1] - self.bbox.min[1])
  1148. # Keep aspect ratios equal
  1149. scale = max(scale_u, scale_v)
  1150. # Zoom out
  1151. if direction == 'out':
  1152. scale = 1 / scale
  1153. self._zoom_data_limits(scale, scale, scale)
  1154. def _zoom_data_limits(self, scale_u, scale_v, scale_w):
  1155. """
  1156. Zoom in or out of a 3D plot.
  1157. Will scale the data limits by the scale factors. These will be
  1158. transformed to the x, y, z data axes based on the current view angles.
  1159. A scale factor > 1 zooms out and a scale factor < 1 zooms in.
  1160. For an axes that has had its aspect ratio set to 'equal', 'equalxy',
  1161. 'equalyz', or 'equalxz', the relevant axes are constrained to zoom
  1162. equally.
  1163. Parameters
  1164. ----------
  1165. scale_u : float
  1166. Scale factor for the u view axis (view screen horizontal).
  1167. scale_v : float
  1168. Scale factor for the v view axis (view screen vertical).
  1169. scale_w : float
  1170. Scale factor for the w view axis (view screen depth).
  1171. """
  1172. scale = np.array([scale_u, scale_v, scale_w])
  1173. # Only perform frame conversion if unequal scale factors
  1174. if not np.allclose(scale, scale_u):
  1175. # Convert the scale factors from the view frame to the data frame
  1176. R = np.array([self._view_u, self._view_v, self._view_w])
  1177. S = scale * np.eye(3)
  1178. scale = np.linalg.norm(R.T @ S, axis=1)
  1179. # Set the constrained scale factors to the factor closest to 1
  1180. if self._aspect in ('equal', 'equalxy', 'equalxz', 'equalyz'):
  1181. ax_idxs = self._equal_aspect_axis_indices(self._aspect)
  1182. min_ax_idxs = np.argmin(np.abs(scale[ax_idxs] - 1))
  1183. scale[ax_idxs] = scale[ax_idxs][min_ax_idxs]
  1184. self._scale_axis_limits(scale[0], scale[1], scale[2])
  1185. def _scale_axis_limits(self, scale_x, scale_y, scale_z):
  1186. """
  1187. Keeping the center of the x, y, and z data axes fixed, scale their
  1188. limits by scale factors. A scale factor > 1 zooms out and a scale
  1189. factor < 1 zooms in.
  1190. Parameters
  1191. ----------
  1192. scale_x : float
  1193. Scale factor for the x data axis.
  1194. scale_y : float
  1195. Scale factor for the y data axis.
  1196. scale_z : float
  1197. Scale factor for the z data axis.
  1198. """
  1199. # Get the axis centers and ranges
  1200. cx, cy, cz, dx, dy, dz = self._get_w_centers_ranges()
  1201. # Set the scaled axis limits
  1202. self.set_xlim3d(cx - dx*scale_x/2, cx + dx*scale_x/2)
  1203. self.set_ylim3d(cy - dy*scale_y/2, cy + dy*scale_y/2)
  1204. self.set_zlim3d(cz - dz*scale_z/2, cz + dz*scale_z/2)
  1205. def _get_w_centers_ranges(self):
  1206. """Get 3D world centers and axis ranges."""
  1207. # Calculate center of axis limits
  1208. minx, maxx, miny, maxy, minz, maxz = self.get_w_lims()
  1209. cx = (maxx + minx)/2
  1210. cy = (maxy + miny)/2
  1211. cz = (maxz + minz)/2
  1212. # Calculate range of axis limits
  1213. dx = (maxx - minx)
  1214. dy = (maxy - miny)
  1215. dz = (maxz - minz)
  1216. return cx, cy, cz, dx, dy, dz
  1217. def set_zlabel(self, zlabel, fontdict=None, labelpad=None, **kwargs):
  1218. """
  1219. Set zlabel. See doc for `.set_ylabel` for description.
  1220. """
  1221. if labelpad is not None:
  1222. self.zaxis.labelpad = labelpad
  1223. return self.zaxis.set_label_text(zlabel, fontdict, **kwargs)
  1224. def get_zlabel(self):
  1225. """
  1226. Get the z-label text string.
  1227. """
  1228. label = self.zaxis.get_label()
  1229. return label.get_text()
  1230. # Axes rectangle characteristics
  1231. # The frame_on methods are not available for 3D axes.
  1232. # Python will raise a TypeError if they are called.
  1233. get_frame_on = None
  1234. set_frame_on = None
  1235. def grid(self, visible=True, **kwargs):
  1236. """
  1237. Set / unset 3D grid.
  1238. .. note::
  1239. Currently, this function does not behave the same as
  1240. `.axes.Axes.grid`, but it is intended to eventually support that
  1241. behavior.
  1242. """
  1243. # TODO: Operate on each axes separately
  1244. if len(kwargs):
  1245. visible = True
  1246. self._draw_grid = visible
  1247. self.stale = True
  1248. def tick_params(self, axis='both', **kwargs):
  1249. """
  1250. Convenience method for changing the appearance of ticks and
  1251. tick labels.
  1252. See `.Axes.tick_params` for full documentation. Because this function
  1253. applies to 3D Axes, *axis* can also be set to 'z', and setting *axis*
  1254. to 'both' autoscales all three axes.
  1255. Also, because of how Axes3D objects are drawn very differently
  1256. from regular 2D axes, some of these settings may have
  1257. ambiguous meaning. For simplicity, the 'z' axis will
  1258. accept settings as if it was like the 'y' axis.
  1259. .. note::
  1260. Axes3D currently ignores some of these settings.
  1261. """
  1262. _api.check_in_list(['x', 'y', 'z', 'both'], axis=axis)
  1263. if axis in ['x', 'y', 'both']:
  1264. super().tick_params(axis, **kwargs)
  1265. if axis in ['z', 'both']:
  1266. zkw = dict(kwargs)
  1267. zkw.pop('top', None)
  1268. zkw.pop('bottom', None)
  1269. zkw.pop('labeltop', None)
  1270. zkw.pop('labelbottom', None)
  1271. self.zaxis.set_tick_params(**zkw)
  1272. # data limits, ticks, tick labels, and formatting
  1273. def invert_zaxis(self):
  1274. """
  1275. Invert the z-axis.
  1276. See Also
  1277. --------
  1278. zaxis_inverted
  1279. get_zlim, set_zlim
  1280. get_zbound, set_zbound
  1281. """
  1282. bottom, top = self.get_zlim()
  1283. self.set_zlim(top, bottom, auto=None)
# Method built by _axis_method_wrapper from the "zaxis" attribute and its
# "get_inverted" method -- presumably it forwards to
# ``self.zaxis.get_inverted()``, like the 2D xaxis_inverted/yaxis_inverted
# helpers (confirm against _axis_method_wrapper).
zaxis_inverted = _axis_method_wrapper("zaxis", "get_inverted")
  1285. def get_zbound(self):
  1286. """
  1287. Return the lower and upper z-axis bounds, in increasing order.
  1288. See Also
  1289. --------
  1290. set_zbound
  1291. get_zlim, set_zlim
  1292. invert_zaxis, zaxis_inverted
  1293. """
  1294. bottom, top = self.get_zlim()
  1295. if bottom < top:
  1296. return bottom, top
  1297. else:
  1298. return top, bottom
  1299. def set_zbound(self, lower=None, upper=None):
  1300. """
  1301. Set the lower and upper numerical bounds of the z-axis.
  1302. This method will honor axes inversion regardless of parameter order.
  1303. It will not change the autoscaling setting (`.get_autoscalez_on()`).
  1304. Parameters
  1305. ----------
  1306. lower, upper : float or None
  1307. The lower and upper bounds. If *None*, the respective axis bound
  1308. is not modified.
  1309. See Also
  1310. --------
  1311. get_zbound
  1312. get_zlim, set_zlim
  1313. invert_zaxis, zaxis_inverted
  1314. """
  1315. if upper is None and np.iterable(lower):
  1316. lower, upper = lower
  1317. old_lower, old_upper = self.get_zbound()
  1318. if lower is None:
  1319. lower = old_lower
  1320. if upper is None:
  1321. upper = old_upper
  1322. self.set_zlim(sorted((lower, upper),
  1323. reverse=bool(self.zaxis_inverted())),
  1324. auto=None)
  1325. def text(self, x, y, z, s, zdir=None, **kwargs):
  1326. """
  1327. Add the text *s* to the 3D Axes at location *x*, *y*, *z* in data coordinates.
  1328. Parameters
  1329. ----------
  1330. x, y, z : float
  1331. The position to place the text.
  1332. s : str
  1333. The text.
  1334. zdir : {'x', 'y', 'z', 3-tuple}, optional
  1335. The direction to be used as the z-direction. Default: 'z'.
  1336. See `.get_dir_vector` for a description of the values.
  1337. **kwargs
  1338. Other arguments are forwarded to `matplotlib.axes.Axes.text`.
  1339. Returns
  1340. -------
  1341. `.Text3D`
  1342. The created `.Text3D` instance.
  1343. """
  1344. text = super().text(x, y, s, **kwargs)
  1345. art3d.text_2d_to_3d(text, z, zdir)
  1346. return text
  1347. text3D = text
  1348. text2D = Axes.text
  1349. def plot(self, xs, ys, *args, zdir='z', **kwargs):
  1350. """
  1351. Plot 2D or 3D data.
  1352. Parameters
  1353. ----------
  1354. xs : 1D array-like
  1355. x coordinates of vertices.
  1356. ys : 1D array-like
  1357. y coordinates of vertices.
  1358. zs : float or 1D array-like
  1359. z coordinates of vertices; either one for all points or one for
  1360. each point.
  1361. zdir : {'x', 'y', 'z'}, default: 'z'
  1362. When plotting 2D data, the direction to use as z.
  1363. **kwargs
  1364. Other arguments are forwarded to `matplotlib.axes.Axes.plot`.
  1365. """
  1366. had_data = self.has_data()
  1367. # `zs` can be passed positionally or as keyword; checking whether
  1368. # args[0] is a string matches the behavior of 2D `plot` (via
  1369. # `_process_plot_var_args`).
  1370. if args and not isinstance(args[0], str):
  1371. zs, *args = args
  1372. if 'zs' in kwargs:
  1373. raise TypeError("plot() for multiple values for argument 'z'")
  1374. else:
  1375. zs = kwargs.pop('zs', 0)
  1376. # Match length
  1377. zs = np.broadcast_to(zs, np.shape(xs))
  1378. lines = super().plot(xs, ys, *args, **kwargs)
  1379. for line in lines:
  1380. art3d.line_2d_to_3d(line, zs=zs, zdir=zdir)
  1381. xs, ys, zs = art3d.juggle_axes(xs, ys, zs, zdir)
  1382. self.auto_scale_xyz(xs, ys, zs, had_data)
  1383. return lines
  1384. plot3D = plot
  1385. def plot_surface(self, X, Y, Z, *, norm=None, vmin=None,
  1386. vmax=None, lightsource=None, **kwargs):
  1387. """
  1388. Create a surface plot.
  1389. By default, it will be colored in shades of a solid color, but it also
  1390. supports colormapping by supplying the *cmap* argument.
  1391. .. note::
  1392. The *rcount* and *ccount* kwargs, which both default to 50,
  1393. determine the maximum number of samples used in each direction. If
  1394. the input data is larger, it will be downsampled (by slicing) to
  1395. these numbers of points.
  1396. .. note::
  1397. To maximize rendering speed consider setting *rstride* and *cstride*
  1398. to divisors of the number of rows minus 1 and columns minus 1
  1399. respectively. For example, given 51 rows rstride can be any of the
  1400. divisors of 50.
  1401. Similarly, a setting of *rstride* and *cstride* equal to 1 (or
  1402. *rcount* and *ccount* equal the number of rows and columns) can use
  1403. the optimized path.
  1404. Parameters
  1405. ----------
  1406. X, Y, Z : 2D arrays
  1407. Data values.
  1408. rcount, ccount : int
  1409. Maximum number of samples used in each direction. If the input
  1410. data is larger, it will be downsampled (by slicing) to these
  1411. numbers of points. Defaults to 50.
  1412. rstride, cstride : int
  1413. Downsampling stride in each direction. These arguments are
  1414. mutually exclusive with *rcount* and *ccount*. If only one of
  1415. *rstride* or *cstride* is set, the other defaults to 10.
  1416. 'classic' mode uses a default of ``rstride = cstride = 10`` instead
  1417. of the new default of ``rcount = ccount = 50``.
  1418. color : color-like
  1419. Color of the surface patches.
  1420. cmap : Colormap
  1421. Colormap of the surface patches.
  1422. facecolors : array-like of colors.
  1423. Colors of each individual patch.
  1424. norm : Normalize
  1425. Normalization for the colormap.
  1426. vmin, vmax : float
  1427. Bounds for the normalization.
  1428. shade : bool, default: True
  1429. Whether to shade the facecolors. Shading is always disabled when
  1430. *cmap* is specified.
  1431. lightsource : `~matplotlib.colors.LightSource`
  1432. The lightsource to use when *shade* is True.
  1433. **kwargs
  1434. Other keyword arguments are forwarded to `.Poly3DCollection`.
  1435. """
  1436. had_data = self.has_data()
  1437. if Z.ndim != 2:
  1438. raise ValueError("Argument Z must be 2-dimensional.")
  1439. Z = cbook._to_unmasked_float_array(Z)
  1440. X, Y, Z = np.broadcast_arrays(X, Y, Z)
  1441. rows, cols = Z.shape
  1442. has_stride = 'rstride' in kwargs or 'cstride' in kwargs
  1443. has_count = 'rcount' in kwargs or 'ccount' in kwargs
  1444. if has_stride and has_count:
  1445. raise ValueError("Cannot specify both stride and count arguments")
  1446. rstride = kwargs.pop('rstride', 10)
  1447. cstride = kwargs.pop('cstride', 10)
  1448. rcount = kwargs.pop('rcount', 50)
  1449. ccount = kwargs.pop('ccount', 50)
  1450. if mpl.rcParams['_internal.classic_mode']:
  1451. # Strides have priority over counts in classic mode.
  1452. # So, only compute strides from counts
  1453. # if counts were explicitly given
  1454. compute_strides = has_count
  1455. else:
  1456. # If the strides are provided then it has priority.
  1457. # Otherwise, compute the strides from the counts.
  1458. compute_strides = not has_stride
  1459. if compute_strides:
  1460. rstride = int(max(np.ceil(rows / rcount), 1))
  1461. cstride = int(max(np.ceil(cols / ccount), 1))
  1462. fcolors = kwargs.pop('facecolors', None)
  1463. cmap = kwargs.get('cmap', None)
  1464. shade = kwargs.pop('shade', cmap is None)
  1465. if shade is None:
  1466. raise ValueError("shade cannot be None.")
  1467. colset = [] # the sampled facecolor
  1468. if (rows - 1) % rstride == 0 and \
  1469. (cols - 1) % cstride == 0 and \
  1470. fcolors is None:
  1471. polys = np.stack(
  1472. [cbook._array_patch_perimeters(a, rstride, cstride)
  1473. for a in (X, Y, Z)],
  1474. axis=-1)
  1475. else:
  1476. # evenly spaced, and including both endpoints
  1477. row_inds = list(range(0, rows-1, rstride)) + [rows-1]
  1478. col_inds = list(range(0, cols-1, cstride)) + [cols-1]
  1479. polys = []
  1480. for rs, rs_next in zip(row_inds[:-1], row_inds[1:]):
  1481. for cs, cs_next in zip(col_inds[:-1], col_inds[1:]):
  1482. ps = [
  1483. # +1 ensures we share edges between polygons
  1484. cbook._array_perimeter(a[rs:rs_next+1, cs:cs_next+1])
  1485. for a in (X, Y, Z)
  1486. ]
  1487. # ps = np.stack(ps, axis=-1)
  1488. ps = np.array(ps).T
  1489. polys.append(ps)
  1490. if fcolors is not None:
  1491. colset.append(fcolors[rs][cs])
  1492. # In cases where there are non-finite values in the data (possibly NaNs from
  1493. # masked arrays), artifacts can be introduced. Here check whether such values
  1494. # are present and remove them.
  1495. if not isinstance(polys, np.ndarray) or not np.isfinite(polys).all():
  1496. new_polys = []
  1497. new_colset = []
  1498. # Depending on fcolors, colset is either an empty list or has as
  1499. # many elements as polys. In the former case new_colset results in
  1500. # a list with None entries, that is discarded later.
  1501. for p, col in itertools.zip_longest(polys, colset):
  1502. new_poly = np.array(p)[np.isfinite(p).all(axis=1)]
  1503. if len(new_poly):
  1504. new_polys.append(new_poly)
  1505. new_colset.append(col)
  1506. # Replace previous polys and, if fcolors is not None, colset
  1507. polys = new_polys
  1508. if fcolors is not None:
  1509. colset = new_colset
  1510. # note that the striding causes some polygons to have more coordinates
  1511. # than others
  1512. if fcolors is not None:
  1513. polyc = art3d.Poly3DCollection(
  1514. polys, edgecolors=colset, facecolors=colset, shade=shade,
  1515. lightsource=lightsource, **kwargs)
  1516. elif cmap:
  1517. polyc = art3d.Poly3DCollection(polys, **kwargs)
  1518. # can't always vectorize, because polys might be jagged
  1519. if isinstance(polys, np.ndarray):
  1520. avg_z = polys[..., 2].mean(axis=-1)
  1521. else:
  1522. avg_z = np.array([ps[:, 2].mean() for ps in polys])
  1523. polyc.set_array(avg_z)
  1524. if vmin is not None or vmax is not None:
  1525. polyc.set_clim(vmin, vmax)
  1526. if norm is not None:
  1527. polyc.set_norm(norm)
  1528. else:
  1529. color = kwargs.pop('color', None)
  1530. if color is None:
  1531. color = self._get_lines.get_next_color()
  1532. color = np.array(mcolors.to_rgba(color))
  1533. polyc = art3d.Poly3DCollection(
  1534. polys, facecolors=color, shade=shade,
  1535. lightsource=lightsource, **kwargs)
  1536. self.add_collection(polyc)
  1537. self.auto_scale_xyz(X, Y, Z, had_data)
  1538. return polyc
  1539. def plot_wireframe(self, X, Y, Z, **kwargs):
  1540. """
  1541. Plot a 3D wireframe.
  1542. .. note::
  1543. The *rcount* and *ccount* kwargs, which both default to 50,
  1544. determine the maximum number of samples used in each direction. If
  1545. the input data is larger, it will be downsampled (by slicing) to
  1546. these numbers of points.
  1547. Parameters
  1548. ----------
  1549. X, Y, Z : 2D arrays
  1550. Data values.
  1551. rcount, ccount : int
  1552. Maximum number of samples used in each direction. If the input
  1553. data is larger, it will be downsampled (by slicing) to these
  1554. numbers of points. Setting a count to zero causes the data to be
  1555. not sampled in the corresponding direction, producing a 3D line
  1556. plot rather than a wireframe plot. Defaults to 50.
  1557. rstride, cstride : int
  1558. Downsampling stride in each direction. These arguments are
  1559. mutually exclusive with *rcount* and *ccount*. If only one of
  1560. *rstride* or *cstride* is set, the other defaults to 1. Setting a
  1561. stride to zero causes the data to be not sampled in the
  1562. corresponding direction, producing a 3D line plot rather than a
  1563. wireframe plot.
  1564. 'classic' mode uses a default of ``rstride = cstride = 1`` instead
  1565. of the new default of ``rcount = ccount = 50``.
  1566. **kwargs
  1567. Other keyword arguments are forwarded to `.Line3DCollection`.
  1568. """
  1569. had_data = self.has_data()
  1570. if Z.ndim != 2:
  1571. raise ValueError("Argument Z must be 2-dimensional.")
  1572. # FIXME: Support masked arrays
  1573. X, Y, Z = np.broadcast_arrays(X, Y, Z)
  1574. rows, cols = Z.shape
  1575. has_stride = 'rstride' in kwargs or 'cstride' in kwargs
  1576. has_count = 'rcount' in kwargs or 'ccount' in kwargs
  1577. if has_stride and has_count:
  1578. raise ValueError("Cannot specify both stride and count arguments")
  1579. rstride = kwargs.pop('rstride', 1)
  1580. cstride = kwargs.pop('cstride', 1)
  1581. rcount = kwargs.pop('rcount', 50)
  1582. ccount = kwargs.pop('ccount', 50)
  1583. if mpl.rcParams['_internal.classic_mode']:
  1584. # Strides have priority over counts in classic mode.
  1585. # So, only compute strides from counts
  1586. # if counts were explicitly given
  1587. if has_count:
  1588. rstride = int(max(np.ceil(rows / rcount), 1)) if rcount else 0
  1589. cstride = int(max(np.ceil(cols / ccount), 1)) if ccount else 0
  1590. else:
  1591. # If the strides are provided then it has priority.
  1592. # Otherwise, compute the strides from the counts.
  1593. if not has_stride:
  1594. rstride = int(max(np.ceil(rows / rcount), 1)) if rcount else 0
  1595. cstride = int(max(np.ceil(cols / ccount), 1)) if ccount else 0
  1596. # We want two sets of lines, one running along the "rows" of
  1597. # Z and another set of lines running along the "columns" of Z.
  1598. # This transpose will make it easy to obtain the columns.
  1599. tX, tY, tZ = np.transpose(X), np.transpose(Y), np.transpose(Z)
  1600. if rstride:
  1601. rii = list(range(0, rows, rstride))
  1602. # Add the last index only if needed
  1603. if rows > 0 and rii[-1] != (rows - 1):
  1604. rii += [rows-1]
  1605. else:
  1606. rii = []
  1607. if cstride:
  1608. cii = list(range(0, cols, cstride))
  1609. # Add the last index only if needed
  1610. if cols > 0 and cii[-1] != (cols - 1):
  1611. cii += [cols-1]
  1612. else:
  1613. cii = []
  1614. if rstride == 0 and cstride == 0:
  1615. raise ValueError("Either rstride or cstride must be non zero")
  1616. # If the inputs were empty, then just
  1617. # reset everything.
  1618. if Z.size == 0:
  1619. rii = []
  1620. cii = []
  1621. xlines = [X[i] for i in rii]
  1622. ylines = [Y[i] for i in rii]
  1623. zlines = [Z[i] for i in rii]
  1624. txlines = [tX[i] for i in cii]
  1625. tylines = [tY[i] for i in cii]
  1626. tzlines = [tZ[i] for i in cii]
  1627. lines = ([list(zip(xl, yl, zl))
  1628. for xl, yl, zl in zip(xlines, ylines, zlines)]
  1629. + [list(zip(xl, yl, zl))
  1630. for xl, yl, zl in zip(txlines, tylines, tzlines)])
  1631. linec = art3d.Line3DCollection(lines, **kwargs)
  1632. self.add_collection(linec)
  1633. self.auto_scale_xyz(X, Y, Z, had_data)
  1634. return linec
  1635. def plot_trisurf(self, *args, color=None, norm=None, vmin=None, vmax=None,
  1636. lightsource=None, **kwargs):
  1637. """
  1638. Plot a triangulated surface.
  1639. The (optional) triangulation can be specified in one of two ways;
  1640. either::
  1641. plot_trisurf(triangulation, ...)
  1642. where triangulation is a `~matplotlib.tri.Triangulation` object, or::
  1643. plot_trisurf(X, Y, ...)
  1644. plot_trisurf(X, Y, triangles, ...)
  1645. plot_trisurf(X, Y, triangles=triangles, ...)
  1646. in which case a Triangulation object will be created. See
  1647. `.Triangulation` for an explanation of these possibilities.
  1648. The remaining arguments are::
  1649. plot_trisurf(..., Z)
  1650. where *Z* is the array of values to contour, one per point
  1651. in the triangulation.
  1652. Parameters
  1653. ----------
  1654. X, Y, Z : array-like
  1655. Data values as 1D arrays.
  1656. color
  1657. Color of the surface patches.
  1658. cmap
  1659. A colormap for the surface patches.
  1660. norm : Normalize
  1661. An instance of Normalize to map values to colors.
  1662. vmin, vmax : float, default: None
  1663. Minimum and maximum value to map.
  1664. shade : bool, default: True
  1665. Whether to shade the facecolors. Shading is always disabled when
  1666. *cmap* is specified.
  1667. lightsource : `~matplotlib.colors.LightSource`
  1668. The lightsource to use when *shade* is True.
  1669. **kwargs
  1670. All other keyword arguments are passed on to
  1671. :class:`~mpl_toolkits.mplot3d.art3d.Poly3DCollection`
  1672. Examples
  1673. --------
  1674. .. plot:: gallery/mplot3d/trisurf3d.py
  1675. .. plot:: gallery/mplot3d/trisurf3d_2.py
  1676. """
  1677. had_data = self.has_data()
  1678. # TODO: Support custom face colours
  1679. if color is None:
  1680. color = self._get_lines.get_next_color()
  1681. color = np.array(mcolors.to_rgba(color))
  1682. cmap = kwargs.get('cmap', None)
  1683. shade = kwargs.pop('shade', cmap is None)
  1684. tri, args, kwargs = \
  1685. Triangulation.get_from_args_and_kwargs(*args, **kwargs)
  1686. try:
  1687. z = kwargs.pop('Z')
  1688. except KeyError:
  1689. # We do this so Z doesn't get passed as an arg to PolyCollection
  1690. z, *args = args
  1691. z = np.asarray(z)
  1692. triangles = tri.get_masked_triangles()
  1693. xt = tri.x[triangles]
  1694. yt = tri.y[triangles]
  1695. zt = z[triangles]
  1696. verts = np.stack((xt, yt, zt), axis=-1)
  1697. if cmap:
  1698. polyc = art3d.Poly3DCollection(verts, *args, **kwargs)
  1699. # average over the three points of each triangle
  1700. avg_z = verts[:, :, 2].mean(axis=1)
  1701. polyc.set_array(avg_z)
  1702. if vmin is not None or vmax is not None:
  1703. polyc.set_clim(vmin, vmax)
  1704. if norm is not None:
  1705. polyc.set_norm(norm)
  1706. else:
  1707. polyc = art3d.Poly3DCollection(
  1708. verts, *args, shade=shade, lightsource=lightsource,
  1709. facecolors=color, **kwargs)
  1710. self.add_collection(polyc)
  1711. self.auto_scale_xyz(tri.x, tri.y, z, had_data)
  1712. return polyc
  1713. def _3d_extend_contour(self, cset, stride=5):
  1714. """
  1715. Extend a contour in 3D by creating
  1716. """
  1717. dz = (cset.levels[1] - cset.levels[0]) / 2
  1718. polyverts = []
  1719. colors = []
  1720. for idx, level in enumerate(cset.levels):
  1721. path = cset.get_paths()[idx]
  1722. subpaths = [*path._iter_connected_components()]
  1723. color = cset.get_edgecolor()[idx]
  1724. top = art3d._paths_to_3d_segments(subpaths, level - dz)
  1725. bot = art3d._paths_to_3d_segments(subpaths, level + dz)
  1726. if not len(top[0]):
  1727. continue
  1728. nsteps = max(round(len(top[0]) / stride), 2)
  1729. stepsize = (len(top[0]) - 1) / (nsteps - 1)
  1730. polyverts.extend([
  1731. (top[0][round(i * stepsize)], top[0][round((i + 1) * stepsize)],
  1732. bot[0][round((i + 1) * stepsize)], bot[0][round(i * stepsize)])
  1733. for i in range(round(nsteps) - 1)])
  1734. colors.extend([color] * (round(nsteps) - 1))
  1735. self.add_collection3d(art3d.Poly3DCollection(
  1736. np.array(polyverts), # All polygons have 4 vertices, so vectorize.
  1737. facecolors=colors, edgecolors=colors, shade=True))
  1738. cset.remove()
  1739. def add_contour_set(
  1740. self, cset, extend3d=False, stride=5, zdir='z', offset=None):
  1741. zdir = '-' + zdir
  1742. if extend3d:
  1743. self._3d_extend_contour(cset, stride)
  1744. else:
  1745. art3d.collection_2d_to_3d(
  1746. cset, zs=offset if offset is not None else cset.levels, zdir=zdir)
  1747. def add_contourf_set(self, cset, zdir='z', offset=None):
  1748. self._add_contourf_set(cset, zdir=zdir, offset=offset)
  1749. def _add_contourf_set(self, cset, zdir='z', offset=None):
  1750. """
  1751. Returns
  1752. -------
  1753. levels : `numpy.ndarray`
  1754. Levels at which the filled contours are added.
  1755. """
  1756. zdir = '-' + zdir
  1757. midpoints = cset.levels[:-1] + np.diff(cset.levels) / 2
  1758. # Linearly interpolate to get levels for any extensions
  1759. if cset._extend_min:
  1760. min_level = cset.levels[0] - np.diff(cset.levels[:2]) / 2
  1761. midpoints = np.insert(midpoints, 0, min_level)
  1762. if cset._extend_max:
  1763. max_level = cset.levels[-1] + np.diff(cset.levels[-2:]) / 2
  1764. midpoints = np.append(midpoints, max_level)
  1765. art3d.collection_2d_to_3d(
  1766. cset, zs=offset if offset is not None else midpoints, zdir=zdir)
  1767. return midpoints
  1768. @_preprocess_data()
  1769. def contour(self, X, Y, Z, *args,
  1770. extend3d=False, stride=5, zdir='z', offset=None, **kwargs):
  1771. """
  1772. Create a 3D contour plot.
  1773. Parameters
  1774. ----------
  1775. X, Y, Z : array-like,
  1776. Input data. See `.Axes.contour` for supported data shapes.
  1777. extend3d : bool, default: False
  1778. Whether to extend contour in 3D.
  1779. stride : int
  1780. Step size for extending contour.
  1781. zdir : {'x', 'y', 'z'}, default: 'z'
  1782. The direction to use.
  1783. offset : float, optional
  1784. If specified, plot a projection of the contour lines at this
  1785. position in a plane normal to *zdir*.
  1786. data : indexable object, optional
  1787. DATA_PARAMETER_PLACEHOLDER
  1788. *args, **kwargs
  1789. Other arguments are forwarded to `matplotlib.axes.Axes.contour`.
  1790. Returns
  1791. -------
  1792. matplotlib.contour.QuadContourSet
  1793. """
  1794. had_data = self.has_data()
  1795. jX, jY, jZ = art3d.rotate_axes(X, Y, Z, zdir)
  1796. cset = super().contour(jX, jY, jZ, *args, **kwargs)
  1797. self.add_contour_set(cset, extend3d, stride, zdir, offset)
  1798. self.auto_scale_xyz(X, Y, Z, had_data)
  1799. return cset
  1800. contour3D = contour
  1801. @_preprocess_data()
  1802. def tricontour(self, *args,
  1803. extend3d=False, stride=5, zdir='z', offset=None, **kwargs):
  1804. """
  1805. Create a 3D contour plot.
  1806. .. note::
  1807. This method currently produces incorrect output due to a
  1808. longstanding bug in 3D PolyCollection rendering.
  1809. Parameters
  1810. ----------
  1811. X, Y, Z : array-like
  1812. Input data. See `.Axes.tricontour` for supported data shapes.
  1813. extend3d : bool, default: False
  1814. Whether to extend contour in 3D.
  1815. stride : int
  1816. Step size for extending contour.
  1817. zdir : {'x', 'y', 'z'}, default: 'z'
  1818. The direction to use.
  1819. offset : float, optional
  1820. If specified, plot a projection of the contour lines at this
  1821. position in a plane normal to *zdir*.
  1822. data : indexable object, optional
  1823. DATA_PARAMETER_PLACEHOLDER
  1824. *args, **kwargs
  1825. Other arguments are forwarded to `matplotlib.axes.Axes.tricontour`.
  1826. Returns
  1827. -------
  1828. matplotlib.tri._tricontour.TriContourSet
  1829. """
  1830. had_data = self.has_data()
  1831. tri, args, kwargs = Triangulation.get_from_args_and_kwargs(
  1832. *args, **kwargs)
  1833. X = tri.x
  1834. Y = tri.y
  1835. if 'Z' in kwargs:
  1836. Z = kwargs.pop('Z')
  1837. else:
  1838. # We do this so Z doesn't get passed as an arg to Axes.tricontour
  1839. Z, *args = args
  1840. jX, jY, jZ = art3d.rotate_axes(X, Y, Z, zdir)
  1841. tri = Triangulation(jX, jY, tri.triangles, tri.mask)
  1842. cset = super().tricontour(tri, jZ, *args, **kwargs)
  1843. self.add_contour_set(cset, extend3d, stride, zdir, offset)
  1844. self.auto_scale_xyz(X, Y, Z, had_data)
  1845. return cset
  1846. def _auto_scale_contourf(self, X, Y, Z, zdir, levels, had_data):
  1847. # Autoscale in the zdir based on the levels added, which are
  1848. # different from data range if any contour extensions are present
  1849. dim_vals = {'x': X, 'y': Y, 'z': Z, zdir: levels}
  1850. # Input data and levels have different sizes, but auto_scale_xyz
  1851. # expected same-size input, so manually take min/max limits
  1852. limits = [(np.nanmin(dim_vals[dim]), np.nanmax(dim_vals[dim]))
  1853. for dim in ['x', 'y', 'z']]
  1854. self.auto_scale_xyz(*limits, had_data)
  1855. @_preprocess_data()
  1856. def contourf(self, X, Y, Z, *args, zdir='z', offset=None, **kwargs):
  1857. """
  1858. Create a 3D filled contour plot.
  1859. Parameters
  1860. ----------
  1861. X, Y, Z : array-like
  1862. Input data. See `.Axes.contourf` for supported data shapes.
  1863. zdir : {'x', 'y', 'z'}, default: 'z'
  1864. The direction to use.
  1865. offset : float, optional
  1866. If specified, plot a projection of the contour lines at this
  1867. position in a plane normal to *zdir*.
  1868. data : indexable object, optional
  1869. DATA_PARAMETER_PLACEHOLDER
  1870. *args, **kwargs
  1871. Other arguments are forwarded to `matplotlib.axes.Axes.contourf`.
  1872. Returns
  1873. -------
  1874. matplotlib.contour.QuadContourSet
  1875. """
  1876. had_data = self.has_data()
  1877. jX, jY, jZ = art3d.rotate_axes(X, Y, Z, zdir)
  1878. cset = super().contourf(jX, jY, jZ, *args, **kwargs)
  1879. levels = self._add_contourf_set(cset, zdir, offset)
  1880. self._auto_scale_contourf(X, Y, Z, zdir, levels, had_data)
  1881. return cset
  1882. contourf3D = contourf
  1883. @_preprocess_data()
  1884. def tricontourf(self, *args, zdir='z', offset=None, **kwargs):
  1885. """
  1886. Create a 3D filled contour plot.
  1887. .. note::
  1888. This method currently produces incorrect output due to a
  1889. longstanding bug in 3D PolyCollection rendering.
  1890. Parameters
  1891. ----------
  1892. X, Y, Z : array-like
  1893. Input data. See `.Axes.tricontourf` for supported data shapes.
  1894. zdir : {'x', 'y', 'z'}, default: 'z'
  1895. The direction to use.
  1896. offset : float, optional
  1897. If specified, plot a projection of the contour lines at this
  1898. position in a plane normal to zdir.
  1899. data : indexable object, optional
  1900. DATA_PARAMETER_PLACEHOLDER
  1901. *args, **kwargs
  1902. Other arguments are forwarded to
  1903. `matplotlib.axes.Axes.tricontourf`.
  1904. Returns
  1905. -------
  1906. matplotlib.tri._tricontour.TriContourSet
  1907. """
  1908. had_data = self.has_data()
  1909. tri, args, kwargs = Triangulation.get_from_args_and_kwargs(
  1910. *args, **kwargs)
  1911. X = tri.x
  1912. Y = tri.y
  1913. if 'Z' in kwargs:
  1914. Z = kwargs.pop('Z')
  1915. else:
  1916. # We do this so Z doesn't get passed as an arg to Axes.tricontourf
  1917. Z, *args = args
  1918. jX, jY, jZ = art3d.rotate_axes(X, Y, Z, zdir)
  1919. tri = Triangulation(jX, jY, tri.triangles, tri.mask)
  1920. cset = super().tricontourf(tri, jZ, *args, **kwargs)
  1921. levels = self._add_contourf_set(cset, zdir, offset)
  1922. self._auto_scale_contourf(X, Y, Z, zdir, levels, had_data)
  1923. return cset
  1924. def add_collection3d(self, col, zs=0, zdir='z'):
  1925. """
  1926. Add a 3D collection object to the plot.
  1927. 2D collection types are converted to a 3D version by
  1928. modifying the object and adding z coordinate information.
  1929. Supported are:
  1930. - PolyCollection
  1931. - LineCollection
  1932. - PatchCollection
  1933. """
  1934. zvals = np.atleast_1d(zs)
  1935. zsortval = (np.min(zvals) if zvals.size
  1936. else 0) # FIXME: arbitrary default
  1937. # FIXME: use issubclass() (although, then a 3D collection
  1938. # object would also pass.) Maybe have a collection3d
  1939. # abstract class to test for and exclude?
  1940. if type(col) is mcoll.PolyCollection:
  1941. art3d.poly_collection_2d_to_3d(col, zs=zs, zdir=zdir)
  1942. col.set_sort_zpos(zsortval)
  1943. elif type(col) is mcoll.LineCollection:
  1944. art3d.line_collection_2d_to_3d(col, zs=zs, zdir=zdir)
  1945. col.set_sort_zpos(zsortval)
  1946. elif type(col) is mcoll.PatchCollection:
  1947. art3d.patch_collection_2d_to_3d(col, zs=zs, zdir=zdir)
  1948. col.set_sort_zpos(zsortval)
  1949. collection = super().add_collection(col)
  1950. return collection
@_preprocess_data(replace_names=["xs", "ys", "zs", "s",
                                 "edgecolors", "c", "facecolor",
                                 "facecolors", "color"])
def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
            *args, **kwargs):
    """
    Create a scatter plot.

    Parameters
    ----------
    xs, ys : array-like
        The data positions.  Masked entries are treated as missing and
        the corresponding points are not drawn.
    zs : float or array-like, default: 0
        The z-positions. Either an array of the same length as *xs* and
        *ys* or a single value to place all points in the same plane.
    zdir : {'x', 'y', 'z', '-x', '-y', '-z'}, default: 'z'
        The axis direction for the *zs*. This is useful when plotting 2D
        data on a 3D Axes. The data must be passed as *xs*, *ys*. Setting
        *zdir* to 'y' then plots the data to the x-z-plane.

        See also :doc:`/gallery/mplot3d/2dcollections3d`.

    s : float or array-like, default: 20
        The marker size in points**2. Either an array of the same length
        as *xs* and *ys* or a single value to make all markers the same
        size.
    c : color, sequence, or sequence of colors, optional
        The marker color. Possible values:

        - A single color format string.
        - A sequence of colors of length n.
        - A sequence of n numbers to be mapped to colors using *cmap* and
          *norm*.
        - A 2D array in which the rows are RGB or RGBA.

        For more details see the *c* argument of `~.axes.Axes.scatter`.
    depthshade : bool, default: True
        Whether to shade the scatter markers to give the appearance of
        depth. Each call to ``scatter()`` will perform its depthshading
        independently.
    data : indexable object, optional
        DATA_PARAMETER_PLACEHOLDER
    **kwargs
        All other keyword arguments are passed on to `~.axes.Axes.scatter`.

    Returns
    -------
    paths : `~matplotlib.collections.PathCollection`
    """

    had_data = self.has_data()
    # Remember the caller's zs so we can detect below whether the processed
    # zs still shares memory with it.
    zs_orig = zs

    # Masked entries become NaN, then the three coordinates are flattened
    # and broadcast to a common length (e.g. a scalar zs is repeated).
    xs, ys, zs = np.broadcast_arrays(
        *[np.ravel(np.ma.filled(t, np.nan)) for t in [xs, ys, zs]])
    s = np.ma.ravel(s)  # This doesn't have to match x, y in size.

    # Drop the points that cbook.delete_masked_points treats as missing,
    # filtering every per-point argument (including a per-point *color*)
    # together.
    xs, ys, zs, s, c, color = cbook.delete_masked_points(
        xs, ys, zs, s, c, kwargs.get('color', None)
        )
    # Only write the filtered *color* back if the caller supplied one.
    if kwargs.get("color") is not None:
        kwargs['color'] = color

    # For xs and ys, 2D scatter() will do the copying.
    # NOTE(review): zs is handed to patch_collection_2d_to_3d below, which
    # presumably keeps it; the copy decouples it from the caller's array.
    if np.may_share_memory(zs_orig, zs):  # Avoid unnecessary copies.
        zs = zs.copy()

    patches = super().scatter(xs, ys, s=s, c=c, *args, **kwargs)
    art3d.patch_collection_2d_to_3d(patches, zs=zs, zdir=zdir,
                                    depthshade=depthshade)

    # Ensure a z-margin of at least 5% once there are points to show.
    if self._zmargin < 0.05 and xs.size > 0:
        self.set_zmargin(0.05)

    self.auto_scale_xyz(xs, ys, zs, had_data)

    return patches

scatter3D = scatter
  2015. @_preprocess_data()
  2016. def bar(self, left, height, zs=0, zdir='z', *args, **kwargs):
  2017. """
  2018. Add 2D bar(s).
  2019. Parameters
  2020. ----------
  2021. left : 1D array-like
  2022. The x coordinates of the left sides of the bars.
  2023. height : 1D array-like
  2024. The height of the bars.
  2025. zs : float or 1D array-like
  2026. Z coordinate of bars; if a single value is specified, it will be
  2027. used for all bars.
  2028. zdir : {'x', 'y', 'z'}, default: 'z'
  2029. When plotting 2D data, the direction to use as z ('x', 'y' or 'z').
  2030. data : indexable object, optional
  2031. DATA_PARAMETER_PLACEHOLDER
  2032. **kwargs
  2033. Other keyword arguments are forwarded to
  2034. `matplotlib.axes.Axes.bar`.
  2035. Returns
  2036. -------
  2037. mpl_toolkits.mplot3d.art3d.Patch3DCollection
  2038. """
  2039. had_data = self.has_data()
  2040. patches = super().bar(left, height, *args, **kwargs)
  2041. zs = np.broadcast_to(zs, len(left))
  2042. verts = []
  2043. verts_zs = []
  2044. for p, z in zip(patches, zs):
  2045. vs = art3d._get_patch_verts(p)
  2046. verts += vs.tolist()
  2047. verts_zs += [z] * len(vs)
  2048. art3d.patch_2d_to_3d(p, z, zdir)
  2049. if 'alpha' in kwargs:
  2050. p.set_alpha(kwargs['alpha'])
  2051. if len(verts) > 0:
  2052. # the following has to be skipped if verts is empty
  2053. # NOTE: Bugs could still occur if len(verts) > 0,
  2054. # but the "2nd dimension" is empty.
  2055. xs, ys = zip(*verts)
  2056. else:
  2057. xs, ys = [], []
  2058. xs, ys, verts_zs = art3d.juggle_axes(xs, ys, verts_zs, zdir)
  2059. self.auto_scale_xyz(xs, ys, verts_zs, had_data)
  2060. return patches
  2061. @_preprocess_data()
  2062. def bar3d(self, x, y, z, dx, dy, dz, color=None,
  2063. zsort='average', shade=True, lightsource=None, *args, **kwargs):
  2064. """
  2065. Generate a 3D barplot.
  2066. This method creates three-dimensional barplot where the width,
  2067. depth, height, and color of the bars can all be uniquely set.
  2068. Parameters
  2069. ----------
  2070. x, y, z : array-like
  2071. The coordinates of the anchor point of the bars.
  2072. dx, dy, dz : float or array-like
  2073. The width, depth, and height of the bars, respectively.
  2074. color : sequence of colors, optional
  2075. The color of the bars can be specified globally or
  2076. individually. This parameter can be:
  2077. - A single color, to color all bars the same color.
  2078. - An array of colors of length N bars, to color each bar
  2079. independently.
  2080. - An array of colors of length 6, to color the faces of the
  2081. bars similarly.
  2082. - An array of colors of length 6 * N bars, to color each face
  2083. independently.
  2084. When coloring the faces of the boxes specifically, this is
  2085. the order of the coloring:
  2086. 1. -Z (bottom of box)
  2087. 2. +Z (top of box)
  2088. 3. -Y
  2089. 4. +Y
  2090. 5. -X
  2091. 6. +X
  2092. zsort : str, optional
  2093. The z-axis sorting scheme passed onto `~.art3d.Poly3DCollection`
  2094. shade : bool, default: True
  2095. When true, this shades the dark sides of the bars (relative
  2096. to the plot's source of light).
  2097. lightsource : `~matplotlib.colors.LightSource`
  2098. The lightsource to use when *shade* is True.
  2099. data : indexable object, optional
  2100. DATA_PARAMETER_PLACEHOLDER
  2101. **kwargs
  2102. Any additional keyword arguments are passed onto
  2103. `~.art3d.Poly3DCollection`.
  2104. Returns
  2105. -------
  2106. collection : `~.art3d.Poly3DCollection`
  2107. A collection of three-dimensional polygons representing the bars.
  2108. """
  2109. had_data = self.has_data()
  2110. x, y, z, dx, dy, dz = np.broadcast_arrays(
  2111. np.atleast_1d(x), y, z, dx, dy, dz)
  2112. minx = np.min(x)
  2113. maxx = np.max(x + dx)
  2114. miny = np.min(y)
  2115. maxy = np.max(y + dy)
  2116. minz = np.min(z)
  2117. maxz = np.max(z + dz)
  2118. # shape (6, 4, 3)
  2119. # All faces are oriented facing outwards - when viewed from the
  2120. # outside, their vertices are in a counterclockwise ordering.
  2121. cuboid = np.array([
  2122. # -z
  2123. (
  2124. (0, 0, 0),
  2125. (0, 1, 0),
  2126. (1, 1, 0),
  2127. (1, 0, 0),
  2128. ),
  2129. # +z
  2130. (
  2131. (0, 0, 1),
  2132. (1, 0, 1),
  2133. (1, 1, 1),
  2134. (0, 1, 1),
  2135. ),
  2136. # -y
  2137. (
  2138. (0, 0, 0),
  2139. (1, 0, 0),
  2140. (1, 0, 1),
  2141. (0, 0, 1),
  2142. ),
  2143. # +y
  2144. (
  2145. (0, 1, 0),
  2146. (0, 1, 1),
  2147. (1, 1, 1),
  2148. (1, 1, 0),
  2149. ),
  2150. # -x
  2151. (
  2152. (0, 0, 0),
  2153. (0, 0, 1),
  2154. (0, 1, 1),
  2155. (0, 1, 0),
  2156. ),
  2157. # +x
  2158. (
  2159. (1, 0, 0),
  2160. (1, 1, 0),
  2161. (1, 1, 1),
  2162. (1, 0, 1),
  2163. ),
  2164. ])
  2165. # indexed by [bar, face, vertex, coord]
  2166. polys = np.empty(x.shape + cuboid.shape)
  2167. # handle each coordinate separately
  2168. for i, p, dp in [(0, x, dx), (1, y, dy), (2, z, dz)]:
  2169. p = p[..., np.newaxis, np.newaxis]
  2170. dp = dp[..., np.newaxis, np.newaxis]
  2171. polys[..., i] = p + dp * cuboid[..., i]
  2172. # collapse the first two axes
  2173. polys = polys.reshape((-1,) + polys.shape[2:])
  2174. facecolors = []
  2175. if color is None:
  2176. color = [self._get_patches_for_fill.get_next_color()]
  2177. color = list(mcolors.to_rgba_array(color))
  2178. if len(color) == len(x):
  2179. # bar colors specified, need to expand to number of faces
  2180. for c in color:
  2181. facecolors.extend([c] * 6)
  2182. else:
  2183. # a single color specified, or face colors specified explicitly
  2184. facecolors = color
  2185. if len(facecolors) < len(x):
  2186. facecolors *= (6 * len(x))
  2187. col = art3d.Poly3DCollection(polys,
  2188. zsort=zsort,
  2189. facecolors=facecolors,
  2190. shade=shade,
  2191. lightsource=lightsource,
  2192. *args, **kwargs)
  2193. self.add_collection(col)
  2194. self.auto_scale_xyz((minx, maxx), (miny, maxy), (minz, maxz), had_data)
  2195. return col
  2196. def set_title(self, label, fontdict=None, loc='center', **kwargs):
  2197. # docstring inherited
  2198. ret = super().set_title(label, fontdict=fontdict, loc=loc, **kwargs)
  2199. (x, y) = self.title.get_position()
  2200. self.title.set_y(0.92 * y)
  2201. return ret
  2202. @_preprocess_data()
  2203. def quiver(self, X, Y, Z, U, V, W, *,
  2204. length=1, arrow_length_ratio=.3, pivot='tail', normalize=False,
  2205. **kwargs):
  2206. """
  2207. Plot a 3D field of arrows.
  2208. The arguments can be array-like or scalars, so long as they can be
  2209. broadcast together. The arguments can also be masked arrays. If an
  2210. element in any of argument is masked, then that corresponding quiver
  2211. element will not be plotted.
  2212. Parameters
  2213. ----------
  2214. X, Y, Z : array-like
  2215. The x, y and z coordinates of the arrow locations (default is
  2216. tail of arrow; see *pivot* kwarg).
  2217. U, V, W : array-like
  2218. The x, y and z components of the arrow vectors.
  2219. length : float, default: 1
  2220. The length of each quiver.
  2221. arrow_length_ratio : float, default: 0.3
  2222. The ratio of the arrow head with respect to the quiver.
  2223. pivot : {'tail', 'middle', 'tip'}, default: 'tail'
  2224. The part of the arrow that is at the grid point; the arrow
  2225. rotates about this point, hence the name *pivot*.
  2226. normalize : bool, default: False
  2227. Whether all arrows are normalized to have the same length, or keep
  2228. the lengths defined by *u*, *v*, and *w*.
  2229. data : indexable object, optional
  2230. DATA_PARAMETER_PLACEHOLDER
  2231. **kwargs
  2232. Any additional keyword arguments are delegated to
  2233. :class:`.Line3DCollection`
  2234. """
  2235. def calc_arrows(UVW):
  2236. # get unit direction vector perpendicular to (u, v, w)
  2237. x = UVW[:, 0]
  2238. y = UVW[:, 1]
  2239. norm = np.linalg.norm(UVW[:, :2], axis=1)
  2240. x_p = np.divide(y, norm, where=norm != 0, out=np.zeros_like(x))
  2241. y_p = np.divide(-x, norm, where=norm != 0, out=np.ones_like(x))
  2242. # compute the two arrowhead direction unit vectors
  2243. rangle = math.radians(15)
  2244. c = math.cos(rangle)
  2245. s = math.sin(rangle)
  2246. # construct the rotation matrices of shape (3, 3, n)
  2247. r13 = y_p * s
  2248. r32 = x_p * s
  2249. r12 = x_p * y_p * (1 - c)
  2250. Rpos = np.array(
  2251. [[c + (x_p ** 2) * (1 - c), r12, r13],
  2252. [r12, c + (y_p ** 2) * (1 - c), -r32],
  2253. [-r13, r32, np.full_like(x_p, c)]])
  2254. # opposite rotation negates all the sin terms
  2255. Rneg = Rpos.copy()
  2256. Rneg[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1
  2257. # Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
  2258. Rpos_vecs = np.einsum("ij...,...j->...i", Rpos, UVW)
  2259. Rneg_vecs = np.einsum("ij...,...j->...i", Rneg, UVW)
  2260. # Stack into (n, 2, 3) result.
  2261. return np.stack([Rpos_vecs, Rneg_vecs], axis=1)
  2262. had_data = self.has_data()
  2263. input_args = [X, Y, Z, U, V, W]
  2264. # extract the masks, if any
  2265. masks = [k.mask for k in input_args
  2266. if isinstance(k, np.ma.MaskedArray)]
  2267. # broadcast to match the shape
  2268. bcast = np.broadcast_arrays(*input_args, *masks)
  2269. input_args = bcast[:6]
  2270. masks = bcast[6:]
  2271. if masks:
  2272. # combine the masks into one
  2273. mask = functools.reduce(np.logical_or, masks)
  2274. # put mask on and compress
  2275. input_args = [np.ma.array(k, mask=mask).compressed()
  2276. for k in input_args]
  2277. else:
  2278. input_args = [np.ravel(k) for k in input_args]
  2279. if any(len(v) == 0 for v in input_args):
  2280. # No quivers, so just make an empty collection and return early
  2281. linec = art3d.Line3DCollection([], **kwargs)
  2282. self.add_collection(linec)
  2283. return linec
  2284. shaft_dt = np.array([0., length], dtype=float)
  2285. arrow_dt = shaft_dt * arrow_length_ratio
  2286. _api.check_in_list(['tail', 'middle', 'tip'], pivot=pivot)
  2287. if pivot == 'tail':
  2288. shaft_dt -= length
  2289. elif pivot == 'middle':
  2290. shaft_dt -= length / 2
  2291. XYZ = np.column_stack(input_args[:3])
  2292. UVW = np.column_stack(input_args[3:]).astype(float)
  2293. # Normalize rows of UVW
  2294. norm = np.linalg.norm(UVW, axis=1)
  2295. # If any row of UVW is all zeros, don't make a quiver for it
  2296. mask = norm > 0
  2297. XYZ = XYZ[mask]
  2298. if normalize:
  2299. UVW = UVW[mask] / norm[mask].reshape((-1, 1))
  2300. else:
  2301. UVW = UVW[mask]
  2302. if len(XYZ) > 0:
  2303. # compute the shaft lines all at once with an outer product
  2304. shafts = (XYZ - np.multiply.outer(shaft_dt, UVW)).swapaxes(0, 1)
  2305. # compute head direction vectors, n heads x 2 sides x 3 dimensions
  2306. head_dirs = calc_arrows(UVW)
  2307. # compute all head lines at once, starting from the shaft ends
  2308. heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
  2309. # stack left and right head lines together
  2310. heads = heads.reshape((len(arrow_dt), -1, 3))
  2311. # transpose to get a list of lines
  2312. heads = heads.swapaxes(0, 1)
  2313. lines = [*shafts, *heads]
  2314. else:
  2315. lines = []
  2316. linec = art3d.Line3DCollection(lines, **kwargs)
  2317. self.add_collection(linec)
  2318. self.auto_scale_xyz(XYZ[:, 0], XYZ[:, 1], XYZ[:, 2], had_data)
  2319. return linec
  2320. quiver3D = quiver
  2321. def voxels(self, *args, facecolors=None, edgecolors=None, shade=True,
  2322. lightsource=None, **kwargs):
  2323. """
  2324. ax.voxels([x, y, z,] /, filled, facecolors=None, edgecolors=None, \
  2325. **kwargs)
  2326. Plot a set of filled voxels
  2327. All voxels are plotted as 1x1x1 cubes on the axis, with
  2328. ``filled[0, 0, 0]`` placed with its lower corner at the origin.
  2329. Occluded faces are not plotted.
  2330. Parameters
  2331. ----------
  2332. filled : 3D np.array of bool
  2333. A 3D array of values, with truthy values indicating which voxels
  2334. to fill
  2335. x, y, z : 3D np.array, optional
  2336. The coordinates of the corners of the voxels. This should broadcast
  2337. to a shape one larger in every dimension than the shape of
  2338. *filled*. These can be used to plot non-cubic voxels.
  2339. If not specified, defaults to increasing integers along each axis,
  2340. like those returned by :func:`~numpy.indices`.
  2341. As indicated by the ``/`` in the function signature, these
  2342. arguments can only be passed positionally.
  2343. facecolors, edgecolors : array-like, optional
  2344. The color to draw the faces and edges of the voxels. Can only be
  2345. passed as keyword arguments.
  2346. These parameters can be:
  2347. - A single color value, to color all voxels the same color. This
  2348. can be either a string, or a 1D RGB/RGBA array
  2349. - ``None``, the default, to use a single color for the faces, and
  2350. the style default for the edges.
  2351. - A 3D `~numpy.ndarray` of color names, with each item the color
  2352. for the corresponding voxel. The size must match the voxels.
  2353. - A 4D `~numpy.ndarray` of RGB/RGBA data, with the components
  2354. along the last axis.
  2355. shade : bool, default: True
  2356. Whether to shade the facecolors.
  2357. lightsource : `~matplotlib.colors.LightSource`
  2358. The lightsource to use when *shade* is True.
  2359. **kwargs
  2360. Additional keyword arguments to pass onto
  2361. `~mpl_toolkits.mplot3d.art3d.Poly3DCollection`.
  2362. Returns
  2363. -------
  2364. faces : dict
  2365. A dictionary indexed by coordinate, where ``faces[i, j, k]`` is a
  2366. `.Poly3DCollection` of the faces drawn for the voxel
  2367. ``filled[i, j, k]``. If no faces were drawn for a given voxel,
  2368. either because it was not asked to be drawn, or it is fully
  2369. occluded, then ``(i, j, k) not in faces``.
  2370. Examples
  2371. --------
  2372. .. plot:: gallery/mplot3d/voxels.py
  2373. .. plot:: gallery/mplot3d/voxels_rgb.py
  2374. .. plot:: gallery/mplot3d/voxels_torus.py
  2375. .. plot:: gallery/mplot3d/voxels_numpy_logo.py
  2376. """
  2377. # work out which signature we should be using, and use it to parse
  2378. # the arguments. Name must be voxels for the correct error message
  2379. if len(args) >= 3:
  2380. # underscores indicate position only
  2381. def voxels(__x, __y, __z, filled, **kwargs):
  2382. return (__x, __y, __z), filled, kwargs
  2383. else:
  2384. def voxels(filled, **kwargs):
  2385. return None, filled, kwargs
  2386. xyz, filled, kwargs = voxels(*args, **kwargs)
  2387. # check dimensions
  2388. if filled.ndim != 3:
  2389. raise ValueError("Argument filled must be 3-dimensional")
  2390. size = np.array(filled.shape, dtype=np.intp)
  2391. # check xyz coordinates, which are one larger than the filled shape
  2392. coord_shape = tuple(size + 1)
  2393. if xyz is None:
  2394. x, y, z = np.indices(coord_shape)
  2395. else:
  2396. x, y, z = (np.broadcast_to(c, coord_shape) for c in xyz)
  2397. def _broadcast_color_arg(color, name):
  2398. if np.ndim(color) in (0, 1):
  2399. # single color, like "red" or [1, 0, 0]
  2400. return np.broadcast_to(color, filled.shape + np.shape(color))
  2401. elif np.ndim(color) in (3, 4):
  2402. # 3D array of strings, or 4D array with last axis rgb
  2403. if np.shape(color)[:3] != filled.shape:
  2404. raise ValueError(
  2405. f"When multidimensional, {name} must match the shape "
  2406. "of filled")
  2407. return color
  2408. else:
  2409. raise ValueError(f"Invalid {name} argument")
  2410. # broadcast and default on facecolors
  2411. if facecolors is None:
  2412. facecolors = self._get_patches_for_fill.get_next_color()
  2413. facecolors = _broadcast_color_arg(facecolors, 'facecolors')
  2414. # broadcast but no default on edgecolors
  2415. edgecolors = _broadcast_color_arg(edgecolors, 'edgecolors')
  2416. # scale to the full array, even if the data is only in the center
  2417. self.auto_scale_xyz(x, y, z)
  2418. # points lying on corners of a square
  2419. square = np.array([
  2420. [0, 0, 0],
  2421. [1, 0, 0],
  2422. [1, 1, 0],
  2423. [0, 1, 0],
  2424. ], dtype=np.intp)
  2425. voxel_faces = defaultdict(list)
  2426. def permutation_matrices(n):
  2427. """Generate cyclic permutation matrices."""
  2428. mat = np.eye(n, dtype=np.intp)
  2429. for i in range(n):
  2430. yield mat
  2431. mat = np.roll(mat, 1, axis=0)
  2432. # iterate over each of the YZ, ZX, and XY orientations, finding faces
  2433. # to render
  2434. for permute in permutation_matrices(3):
  2435. # find the set of ranges to iterate over
  2436. pc, qc, rc = permute.T.dot(size)
  2437. pinds = np.arange(pc)
  2438. qinds = np.arange(qc)
  2439. rinds = np.arange(rc)
  2440. square_rot_pos = square.dot(permute.T)
  2441. square_rot_neg = square_rot_pos[::-1]
  2442. # iterate within the current plane
  2443. for p in pinds:
  2444. for q in qinds:
  2445. # iterate perpendicularly to the current plane, handling
  2446. # boundaries. We only draw faces between a voxel and an
  2447. # empty space, to avoid drawing internal faces.
  2448. # draw lower faces
  2449. p0 = permute.dot([p, q, 0])
  2450. i0 = tuple(p0)
  2451. if filled[i0]:
  2452. voxel_faces[i0].append(p0 + square_rot_neg)
  2453. # draw middle faces
  2454. for r1, r2 in zip(rinds[:-1], rinds[1:]):
  2455. p1 = permute.dot([p, q, r1])
  2456. p2 = permute.dot([p, q, r2])
  2457. i1 = tuple(p1)
  2458. i2 = tuple(p2)
  2459. if filled[i1] and not filled[i2]:
  2460. voxel_faces[i1].append(p2 + square_rot_pos)
  2461. elif not filled[i1] and filled[i2]:
  2462. voxel_faces[i2].append(p2 + square_rot_neg)
  2463. # draw upper faces
  2464. pk = permute.dot([p, q, rc-1])
  2465. pk2 = permute.dot([p, q, rc])
  2466. ik = tuple(pk)
  2467. if filled[ik]:
  2468. voxel_faces[ik].append(pk2 + square_rot_pos)
  2469. # iterate over the faces, and generate a Poly3DCollection for each
  2470. # voxel
  2471. polygons = {}
  2472. for coord, faces_inds in voxel_faces.items():
  2473. # convert indices into 3D positions
  2474. if xyz is None:
  2475. faces = faces_inds
  2476. else:
  2477. faces = []
  2478. for face_inds in faces_inds:
  2479. ind = face_inds[:, 0], face_inds[:, 1], face_inds[:, 2]
  2480. face = np.empty(face_inds.shape)
  2481. face[:, 0] = x[ind]
  2482. face[:, 1] = y[ind]
  2483. face[:, 2] = z[ind]
  2484. faces.append(face)
  2485. # shade the faces
  2486. facecolor = facecolors[coord]
  2487. edgecolor = edgecolors[coord]
  2488. poly = art3d.Poly3DCollection(
  2489. faces, facecolors=facecolor, edgecolors=edgecolor,
  2490. shade=shade, lightsource=lightsource, **kwargs)
  2491. self.add_collection3d(poly)
  2492. polygons[coord] = poly
  2493. return polygons
  2494. @_preprocess_data(replace_names=["x", "y", "z", "xerr", "yerr", "zerr"])
  2495. def errorbar(self, x, y, z, zerr=None, yerr=None, xerr=None, fmt='',
  2496. barsabove=False, errorevery=1, ecolor=None, elinewidth=None,
  2497. capsize=None, capthick=None, xlolims=False, xuplims=False,
  2498. ylolims=False, yuplims=False, zlolims=False, zuplims=False,
  2499. **kwargs):
  2500. """
  2501. Plot lines and/or markers with errorbars around them.
  2502. *x*/*y*/*z* define the data locations, and *xerr*/*yerr*/*zerr* define
  2503. the errorbar sizes. By default, this draws the data markers/lines as
  2504. well the errorbars. Use fmt='none' to draw errorbars only.
  2505. Parameters
  2506. ----------
  2507. x, y, z : float or array-like
  2508. The data positions.
  2509. xerr, yerr, zerr : float or array-like, shape (N,) or (2, N), optional
  2510. The errorbar sizes:
  2511. - scalar: Symmetric +/- values for all data points.
  2512. - shape(N,): Symmetric +/-values for each data point.
  2513. - shape(2, N): Separate - and + values for each bar. First row
  2514. contains the lower errors, the second row contains the upper
  2515. errors.
  2516. - *None*: No errorbar.
  2517. Note that all error arrays should have *positive* values.
  2518. fmt : str, default: ''
  2519. The format for the data points / data lines. See `.plot` for
  2520. details.
  2521. Use 'none' (case-insensitive) to plot errorbars without any data
  2522. markers.
  2523. ecolor : color, default: None
  2524. The color of the errorbar lines. If None, use the color of the
  2525. line connecting the markers.
  2526. elinewidth : float, default: None
  2527. The linewidth of the errorbar lines. If None, the linewidth of
  2528. the current style is used.
  2529. capsize : float, default: :rc:`errorbar.capsize`
  2530. The length of the error bar caps in points.
  2531. capthick : float, default: None
  2532. An alias to the keyword argument *markeredgewidth* (a.k.a. *mew*).
  2533. This setting is a more sensible name for the property that
  2534. controls the thickness of the error bar cap in points. For
  2535. backwards compatibility, if *mew* or *markeredgewidth* are given,
  2536. then they will over-ride *capthick*. This may change in future
  2537. releases.
  2538. barsabove : bool, default: False
  2539. If True, will plot the errorbars above the plot
  2540. symbols. Default is below.
  2541. xlolims, ylolims, zlolims : bool, default: False
  2542. These arguments can be used to indicate that a value gives only
  2543. lower limits. In that case a caret symbol is used to indicate
  2544. this. *lims*-arguments may be scalars, or array-likes of the same
  2545. length as the errors. To use limits with inverted axes,
  2546. `~.Axes.set_xlim` or `~.Axes.set_ylim` must be called before
  2547. `errorbar`. Note the tricky parameter names: setting e.g.
  2548. *ylolims* to True means that the y-value is a *lower* limit of the
  2549. True value, so, only an *upward*-pointing arrow will be drawn!
  2550. xuplims, yuplims, zuplims : bool, default: False
  2551. Same as above, but for controlling the upper limits.
  2552. errorevery : int or (int, int), default: 1
  2553. draws error bars on a subset of the data. *errorevery* =N draws
  2554. error bars on the points (x[::N], y[::N], z[::N]).
  2555. *errorevery* =(start, N) draws error bars on the points
  2556. (x[start::N], y[start::N], z[start::N]). e.g. *errorevery* =(6, 3)
  2557. adds error bars to the data at (x[6], x[9], x[12], x[15], ...).
  2558. Used to avoid overlapping error bars when two series share x-axis
  2559. values.
  2560. Returns
  2561. -------
  2562. errlines : list
  2563. List of `~mpl_toolkits.mplot3d.art3d.Line3DCollection` instances
  2564. each containing an errorbar line.
  2565. caplines : list
  2566. List of `~mpl_toolkits.mplot3d.art3d.Line3D` instances each
  2567. containing a capline object.
  2568. limmarks : list
  2569. List of `~mpl_toolkits.mplot3d.art3d.Line3D` instances each
  2570. containing a marker with an upper or lower limit.
  2571. Other Parameters
  2572. ----------------
  2573. data : indexable object, optional
  2574. DATA_PARAMETER_PLACEHOLDER
  2575. **kwargs
  2576. All other keyword arguments for styling errorbar lines are passed
  2577. `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
  2578. Examples
  2579. --------
  2580. .. plot:: gallery/mplot3d/errorbar3d.py
  2581. """
  2582. had_data = self.has_data()
  2583. kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
  2584. # Drop anything that comes in as None to use the default instead.
  2585. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  2586. kwargs.setdefault('zorder', 2)
  2587. self._process_unit_info([("x", x), ("y", y), ("z", z)], kwargs,
  2588. convert=False)
  2589. # make sure all the args are iterable; use lists not arrays to
  2590. # preserve units
  2591. x = x if np.iterable(x) else [x]
  2592. y = y if np.iterable(y) else [y]
  2593. z = z if np.iterable(z) else [z]
  2594. if not len(x) == len(y) == len(z):
  2595. raise ValueError("'x', 'y', and 'z' must have the same size")
  2596. everymask = self._errorevery_to_mask(x, errorevery)
  2597. label = kwargs.pop("label", None)
  2598. kwargs['label'] = '_nolegend_'
  2599. # Create the main line and determine overall kwargs for child artists.
  2600. # We avoid calling self.plot() directly, or self._get_lines(), because
  2601. # that would call self._process_unit_info again, and do other indirect
  2602. # data processing.
  2603. (data_line, base_style), = self._get_lines._plot_args(
  2604. self, (x, y) if fmt == '' else (x, y, fmt), kwargs, return_kwargs=True)
  2605. art3d.line_2d_to_3d(data_line, zs=z)
  2606. # Do this after creating `data_line` to avoid modifying `base_style`.
  2607. if barsabove:
  2608. data_line.set_zorder(kwargs['zorder'] - .1)
  2609. else:
  2610. data_line.set_zorder(kwargs['zorder'] + .1)
  2611. # Add line to plot, or throw it away and use it to determine kwargs.
  2612. if fmt.lower() != 'none':
  2613. self.add_line(data_line)
  2614. else:
  2615. data_line = None
  2616. # Remove alpha=0 color that _process_plot_format returns.
  2617. base_style.pop('color')
  2618. if 'color' not in base_style:
  2619. base_style['color'] = 'C0'
  2620. if ecolor is None:
  2621. ecolor = base_style['color']
  2622. # Eject any line-specific information from format string, as it's not
  2623. # needed for bars or caps.
  2624. for key in ['marker', 'markersize', 'markerfacecolor',
  2625. 'markeredgewidth', 'markeredgecolor', 'markevery',
  2626. 'linestyle', 'fillstyle', 'drawstyle', 'dash_capstyle',
  2627. 'dash_joinstyle', 'solid_capstyle', 'solid_joinstyle']:
  2628. base_style.pop(key, None)
  2629. # Make the style dict for the line collections (the bars).
  2630. eb_lines_style = {**base_style, 'color': ecolor}
  2631. if elinewidth:
  2632. eb_lines_style['linewidth'] = elinewidth
  2633. elif 'linewidth' in kwargs:
  2634. eb_lines_style['linewidth'] = kwargs['linewidth']
  2635. for key in ('transform', 'alpha', 'zorder', 'rasterized'):
  2636. if key in kwargs:
  2637. eb_lines_style[key] = kwargs[key]
  2638. # Make the style dict for caps (the "hats").
  2639. eb_cap_style = {**base_style, 'linestyle': 'None'}
  2640. if capsize is None:
  2641. capsize = mpl.rcParams["errorbar.capsize"]
  2642. if capsize > 0:
  2643. eb_cap_style['markersize'] = 2. * capsize
  2644. if capthick is not None:
  2645. eb_cap_style['markeredgewidth'] = capthick
  2646. eb_cap_style['color'] = ecolor
  2647. def _apply_mask(arrays, mask):
  2648. # Return, for each array in *arrays*, the elements for which *mask*
  2649. # is True, without using fancy indexing.
  2650. return [[*itertools.compress(array, mask)] for array in arrays]
  2651. def _extract_errs(err, data, lomask, himask):
  2652. # For separate +/- error values we need to unpack err
  2653. if len(err.shape) == 2:
  2654. low_err, high_err = err
  2655. else:
  2656. low_err, high_err = err, err
  2657. lows = np.where(lomask | ~everymask, data, data - low_err)
  2658. highs = np.where(himask | ~everymask, data, data + high_err)
  2659. return lows, highs
  2660. # collect drawn items while looping over the three coordinates
  2661. errlines, caplines, limmarks = [], [], []
  2662. # list of endpoint coordinates, used for auto-scaling
  2663. coorderrs = []
  2664. # define the markers used for errorbar caps and limits below
  2665. # the dictionary key is mapped by the `i_xyz` helper dictionary
  2666. capmarker = {0: '|', 1: '|', 2: '_'}
  2667. i_xyz = {'x': 0, 'y': 1, 'z': 2}
  2668. # Calculate marker size from points to quiver length. Because these are
  2669. # not markers, and 3D Axes do not use the normal transform stack, this
  2670. # is a bit involved. Since the quiver arrows will change size as the
  2671. # scene is rotated, they are given a standard size based on viewing
  2672. # them directly in planar form.
  2673. quiversize = eb_cap_style.get('markersize',
  2674. mpl.rcParams['lines.markersize']) ** 2
  2675. quiversize *= self.figure.dpi / 72
  2676. quiversize = self.transAxes.inverted().transform([
  2677. (0, 0), (quiversize, quiversize)])
  2678. quiversize = np.mean(np.diff(quiversize, axis=0))
  2679. # quiversize is now in Axes coordinates, and to convert back to data
  2680. # coordinates, we need to run it through the inverse 3D transform. For
  2681. # consistency, this uses a fixed elevation, azimuth, and roll.
  2682. with cbook._setattr_cm(self, elev=0, azim=0, roll=0):
  2683. invM = np.linalg.inv(self.get_proj())
  2684. # elev=azim=roll=0 produces the Y-Z plane, so quiversize in 2D 'x' is
  2685. # 'y' in 3D, hence the 1 index.
  2686. quiversize = np.dot(invM, [quiversize, 0, 0, 0])[1]
  2687. # Quivers use a fixed 15-degree arrow head, so scale up the length so
  2688. # that the size corresponds to the base. In other words, this constant
  2689. # corresponds to the equation tan(15) = (base / 2) / (arrow length).
  2690. quiversize *= 1.8660254037844388
  2691. eb_quiver_style = {**eb_cap_style,
  2692. 'length': quiversize, 'arrow_length_ratio': 1}
  2693. eb_quiver_style.pop('markersize', None)
  2694. # loop over x-, y-, and z-direction and draw relevant elements
  2695. for zdir, data, err, lolims, uplims in zip(
  2696. ['x', 'y', 'z'], [x, y, z], [xerr, yerr, zerr],
  2697. [xlolims, ylolims, zlolims], [xuplims, yuplims, zuplims]):
  2698. dir_vector = art3d.get_dir_vector(zdir)
  2699. i_zdir = i_xyz[zdir]
  2700. if err is None:
  2701. continue
  2702. if not np.iterable(err):
  2703. err = [err] * len(data)
  2704. err = np.atleast_1d(err)
  2705. # arrays fine here, they are booleans and hence not units
  2706. lolims = np.broadcast_to(lolims, len(data)).astype(bool)
  2707. uplims = np.broadcast_to(uplims, len(data)).astype(bool)
  2708. # a nested list structure that expands to (xl,xh),(yl,yh),(zl,zh),
  2709. # where x/y/z and l/h correspond to dimensions and low/high
  2710. # positions of errorbars in a dimension we're looping over
  2711. coorderr = [
  2712. _extract_errs(err * dir_vector[i], coord, lolims, uplims)
  2713. for i, coord in enumerate([x, y, z])]
  2714. (xl, xh), (yl, yh), (zl, zh) = coorderr
  2715. # draws capmarkers - flat caps orthogonal to the error bars
  2716. nolims = ~(lolims | uplims)
  2717. if nolims.any() and capsize > 0:
  2718. lo_caps_xyz = _apply_mask([xl, yl, zl], nolims & everymask)
  2719. hi_caps_xyz = _apply_mask([xh, yh, zh], nolims & everymask)
  2720. # setting '_' for z-caps and '|' for x- and y-caps;
  2721. # these markers will rotate as the viewing angle changes
  2722. cap_lo = art3d.Line3D(*lo_caps_xyz, ls='',
  2723. marker=capmarker[i_zdir],
  2724. **eb_cap_style)
  2725. cap_hi = art3d.Line3D(*hi_caps_xyz, ls='',
  2726. marker=capmarker[i_zdir],
  2727. **eb_cap_style)
  2728. self.add_line(cap_lo)
  2729. self.add_line(cap_hi)
  2730. caplines.append(cap_lo)
  2731. caplines.append(cap_hi)
  2732. if lolims.any():
  2733. xh0, yh0, zh0 = _apply_mask([xh, yh, zh], lolims & everymask)
  2734. self.quiver(xh0, yh0, zh0, *dir_vector, **eb_quiver_style)
  2735. if uplims.any():
  2736. xl0, yl0, zl0 = _apply_mask([xl, yl, zl], uplims & everymask)
  2737. self.quiver(xl0, yl0, zl0, *-dir_vector, **eb_quiver_style)
  2738. errline = art3d.Line3DCollection(np.array(coorderr).T,
  2739. **eb_lines_style)
  2740. self.add_collection(errline)
  2741. errlines.append(errline)
  2742. coorderrs.append(coorderr)
  2743. coorderrs = np.array(coorderrs)
  2744. def _digout_minmax(err_arr, coord_label):
  2745. return (np.nanmin(err_arr[:, i_xyz[coord_label], :, :]),
  2746. np.nanmax(err_arr[:, i_xyz[coord_label], :, :]))
  2747. minx, maxx = _digout_minmax(coorderrs, 'x')
  2748. miny, maxy = _digout_minmax(coorderrs, 'y')
  2749. minz, maxz = _digout_minmax(coorderrs, 'z')
  2750. self.auto_scale_xyz((minx, maxx), (miny, maxy), (minz, maxz), had_data)
  2751. # Adapting errorbar containers for 3d case, assuming z-axis points "up"
  2752. errorbar_container = mcontainer.ErrorbarContainer(
  2753. (data_line, tuple(caplines), tuple(errlines)),
  2754. has_xerr=(xerr is not None or yerr is not None),
  2755. has_yerr=(zerr is not None),
  2756. label=label)
  2757. self.containers.append(errorbar_container)
  2758. return errlines, caplines, limmarks
  2759. @_api.make_keyword_only("3.8", "call_axes_locator")
  2760. def get_tightbbox(self, renderer=None, call_axes_locator=True,
  2761. bbox_extra_artists=None, *, for_layout_only=False):
  2762. ret = super().get_tightbbox(renderer,
  2763. call_axes_locator=call_axes_locator,
  2764. bbox_extra_artists=bbox_extra_artists,
  2765. for_layout_only=for_layout_only)
  2766. batch = [ret]
  2767. if self._axis3don:
  2768. for axis in self._axis_map.values():
  2769. if axis.get_visible():
  2770. axis_bb = martist._get_tightbbox_for_layout_only(
  2771. axis, renderer)
  2772. if axis_bb:
  2773. batch.append(axis_bb)
  2774. return mtransforms.Bbox.union(batch)
  2775. @_preprocess_data()
  2776. def stem(self, x, y, z, *, linefmt='C0-', markerfmt='C0o', basefmt='C3-',
  2777. bottom=0, label=None, orientation='z'):
  2778. """
  2779. Create a 3D stem plot.
  2780. A stem plot draws lines perpendicular to a baseline, and places markers
  2781. at the heads. By default, the baseline is defined by *x* and *y*, and
  2782. stems are drawn vertically from *bottom* to *z*.
  2783. Parameters
  2784. ----------
  2785. x, y, z : array-like
  2786. The positions of the heads of the stems. The stems are drawn along
  2787. the *orientation*-direction from the baseline at *bottom* (in the
  2788. *orientation*-coordinate) to the heads. By default, the *x* and *y*
  2789. positions are used for the baseline and *z* for the head position,
  2790. but this can be changed by *orientation*.
  2791. linefmt : str, default: 'C0-'
  2792. A string defining the properties of the vertical lines. Usually,
  2793. this will be a color or a color and a linestyle:
  2794. ========= =============
  2795. Character Line Style
  2796. ========= =============
  2797. ``'-'`` solid line
  2798. ``'--'`` dashed line
  2799. ``'-.'`` dash-dot line
  2800. ``':'`` dotted line
  2801. ========= =============
  2802. Note: While it is technically possible to specify valid formats
  2803. other than color or color and linestyle (e.g. 'rx' or '-.'), this
  2804. is beyond the intention of the method and will most likely not
  2805. result in a reasonable plot.
  2806. markerfmt : str, default: 'C0o'
  2807. A string defining the properties of the markers at the stem heads.
  2808. basefmt : str, default: 'C3-'
  2809. A format string defining the properties of the baseline.
  2810. bottom : float, default: 0
  2811. The position of the baseline, in *orientation*-coordinates.
  2812. label : str, default: None
  2813. The label to use for the stems in legends.
  2814. orientation : {'x', 'y', 'z'}, default: 'z'
  2815. The direction along which stems are drawn.
  2816. data : indexable object, optional
  2817. DATA_PARAMETER_PLACEHOLDER
  2818. Returns
  2819. -------
  2820. `.StemContainer`
  2821. The container may be treated like a tuple
  2822. (*markerline*, *stemlines*, *baseline*)
  2823. Examples
  2824. --------
  2825. .. plot:: gallery/mplot3d/stem3d_demo.py
  2826. """
  2827. from matplotlib.container import StemContainer
  2828. had_data = self.has_data()
  2829. _api.check_in_list(['x', 'y', 'z'], orientation=orientation)
  2830. xlim = (np.min(x), np.max(x))
  2831. ylim = (np.min(y), np.max(y))
  2832. zlim = (np.min(z), np.max(z))
  2833. # Determine the appropriate plane for the baseline and the direction of
  2834. # stemlines based on the value of orientation.
  2835. if orientation == 'x':
  2836. basex, basexlim = y, ylim
  2837. basey, baseylim = z, zlim
  2838. lines = [[(bottom, thisy, thisz), (thisx, thisy, thisz)]
  2839. for thisx, thisy, thisz in zip(x, y, z)]
  2840. elif orientation == 'y':
  2841. basex, basexlim = x, xlim
  2842. basey, baseylim = z, zlim
  2843. lines = [[(thisx, bottom, thisz), (thisx, thisy, thisz)]
  2844. for thisx, thisy, thisz in zip(x, y, z)]
  2845. else:
  2846. basex, basexlim = x, xlim
  2847. basey, baseylim = y, ylim
  2848. lines = [[(thisx, thisy, bottom), (thisx, thisy, thisz)]
  2849. for thisx, thisy, thisz in zip(x, y, z)]
  2850. # Determine style for stem lines.
  2851. linestyle, linemarker, linecolor = _process_plot_format(linefmt)
  2852. if linestyle is None:
  2853. linestyle = mpl.rcParams['lines.linestyle']
  2854. # Plot everything in required order.
  2855. baseline, = self.plot(basex, basey, basefmt, zs=bottom,
  2856. zdir=orientation, label='_nolegend_')
  2857. stemlines = art3d.Line3DCollection(
  2858. lines, linestyles=linestyle, colors=linecolor, label='_nolegend_')
  2859. self.add_collection(stemlines)
  2860. markerline, = self.plot(x, y, z, markerfmt, label='_nolegend_')
  2861. stem_container = StemContainer((markerline, stemlines, baseline),
  2862. label=label)
  2863. self.add_container(stem_container)
  2864. jx, jy, jz = art3d.juggle_axes(basexlim, baseylim, [bottom, bottom],
  2865. orientation)
  2866. self.auto_scale_xyz([*jx, *xlim], [*jy, *ylim], [*jz, *zlim], had_data)
  2867. return stem_container
# ``stem3D`` is an alias bound to the same object as ``stem``.
stem3D = stem
  2869. def get_test_data(delta=0.05):
  2870. """Return a tuple X, Y, Z with a test data set."""
  2871. x = y = np.arange(-3.0, 3.0, delta)
  2872. X, Y = np.meshgrid(x, y)
  2873. Z1 = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
  2874. Z2 = (np.exp(-(((X - 1) / 1.5)**2 + ((Y - 1) / 0.5)**2) / 2) /
  2875. (2 * np.pi * 0.5 * 1.5))
  2876. Z = Z2 - Z1
  2877. X = X * 10
  2878. Y = Y * 10
  2879. Z = Z * 500
  2880. return X, Y, Z