axis3d.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # axis3d.py, original mplot3d version by John Porter
  2. # Created: 23 Sep 2005
  3. # Parts rewritten by Reinier Heeres <reinier@heeres.eu>
  4. from __future__ import (absolute_import, division, print_function,
  5. unicode_literals)
  6. import six
  7. import math
  8. import copy
  9. from matplotlib import lines as mlines, axis as maxis, patches as mpatches
  10. from matplotlib import rcParams
  11. from . import art3d
  12. from . import proj3d
  13. import numpy as np
  14. def get_flip_min_max(coord, index, mins, maxs):
  15. if coord[index] == mins[index]:
  16. return maxs[index]
  17. else:
  18. return mins[index]
  19. def move_from_center(coord, centers, deltas, axmask=(True, True, True)):
  20. '''Return a coordinate that is moved by "deltas" away from the center.'''
  21. coord = copy.copy(coord)
  22. for i in range(3):
  23. if not axmask[i]:
  24. continue
  25. if coord[i] < centers[i]:
  26. coord[i] -= deltas[i]
  27. else:
  28. coord[i] += deltas[i]
  29. return coord
  30. def tick_update_position(tick, tickxs, tickys, labelpos):
  31. '''Update tick line and label position and style.'''
  32. for (label, on) in [(tick.label1, tick.label1On),
  33. (tick.label2, tick.label2On)]:
  34. if on:
  35. label.set_position(labelpos)
  36. tick.tick1On, tick.tick2On = True, False
  37. tick.tick1line.set_linestyle('-')
  38. tick.tick1line.set_marker('')
  39. tick.tick1line.set_data(tickxs, tickys)
  40. tick.gridline.set_data(0, 0)
  41. class Axis(maxis.XAxis):
  42. # These points from the unit cube make up the x, y and z-planes
  43. _PLANES = (
  44. (0, 3, 7, 4), (1, 2, 6, 5), # yz planes
  45. (0, 1, 5, 4), (3, 2, 6, 7), # xz planes
  46. (0, 1, 2, 3), (4, 5, 6, 7), # xy planes
  47. )
  48. # Some properties for the axes
  49. _AXINFO = {
  50. 'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2),
  51. 'color': (0.95, 0.95, 0.95, 0.5)},
  52. 'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2),
  53. 'color': (0.90, 0.90, 0.90, 0.5)},
  54. 'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1),
  55. 'color': (0.925, 0.925, 0.925, 0.5)},
  56. }
  57. def __init__(self, adir, v_intervalx, d_intervalx, axes, *args, **kwargs):
  58. # adir identifies which axes this is
  59. self.adir = adir
  60. # data and viewing intervals for this direction
  61. self.d_interval = d_intervalx
  62. self.v_interval = v_intervalx
  63. # This is a temporary member variable.
  64. # Do not depend on this existing in future releases!
  65. self._axinfo = self._AXINFO[adir].copy()
  66. if rcParams['_internal.classic_mode']:
  67. self._axinfo.update(
  68. {'label': {'va': 'center',
  69. 'ha': 'center'},
  70. 'tick': {'inward_factor': 0.2,
  71. 'outward_factor': 0.1,
  72. 'linewidth': rcParams['lines.linewidth'],
  73. 'color': 'k'},
  74. 'axisline': {'linewidth': 0.75,
  75. 'color': (0, 0, 0, 1)},
  76. 'grid': {'color': (0.9, 0.9, 0.9, 1),
  77. 'linewidth': 1.0,
  78. 'linestyle': '-'},
  79. })
  80. else:
  81. self._axinfo.update(
  82. {'label': {'va': 'center',
  83. 'ha': 'center'},
  84. 'tick': {'inward_factor': 0.2,
  85. 'outward_factor': 0.1,
  86. 'linewidth': rcParams.get(
  87. adir + 'tick.major.width',
  88. rcParams['xtick.major.width']),
  89. 'color': rcParams.get(
  90. adir + 'tick.color',
  91. rcParams['xtick.color'])},
  92. 'axisline': {'linewidth': rcParams['axes.linewidth'],
  93. 'color': rcParams['axes.edgecolor']},
  94. 'grid': {'color': rcParams['grid.color'],
  95. 'linewidth': rcParams['grid.linewidth'],
  96. 'linestyle': rcParams['grid.linestyle']},
  97. })
  98. maxis.XAxis.__init__(self, axes, *args, **kwargs)
  99. self.set_rotate_label(kwargs.get('rotate_label', None))
  100. def init3d(self):
  101. self.line = mlines.Line2D(
  102. xdata=(0, 0), ydata=(0, 0),
  103. linewidth=self._axinfo['axisline']['linewidth'],
  104. color=self._axinfo['axisline']['color'],
  105. antialiased=True)
  106. # Store dummy data in Polygon object
  107. self.pane = mpatches.Polygon(
  108. np.array([[0, 0], [0, 1], [1, 0], [0, 0]]),
  109. closed=False, alpha=0.8, facecolor='k', edgecolor='k')
  110. self.set_pane_color(self._axinfo['color'])
  111. self.axes._set_artist_props(self.line)
  112. self.axes._set_artist_props(self.pane)
  113. self.gridlines = art3d.Line3DCollection([])
  114. self.axes._set_artist_props(self.gridlines)
  115. self.axes._set_artist_props(self.label)
  116. self.axes._set_artist_props(self.offsetText)
  117. # Need to be able to place the label at the correct location
  118. self.label._transform = self.axes.transData
  119. self.offsetText._transform = self.axes.transData
  120. def get_tick_positions(self):
  121. majorLocs = self.major.locator()
  122. self.major.formatter.set_locs(majorLocs)
  123. majorLabels = [self.major.formatter(val, i)
  124. for i, val in enumerate(majorLocs)]
  125. return majorLabels, majorLocs
  126. def get_major_ticks(self, numticks=None):
  127. ticks = maxis.XAxis.get_major_ticks(self, numticks)
  128. for t in ticks:
  129. t.tick1line.set_transform(self.axes.transData)
  130. t.tick2line.set_transform(self.axes.transData)
  131. t.gridline.set_transform(self.axes.transData)
  132. t.label1.set_transform(self.axes.transData)
  133. t.label2.set_transform(self.axes.transData)
  134. return ticks
  135. def set_pane_pos(self, xys):
  136. xys = np.asarray(xys)
  137. xys = xys[:,:2]
  138. self.pane.xy = xys
  139. self.stale = True
  140. def set_pane_color(self, color):
  141. '''Set pane color to a RGBA tuple.'''
  142. self._axinfo['color'] = color
  143. self.pane.set_edgecolor(color)
  144. self.pane.set_facecolor(color)
  145. self.pane.set_alpha(color[-1])
  146. self.stale = True
  147. def set_rotate_label(self, val):
  148. '''
  149. Whether to rotate the axis label: True, False or None.
  150. If set to None the label will be rotated if longer than 4 chars.
  151. '''
  152. self._rotate_label = val
  153. self.stale = True
  154. def get_rotate_label(self, text):
  155. if self._rotate_label is not None:
  156. return self._rotate_label
  157. else:
  158. return len(text) > 4
  159. def _get_coord_info(self, renderer):
  160. minx, maxx, miny, maxy, minz, maxz = self.axes.get_w_lims()
  161. if minx > maxx:
  162. minx, maxx = maxx, minx
  163. if miny > maxy:
  164. miny, maxy = maxy, miny
  165. if minz > maxz:
  166. minz, maxz = maxz, minz
  167. mins = np.array((minx, miny, minz))
  168. maxs = np.array((maxx, maxy, maxz))
  169. centers = (maxs + mins) / 2.
  170. deltas = (maxs - mins) / 12.
  171. mins = mins - deltas / 4.
  172. maxs = maxs + deltas / 4.
  173. vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
  174. tc = self.axes.tunit_cube(vals, renderer.M)
  175. avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2]
  176. for p1, p2, p3, p4 in self._PLANES]
  177. highs = np.array([avgz[2*i] < avgz[2*i+1] for i in range(3)])
  178. return mins, maxs, centers, deltas, tc, highs
  179. def draw_pane(self, renderer):
  180. renderer.open_group('pane3d')
  181. mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
  182. info = self._axinfo
  183. index = info['i']
  184. if not highs[index]:
  185. plane = self._PLANES[2 * index]
  186. else:
  187. plane = self._PLANES[2 * index + 1]
  188. xys = [tc[p] for p in plane]
  189. self.set_pane_pos(xys)
  190. self.pane.draw(renderer)
  191. renderer.close_group('pane3d')
  192. def draw(self, renderer):
  193. self.label._transform = self.axes.transData
  194. renderer.open_group('axis3d')
  195. # code from XAxis
  196. majorTicks = self.get_major_ticks()
  197. majorLocs = self.major.locator()
  198. info = self._axinfo
  199. index = info['i']
  200. # filter locations here so that no extra grid lines are drawn
  201. locmin, locmax = self.get_view_interval()
  202. if locmin > locmax:
  203. locmin, locmax = locmax, locmin
  204. # Rudimentary clipping
  205. majorLocs = [loc for loc in majorLocs if
  206. locmin <= loc <= locmax]
  207. self.major.formatter.set_locs(majorLocs)
  208. majorLabels = [self.major.formatter(val, i)
  209. for i, val in enumerate(majorLocs)]
  210. mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
  211. # Determine grid lines
  212. minmax = np.where(highs, maxs, mins)
  213. # Draw main axis line
  214. juggled = info['juggled']
  215. edgep1 = minmax.copy()
  216. edgep1[juggled[0]] = get_flip_min_max(edgep1, juggled[0], mins, maxs)
  217. edgep2 = edgep1.copy()
  218. edgep2[juggled[1]] = get_flip_min_max(edgep2, juggled[1], mins, maxs)
  219. pep = proj3d.proj_trans_points([edgep1, edgep2], renderer.M)
  220. centpt = proj3d.proj_transform(
  221. centers[0], centers[1], centers[2], renderer.M)
  222. self.line.set_data((pep[0][0], pep[0][1]), (pep[1][0], pep[1][1]))
  223. self.line.draw(renderer)
  224. # Grid points where the planes meet
  225. xyz0 = []
  226. for val in majorLocs:
  227. coord = minmax.copy()
  228. coord[index] = val
  229. xyz0.append(coord)
  230. # Draw labels
  231. peparray = np.asanyarray(pep)
  232. # The transAxes transform is used because the Text object
  233. # rotates the text relative to the display coordinate system.
  234. # Therefore, if we want the labels to remain parallel to the
  235. # axis regardless of the aspect ratio, we need to convert the
  236. # edge points of the plane to display coordinates and calculate
  237. # an angle from that.
  238. # TODO: Maybe Text objects should handle this themselves?
  239. dx, dy = (self.axes.transAxes.transform([peparray[0:2, 1]]) -
  240. self.axes.transAxes.transform([peparray[0:2, 0]]))[0]
  241. lxyz = 0.5*(edgep1 + edgep2)
  242. # A rough estimate; points are ambiguous since 3D plots rotate
  243. ax_scale = self.axes.bbox.size / self.figure.bbox.size
  244. ax_inches = np.multiply(ax_scale, self.figure.get_size_inches())
  245. ax_points_estimate = sum(72. * ax_inches)
  246. deltas_per_point = 48. / ax_points_estimate
  247. default_offset = 21.
  248. labeldeltas = (
  249. (self.labelpad + default_offset) * deltas_per_point * deltas)
  250. axmask = [True, True, True]
  251. axmask[index] = False
  252. lxyz = move_from_center(lxyz, centers, labeldeltas, axmask)
  253. tlx, tly, tlz = proj3d.proj_transform(lxyz[0], lxyz[1], lxyz[2],
  254. renderer.M)
  255. self.label.set_position((tlx, tly))
  256. if self.get_rotate_label(self.label.get_text()):
  257. angle = art3d.norm_text_angle(math.degrees(math.atan2(dy, dx)))
  258. self.label.set_rotation(angle)
  259. self.label.set_va(info['label']['va'])
  260. self.label.set_ha(info['label']['ha'])
  261. self.label.draw(renderer)
  262. # Draw Offset text
  263. # Which of the two edge points do we want to
  264. # use for locating the offset text?
  265. if juggled[2] == 2 :
  266. outeredgep = edgep1
  267. outerindex = 0
  268. else :
  269. outeredgep = edgep2
  270. outerindex = 1
  271. pos = copy.copy(outeredgep)
  272. pos = move_from_center(pos, centers, labeldeltas, axmask)
  273. olx, oly, olz = proj3d.proj_transform(
  274. pos[0], pos[1], pos[2], renderer.M)
  275. self.offsetText.set_text( self.major.formatter.get_offset() )
  276. self.offsetText.set_position( (olx, oly) )
  277. angle = art3d.norm_text_angle(math.degrees(math.atan2(dy, dx)))
  278. self.offsetText.set_rotation(angle)
  279. # Must set rotation mode to "anchor" so that
  280. # the alignment point is used as the "fulcrum" for rotation.
  281. self.offsetText.set_rotation_mode('anchor')
  282. #----------------------------------------------------------------------
  283. # Note: the following statement for determining the proper alignment of
  284. # the offset text. This was determined entirely by trial-and-error
  285. # and should not be in any way considered as "the way". There are
  286. # still some edge cases where alignment is not quite right, but this
  287. # seems to be more of a geometry issue (in other words, I might be
  288. # using the wrong reference points).
  289. #
  290. # (TT, FF, TF, FT) are the shorthand for the tuple of
  291. # (centpt[info['tickdir']] <= peparray[info['tickdir'], outerindex],
  292. # centpt[index] <= peparray[index, outerindex])
  293. #
  294. # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
  295. # from the variable 'highs'.
  296. # ---------------------------------------------------------------------
  297. if centpt[info['tickdir']] > peparray[info['tickdir'], outerindex] :
  298. # if FT and if highs has an even number of Trues
  299. if (centpt[index] <= peparray[index, outerindex]
  300. and ((len(highs.nonzero()[0]) % 2) == 0)) :
  301. # Usually, this means align right, except for the FTT case,
  302. # in which offset for axis 1 and 2 are aligned left.
  303. if highs.tolist() == [False, True, True] and index in (1, 2) :
  304. align = 'left'
  305. else :
  306. align = 'right'
  307. else :
  308. # The FF case
  309. align = 'left'
  310. else :
  311. # if TF and if highs has an even number of Trues
  312. if (centpt[index] > peparray[index, outerindex]
  313. and ((len(highs.nonzero()[0]) % 2) == 0)) :
  314. # Usually mean align left, except if it is axis 2
  315. if index == 2 :
  316. align = 'right'
  317. else :
  318. align = 'left'
  319. else :
  320. # The TT case
  321. align = 'right'
  322. self.offsetText.set_va('center')
  323. self.offsetText.set_ha(align)
  324. self.offsetText.draw(renderer)
  325. # Draw grid lines
  326. if len(xyz0) > 0:
  327. # Grid points at end of one plane
  328. xyz1 = copy.deepcopy(xyz0)
  329. newindex = (index + 1) % 3
  330. newval = get_flip_min_max(xyz1[0], newindex, mins, maxs)
  331. for i in range(len(majorLocs)):
  332. xyz1[i][newindex] = newval
  333. # Grid points at end of the other plane
  334. xyz2 = copy.deepcopy(xyz0)
  335. newindex = (index + 2) % 3
  336. newval = get_flip_min_max(xyz2[0], newindex, mins, maxs)
  337. for i in range(len(majorLocs)):
  338. xyz2[i][newindex] = newval
  339. lines = list(zip(xyz1, xyz0, xyz2))
  340. if self.axes._draw_grid:
  341. self.gridlines.set_segments(lines)
  342. self.gridlines.set_color([info['grid']['color']] * len(lines))
  343. self.gridlines.set_linewidth(
  344. [info['grid']['linewidth']] * len(lines))
  345. self.gridlines.set_linestyle(
  346. [info['grid']['linestyle']] * len(lines))
  347. self.gridlines.draw(renderer, project=True)
  348. # Draw ticks
  349. tickdir = info['tickdir']
  350. tickdelta = deltas[tickdir]
  351. if highs[tickdir]:
  352. ticksign = 1
  353. else:
  354. ticksign = -1
  355. for tick, loc, label in zip(majorTicks, majorLocs, majorLabels):
  356. if tick is None:
  357. continue
  358. # Get tick line positions
  359. pos = copy.copy(edgep1)
  360. pos[index] = loc
  361. pos[tickdir] = (
  362. edgep1[tickdir]
  363. + info['tick']['outward_factor'] * ticksign * tickdelta)
  364. x1, y1, z1 = proj3d.proj_transform(pos[0], pos[1], pos[2],
  365. renderer.M)
  366. pos[tickdir] = (
  367. edgep1[tickdir]
  368. - info['tick']['inward_factor'] * ticksign * tickdelta)
  369. x2, y2, z2 = proj3d.proj_transform(pos[0], pos[1], pos[2],
  370. renderer.M)
  371. # Get position of label
  372. default_offset = 8. # A rough estimate
  373. labeldeltas = (
  374. (tick.get_pad() + default_offset) * deltas_per_point * deltas)
  375. axmask = [True, True, True]
  376. axmask[index] = False
  377. pos[tickdir] = edgep1[tickdir]
  378. pos = move_from_center(pos, centers, labeldeltas, axmask)
  379. lx, ly, lz = proj3d.proj_transform(pos[0], pos[1], pos[2],
  380. renderer.M)
  381. tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
  382. tick.tick1line.set_linewidth(info['tick']['linewidth'])
  383. tick.tick1line.set_color(info['tick']['color'])
  384. tick.set_label1(label)
  385. tick.set_label2(label)
  386. tick.draw(renderer)
  387. renderer.close_group('axis3d')
  388. self.stale = False
  389. def get_view_interval(self):
  390. """return the Interval instance for this 3d axis view limits"""
  391. return self.v_interval
  392. def set_view_interval(self, vmin, vmax, ignore=False):
  393. if ignore:
  394. self.v_interval = vmin, vmax
  395. else:
  396. Vmin, Vmax = self.get_view_interval()
  397. self.v_interval = min(vmin, Vmin), max(vmax, Vmax)
  398. # TODO: Get this to work properly when mplot3d supports
  399. # the transforms framework.
  400. def get_tightbbox(self, renderer) :
  401. # Currently returns None so that Axis.get_tightbbox
  402. # doesn't return junk info.
  403. return None
  404. # Use classes to look at different data limits
  405. class XAxis(Axis):
  406. def get_data_interval(self):
  407. 'return the Interval instance for this axis data limits'
  408. return self.axes.xy_dataLim.intervalx
  409. class YAxis(Axis):
  410. def get_data_interval(self):
  411. 'return the Interval instance for this axis data limits'
  412. return self.axes.xy_dataLim.intervaly
  413. class ZAxis(Axis):
  414. def get_data_interval(self):
  415. 'return the Interval instance for this axis data limits'
  416. return self.axes.zz_dataLim.intervalx