stackplot.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. """
  2. Stacked area plot for 1D arrays inspired by Douglas Y'barbo's stackoverflow
  3. answer:
  4. http://stackoverflow.com/questions/2225995/how-can-i-create-stacked-line-graph-with-matplotlib
  5. (http://stackoverflow.com/users/66549/doug)
  6. """
  7. from __future__ import (absolute_import, division, print_function,
  8. unicode_literals)
  9. import six
  10. from six.moves import xrange
  11. import numpy as np
  12. __all__ = ['stackplot']
  13. def stackplot(axes, x, *args, **kwargs):
  14. """
  15. Draws a stacked area plot.
  16. Parameters
  17. ----------
  18. x : 1d array of dimension N
  19. y : 2d array (dimension MxN), or sequence of 1d arrays (each dimension 1xN)
  20. The data is assumed to be unstacked. Each of the following
  21. calls is legal::
  22. stackplot(x, y) # where y is MxN
  23. stackplot(x, y1, y2, y3, y4) # where y1, y2, y3, y4, are all 1xNm
  24. baseline : ['zero' | 'sym' | 'wiggle' | 'weighted_wiggle']
  25. Method used to calculate the baseline:
  26. - ``'zero'``: Constant zero baseline, i.e. a simple stacked plot.
  27. - ``'sym'``: Symmetric around zero and is sometimes called
  28. 'ThemeRiver'.
  29. - ``'wiggle'``: Minimizes the sum of the squared slopes.
  30. - ``'weighted_wiggle'``: Does the same but weights to account for
  31. size of each layer. It is also called 'Streamgraph'-layout. More
  32. details can be found at http://leebyron.com/streamgraph/.
  33. labels : Length N sequence of strings
  34. Labels to assign to each data series.
  35. colors : Length N sequence of colors
  36. A list or tuple of colors. These will be cycled through and used to
  37. colour the stacked areas.
  38. **kwargs :
  39. All other keyword arguments are passed to `Axes.fill_between()`.
  40. Returns
  41. -------
  42. list : list of `.PolyCollection`
  43. A list of `.PolyCollection` instances, one for each element in the
  44. stacked area plot.
  45. """
  46. y = np.row_stack(args)
  47. labels = iter(kwargs.pop('labels', []))
  48. colors = kwargs.pop('colors', None)
  49. if colors is not None:
  50. axes.set_prop_cycle(color=colors)
  51. baseline = kwargs.pop('baseline', 'zero')
  52. # Assume data passed has not been 'stacked', so stack it here.
  53. # We'll need a float buffer for the upcoming calculations.
  54. stack = np.cumsum(y, axis=0, dtype=np.promote_types(y.dtype, np.float32))
  55. if baseline == 'zero':
  56. first_line = 0.
  57. elif baseline == 'sym':
  58. first_line = -np.sum(y, 0) * 0.5
  59. stack += first_line[None, :]
  60. elif baseline == 'wiggle':
  61. m = y.shape[0]
  62. first_line = (y * (m - 0.5 - np.arange(m)[:, None])).sum(0)
  63. first_line /= -m
  64. stack += first_line
  65. elif baseline == 'weighted_wiggle':
  66. m, n = y.shape
  67. total = np.sum(y, 0)
  68. # multiply by 1/total (or zero) to avoid infinities in the division:
  69. inv_total = np.zeros_like(total)
  70. mask = total > 0
  71. inv_total[mask] = 1.0 / total[mask]
  72. increase = np.hstack((y[:, 0:1], np.diff(y)))
  73. below_size = total - stack
  74. below_size += 0.5 * y
  75. move_up = below_size * inv_total
  76. move_up[:, 0] = 0.5
  77. center = (move_up - 0.5) * increase
  78. center = np.cumsum(center.sum(0))
  79. first_line = center - 0.5 * total
  80. stack += first_line
  81. else:
  82. errstr = "Baseline method %s not recognised. " % baseline
  83. errstr += "Expected 'zero', 'sym', 'wiggle' or 'weighted_wiggle'"
  84. raise ValueError(errstr)
  85. # Color between x = 0 and the first array.
  86. color = axes._get_lines.get_next_color()
  87. coll = axes.fill_between(x, first_line, stack[0, :],
  88. facecolor=color, label=next(labels, None),
  89. **kwargs)
  90. coll.sticky_edges.y[:] = [0]
  91. r = [coll]
  92. # Color between array i-1 and array i
  93. for i in xrange(len(y) - 1):
  94. color = axes._get_lines.get_next_color()
  95. r.append(axes.fill_between(x, stack[i, :], stack[i + 1, :],
  96. facecolor=color, label=next(labels, None),
  97. **kwargs))
  98. return r