_compressed.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317
  """Base class for sparse matrix formats using compressed storage.

  Defines ``_cs_matrix``, the shared implementation behind the compressed
  row- and column-oriented sparse formats.
  """
  # This module exports no public names.
  __all__ = []
  3. from warnings import warn
  4. import operator
  5. import numpy as np
  6. from scipy._lib._util import _prune_array
  7. from ._base import _spbase, issparse, SparseEfficiencyWarning
  8. from ._data import _data_matrix, _minmax_mixin
  9. from . import _sparsetools
  10. from ._sparsetools import (get_csr_submatrix, csr_sample_offsets, csr_todense,
  11. csr_sample_values, csr_row_index, csr_row_slice,
  12. csr_column_index1, csr_column_index2)
  13. from ._index import IndexMixin
  14. from ._sputils import (upcast, upcast_char, to_native, isdense, isshape,
  15. getdtype, isscalarlike, isintlike, downcast_intp_index, get_sum_dtype, check_shape,
  16. is_pydata_spmatrix)
  17. class _cs_matrix(_data_matrix, _minmax_mixin, IndexMixin):
  18. """base matrix class for compressed row- and column-oriented matrices"""
  def __init__(self, arg1, shape=None, dtype=None, copy=False):
      """Construct a compressed sparse matrix/array.

      Parameters
      ----------
      arg1 : sparse object, tuple, or array_like
          Accepted forms:

          * another sparse object -- converted to this format (copied when it
            already has this format and ``copy`` is true);
          * ``(M, N)`` -- an empty matrix of that shape;
          * ``(data, ij)`` -- routed through the COO container;
          * ``(data, indices, indptr)`` -- the raw compressed arrays;
          * anything ``np.asarray`` accepts -- treated as dense input.
      shape : tuple of int, optional
          Matrix dimensions. Inferred from ``indptr``/``indices`` when omitted.
      dtype : dtype, optional
          Data type of the stored values.
      copy : bool, optional
          Copy the input arrays instead of reusing them when possible.

      Raises
      ------
      ValueError
          If ``arg1`` is not a recognised constructor form, or the shape
          cannot be inferred.
      """
      _data_matrix.__init__(self)

      if issparse(arg1):
          if arg1.format == self.format and copy:
              arg1 = arg1.copy()
          else:
              arg1 = arg1.asformat(self.format)
          self._set_self(arg1)

      elif isinstance(arg1, tuple):
          if isshape(arg1):
              # It's a tuple of matrix dimensions (M, N): create an empty matrix.
              self._shape = check_shape(arg1)
              M, N = self.shape
              # Select index dtype large enough to pass array and
              # scalar parameters to sparsetools
              idx_dtype = self._get_index_dtype(maxval=max(M, N))
              self.data = np.zeros(0, getdtype(dtype, default=float))
              self.indices = np.zeros(0, idx_dtype)
              self.indptr = np.zeros(self._swap((M, N))[0] + 1,
                                     dtype=idx_dtype)
          elif len(arg1) == 2:
              # (data, ij) format
              other = self.__class__(
                  self._coo_container(arg1, shape=shape, dtype=dtype)
              )
              self._set_self(other)
          elif len(arg1) == 3:
              # (data, indices, indptr) format
              (data, indices, indptr) = arg1
              # Select index dtype large enough to pass array and
              # scalar parameters to sparsetools
              maxval = None
              if shape is not None:
                  maxval = max(shape)
              idx_dtype = self._get_index_dtype((indices, indptr),
                                                maxval=maxval,
                                                check_contents=True)
              # Under NumPy >= 2, ``np.array(..., copy=False)`` raises when a
              # copy cannot be avoided (e.g. a dtype conversion). ``np.asarray``
              # keeps the intended "copy only if needed" meaning on every
              # NumPy version.
              if copy:
                  self.indices = np.array(indices, copy=True, dtype=idx_dtype)
                  self.indptr = np.array(indptr, copy=True, dtype=idx_dtype)
                  self.data = np.array(data, copy=True, dtype=dtype)
              else:
                  self.indices = np.asarray(indices, dtype=idx_dtype)
                  self.indptr = np.asarray(indptr, dtype=idx_dtype)
                  self.data = np.asarray(data, dtype=dtype)
          else:
              raise ValueError(f"unrecognized {self.format}_matrix "
                               "constructor usage")

      else:
          # must be dense
          try:
              arg1 = np.asarray(arg1)
          except Exception as e:
              raise ValueError(f"unrecognized {self.format}_matrix "
                               "constructor usage") from e
          self._set_self(self.__class__(
              self._coo_container(arg1, dtype=dtype)
          ))

      # Read matrix dimensions given, if any
      if shape is not None:
          self._shape = check_shape(shape)
      elif self.shape is None:
          # shape not already set, try to infer dimensions
          try:
              major_dim = len(self.indptr) - 1
              minor_dim = self.indices.max() + 1
          except Exception as e:
              raise ValueError('unable to infer matrix dimensions') from e
          else:
              self._shape = check_shape(self._swap((major_dim, minor_dim)))

      if dtype is not None:
          self.data = self.data.astype(dtype, copy=False)

      self.check_format(full_check=False)
  def _getnnz(self, axis=None):
      # Whole-matrix count: the final index pointer is the stored-entry total.
      # (The docstring is taken from _spbase below.)
      if axis is None:
          return int(self.indptr[-1])

      if axis < 0:
          axis += 2
      # Re-express the caller's axis in terms of the storage orientation.
      axis, _ = self._swap((axis, 1 - axis))
      _, minor_len = self._swap(self.shape)

      if axis == 0:
          # Occurrences of each minor index.
          return np.bincount(downcast_intp_index(self.indices),
                             minlength=minor_len)
      if axis == 1:
          # Stored entries in each major slice.
          return np.diff(self.indptr)
      raise ValueError('axis out of bounds')

  _getnnz.__doc__ = _spbase._getnnz.__doc__
  def _set_self(self, other, copy=False):
      """Adopt the ``data``, ``indices``, ``indptr`` and shape of *other*.

      With ``copy`` true, *other* is copied first so no arrays are shared.
      """
      source = other.copy() if copy else other
      self.data = source.data
      self.indices = source.indices
      self.indptr = source.indptr
      self._shape = check_shape(source.shape)
  def check_format(self, full_check=True):
      """Check whether the matrix format is valid.

      Also normalises ``indptr``/``indices`` to a common index dtype, passes
      ``data`` through ``to_native`` and calls ``self.prune()``.

      Parameters
      ----------
      full_check : bool, optional
          If `True` (default), rigorous check, O(N) operations. Otherwise
          basic check, O(1) operations.

      Raises
      ------
      ValueError
          If the arrays are not 1-D or have inconsistent sizes or, with
          ``full_check``, contain out-of-range indices or a decreasing index
          pointer.
      """
      # use _swap to determine proper bounds
      _, minor_name = self._swap(('row', 'column'))
      major_dim, minor_dim = self._swap(self.shape)

      # index arrays should have integer data types
      if self.indptr.dtype.kind != 'i':
          warn(f"indptr array has non-integer dtype ({self.indptr.dtype.name})",
               stacklevel=3)
      if self.indices.dtype.kind != 'i':
          warn(f"indices array has non-integer dtype ({self.indices.dtype.name})",
               stacklevel=3)

      idx_dtype = self._get_index_dtype((self.indptr, self.indices))
      self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
      self.indices = np.asarray(self.indices, dtype=idx_dtype)
      self.data = to_native(self.data)

      # check array shapes
      for x in [self.data.ndim, self.indices.ndim, self.indptr.ndim]:
          if x != 1:
              raise ValueError('data, indices, and indptr should be 1-D')

      # check index pointer
      if len(self.indptr) != major_dim + 1:
          raise ValueError(f"index pointer size ({len(self.indptr)}) "
                           f"should be ({major_dim + 1})")
      if self.indptr[0] != 0:
          raise ValueError("index pointer should start with 0")

      # check index and data arrays
      if len(self.indices) != len(self.data):
          raise ValueError("indices and data should have the same size")
      if self.indptr[-1] > len(self.indices):
          raise ValueError("Last value of index pointer should be less than "
                           "the size of index and data arrays")

      self.prune()

      if full_check:
          # check format validity (more expensive)
          if self.nnz > 0:
              if self.indices.max() >= minor_dim:
                  raise ValueError(f"{minor_name} index values must be "
                                   f"< {minor_dim}")
              if self.indices.min() < 0:
                  raise ValueError(f"{minor_name} index values must be >= 0")
          # The index pointer must be non-decreasing even when nothing is
          # stored; otherwise interior entries could point past the (empty)
          # index/data arrays. np.any also copes with a length-1 indptr.
          if np.any(np.diff(self.indptr) < 0):
              raise ValueError("index pointer values must form a "
                               "non-decreasing sequence")

          # TODO: check for unsorted indices and duplicates?
  171. #######################
  172. # Boolean comparisons #
  173. #######################
  def _scalar_binopt(self, other, op):
      """Apply ``op(stored_values, other)`` and return a new canonical matrix.

      Intended for scalar operations that create no new nonzeros. Note that
      ``self`` is modified in place by ``sum_duplicates()`` first.
      """
      # Merge duplicates so op sees each entry's full value exactly once.
      self.sum_duplicates()
      outcome = self._with_data(op(self.data, other), copy=True)
      # Entries that became zero/False are removed from the result.
      outcome.eliminate_zeros()
      return outcome
  def __eq__(self, other):
      """Elementwise ``==``.

      Comparisons that make implicit zeros True (against scalar 0, or against
      an equally-shaped sparse operand) are computed as "all True minus the
      != result" and emit a SparseEfficiencyWarning.
      """
      if isscalarlike(other):
          if np.isnan(other):
              # NaN equals nothing: an empty (all-False) boolean matrix.
              return self.__class__(self.shape, dtype=np.bool_)
          if other == 0:
              warn("Comparing a sparse matrix with 0 using == is inefficient"
                   ", try using != instead.", SparseEfficiencyWarning,
                   stacklevel=3)
              everything = self.__class__(np.ones(self.shape, dtype=np.bool_))
              differing = self._scalar_binopt(other, operator.ne)
              return everything - differing
          return self._scalar_binopt(other, operator.eq)

      if isdense(other):
          return self.todense() == other

      if is_pydata_spmatrix(other):
          # Defer to the pydata sparse object's reflected comparison.
          return NotImplemented

      if not issparse(other):
          return False

      warn("Comparing sparse matrices using == is inefficient, try using"
           " != instead.", SparseEfficiencyWarning, stacklevel=3)
      # TODO sparse broadcasting
      if self.shape != other.shape:
          return False
      if self.format != other.format:
          other = other.asformat(self.format)
      differing = self._binopt(other, '_ne_')
      everything = self.__class__(np.ones(self.shape, dtype=np.bool_))
      return everything - differing
  def __ne__(self, other):
      """Elementwise ``!=``.

      Comparisons that make implicit zeros True (against NaN or a nonzero
      scalar) emit a SparseEfficiencyWarning because the result is dense.
      Sparse operands of a different shape compare as ``True``.
      """
      # Scalar other.
      if isscalarlike(other):
          if np.isnan(other):
              warn("Comparing a sparse matrix with nan using != is"
                   " inefficient", SparseEfficiencyWarning, stacklevel=3)
              # Every entry differs from NaN.
              all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
              return all_true
          elif other != 0:
              warn("Comparing a sparse matrix with a nonzero scalar using !="
                   " is inefficient, try using == instead.",
                   SparseEfficiencyWarning, stacklevel=3)
              # Build the all-True block as bool directly (as __eq__ does)
              # rather than allocating a float64 dense array and converting.
              all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
              inv = self._scalar_binopt(other, operator.eq)
              return all_true - inv
          else:
              return self._scalar_binopt(other, operator.ne)
      # Dense other.
      elif isdense(other):
          return self.todense() != other
      # Pydata sparse other.
      elif is_pydata_spmatrix(other):
          return NotImplemented
      # Sparse other.
      elif issparse(other):
          # TODO sparse broadcasting
          if self.shape != other.shape:
              return True
          elif self.format != other.format:
              other = other.asformat(self.format)
          return self._binopt(other, '_ne_')
      else:
          return True
  def _inequality(self, other, op, op_name, bad_scalar_msg):
      """Shared implementation of ``<``, ``>``, ``<=`` and ``>=``.

      Parameters
      ----------
      other : scalar, dense array, or sparse matrix
      op : callable
          The ``operator`` function (e.g. ``operator.lt``).
      op_name : str
          Name of the matching ``_binopt`` operation (``'_lt_'``, ...).
      bad_scalar_msg : str
          Warning text used when comparing with a scalar makes the implicit
          zeros True, so the result is dense.

      Raises
      ------
      NotImplementedError
          For ``<=``/``>=`` against the scalar 0.
      ValueError
          For sparse operands of different shapes, or unsupported operands.
      """
      # Scalar other.
      if isscalarlike(other):
          if 0 == other and op_name in ('_le_', '_ge_'):
              raise NotImplementedError(" >= and <= don't work with 0.")
          elif op(0, other):
              # Implicit zeros satisfy the comparison: the result is dense.
              # stacklevel=3 points the warning at the caller of the operator,
              # matching the other warnings in this class.
              warn(bad_scalar_msg, SparseEfficiencyWarning, stacklevel=3)
              other_arr = np.full(self.shape, other,
                                  dtype=np.result_type(other))
              other_arr = self.__class__(other_arr)
              return self._binopt(other_arr, op_name)
          else:
              return self._scalar_binopt(other, op)
      # Dense other.
      elif isdense(other):
          return op(self.todense(), other)
      # Sparse other.
      elif issparse(other):
          # TODO sparse broadcasting
          if self.shape != other.shape:
              raise ValueError("inconsistent shapes")
          elif self.format != other.format:
              other = other.asformat(self.format)
          if op_name not in ('_ge_', '_le_'):
              return self._binopt(other, op_name)

          warn("Comparing sparse matrices using >= and <= is inefficient, "
               "using <, >, or !=, instead.", SparseEfficiencyWarning,
               stacklevel=3)
          # a <= b  <=>  not (a > b);  a >= b  <=>  not (a < b)
          all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
          res = self._binopt(other, '_gt_' if op_name == '_le_' else '_lt_')
          return all_true - res
      else:
          raise ValueError("Operands could not be compared.")
  def __lt__(self, other):
      # A positive scalar makes every implicit zero satisfy 0 < x, which is
      # the case the efficiency warning text describes.
      warning_text = ("Comparing a sparse matrix with a scalar "
                      "greater than zero using < is inefficient, "
                      "try using >= instead.")
      return self._inequality(other, operator.lt, '_lt_', warning_text)
  def __gt__(self, other):
      # A negative scalar makes every implicit zero satisfy 0 > x, which is
      # the case the efficiency warning text describes.
      warning_text = ("Comparing a sparse matrix with a scalar "
                      "less than zero using > is inefficient, "
                      "try using <= instead.")
      return self._inequality(other, operator.gt, '_gt_', warning_text)
  def __le__(self, other):
      # A positive scalar makes every implicit zero satisfy 0 <= x (a scalar
      # 0 is rejected inside _inequality).
      warning_text = ("Comparing a sparse matrix with a scalar "
                      "greater than zero using <= is inefficient, "
                      "try using > instead.")
      return self._inequality(other, operator.le, '_le_', warning_text)
  def __ge__(self, other):
      # A negative scalar makes every implicit zero satisfy 0 >= x (a scalar
      # 0 is rejected inside _inequality).
      warning_text = ("Comparing a sparse matrix with a scalar "
                      "less than zero using >= is inefficient, "
                      "try using < instead.")
      return self._inequality(other, operator.ge, '_ge_', warning_text)
  301. #################################
  302. # Arithmetic operator overrides #
  303. #################################
  def _add_dense(self, other):
      """Return ``self + other`` for a dense ``other`` of identical shape."""
      if self.shape != other.shape:
          raise ValueError(f'Incompatible shapes ({self.shape} and {other.shape})')

      result_dtype = upcast_char(self.dtype.char, other.dtype.char)
      # Lay the copy out in the storage orientation ('C' when rows are the
      # major axis per self._swap, else 'F').
      layout = self._swap('CF')[0]
      result = np.array(other, dtype=result_dtype, order=layout, copy=True)

      n_major, n_minor = self._swap(self.shape)
      # csr_todense accumulates into a C-contiguous (major, minor) buffer;
      # for Fortran-ordered storage that buffer is the transposed view.
      target = result if result.flags.c_contiguous else result.T
      csr_todense(n_major, n_minor, self.indptr, self.indices, self.data,
                  target)
      return self._container(result, copy=False)
  def _add_sparse(self, other):
      # Elementwise sum, delegated to self._binopt with the '_plus_' operation.
      operation = '_plus_'
      return self._binopt(other, operation)
  def _sub_sparse(self, other):
      # Elementwise difference, delegated to self._binopt with '_minus_'.
      operation = '_minus_'
      return self._binopt(other, operation)
  def multiply(self, other):
      """Point-wise multiplication by another matrix, vector, or
      scalar.

      Sparse operands of a different shape are combined under
      broadcasting-like rules (1x1 operands, row vectors and column vectors).
      Dense operands are handled through a COO copy of ``self``.

      Raises
      ------
      ValueError
          If the two shapes cannot be combined ("inconsistent shapes").
      """
      # Scalar multiplication.
      if isscalarlike(other):
          return self._mul_scalar(other)
      # Sparse matrix or vector.
      # The elif order below matters: equal shapes and 1x1 operands are
      # tested before any vector broadcasting case.
      if issparse(other):
          if self.shape == other.shape:
              # Convert to this format so _binopt sees matching storage.
              other = self.__class__(other)
              return self._binopt(other, '_elmul_')
          # Single element.
          elif other.shape == (1, 1):
              return self._mul_scalar(other.toarray()[0, 0])
          elif self.shape == (1, 1):
              return other._mul_scalar(self.toarray()[0, 0])
          # A row times a column.
          # Broadcasting a column against a row gives their outer product,
          # computed here as a matrix product.
          elif self.shape[1] == 1 and other.shape[0] == 1:
              return self._mul_sparse_matrix(other.tocsc())
          elif self.shape[0] == 1 and other.shape[1] == 1:
              return other._mul_sparse_matrix(self.tocsc())
          # Row vector times matrix. other is a row.
          # Scaling columns by a row vector == right-multiplying by a diagonal
          # matrix holding that vector.
          elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
              other = self._dia_container(
                  (other.toarray().ravel(), [0]),
                  shape=(other.shape[1], other.shape[1])
              )
              return self._mul_sparse_matrix(other)
          # self is a row.
          elif self.shape[0] == 1 and self.shape[1] == other.shape[1]:
              copy = self._dia_container(
                  (self.toarray().ravel(), [0]),
                  shape=(self.shape[1], self.shape[1])
              )
              return other._mul_sparse_matrix(copy)
          # Column vector times matrix. other is a column.
          # Scaling rows by a column vector == left-multiplying by a diagonal
          # matrix holding that vector.
          elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
              other = self._dia_container(
                  (other.toarray().ravel(), [0]),
                  shape=(other.shape[0], other.shape[0])
              )
              return other._mul_sparse_matrix(self)
          # self is a column.
          elif self.shape[1] == 1 and self.shape[0] == other.shape[0]:
              copy = self._dia_container(
                  (self.toarray().ravel(), [0]),
                  shape=(self.shape[0], self.shape[0])
              )
              return copy._mul_sparse_matrix(other)
          else:
              raise ValueError("inconsistent shapes")
      # Assume other is a dense matrix/array, which produces a single-item
      # object array if other isn't convertible to ndarray.
      other = np.atleast_2d(other)
      if other.ndim != 2:
          # More than two dimensions: fall back to a fully dense product.
          return np.multiply(self.toarray(), other)
      # Single element / wrapped object.
      if other.size == 1:
          return self._mul_scalar(other.flat[0])
      # Fast case for trivial sparse matrix.
      elif self.shape == (1, 1):
          return np.multiply(self.toarray()[0, 0], other)
      # Work with COO triplets: ret.row / ret.col locate every stored value.
      ret = self.tocoo()
      # Matching shapes.
      if self.shape == other.shape:
          data = np.multiply(ret.data, other[ret.row, ret.col])
      # Sparse row vector times...
      elif self.shape[0] == 1:
          if other.shape[1] == 1:  # Dense column vector.
              data = np.multiply(ret.data, other)
          elif other.shape[1] == self.shape[1]:  # Dense matrix.
              data = np.multiply(ret.data, other[:, ret.col])
          else:
              raise ValueError("inconsistent shapes")
          # `data` is (other.shape[0], nnz): one copy of the row's sparsity
          # pattern per dense row, laid out row by row.
          row = np.repeat(np.arange(other.shape[0]), len(ret.row))
          col = np.tile(ret.col, other.shape[0])
          # view(np.ndarray) drops an np.matrix subclass so ravel() is 1-D.
          return self._coo_container(
              (data.view(np.ndarray).ravel(), (row, col)),
              shape=(other.shape[0], self.shape[1]),
              copy=False
          )
      # Sparse column vector times...
      elif self.shape[1] == 1:
          if other.shape[0] == 1:  # Dense row vector.
              data = np.multiply(ret.data[:, None], other)
          elif other.shape[0] == self.shape[0]:  # Dense matrix.
              data = np.multiply(ret.data[:, None], other[ret.row])
          else:
              raise ValueError("inconsistent shapes")
          # `data` is (nnz, other.shape[1]): each stored value spreads across
          # every dense column.
          row = np.repeat(ret.row, other.shape[1])
          col = np.tile(np.arange(other.shape[1]), len(ret.col))
          return self._coo_container(
              (data.view(np.ndarray).ravel(), (row, col)),
              shape=(self.shape[0], other.shape[1]),
              copy=False
          )
      # Sparse matrix times dense row vector.
      elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
          data = np.multiply(ret.data, other[:, ret.col].ravel())
      # Sparse matrix times dense column vector.
      elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
          data = np.multiply(ret.data, other[ret.row].ravel())
      else:
          raise ValueError("inconsistent shapes")
      # Same sparsity pattern as self; only the stored values change.
      ret.data = data.view(np.ndarray).ravel()
      return ret
  426. ###########################
  427. # Multiplication handlers #
  428. ###########################
  def _mul_vector(self, other):
      """Return ``self @ other`` for a 1-D dense vector ``other``."""
      n_rows, n_cols = self.shape
      out_dtype = upcast_char(self.dtype.char, other.dtype.char)
      product = np.zeros(n_rows, dtype=out_dtype)
      # Dispatch to csr_matvec or csc_matvec according to the format.
      matvec = getattr(_sparsetools, f'{self.format}_matvec')
      matvec(n_rows, n_cols, self.indptr, self.indices, self.data, other,
             product)
      return product
  def _mul_multivector(self, other):
      """Return ``self @ other`` for a 2-D dense ``other`` (a block of columns)."""
      n_rows, n_cols = self.shape
      n_vecs = other.shape[1]  # number of column vectors
      out_dtype = upcast_char(self.dtype.char, other.dtype.char)
      product = np.zeros((n_rows, n_vecs), dtype=out_dtype)
      # Dispatch to csr_matvecs or csc_matvecs. `product` is freshly allocated
      # and C-contiguous, so product.ravel() is a view the kernel writes into.
      matvecs = getattr(_sparsetools, f'{self.format}_matvecs')
      matvecs(n_rows, n_cols, n_vecs, self.indptr, self.indices, self.data,
              other.ravel(), product.ravel())
      return product
  def _mul_sparse_matrix(self, other):
      """Return the sparse product ``self @ other`` in this format.

      ``other`` is first converted to this format. The ``*_matmat_maxnnz``
      kernel sizes the output, then ``*_matmat`` fills preallocated arrays.
      """
      n_rows, _ = self.shape
      _, n_cols = other.shape
      major_len = self._swap((n_rows, n_cols))[0]
      other = self.__class__(other)  # convert to this format

      def as_index(arr, index_dtype):
          return np.asarray(arr, dtype=index_dtype)

      operands = (self.indptr, self.indices, other.indptr, other.indices)
      idx_dtype = self._get_index_dtype(operands)
      count_nnz = getattr(_sparsetools, self.format + '_matmat_maxnnz')
      nnz = count_nnz(n_rows, n_cols,
                      *(as_index(arr, idx_dtype) for arr in operands))

      # Choose the index dtype again, now also large enough to hold nnz.
      idx_dtype = self._get_index_dtype(operands, maxval=nnz)
      indptr = np.empty(major_len + 1, dtype=idx_dtype)
      indices = np.empty(nnz, dtype=idx_dtype)
      data = np.empty(nnz, dtype=upcast(self.dtype, other.dtype))

      matmat = getattr(_sparsetools, self.format + '_matmat')
      matmat(n_rows, n_cols,
             as_index(self.indptr, idx_dtype),
             as_index(self.indices, idx_dtype),
             self.data,
             as_index(other.indptr, idx_dtype),
             as_index(other.indices, idx_dtype),
             other.data,
             indptr, indices, data)
      return self.__class__((data, indices, indptr), shape=(n_rows, n_cols))
  def diagonal(self, k=0):
      # Docstring is copied from _spbase below.
      n_rows, n_cols = self.shape
      # Offsets outside (-n_rows, n_cols) select an empty diagonal.
      if not (-n_rows < k < n_cols):
          return np.empty(0, dtype=self.data.dtype)
      extract = getattr(_sparsetools, self.format + "_diagonal")
      length = min(n_rows + min(k, 0), n_cols - max(k, 0))
      out = np.empty(length, dtype=upcast(self.dtype))
      extract(k, n_rows, n_cols, self.indptr, self.indices, self.data, out)
      return out

  diagonal.__doc__ = _spbase.diagonal.__doc__
  487. #####################
  488. # Other binary ops #
  489. #####################
  def _maximum_minimum(self, other, npop, op_name, dense_check):
      """Shared implementation of elementwise :meth:`maximum`/:meth:`minimum`.

      Parameters
      ----------
      other : scalar, dense array, or sparse matrix
      npop : ufunc
          ``np.maximum`` or ``np.minimum``.
      op_name : str
          Name of the ``self._binopt`` operation for sparse operands.
      dense_check : callable
          True when a scalar ``other`` would turn implicit zeros into
          nonzeros, so the result is effectively dense.

      Raises
      ------
      ValueError
          If ``other`` is not scalar-like, dense, or sparse.
      """
      if not isscalarlike(other):
          if isdense(other):
              return npop(self.todense(), other)
          if issparse(other):
              return self._binopt(other, op_name)
          raise ValueError("Operands not compatible.")

      if not dense_check(other):
          # Implicit zeros are unaffected, so only stored values change.
          # Note: sum_duplicates() modifies self in place.
          self.sum_duplicates()
          combined = npop(self.data, np.asarray(other))
          return self.__class__((combined, self.indices, self.indptr),
                                dtype=combined.dtype, shape=self.shape)

      warn("Taking maximum (minimum) with > 0 (< 0) number results"
           " to a dense matrix.", SparseEfficiencyWarning,
           stacklevel=3)
      filled = np.empty(self.shape, dtype=np.asarray(other).dtype)
      filled.fill(other)
      return self._binopt(self.__class__(filled), op_name)
  def maximum(self, other):
      # A positive scalar lifts every implicit zero, densifying the result.
      def _densifies(value):
          return np.asarray(value) > 0

      return self._maximum_minimum(other, np.maximum, '_maximum_', _densifies)

  maximum.__doc__ = _spbase.maximum.__doc__
  def minimum(self, other):
      # A negative scalar lowers every implicit zero, densifying the result.
      def _densifies(value):
          return np.asarray(value) < 0

      return self._maximum_minimum(other, np.minimum, '_minimum_', _densifies)

  minimum.__doc__ = _spbase.minimum.__doc__
  520. #####################
  521. # Reduce operations #
  522. #####################
  def sum(self, axis=None, dtype=None, out=None):
      """Sum the matrix over the given axis. If the axis is None, sum
      over both rows and columns, returning a scalar.
      """
      # Fast path: reduce along the *minor* axis, i.e. axis 1/-1 when rows
      # are the major axis and axis 0/-2 when columns are, by summing each
      # major slice of ``data`` directly. Every other case (axis=None, the
      # major axis, invalid axes, and formats with a ``blocksize``) is
      # delegated to _spbase.sum, which also validates ``axis``.
      # The column-major alias must be -2: with the former (0, 2), an invalid
      # axis=2 silently took the fast path and returned the axis-0 sum.
      if (not hasattr(self, 'blocksize') and
              axis in self._swap(((1, -1), (0, -2)))[0]):
          # faster than multiplication for large minor axis in CSC/CSR
          res_dtype = get_sum_dtype(self.dtype)
          ret = np.zeros(len(self.indptr) - 1, dtype=res_dtype)
          major_index, value = self._minor_reduce(np.add)
          ret[major_index] = value
          ret = self._ascontainer(ret)
          if axis % 2 == 1:
              ret = ret.T

          if out is not None and out.shape != ret.shape:
              raise ValueError('dimensions do not match')

          # axis=() reduces nothing; it only applies the dtype/out handling.
          return ret.sum(axis=(), dtype=dtype, out=out)
      return _spbase.sum(self, axis=axis, dtype=dtype, out=out)

  sum.__doc__ = _spbase.sum.__doc__
  def _minor_reduce(self, ufunc, data=None):
      """Reduce the stored values of each non-empty major slice with *ufunc*.

      Pass ``data`` to reduce an array derived from ``self.data`` (same
      layout) instead of ``self.data`` itself.

      Warning: this does not call sum_duplicates()

      Returns
      -------
      major_index : array of ints
          Positions along the major axis that hold at least one entry.
      value : array of self.dtype
          The reduction result for each of those positions.
      """
      values = self.data if data is None else data
      # Major slices whose indptr range is non-empty.
      nonempty = np.flatnonzero(np.diff(self.indptr))
      segment_starts = downcast_intp_index(self.indptr[nonempty])
      reduced = ufunc.reduceat(values, segment_starts)
      return nonempty, reduced
  564. #######################
  565. # Getting and Setting #
  566. #######################
  def _get_intXint(self, row, col):
      # Fetch the single element as a 1x1 submatrix; any duplicate entries
      # for that position are summed.
      n_major, n_minor = self._swap(self.shape)
      major, minor = self._swap((row, col))
      _, _, values = get_csr_submatrix(
          n_major, n_minor, self.indptr, self.indices, self.data,
          major, major + 1, minor, minor + 1)
      return values.sum(dtype=self.dtype)
  def _get_sliceXslice(self, row, col):
      major, minor = self._swap((row, col))
      unit_steps = (1, None)
      # Strided slices: slice the major axis first, then the minor axis.
      if major.step not in unit_steps or minor.step not in unit_steps:
          return self._major_slice(major)._minor_slice(minor)
      # Contiguous in both axes: extract a copied submatrix in one go.
      return self._get_submatrix(major, minor, copy=True)
  def _get_arrayXarray(self, row, col):
      # inner indexing: position k of (row, col) selects one element.
      idx_dtype = self.indices.dtype
      n_major, n_minor = self._swap(self.shape)
      major, minor = (np.asarray(arr, dtype=idx_dtype)
                      for arr in self._swap((row, col)))
      picked = np.empty(major.size, dtype=self.dtype)
      csr_sample_values(n_major, n_minor, self.indptr, self.indices,
                        self.data, major.size, major.ravel(), minor.ravel(),
                        picked)
      if major.ndim != 1:
          return self.__class__(picked.reshape(major.shape))
      return self._ascontainer(picked)
  def _get_columnXarray(self, row, col):
      # outer indexing: select major slices first, then minor positions.
      major, minor = self._swap((row, col))
      selected = self._major_index_fancy(major)
      return selected._minor_index_fancy(minor)
  def _major_index_fancy(self, idx):
      """Index along the major axis where idx is an array of ints.

      Returns a new matrix made of the selected major slices, in the order
      given by ``idx``.
      """
      idx_dtype = self.indices.dtype
      indices = np.asarray(idx, dtype=idx_dtype).ravel()

      _, N = self._swap(self.shape)
      M = len(indices)
      new_shape = self._swap((M, N))
      if M == 0:
          return self.__class__(new_shape, dtype=self.dtype)

      # Entry counts of the selected slices; their running sum is the new
      # index pointer.
      row_nnz = self.indptr[indices + 1] - self.indptr[indices]
      res_indptr = np.zeros(M + 1, dtype=idx_dtype)
      np.cumsum(row_nnz, out=res_indptr[1:])

      nnz = res_indptr[-1]
      res_indices = np.empty(nnz, dtype=idx_dtype)
      res_data = np.empty(nnz, dtype=self.dtype)
      csr_row_index(M, indices, self.indptr, self.indices, self.data,
                    res_indices, res_data)

      return self.__class__((res_data, res_indices, res_indptr),
                            shape=new_shape, copy=False)
  def _major_slice(self, idx, copy=False):
      """Index along the major axis where idx is a slice object.

      ``copy`` only matters for the full slice and for unit-step slices,
      where the result could otherwise share arrays with ``self``.
      """
      if idx == slice(None):
          return self.copy() if copy else self

      n_major, n_minor = self._swap(self.shape)
      start, stop, step = idx.indices(n_major)
      n_out = len(range(start, stop, step))
      new_shape = self._swap((n_out, n_minor))
      if n_out == 0:
          return self.__class__(new_shape, dtype=self.dtype)

      # With a negative step, slice.indices can report stop == -1 ("run past
      # index 0"). As a slice bound -1 would mean "the last element", so the
      # lower-bound slice uses None instead.
      lower_stop = None if (stop == -1 and start >= 0) else stop
      row_nnz = (self.indptr[start + 1:stop + 1:step]
                 - self.indptr[start:lower_stop:step])

      idx_dtype = self.indices.dtype
      res_indptr = np.zeros(n_out + 1, dtype=idx_dtype)
      np.cumsum(row_nnz, out=res_indptr[1:])

      if step == 1:
          # Consecutive slices occupy one contiguous run of indices/data.
          span = slice(self.indptr[start], self.indptr[stop])
          res_indices = np.array(self.indices[span], copy=copy)
          res_data = np.array(self.data[span], copy=copy)
      else:
          nnz = res_indptr[-1]
          res_indices = np.empty(nnz, dtype=idx_dtype)
          res_data = np.empty(nnz, dtype=self.dtype)
          csr_row_slice(start, stop, step, self.indptr, self.indices,
                        self.data, res_indices, res_data)

      return self.__class__((res_data, res_indices, res_indptr),
                            shape=new_shape, copy=False)
  def _minor_index_fancy(self, idx):
      """Index along the minor axis where idx is an array of ints.

      Runs the two-pass csr_column_index1 / csr_column_index2 kernels.
      """
      idx_dtype = self.indices.dtype
      picks = np.asarray(idx, dtype=idx_dtype).ravel()

      n_major, n_minor = self._swap(self.shape)
      n_picks = len(picks)
      new_shape = self._swap((n_major, n_picks))
      if n_picks == 0:
          return self.__class__(new_shape, dtype=self.dtype)

      # Pass 1: count the requested entries and build the new index pointer.
      col_offsets = np.zeros(n_minor, dtype=idx_dtype)
      res_indptr = np.empty_like(self.indptr)
      csr_column_index1(n_picks, picks, n_major, n_minor, self.indptr,
                        self.indices, col_offsets, res_indptr)

      # Pass 2: copy indices/data for the selected minor positions.
      col_order = np.argsort(picks).astype(idx_dtype, copy=False)
      nnz = res_indptr[-1]
      res_indices = np.empty(nnz, dtype=idx_dtype)
      res_data = np.empty(nnz, dtype=self.dtype)
      csr_column_index2(col_order, col_offsets, len(self.indices),
                        self.indices, self.data, res_indices, res_data)
      return self.__class__((res_data, res_indices, res_indptr),
                            shape=new_shape, copy=False)
  675. def _minor_slice(self, idx, copy=False):
  676. """Index along the minor axis where idx is a slice object.
  677. """
  678. if idx == slice(None):
  679. return self.copy() if copy else self
  680. M, N = self._swap(self.shape)
  681. start, stop, step = idx.indices(N)
  682. N = len(range(start, stop, step))
  683. if N == 0:
  684. return self.__class__(self._swap((M, N)), dtype=self.dtype)
  685. if step == 1:
  686. return self._get_submatrix(minor=idx, copy=copy)
  687. # TODO: don't fall back to fancy indexing here
  688. return self._minor_index_fancy(np.arange(start, stop, step))
  689. def _get_submatrix(self, major=None, minor=None, copy=False):
  690. """Return a submatrix of this matrix.
  691. major, minor: None, int, or slice with step 1
  692. """
  693. M, N = self._swap(self.shape)
  694. i0, i1 = _process_slice(major, M)
  695. j0, j1 = _process_slice(minor, N)
  696. if i0 == 0 and j0 == 0 and i1 == M and j1 == N:
  697. return self.copy() if copy else self
  698. indptr, indices, data = get_csr_submatrix(
  699. M, N, self.indptr, self.indices, self.data, i0, i1, j0, j1)
  700. shape = self._swap((i1 - i0, j1 - j0))
  701. return self.__class__((data, indices, indptr), shape=shape,
  702. dtype=self.dtype, copy=False)
  703. def _set_intXint(self, row, col, x):
  704. i, j = self._swap((row, col))
  705. self._set_many(i, j, x)
  706. def _set_arrayXarray(self, row, col, x):
  707. i, j = self._swap((row, col))
  708. self._set_many(i, j, x)
    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign the sparse matrix ``x`` at the index grid (row, col).

        ``row`` and ``col`` are 2-D index arrays of the same shape (see the
        unpacking of ``row.shape`` below). ``x`` is read through ``.row``,
        ``.col``, ``.data`` and ``.shape``, so it is presumably a COO-format
        matrix (TODO confirm against the caller). An axis of length 1 in
        ``x`` is broadcast along the matching axis of the index grid.
        Previously stored values at the targeted positions are zeroed first,
        then only the entries stored in ``x`` are written.
        """
        # clear entries that will be overwritten
        self._zero_many(*self._swap((row, col)))

        M, N = row.shape  # matches col.shape
        # Broadcast along an axis when the target grid is longer than 1 there
        # but x has length 1.
        broadcast_row = M != 1 and x.shape[0] == 1
        broadcast_col = N != 1 and x.shape[1] == 1
        r, c = x.row, x.col

        x = np.asarray(x.data, dtype=self.dtype)
        # Nothing stored in x: the zeroing above is the whole assignment.
        if x.size == 0:
            return

        if broadcast_row:
            # Replicate every stored entry of x once per grid row.
            r = np.repeat(np.arange(M), len(r))
            c = np.tile(c, M)
            x = np.tile(x, M)
        if broadcast_col:
            # Replicate every (possibly row-broadcast) entry once per grid
            # column.
            r = np.repeat(r, N)
            c = np.tile(np.arange(N), len(c))
            x = np.repeat(x, N)
        # only assign entries in the new sparsity structure
        i, j = self._swap((row[r, c], col[r, c]))
        self._set_many(i, j, x)
  730. def _setdiag(self, values, k):
  731. if 0 in self.shape:
  732. return
  733. M, N = self.shape
  734. broadcast = (values.ndim == 0)
  735. if k < 0:
  736. if broadcast:
  737. max_index = min(M + k, N)
  738. else:
  739. max_index = min(M + k, N, len(values))
  740. i = np.arange(max_index, dtype=self.indices.dtype)
  741. j = np.arange(max_index, dtype=self.indices.dtype)
  742. i -= k
  743. else:
  744. if broadcast:
  745. max_index = min(M, N - k)
  746. else:
  747. max_index = min(M, N - k, len(values))
  748. i = np.arange(max_index, dtype=self.indices.dtype)
  749. j = np.arange(max_index, dtype=self.indices.dtype)
  750. j += k
  751. if not broadcast:
  752. values = values[:len(i)]
  753. self[i, j] = values
  754. def _prepare_indices(self, i, j):
  755. M, N = self._swap(self.shape)
  756. def check_bounds(indices, bound):
  757. idx = indices.max()
  758. if idx >= bound:
  759. raise IndexError('index (%d) out of range (>= %d)' %
  760. (idx, bound))
  761. idx = indices.min()
  762. if idx < -bound:
  763. raise IndexError('index (%d) out of range (< -%d)' %
  764. (idx, bound))
  765. i = np.array(i, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
  766. j = np.array(j, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
  767. check_bounds(i, M)
  768. check_bounds(j, N)
  769. return i, j, M, N
  770. def _set_many(self, i, j, x):
  771. """Sets value at each (i, j) to x
  772. Here (i,j) index major and minor respectively, and must not contain
  773. duplicate entries.
  774. """
  775. i, j, M, N = self._prepare_indices(i, j)
  776. x = np.array(x, dtype=self.dtype, copy=False, ndmin=1).ravel()
  777. n_samples = x.size
  778. offsets = np.empty(n_samples, dtype=self.indices.dtype)
  779. ret = csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
  780. i, j, offsets)
  781. if ret == 1:
  782. # rinse and repeat
  783. self.sum_duplicates()
  784. csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
  785. i, j, offsets)
  786. if -1 not in offsets:
  787. # only affects existing non-zero cells
  788. self.data[offsets] = x
  789. return
  790. else:
  791. warn("Changing the sparsity structure of a {}_matrix is expensive."
  792. " lil_matrix is more efficient.".format(self.format),
  793. SparseEfficiencyWarning, stacklevel=3)
  794. # replace where possible
  795. mask = offsets > -1
  796. self.data[offsets[mask]] = x[mask]
  797. # only insertions remain
  798. mask = ~mask
  799. i = i[mask]
  800. i[i < 0] += M
  801. j = j[mask]
  802. j[j < 0] += N
  803. self._insert_many(i, j, x[mask])
  804. def _zero_many(self, i, j):
  805. """Sets value at each (i, j) to zero, preserving sparsity structure.
  806. Here (i,j) index major and minor respectively.
  807. """
  808. i, j, M, N = self._prepare_indices(i, j)
  809. n_samples = len(i)
  810. offsets = np.empty(n_samples, dtype=self.indices.dtype)
  811. ret = csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
  812. i, j, offsets)
  813. if ret == 1:
  814. # rinse and repeat
  815. self.sum_duplicates()
  816. csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
  817. i, j, offsets)
  818. # only assign zeros to the existing sparsity structure
  819. self.data[offsets[offsets > -1]] = 0
    def _insert_many(self, i, j, x):
        """Inserts new nonzero at each (i, j) with value x

        Here (i,j) index major and minor respectively.
        i, j and x must be non-empty, 1d arrays.
        Inserts each major group (e.g. all entries per row) at a time.
        Maintains has_sorted_indices property.
        Modifies i, j, x in place.

        When the same (i, j) pair occurs more than once, the value that
        comes last in the input is kept.

        NOTE(review): the ``take`` calls below rebind ``i``, ``j`` and ``x``
        to new arrays, so the caller's arrays do not appear to be modified,
        despite the line above. Confirm before relying on either claim.
        """
        order = np.argsort(i, kind='mergesort')  # stable for duplicates
        i = i.take(order, mode='clip')
        j = j.take(order, mode='clip')
        x = x.take(order, mode='clip')

        do_sort = self.has_sorted_indices

        # Update index data type
        # (wide enough for the existing structure plus x.size new entries)
        idx_dtype = self._get_index_dtype((self.indices, self.indptr),
                                          maxval=(self.indptr[-1] + x.size))
        self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
        self.indices = np.asarray(self.indices, dtype=idx_dtype)
        i = np.asarray(i, dtype=idx_dtype)
        j = np.asarray(j, dtype=idx_dtype)

        # Collate old and new in chunks by major index.
        # ui holds the distinct (sorted) major indices receiving insertions.
        # ui_indptr holds where each one's run starts in the sorted i/j/x,
        # closed off with len(j) so consecutive pairs delimit each run.
        indices_parts = []
        data_parts = []
        ui, ui_indptr = np.unique(i, return_index=True)
        ui_indptr = np.append(ui_indptr, len(j))
        new_nnzs = np.diff(ui_indptr)
        prev = 0
        for c, (ii, js, je) in enumerate(zip(ui, ui_indptr, ui_indptr[1:])):
            # old entries
            # These are the stored entries of major slices prev .. ii-1. From
            # the second iteration on, this includes the previous touched
            # slice's own old entries, which therefore follow that slice's
            # new entries (hence the sort at the end).
            start = self.indptr[prev]
            stop = self.indptr[ii]
            indices_parts.append(self.indices[start:stop])
            data_parts.append(self.data[start:stop])

            # handle duplicate j: keep last setting
            # np.unique over the reversed run reports each j's first position
            # in reversed order, i.e. its last occurrence in the input.
            uj, uj_indptr = np.unique(j[js:je][::-1], return_index=True)
            if len(uj) == je - js:
                indices_parts.append(j[js:je])
                data_parts.append(x[js:je])
            else:
                indices_parts.append(j[js:je][::-1][uj_indptr])
                data_parts.append(x[js:je][::-1][uj_indptr])
                new_nnzs[c] = len(uj)

            prev = ii

        # remaining old entries
        # (relies on the loop having run at least once, i.e. i non-empty as
        # the docstring requires; otherwise ii would be unbound here)
        start = self.indptr[ii]
        indices_parts.append(self.indices[start:])
        data_parts.append(self.data[start:])

        # update attributes
        self.indices = np.concatenate(indices_parts)
        self.data = np.concatenate(data_parts)
        # Rebuild indptr. Each slice's new count is its old count plus the
        # number of (deduplicated) entries inserted into it; a running sum
        # with a leading 0 then gives the pointers.
        nnzs = np.empty(self.indptr.shape, dtype=idx_dtype)
        nnzs[0] = idx_dtype(0)
        indptr_diff = np.diff(self.indptr)
        indptr_diff[ui] += new_nnzs
        nnzs[1:] = indptr_diff
        self.indptr = np.cumsum(nnzs, out=nnzs)

        if do_sort:
            # TODO: only sort where necessary
            self.has_sorted_indices = False
            self.sort_indices()

        self.check_format(full_check=False)
  881. ######################
  882. # Conversion methods #
  883. ######################
  884. def tocoo(self, copy=True):
  885. major_dim, minor_dim = self._swap(self.shape)
  886. minor_indices = self.indices
  887. major_indices = np.empty(len(minor_indices), dtype=self.indices.dtype)
  888. _sparsetools.expandptr(major_dim, self.indptr, major_indices)
  889. row, col = self._swap((major_indices, minor_indices))
  890. return self._coo_container(
  891. (self.data, (row, col)), self.shape, copy=copy,
  892. dtype=self.dtype
  893. )
  894. tocoo.__doc__ = _spbase.tocoo.__doc__
  895. def toarray(self, order=None, out=None):
  896. if out is None and order is None:
  897. order = self._swap('cf')[0]
  898. out = self._process_toarray_args(order, out)
  899. if not (out.flags.c_contiguous or out.flags.f_contiguous):
  900. raise ValueError('Output array must be C or F contiguous')
  901. # align ideal order with output array order
  902. if out.flags.c_contiguous:
  903. x = self.tocsr()
  904. y = out
  905. else:
  906. x = self.tocsc()
  907. y = out.T
  908. M, N = x._swap(x.shape)
  909. csr_todense(M, N, x.indptr, x.indices, x.data, y)
  910. return out
  911. toarray.__doc__ = _spbase.toarray.__doc__
  912. ##############################################################
  913. # methods that examine or modify the internal data structure #
  914. ##############################################################
  915. def eliminate_zeros(self):
  916. """Remove zero entries from the matrix
  917. This is an *in place* operation.
  918. """
  919. M, N = self._swap(self.shape)
  920. _sparsetools.csr_eliminate_zeros(M, N, self.indptr, self.indices,
  921. self.data)
  922. self.prune() # nnz may have changed
    def __get_has_canonical_format(self):
        """Determine whether the matrix has sorted indices and no duplicates

        Returns
            - True: if the above applies
            - False: otherwise

        has_canonical_format implies has_sorted_indices, so if the latter flag
        is False, so will the former be; if the former is found True, the
        latter flag is also set.

        The answer is cached in ``_has_canonical_format``. A missing
        ``_has_sorted_indices`` attribute is treated as "not known to be
        unsorted" (getattr default True), which lets the full check run.
        """
        # first check to see if result was cached
        if not getattr(self, '_has_sorted_indices', True):
            # not sorted => not canonical
            self._has_canonical_format = False
        elif not hasattr(self, '_has_canonical_format'):
            # No cached answer yet: compute it from the structure arrays and
            # store it through the property setter, which also marks the
            # indices as sorted when the answer is True.
            self.has_canonical_format = bool(
                _sparsetools.csr_has_canonical_format(
                    len(self.indptr) - 1, self.indptr, self.indices))
        return self._has_canonical_format

    def __set_has_canonical_format(self, val):
        # Cache the flag. Canonical format implies sorted indices, so a true
        # value is propagated to has_sorted_indices as well.
        self._has_canonical_format = bool(val)
        if val:
            self.has_sorted_indices = True

    # The double-underscore accessors are class-private (name-mangled); the
    # property is the intended public access point.
    has_canonical_format = property(fget=__get_has_canonical_format,
                                    fset=__set_has_canonical_format)
  947. def sum_duplicates(self):
  948. """Eliminate duplicate matrix entries by adding them together
  949. This is an *in place* operation.
  950. """
  951. if self.has_canonical_format:
  952. return
  953. self.sort_indices()
  954. M, N = self._swap(self.shape)
  955. _sparsetools.csr_sum_duplicates(M, N, self.indptr, self.indices,
  956. self.data)
  957. self.prune() # nnz may have changed
  958. self.has_canonical_format = True
    def __get_sorted(self):
        """Determine whether the matrix has sorted indices

        Returns
            - True: if the indices of the matrix are in sorted order
            - False: otherwise

        The answer is computed from the structure arrays on first access and
        cached in ``_has_sorted_indices``. Later accesses return the cached
        flag, which the setter can overwrite.
        """
        # first check to see if result was cached
        if not hasattr(self, '_has_sorted_indices'):
            self._has_sorted_indices = bool(
                _sparsetools.csr_has_sorted_indices(
                    len(self.indptr) - 1, self.indptr, self.indices))
        return self._has_sorted_indices

    def __set_sorted(self, val):
        # Store the caller's claim as the cached flag; the index arrays are
        # not re-checked here.
        self._has_sorted_indices = bool(val)

    has_sorted_indices = property(fget=__get_sorted, fset=__set_sorted)
  974. def sorted_indices(self):
  975. """Return a copy of this matrix with sorted indices
  976. """
  977. A = self.copy()
  978. A.sort_indices()
  979. return A
  980. # an alternative that has linear complexity is the following
  981. # although the previous option is typically faster
  982. # return self.toother().toother()
  983. def sort_indices(self):
  984. """Sort the indices of this matrix *in place*
  985. """
  986. if not self.has_sorted_indices:
  987. _sparsetools.csr_sort_indices(len(self.indptr) - 1, self.indptr,
  988. self.indices, self.data)
  989. self.has_sorted_indices = True
  990. def prune(self):
  991. """Remove empty space after all non-zero elements.
  992. """
  993. major_dim = self._swap(self.shape)[0]
  994. if len(self.indptr) != major_dim + 1:
  995. raise ValueError('index pointer has invalid length')
  996. if len(self.indices) < self.nnz:
  997. raise ValueError('indices array has fewer than nnz elements')
  998. if len(self.data) < self.nnz:
  999. raise ValueError('data array has fewer than nnz elements')
  1000. self.indices = _prune_array(self.indices[:self.nnz])
  1001. self.data = _prune_array(self.data[:self.nnz])
    def resize(self, *shape):
        # Resize in place. The major-axis step truncates or extends indptr;
        # the minor-axis step drops stored entries whose minor index no
        # longer fits. The public docstring is copied from _spbase below.
        shape = check_shape(shape)
        if hasattr(self, 'blocksize'):
            # Block-structured storage: old and new M, N are counted in
            # blocks, so the new shape must be a whole number of blocks.
            bm, bn = self.blocksize
            new_M, rm = divmod(shape[0], bm)
            new_N, rn = divmod(shape[1], bn)
            if rm or rn:
                raise ValueError("shape must be divisible into {} blocks. "
                                 "Got {}".format(self.blocksize, shape))
            M, N = self.shape[0] // bm, self.shape[1] // bn
        else:
            new_M, new_N = self._swap(shape)
            M, N = self._swap(self.shape)

        if new_M < M:
            # Keep only the first new_M major slices and their entries.
            self.indices = self.indices[:self.indptr[new_M]]
            self.data = self.data[:self.indptr[new_M]]
            self.indptr = self.indptr[:new_M + 1]
        elif new_M > M:
            # Grow indptr, then overwrite every appended pointer with the old
            # end offset so the new major slices are empty.
            self.indptr = np.resize(self.indptr, new_M + 1)
            self.indptr[M + 1:].fill(self.indptr[M])

        if new_N < N:
            mask = self.indices < new_N
            if not np.all(mask):
                self.indices = self.indices[mask]
                self.data = self.data[mask]
                # Rebuild indptr in place from per-slice counts of surviving
                # entries. _minor_reduce runs against the not-yet-updated
                # indptr; presumably it returns the non-empty major slices
                # and the sum of `mask` within each (TODO confirm against
                # _minor_reduce).
                major_index, val = self._minor_reduce(np.add, mask)
                self.indptr.fill(0)
                self.indptr[1:][major_index] = val
                np.cumsum(self.indptr, out=self.indptr)

        self._shape = shape

    resize.__doc__ = _spbase.resize.__doc__
  1033. ###################
  1034. # utility methods #
  1035. ###################
  1036. # needed by _data_matrix
  1037. def _with_data(self, data, copy=True):
  1038. """Returns a matrix with the same sparsity structure as self,
  1039. but with different data. By default the structure arrays
  1040. (i.e. .indptr and .indices) are copied.
  1041. """
  1042. if copy:
  1043. return self.__class__((data, self.indices.copy(),
  1044. self.indptr.copy()),
  1045. shape=self.shape,
  1046. dtype=data.dtype)
  1047. else:
  1048. return self.__class__((data, self.indices, self.indptr),
  1049. shape=self.shape, dtype=data.dtype)
  1050. def _binopt(self, other, op):
  1051. """apply the binary operation fn to two sparse matrices."""
  1052. other = self.__class__(other)
  1053. # e.g. csr_plus_csr, csr_minus_csr, etc.
  1054. fn = getattr(_sparsetools, self.format + op + self.format)
  1055. maxnnz = self.nnz + other.nnz
  1056. idx_dtype = self._get_index_dtype((self.indptr, self.indices,
  1057. other.indptr, other.indices),
  1058. maxval=maxnnz)
  1059. indptr = np.empty(self.indptr.shape, dtype=idx_dtype)
  1060. indices = np.empty(maxnnz, dtype=idx_dtype)
  1061. bool_ops = ['_ne_', '_lt_', '_gt_', '_le_', '_ge_']
  1062. if op in bool_ops:
  1063. data = np.empty(maxnnz, dtype=np.bool_)
  1064. else:
  1065. data = np.empty(maxnnz, dtype=upcast(self.dtype, other.dtype))
  1066. fn(self.shape[0], self.shape[1],
  1067. np.asarray(self.indptr, dtype=idx_dtype),
  1068. np.asarray(self.indices, dtype=idx_dtype),
  1069. self.data,
  1070. np.asarray(other.indptr, dtype=idx_dtype),
  1071. np.asarray(other.indices, dtype=idx_dtype),
  1072. other.data,
  1073. indptr, indices, data)
  1074. A = self.__class__((data, indices, indptr), shape=self.shape)
  1075. A.prune()
  1076. return A
  1077. def _divide_sparse(self, other):
  1078. """
  1079. Divide this matrix by a second sparse matrix.
  1080. """
  1081. if other.shape != self.shape:
  1082. raise ValueError('inconsistent shapes')
  1083. r = self._binopt(other, '_eldiv_')
  1084. if np.issubdtype(r.dtype, np.inexact):
  1085. # Eldiv leaves entries outside the combined sparsity
  1086. # pattern empty, so they must be filled manually.
  1087. # Everything outside of other's sparsity is NaN, and everything
  1088. # inside it is either zero or defined by eldiv.
  1089. out = np.empty(self.shape, dtype=self.dtype)
  1090. out.fill(np.nan)
  1091. row, col = other.nonzero()
  1092. out[row, col] = 0
  1093. r = r.tocoo()
  1094. out[r.row, r.col] = r.data
  1095. out = self._container(out)
  1096. else:
  1097. # integers types go with nan <-> 0
  1098. out = r
  1099. return out
  1100. def _process_slice(sl, num):
  1101. if sl is None:
  1102. i0, i1 = 0, num
  1103. elif isinstance(sl, slice):
  1104. i0, i1, stride = sl.indices(num)
  1105. if stride != 1:
  1106. raise ValueError('slicing with step != 1 not supported')
  1107. i0 = min(i0, i1) # give an empty slice when i0 > i1
  1108. elif isintlike(sl):
  1109. if sl < 0:
  1110. sl += num
  1111. i0, i1 = sl, sl + 1
  1112. if i0 < 0 or i1 > num:
  1113. raise IndexError('index out of bounds: 0 <= %d < %d <= %d' %
  1114. (i0, i1, num))
  1115. else:
  1116. raise TypeError('expected slice or scalar')
  1117. return i0, i1