_matrix.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from ._sputils import isintlike, isscalarlike
  2. class spmatrix:
  3. """This class provides a base class for all sparse matrix classes.
  4. It cannot be instantiated. Most of the work is provided by subclasses.
  5. """
  6. _is_array = False
  7. @property
  8. def _bsr_container(self):
  9. from ._bsr import bsr_matrix
  10. return bsr_matrix
  11. @property
  12. def _coo_container(self):
  13. from ._coo import coo_matrix
  14. return coo_matrix
  15. @property
  16. def _csc_container(self):
  17. from ._csc import csc_matrix
  18. return csc_matrix
  19. @property
  20. def _csr_container(self):
  21. from ._csr import csr_matrix
  22. return csr_matrix
  23. @property
  24. def _dia_container(self):
  25. from ._dia import dia_matrix
  26. return dia_matrix
  27. @property
  28. def _dok_container(self):
  29. from ._dok import dok_matrix
  30. return dok_matrix
  31. @property
  32. def _lil_container(self):
  33. from ._lil import lil_matrix
  34. return lil_matrix
  35. # Restore matrix multiplication
  36. def __mul__(self, other):
  37. return self._mul_dispatch(other)
  38. def __rmul__(self, other):
  39. return self._rmul_dispatch(other)
  40. # Restore matrix power
  41. def __pow__(self, other):
  42. M, N = self.shape
  43. if M != N:
  44. raise TypeError('sparse matrix is not square')
  45. if isintlike(other):
  46. other = int(other)
  47. if other < 0:
  48. raise ValueError('exponent must be >= 0')
  49. if other == 0:
  50. from ._construct import eye
  51. return eye(M, dtype=self.dtype)
  52. if other == 1:
  53. return self.copy()
  54. tmp = self.__pow__(other // 2)
  55. if other % 2:
  56. return self @ tmp @ tmp
  57. else:
  58. return tmp @ tmp
  59. if isscalarlike(other):
  60. raise ValueError('exponent must be an integer')
  61. return NotImplemented
  62. ## Backward compatibility
  63. def set_shape(self, shape):
  64. """Set the shape of the matrix in-place"""
  65. # Make sure copy is False since this is in place
  66. # Make sure format is unchanged because we are doing a __dict__ swap
  67. new_self = self.reshape(shape, copy=False).asformat(self.format)
  68. self.__dict__ = new_self.__dict__
  69. def get_shape(self):
  70. """Get the shape of the matrix"""
  71. return self._shape
  72. shape = property(fget=get_shape, fset=set_shape,
  73. doc="Shape of the matrix")
  74. def asfptype(self):
  75. """Upcast array to a floating point format (if necessary)"""
  76. return self._asfptype()
  77. def getmaxprint(self):
  78. """Maximum number of elements to display when printed."""
  79. return self._getmaxprint()
  80. def getformat(self):
  81. """Matrix storage format"""
  82. return self.format
  83. def getnnz(self, axis=None):
  84. """Number of stored values, including explicit zeros.
  85. Parameters
  86. ----------
  87. axis : None, 0, or 1
  88. Select between the number of values across the whole array, in
  89. each column, or in each row.
  90. """
  91. return self._getnnz(axis=axis)
  92. def getH(self):
  93. """Return the Hermitian transpose of this array.
  94. See Also
  95. --------
  96. numpy.matrix.getH : NumPy's implementation of `getH` for matrices
  97. """
  98. return self.conjugate().transpose()
  99. def getcol(self, j):
  100. """Returns a copy of column j of the array, as an (m x 1) sparse
  101. array (column vector).
  102. """
  103. return self._getcol(j)
  104. def getrow(self, i):
  105. """Returns a copy of row i of the array, as a (1 x n) sparse
  106. array (row vector).
  107. """
  108. return self._getrow(i)
  109. def _array_doc_to_matrix(docstr):
  110. # For opimized builds with stripped docstrings
  111. if docstr is None:
  112. return None
  113. return (
  114. docstr.replace('sparse arrays', 'sparse matrices')
  115. .replace('sparse array', 'sparse matrix')
  116. )