decompressor.c 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765
  1. /**
  2. * Copyright (c) 2016-present, Gregory Szorc
  3. * All rights reserved.
  4. *
  5. * This software may be modified and distributed under the terms
  6. * of the BSD license. See the LICENSE file for details.
  7. */
  8. #include "python-zstandard.h"
  9. extern PyObject *ZstdError;
  10. /**
  11. * Ensure the ZSTD_DCtx on a decompressor is initiated and ready for a new
  12. * operation.
  13. */
  14. int ensure_dctx(ZstdDecompressor *decompressor, int loadDict) {
  15. size_t zresult;
  16. ZSTD_DCtx_reset(decompressor->dctx, ZSTD_reset_session_only);
  17. if (decompressor->maxWindowSize) {
  18. zresult = ZSTD_DCtx_setMaxWindowSize(decompressor->dctx,
  19. decompressor->maxWindowSize);
  20. if (ZSTD_isError(zresult)) {
  21. PyErr_Format(ZstdError, "unable to set max window size: %s",
  22. ZSTD_getErrorName(zresult));
  23. return 1;
  24. }
  25. }
  26. zresult = ZSTD_DCtx_setParameter(decompressor->dctx, ZSTD_d_format,
  27. decompressor->format);
  28. if (ZSTD_isError(zresult)) {
  29. PyErr_Format(ZstdError, "unable to set decoding format: %s",
  30. ZSTD_getErrorName(zresult));
  31. return 1;
  32. }
  33. if (loadDict && decompressor->dict) {
  34. if (ensure_ddict(decompressor->dict)) {
  35. return 1;
  36. }
  37. zresult =
  38. ZSTD_DCtx_refDDict(decompressor->dctx, decompressor->dict->ddict);
  39. if (ZSTD_isError(zresult)) {
  40. PyErr_Format(ZstdError,
  41. "unable to reference prepared dictionary: %s",
  42. ZSTD_getErrorName(zresult));
  43. return 1;
  44. }
  45. }
  46. return 0;
  47. }
  48. static int Decompressor_init(ZstdDecompressor *self, PyObject *args,
  49. PyObject *kwargs) {
  50. static char *kwlist[] = {"dict_data", "max_window_size", "format", NULL};
  51. PyObject *dict = NULL;
  52. Py_ssize_t maxWindowSize = 0;
  53. ZSTD_format_e format = ZSTD_f_zstd1;
  54. self->dctx = NULL;
  55. self->dict = NULL;
  56. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OnI:ZstdDecompressor",
  57. kwlist, &dict, &maxWindowSize, &format)) {
  58. return -1;
  59. }
  60. if (dict) {
  61. if (dict == Py_None) {
  62. dict = NULL;
  63. }
  64. else if (!PyObject_IsInstance(dict,
  65. (PyObject *)ZstdCompressionDictType)) {
  66. PyErr_Format(PyExc_TypeError,
  67. "dict_data must be zstd.ZstdCompressionDict");
  68. return -1;
  69. }
  70. }
  71. self->dctx = ZSTD_createDCtx();
  72. if (!self->dctx) {
  73. PyErr_NoMemory();
  74. goto except;
  75. }
  76. self->maxWindowSize = maxWindowSize;
  77. self->format = format;
  78. if (dict) {
  79. self->dict = (ZstdCompressionDict *)dict;
  80. Py_INCREF(dict);
  81. }
  82. if (ensure_dctx(self, 1)) {
  83. goto except;
  84. }
  85. return 0;
  86. except:
  87. Py_CLEAR(self->dict);
  88. if (self->dctx) {
  89. ZSTD_freeDCtx(self->dctx);
  90. self->dctx = NULL;
  91. }
  92. return -1;
  93. }
  94. static void Decompressor_dealloc(ZstdDecompressor *self) {
  95. Py_CLEAR(self->dict);
  96. if (self->dctx) {
  97. ZSTD_freeDCtx(self->dctx);
  98. self->dctx = NULL;
  99. }
  100. PyObject_Del(self);
  101. }
  102. static PyObject *Decompressor_memory_size(ZstdDecompressor *self) {
  103. if (self->dctx) {
  104. return PyLong_FromSize_t(ZSTD_sizeof_DCtx(self->dctx));
  105. }
  106. else {
  107. PyErr_SetString(
  108. ZstdError,
  109. "no decompressor context found; this should never happen");
  110. return NULL;
  111. }
  112. }
  113. static PyObject *Decompressor_copy_stream(ZstdDecompressor *self,
  114. PyObject *args, PyObject *kwargs) {
  115. static char *kwlist[] = {"ifh", "ofh", "read_size", "write_size", NULL};
  116. PyObject *source;
  117. PyObject *dest;
  118. size_t inSize = ZSTD_DStreamInSize();
  119. size_t outSize = ZSTD_DStreamOutSize();
  120. ZSTD_inBuffer input;
  121. ZSTD_outBuffer output;
  122. Py_ssize_t totalRead = 0;
  123. Py_ssize_t totalWrite = 0;
  124. char *readBuffer;
  125. Py_ssize_t readSize;
  126. PyObject *readResult = NULL;
  127. PyObject *res = NULL;
  128. size_t zresult = 0;
  129. PyObject *writeResult;
  130. PyObject *totalReadPy;
  131. PyObject *totalWritePy;
  132. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|kk:copy_stream", kwlist,
  133. &source, &dest, &inSize, &outSize)) {
  134. return NULL;
  135. }
  136. if (!PyObject_HasAttrString(source, "read")) {
  137. PyErr_SetString(PyExc_ValueError,
  138. "first argument must have a read() method");
  139. return NULL;
  140. }
  141. if (!PyObject_HasAttrString(dest, "write")) {
  142. PyErr_SetString(PyExc_ValueError,
  143. "second argument must have a write() method");
  144. return NULL;
  145. }
  146. /* Prevent free on uninitialized memory in finally. */
  147. output.dst = NULL;
  148. if (ensure_dctx(self, 1)) {
  149. res = NULL;
  150. goto finally;
  151. }
  152. output.dst = PyMem_Malloc(outSize);
  153. if (!output.dst) {
  154. PyErr_NoMemory();
  155. res = NULL;
  156. goto finally;
  157. }
  158. output.size = outSize;
  159. output.pos = 0;
  160. /* Read source stream until EOF */
  161. while (1) {
  162. readResult = PyObject_CallMethod(source, "read", "n", inSize);
  163. if (!readResult) {
  164. goto finally;
  165. }
  166. PyBytes_AsStringAndSize(readResult, &readBuffer, &readSize);
  167. /* If no data was read, we're at EOF. */
  168. if (0 == readSize) {
  169. break;
  170. }
  171. totalRead += readSize;
  172. /* Send data to decompressor */
  173. input.src = readBuffer;
  174. input.size = readSize;
  175. input.pos = 0;
  176. while (input.pos < input.size) {
  177. Py_BEGIN_ALLOW_THREADS zresult =
  178. ZSTD_decompressStream(self->dctx, &output, &input);
  179. Py_END_ALLOW_THREADS
  180. if (ZSTD_isError(zresult)) {
  181. PyErr_Format(ZstdError, "zstd decompressor error: %s",
  182. ZSTD_getErrorName(zresult));
  183. res = NULL;
  184. goto finally;
  185. }
  186. if (output.pos) {
  187. writeResult = PyObject_CallMethod(dest, "write", "y#",
  188. output.dst, output.pos);
  189. if (NULL == writeResult) {
  190. res = NULL;
  191. goto finally;
  192. }
  193. Py_XDECREF(writeResult);
  194. totalWrite += output.pos;
  195. output.pos = 0;
  196. }
  197. }
  198. Py_CLEAR(readResult);
  199. }
  200. /* Source stream is exhausted. Finish up. */
  201. totalReadPy = PyLong_FromSsize_t(totalRead);
  202. totalWritePy = PyLong_FromSsize_t(totalWrite);
  203. res = PyTuple_Pack(2, totalReadPy, totalWritePy);
  204. Py_DECREF(totalReadPy);
  205. Py_DECREF(totalWritePy);
  206. finally:
  207. if (output.dst) {
  208. PyMem_Free(output.dst);
  209. }
  210. Py_XDECREF(readResult);
  211. return res;
  212. }
  213. PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args,
  214. PyObject *kwargs) {
  215. static char *kwlist[] = {
  216. "data",
  217. "max_output_size",
  218. "read_across_frames",
  219. "allow_extra_data",
  220. NULL
  221. };
  222. Py_buffer source;
  223. Py_ssize_t maxOutputSize = 0;
  224. unsigned long long decompressedSize;
  225. PyObject *readAcrossFrames = NULL;
  226. PyObject *allowExtraData = NULL;
  227. size_t destCapacity;
  228. PyObject *result = NULL;
  229. size_t zresult;
  230. ZSTD_outBuffer outBuffer;
  231. ZSTD_inBuffer inBuffer;
  232. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|nOO:decompress", kwlist,
  233. &source, &maxOutputSize, &readAcrossFrames,
  234. &allowExtraData)) {
  235. return NULL;
  236. }
  237. if (readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0) {
  238. PyErr_SetString(ZstdError,
  239. "ZstdDecompressor.read_across_frames=True is not yet implemented"
  240. );
  241. goto finally;
  242. }
  243. if (ensure_dctx(self, 1)) {
  244. goto finally;
  245. }
  246. decompressedSize = ZSTD_getFrameContentSize(source.buf, source.len);
  247. if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
  248. PyErr_SetString(ZstdError,
  249. "error determining content size from frame header");
  250. goto finally;
  251. }
  252. /* Special case of empty frame. */
  253. else if (0 == decompressedSize) {
  254. result = PyBytes_FromStringAndSize("", 0);
  255. goto finally;
  256. }
  257. /* Missing content size in frame header. */
  258. if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
  259. if (0 == maxOutputSize) {
  260. PyErr_SetString(ZstdError,
  261. "could not determine content size in frame header");
  262. goto finally;
  263. }
  264. result = PyBytes_FromStringAndSize(NULL, maxOutputSize);
  265. destCapacity = maxOutputSize;
  266. decompressedSize = 0;
  267. }
  268. /* Size is recorded in frame header. */
  269. else {
  270. assert(SIZE_MAX >= PY_SSIZE_T_MAX);
  271. if (decompressedSize > PY_SSIZE_T_MAX) {
  272. PyErr_SetString(
  273. ZstdError, "frame is too large to decompress on this platform");
  274. goto finally;
  275. }
  276. result = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)decompressedSize);
  277. destCapacity = (size_t)decompressedSize;
  278. }
  279. if (!result) {
  280. goto finally;
  281. }
  282. outBuffer.dst = PyBytes_AsString(result);
  283. outBuffer.size = destCapacity;
  284. outBuffer.pos = 0;
  285. inBuffer.src = source.buf;
  286. inBuffer.size = source.len;
  287. inBuffer.pos = 0;
  288. Py_BEGIN_ALLOW_THREADS zresult =
  289. ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
  290. Py_END_ALLOW_THREADS
  291. if (ZSTD_isError(zresult)) {
  292. PyErr_Format(ZstdError, "decompression error: %s",
  293. ZSTD_getErrorName(zresult));
  294. Py_CLEAR(result);
  295. goto finally;
  296. }
  297. else if (zresult) {
  298. PyErr_Format(ZstdError,
  299. "decompression error: did not decompress full frame");
  300. Py_CLEAR(result);
  301. goto finally;
  302. }
  303. else if (decompressedSize && outBuffer.pos != decompressedSize) {
  304. PyErr_Format(
  305. ZstdError,
  306. "decompression error: decompressed %zu bytes; expected %llu",
  307. zresult, decompressedSize);
  308. Py_CLEAR(result);
  309. goto finally;
  310. }
  311. else if (outBuffer.pos < destCapacity) {
  312. if (safe_pybytes_resize(&result, outBuffer.pos)) {
  313. Py_CLEAR(result);
  314. goto finally;
  315. }
  316. }
  317. else if ((allowExtraData ? PyObject_IsTrue(allowExtraData) : 1) == 0
  318. && inBuffer.pos < inBuffer.size) {
  319. PyErr_Format(
  320. ZstdError,
  321. "compressed input contains %zu bytes of unused data, which is disallowed",
  322. inBuffer.size - inBuffer.pos
  323. );
  324. Py_CLEAR(result);
  325. goto finally;
  326. }
  327. finally:
  328. PyBuffer_Release(&source);
  329. return result;
  330. }
  331. static ZstdDecompressionObj *Decompressor_decompressobj(ZstdDecompressor *self,
  332. PyObject *args,
  333. PyObject *kwargs) {
  334. static char *kwlist[] = {"write_size", "read_across_frames", NULL};
  335. ZstdDecompressionObj *result = NULL;
  336. size_t outSize = ZSTD_DStreamOutSize();
  337. PyObject *readAcrossFrames = NULL;
  338. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|kO:decompressobj", kwlist,
  339. &outSize, &readAcrossFrames)) {
  340. return NULL;
  341. }
  342. if (!outSize) {
  343. PyErr_SetString(PyExc_ValueError, "write_size must be positive");
  344. return NULL;
  345. }
  346. result = (ZstdDecompressionObj *)PyObject_CallObject(
  347. (PyObject *)ZstdDecompressionObjType, NULL);
  348. if (!result) {
  349. return NULL;
  350. }
  351. if (ensure_dctx(self, 1)) {
  352. Py_DECREF(result);
  353. return NULL;
  354. }
  355. result->decompressor = self;
  356. Py_INCREF(result->decompressor);
  357. result->outSize = outSize;
  358. result->readAcrossFrames =
  359. readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;
  360. return result;
  361. }
  362. static ZstdDecompressorIterator *
  363. Decompressor_read_to_iter(ZstdDecompressor *self, PyObject *args,
  364. PyObject *kwargs) {
  365. static char *kwlist[] = {"reader", "read_size", "write_size", "skip_bytes",
  366. NULL};
  367. PyObject *reader;
  368. size_t inSize = ZSTD_DStreamInSize();
  369. size_t outSize = ZSTD_DStreamOutSize();
  370. ZstdDecompressorIterator *result;
  371. size_t skipBytes = 0;
  372. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_to_iter", kwlist,
  373. &reader, &inSize, &outSize, &skipBytes)) {
  374. return NULL;
  375. }
  376. if (skipBytes >= inSize) {
  377. PyErr_SetString(PyExc_ValueError,
  378. "skip_bytes must be smaller than read_size");
  379. return NULL;
  380. }
  381. result = (ZstdDecompressorIterator *)PyObject_CallObject(
  382. (PyObject *)ZstdDecompressorIteratorType, NULL);
  383. if (!result) {
  384. return NULL;
  385. }
  386. if (PyObject_HasAttrString(reader, "read")) {
  387. result->reader = reader;
  388. Py_INCREF(result->reader);
  389. }
  390. else if (1 == PyObject_CheckBuffer(reader)) {
  391. /* Object claims it is a buffer. Try to get a handle to it. */
  392. if (0 != PyObject_GetBuffer(reader, &result->buffer, PyBUF_CONTIG_RO)) {
  393. goto except;
  394. }
  395. }
  396. else {
  397. PyErr_SetString(PyExc_ValueError,
  398. "must pass an object with a read() method or conforms "
  399. "to buffer protocol");
  400. goto except;
  401. }
  402. result->decompressor = self;
  403. Py_INCREF(result->decompressor);
  404. result->inSize = inSize;
  405. result->outSize = outSize;
  406. result->skipBytes = skipBytes;
  407. if (ensure_dctx(self, 1)) {
  408. goto except;
  409. }
  410. result->input.src = PyMem_Malloc(inSize);
  411. if (!result->input.src) {
  412. PyErr_NoMemory();
  413. goto except;
  414. }
  415. goto finally;
  416. except:
  417. Py_CLEAR(result);
  418. finally:
  419. return result;
  420. }
  421. static ZstdDecompressionReader *
  422. Decompressor_stream_reader(ZstdDecompressor *self, PyObject *args,
  423. PyObject *kwargs) {
  424. static char *kwlist[] = {"source", "read_size", "read_across_frames",
  425. "closefd", NULL};
  426. PyObject *source;
  427. size_t readSize = ZSTD_DStreamInSize();
  428. PyObject *readAcrossFrames = NULL;
  429. PyObject *closefd = NULL;
  430. ZstdDecompressionReader *result;
  431. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kOO:stream_reader",
  432. kwlist, &source, &readSize,
  433. &readAcrossFrames, &closefd)) {
  434. return NULL;
  435. }
  436. if (ensure_dctx(self, 1)) {
  437. return NULL;
  438. }
  439. result = (ZstdDecompressionReader *)PyObject_CallObject(
  440. (PyObject *)ZstdDecompressionReaderType, NULL);
  441. if (NULL == result) {
  442. return NULL;
  443. }
  444. result->entered = 0;
  445. result->closed = 0;
  446. if (PyObject_HasAttrString(source, "read")) {
  447. result->reader = source;
  448. Py_INCREF(source);
  449. result->readSize = readSize;
  450. }
  451. else if (1 == PyObject_CheckBuffer(source)) {
  452. if (0 != PyObject_GetBuffer(source, &result->buffer, PyBUF_CONTIG_RO)) {
  453. Py_CLEAR(result);
  454. return NULL;
  455. }
  456. }
  457. else {
  458. PyErr_SetString(PyExc_TypeError,
  459. "must pass an object with a read() method or that "
  460. "conforms to the buffer protocol");
  461. Py_CLEAR(result);
  462. return NULL;
  463. }
  464. result->decompressor = self;
  465. Py_INCREF(self);
  466. result->readAcrossFrames =
  467. readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;
  468. result->closefd = closefd ? PyObject_IsTrue(closefd) : 1;
  469. return result;
  470. }
  471. static ZstdDecompressionWriter *
  472. Decompressor_stream_writer(ZstdDecompressor *self, PyObject *args,
  473. PyObject *kwargs) {
  474. static char *kwlist[] = {"writer", "write_size", "write_return_read",
  475. "closefd", NULL};
  476. PyObject *writer;
  477. size_t outSize = ZSTD_DStreamOutSize();
  478. PyObject *writeReturnRead = NULL;
  479. PyObject *closefd = NULL;
  480. ZstdDecompressionWriter *result;
  481. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kOO:stream_writer",
  482. kwlist, &writer, &outSize,
  483. &writeReturnRead, &closefd)) {
  484. return NULL;
  485. }
  486. if (!PyObject_HasAttrString(writer, "write")) {
  487. PyErr_SetString(PyExc_ValueError,
  488. "must pass an object with a write() method");
  489. return NULL;
  490. }
  491. if (ensure_dctx(self, 1)) {
  492. return NULL;
  493. }
  494. result = (ZstdDecompressionWriter *)PyObject_CallObject(
  495. (PyObject *)ZstdDecompressionWriterType, NULL);
  496. if (!result) {
  497. return NULL;
  498. }
  499. result->entered = 0;
  500. result->closing = 0;
  501. result->closed = 0;
  502. result->decompressor = self;
  503. Py_INCREF(result->decompressor);
  504. result->writer = writer;
  505. Py_INCREF(result->writer);
  506. result->outSize = outSize;
  507. result->writeReturnRead =
  508. writeReturnRead ? PyObject_IsTrue(writeReturnRead) : 1;
  509. result->closefd = closefd ? PyObject_IsTrue(closefd) : 1;
  510. return result;
  511. }
  512. static PyObject *
  513. Decompressor_decompress_content_dict_chain(ZstdDecompressor *self,
  514. PyObject *args, PyObject *kwargs) {
  515. static char *kwlist[] = {"frames", NULL};
  516. PyObject *chunks;
  517. Py_ssize_t chunksLen;
  518. Py_ssize_t chunkIndex;
  519. char parity = 0;
  520. PyObject *chunk;
  521. char *chunkData;
  522. Py_ssize_t chunkSize;
  523. size_t zresult;
  524. ZSTD_frameHeader frameHeader;
  525. void *buffer1 = NULL;
  526. size_t buffer1Size = 0;
  527. size_t buffer1ContentSize = 0;
  528. void *buffer2 = NULL;
  529. size_t buffer2Size = 0;
  530. size_t buffer2ContentSize = 0;
  531. void *destBuffer = NULL;
  532. PyObject *result = NULL;
  533. ZSTD_outBuffer outBuffer;
  534. ZSTD_inBuffer inBuffer;
  535. if (!PyArg_ParseTupleAndKeywords(args, kwargs,
  536. "O!:decompress_content_dict_chain", kwlist,
  537. &PyList_Type, &chunks)) {
  538. return NULL;
  539. }
  540. chunksLen = PyList_Size(chunks);
  541. if (!chunksLen) {
  542. PyErr_SetString(PyExc_ValueError, "empty input chain");
  543. return NULL;
  544. }
  545. /* The first chunk should not be using a dictionary. We handle it specially.
  546. */
  547. chunk = PyList_GetItem(chunks, 0);
  548. if (!PyBytes_Check(chunk)) {
  549. PyErr_SetString(PyExc_ValueError, "chunk 0 must be bytes");
  550. return NULL;
  551. }
  552. /* We require that all chunks be zstd frames and that they have content size
  553. * set. */
  554. PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
  555. zresult = ZSTD_getFrameHeader(&frameHeader, (void *)chunkData, chunkSize);
  556. if (ZSTD_isError(zresult)) {
  557. PyErr_SetString(PyExc_ValueError, "chunk 0 is not a valid zstd frame");
  558. return NULL;
  559. }
  560. else if (zresult) {
  561. PyErr_SetString(PyExc_ValueError,
  562. "chunk 0 is too small to contain a zstd frame");
  563. return NULL;
  564. }
  565. if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
  566. PyErr_SetString(PyExc_ValueError,
  567. "chunk 0 missing content size in frame");
  568. return NULL;
  569. }
  570. assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
  571. /* We check against PY_SSIZE_T_MAX here because we ultimately cast the
  572. * result to a Python object and it's length can be no greater than
  573. * Py_ssize_t. In theory, we could have an intermediate frame that is
  574. * larger. But a) why would this API be used for frames that large b)
  575. * it isn't worth the complexity to support. */
  576. assert(SIZE_MAX >= PY_SSIZE_T_MAX);
  577. if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
  578. PyErr_SetString(PyExc_ValueError,
  579. "chunk 0 is too large to decompress on this platform");
  580. return NULL;
  581. }
  582. if (ensure_dctx(self, 0)) {
  583. goto finally;
  584. }
  585. buffer1Size = (size_t)frameHeader.frameContentSize;
  586. buffer1 = PyMem_Malloc(buffer1Size);
  587. if (!buffer1) {
  588. goto finally;
  589. }
  590. outBuffer.dst = buffer1;
  591. outBuffer.size = buffer1Size;
  592. outBuffer.pos = 0;
  593. inBuffer.src = chunkData;
  594. inBuffer.size = chunkSize;
  595. inBuffer.pos = 0;
  596. Py_BEGIN_ALLOW_THREADS zresult =
  597. ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
  598. Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
  599. PyErr_Format(ZstdError, "could not decompress chunk 0: %s",
  600. ZSTD_getErrorName(zresult));
  601. goto finally;
  602. }
  603. else if (zresult) {
  604. PyErr_Format(ZstdError, "chunk 0 did not decompress full frame");
  605. goto finally;
  606. }
  607. buffer1ContentSize = outBuffer.pos;
  608. /* Special case of a simple chain. */
  609. if (1 == chunksLen) {
  610. result = PyBytes_FromStringAndSize(buffer1, buffer1Size);
  611. goto finally;
  612. }
  613. /* This should ideally look at next chunk. But this is slightly simpler. */
  614. buffer2Size = (size_t)frameHeader.frameContentSize;
  615. buffer2 = PyMem_Malloc(buffer2Size);
  616. if (!buffer2) {
  617. goto finally;
  618. }
  619. /* For each subsequent chunk, use the previous fulltext as a content
  620. dictionary. Our strategy is to have 2 buffers. One holds the previous
  621. fulltext (to be used as a content dictionary) and the other holds the new
  622. fulltext. The buffers grow when needed but never decrease in size. This
  623. limits the memory allocator overhead.
  624. */
  625. for (chunkIndex = 1; chunkIndex < chunksLen; chunkIndex++) {
  626. chunk = PyList_GetItem(chunks, chunkIndex);
  627. if (!PyBytes_Check(chunk)) {
  628. PyErr_Format(PyExc_ValueError, "chunk %zd must be bytes",
  629. chunkIndex);
  630. goto finally;
  631. }
  632. PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
  633. zresult =
  634. ZSTD_getFrameHeader(&frameHeader, (void *)chunkData, chunkSize);
  635. if (ZSTD_isError(zresult)) {
  636. PyErr_Format(PyExc_ValueError,
  637. "chunk %zd is not a valid zstd frame", chunkIndex);
  638. goto finally;
  639. }
  640. else if (zresult) {
  641. PyErr_Format(PyExc_ValueError,
  642. "chunk %zd is too small to contain a zstd frame",
  643. chunkIndex);
  644. goto finally;
  645. }
  646. if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
  647. PyErr_Format(PyExc_ValueError,
  648. "chunk %zd missing content size in frame", chunkIndex);
  649. goto finally;
  650. }
  651. assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
  652. if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
  653. PyErr_Format(
  654. PyExc_ValueError,
  655. "chunk %zd is too large to decompress on this platform",
  656. chunkIndex);
  657. goto finally;
  658. }
  659. inBuffer.src = chunkData;
  660. inBuffer.size = chunkSize;
  661. inBuffer.pos = 0;
  662. parity = chunkIndex % 2;
  663. /* This could definitely be abstracted to reduce code duplication. */
  664. if (parity) {
  665. /* Resize destination buffer to hold larger content. */
  666. if (buffer2Size < frameHeader.frameContentSize) {
  667. buffer2Size = (size_t)frameHeader.frameContentSize;
  668. destBuffer = PyMem_Realloc(buffer2, buffer2Size);
  669. if (!destBuffer) {
  670. goto finally;
  671. }
  672. buffer2 = destBuffer;
  673. }
  674. Py_BEGIN_ALLOW_THREADS zresult = ZSTD_DCtx_refPrefix_advanced(
  675. self->dctx, buffer1, buffer1ContentSize, ZSTD_dct_rawContent);
  676. Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
  677. PyErr_Format(ZstdError,
  678. "failed to load prefix dictionary at chunk %zd",
  679. chunkIndex);
  680. goto finally;
  681. }
  682. outBuffer.dst = buffer2;
  683. outBuffer.size = buffer2Size;
  684. outBuffer.pos = 0;
  685. Py_BEGIN_ALLOW_THREADS zresult =
  686. ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
  687. Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
  688. PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
  689. chunkIndex, ZSTD_getErrorName(zresult));
  690. goto finally;
  691. }
  692. else if (zresult) {
  693. PyErr_Format(ZstdError,
  694. "chunk %zd did not decompress full frame",
  695. chunkIndex);
  696. goto finally;
  697. }
  698. buffer2ContentSize = outBuffer.pos;
  699. }
  700. else {
  701. if (buffer1Size < frameHeader.frameContentSize) {
  702. buffer1Size = (size_t)frameHeader.frameContentSize;
  703. destBuffer = PyMem_Realloc(buffer1, buffer1Size);
  704. if (!destBuffer) {
  705. goto finally;
  706. }
  707. buffer1 = destBuffer;
  708. }
  709. Py_BEGIN_ALLOW_THREADS zresult = ZSTD_DCtx_refPrefix_advanced(
  710. self->dctx, buffer2, buffer2ContentSize, ZSTD_dct_rawContent);
  711. Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
  712. PyErr_Format(ZstdError,
  713. "failed to load prefix dictionary at chunk %zd",
  714. chunkIndex);
  715. goto finally;
  716. }
  717. outBuffer.dst = buffer1;
  718. outBuffer.size = buffer1Size;
  719. outBuffer.pos = 0;
  720. Py_BEGIN_ALLOW_THREADS zresult =
  721. ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
  722. Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
  723. PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
  724. chunkIndex, ZSTD_getErrorName(zresult));
  725. goto finally;
  726. }
  727. else if (zresult) {
  728. PyErr_Format(ZstdError,
  729. "chunk %zd did not decompress full frame",
  730. chunkIndex);
  731. goto finally;
  732. }
  733. buffer1ContentSize = outBuffer.pos;
  734. }
  735. }
  736. result = PyBytes_FromStringAndSize(parity ? buffer2 : buffer1,
  737. parity ? buffer2ContentSize
  738. : buffer1ContentSize);
  739. finally:
  740. if (buffer2) {
  741. PyMem_Free(buffer2);
  742. }
  743. if (buffer1) {
  744. PyMem_Free(buffer1);
  745. }
  746. return result;
  747. }
  748. typedef struct {
  749. void *sourceData;
  750. size_t sourceSize;
  751. size_t destSize;
  752. } FramePointer;
  753. typedef struct {
  754. FramePointer *frames;
  755. Py_ssize_t framesSize;
  756. unsigned long long compressedSize;
  757. } FrameSources;
  758. typedef struct {
  759. void *dest;
  760. Py_ssize_t destSize;
  761. BufferSegment *segments;
  762. Py_ssize_t segmentsSize;
  763. } DecompressorDestBuffer;
  /* Error status a decompression worker reports back to its dispatcher. */
  typedef enum {
      /* The worker finished successfully. */
      DecompressorWorkerError_none = 0,
      /* ZSTD_decompressStream() returned an error code. */
      DecompressorWorkerError_zstd = 1,
      /* An allocation failed, or a size did not fit in size_t. */
      DecompressorWorkerError_memory = 2,
      /* A frame did not decompress to exactly its expected size. */
      DecompressorWorkerError_sizeMismatch = 3,
      /* A frame's decompressed size could not be determined. */
      DecompressorWorkerError_unknownSize = 4,
  } DecompressorWorkerError;
  /* Per-worker input range, settings, output buffers and error status. */
  typedef struct {
      /* Shared array of all frame descriptors. A worker only reads and writes
         the entries in [startOffset, endOffset]. */
      FramePointer *framePointers;
      /* Inclusive range of framePointers indices this worker processes. */
      Py_ssize_t startOffset;
      Py_ssize_t endOffset;
      /* Sum of sourceSize over the assigned frames; used to size output
         buffers. */
      unsigned long long totalSourceSize;
      /* Decompression context owned by this worker. */
      ZSTD_DCtx *dctx;
      /* Non-zero: a frame whose content size is unknown is an error. */
      int requireOutputSizes;
      /* Output buffers produced by the worker (destCount entries). */
      DecompressorDestBuffer *destBuffers;
      Py_ssize_t destCount;
      /* Index into framePointers of the frame an error occurred on. */
      Py_ssize_t errorOffset;
      /* Error status; DecompressorWorkerError_none on success. */
      DecompressorWorkerError error;
      /* The zstd error code for DecompressorWorkerError_zstd, or the number of
         bytes actually produced for DecompressorWorkerError_sizeMismatch. */
      size_t zresult;
  } DecompressorWorkerState;
  #ifdef HAVE_ZSTD_POOL_APIS
  /*
   * Shrink destBuffer so its allocation is exactly `used` bytes.
   *
   * Shrinking is only an optimization. It is skipped when there is nothing to
   * trim, and also when `used` is 0: realloc(p, 0) may free p and return NULL,
   * which would report a bogus allocation failure and leave a dangling
   * destBuffer->dest for the caller's cleanup code to free a second time.
   *
   * Returns 0 on success (or when nothing was done) and -1 if realloc()
   * failed; on failure destBuffer is unchanged and still owns its memory.
   */
  static int decompress_worker_shrink_dest(DecompressorDestBuffer *destBuffer,
                                           size_t used) {
      void *shrunk;

      if (0 == used || (size_t)destBuffer->destSize <= used) {
          return 0;
      }

      shrunk = realloc(destBuffer->dest, used);
      if (NULL == shrunk) {
          return -1;
      }

      destBuffer->dest = shrunk;
      destBuffer->destSize = (Py_ssize_t)used;
      return 0;
  }

  /*
   * Decompress framePointers[state->startOffset .. state->endOffset]
   * (inclusive).
   *
   * Called by decompress_from_framesources() either through a zstd POOL
   * thread or directly, in both cases with the GIL released, so it only
   * touches C memory. Output is packed into one or more malloc()-allocated
   * DecompressorDestBuffer entries stored in state->destBuffers
   * (state->destCount of them). A new buffer is started whenever the next
   * frame does not fit in the current one.
   *
   * Frames whose FramePointer.destSize is 0 get it filled in from the zstd
   * frame header.
   *
   * On failure, state->error is set (together with state->errorOffset and,
   * for zstd and size-mismatch errors, state->zresult) and the function
   * returns early. Buffers allocated so far stay referenced from `state` so
   * the caller can free them.
   */
  static void decompress_worker(DecompressorWorkerState *state) {
      size_t allocationSize;
      DecompressorDestBuffer *destBuffer;
      Py_ssize_t frameIndex;
      /* Next free entry in the current buffer's segment table. */
      Py_ssize_t localOffset = 0;
      /* Frame index the current output buffer starts at. */
      Py_ssize_t currentBufferStartIndex = state->startOffset;
      /* Frames not yet decompressed, including the current one. */
      Py_ssize_t remainingItems = state->endOffset - state->startOffset + 1;
      void *tmpBuf;
      /* Bytes already written to the current output buffer. */
      Py_ssize_t destOffset = 0;
      FramePointer *framePointers = state->framePointers;
      size_t zresult;

      assert(NULL == state->destBuffers);
      assert(0 == state->destCount);
      assert(state->endOffset - state->startOffset >= 0);

      /* We could get here due to the way work is allocated. Ideally we
         wouldn't get here. But that would require a bit of a refactor in the
         caller. */
      if (state->totalSourceSize > SIZE_MAX) {
          state->error = DecompressorWorkerError_memory;
          state->errorOffset = 0;
          return;
      }

      /*
       * We need to allocate a buffer to hold decompressed data. How we do this
       * depends on what we know about the output. The following scenarios are
       * possible:
       *
       * 1. All structs defining frames declare the output size.
       * 2. The decompressed size is embedded within the zstd frame.
       * 3. The decompressed size is not stored anywhere.
       *
       * For now, we only support #1 and #2.
       */

      /* Resolve output segments. */
      for (frameIndex = state->startOffset; frameIndex <= state->endOffset;
           frameIndex++) {
          FramePointer *fp = &framePointers[frameIndex];
          unsigned long long decompressedSize;

          if (0 == fp->destSize) {
              decompressedSize =
                  ZSTD_getFrameContentSize(fp->sourceData, fp->sourceSize);

              if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
                  state->error = DecompressorWorkerError_unknownSize;
                  state->errorOffset = frameIndex;
                  return;
              }
              else if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
                  if (state->requireOutputSizes) {
                      state->error = DecompressorWorkerError_unknownSize;
                      state->errorOffset = frameIndex;
                      return;
                  }

                  /* This will fail the assert for .destSize > 0 below. */
                  decompressedSize = 0;
              }

              if (decompressedSize > SIZE_MAX) {
                  state->error = DecompressorWorkerError_memory;
                  state->errorOffset = frameIndex;
                  return;
              }

              fp->destSize = (size_t)decompressedSize;
          }
      }

      state->destBuffers = calloc(1, sizeof(DecompressorDestBuffer));
      if (NULL == state->destBuffers) {
          state->error = DecompressorWorkerError_memory;
          return;
      }

      state->destCount = 1;

      destBuffer = &state->destBuffers[state->destCount - 1];

      assert(framePointers[state->startOffset].destSize > 0); /* For now. */

      /* Size the first buffer for the whole assignment, but always large
         enough to hold at least the first frame. */
      allocationSize = roundpow2((size_t)state->totalSourceSize);

      if (framePointers[state->startOffset].destSize > allocationSize) {
          allocationSize = roundpow2(framePointers[state->startOffset].destSize);
      }

      destBuffer->dest = malloc(allocationSize);
      if (NULL == destBuffer->dest) {
          state->error = DecompressorWorkerError_memory;
          return;
      }

      destBuffer->destSize = allocationSize;

      destBuffer->segments = calloc(remainingItems, sizeof(BufferSegment));
      if (NULL == destBuffer->segments) {
          /* Caller will free destBuffer->dest as part of cleanup. */
          state->error = DecompressorWorkerError_memory;
          return;
      }

      destBuffer->segmentsSize = remainingItems;

      for (frameIndex = state->startOffset; frameIndex <= state->endOffset;
           frameIndex++) {
          ZSTD_outBuffer outBuffer;
          ZSTD_inBuffer inBuffer;
          const void *source = framePointers[frameIndex].sourceData;
          const size_t sourceSize = framePointers[frameIndex].sourceSize;
          void *dest;
          const size_t decompressedSize = framePointers[frameIndex].destSize;
          size_t destAvailable = destBuffer->destSize - destOffset;

          assert(decompressedSize > 0); /* For now. */

          /*
           * Not enough space in current buffer. Finish the current one, then
           * allocate and switch to a new one.
           */
          if (decompressedSize > destAvailable) {
              /* Shrinking the destination buffer is optional. But it should
                 be cheap, so we just do it. */
              if (decompress_worker_shrink_dest(destBuffer,
                                                (size_t)destOffset)) {
                  state->error = DecompressorWorkerError_memory;
                  return;
              }

              /* Truncate segments buffer. The size is never 0 because the
                 first frame of every buffer is guaranteed to fit in it. */
              tmpBuf = realloc(destBuffer->segments,
                               (frameIndex - currentBufferStartIndex) *
                                   sizeof(BufferSegment));
              if (NULL == tmpBuf) {
                  state->error = DecompressorWorkerError_memory;
                  return;
              }

              destBuffer->segments = tmpBuf;
              destBuffer->segmentsSize = frameIndex - currentBufferStartIndex;

              /* Grow space for new DestBuffer. */
              tmpBuf = realloc(state->destBuffers,
                               (state->destCount + 1) *
                                   sizeof(DecompressorDestBuffer));
              if (NULL == tmpBuf) {
                  state->error = DecompressorWorkerError_memory;
                  return;
              }

              state->destBuffers = tmpBuf;
              state->destCount++;

              destBuffer = &state->destBuffers[state->destCount - 1];

              /* Don't take any chances with non-NULL pointers. */
              memset(destBuffer, 0, sizeof(DecompressorDestBuffer));

              allocationSize = roundpow2((size_t)state->totalSourceSize);

              if (decompressedSize > allocationSize) {
                  allocationSize = roundpow2(decompressedSize);
              }

              destBuffer->dest = malloc(allocationSize);
              if (NULL == destBuffer->dest) {
                  state->error = DecompressorWorkerError_memory;
                  return;
              }

              destBuffer->destSize = allocationSize;
              destAvailable = allocationSize;
              destOffset = 0;
              localOffset = 0;

              destBuffer->segments =
                  calloc(remainingItems, sizeof(BufferSegment));
              if (NULL == destBuffer->segments) {
                  state->error = DecompressorWorkerError_memory;
                  return;
              }

              destBuffer->segmentsSize = remainingItems;
              currentBufferStartIndex = frameIndex;
          }

          dest = (char *)destBuffer->dest + destOffset;

          outBuffer.dst = dest;
          outBuffer.size = decompressedSize;
          outBuffer.pos = 0;

          inBuffer.src = source;
          inBuffer.size = sourceSize;
          inBuffer.pos = 0;

          zresult = ZSTD_decompressStream(state->dctx, &outBuffer, &inBuffer);
          if (ZSTD_isError(zresult)) {
              state->error = DecompressorWorkerError_zstd;
              state->zresult = zresult;
              state->errorOffset = frameIndex;
              return;
          }
          else if (zresult || outBuffer.pos != decompressedSize) {
              /* Either the frame needs more output space than declared, or
                 it produced fewer bytes than declared. */
              state->error = DecompressorWorkerError_sizeMismatch;
              state->zresult = outBuffer.pos;
              state->errorOffset = frameIndex;
              return;
          }

          destBuffer->segments[localOffset].offset = destOffset;
          destBuffer->segments[localOffset].length = outBuffer.pos;
          destOffset += outBuffer.pos;
          localOffset++;
          remainingItems--;
      }

      if (decompress_worker_shrink_dest(destBuffer, (size_t)destOffset)) {
          state->error = DecompressorWorkerError_memory;
          return;
      }
  }
  #endif
  #ifdef HAVE_ZSTD_POOL_APIS
  /*
   * Decompress every frame described by `frames`, using up to `threadCount`
   * workers, and return a new BufferWithSegmentsCollection holding the
   * output (one BufferWithSegments per worker output buffer).
   *
   * Frames are split into contiguous runs of roughly equal compressed size,
   * one run per worker. Each worker gets its own ZSTD_DCtx (referencing the
   * decompressor's prepared dictionary, if any) and runs decompress_worker()
   * with the GIL released. With threadCount == 1 the worker runs inline on
   * the calling thread and no pool is created.
   *
   * Returns NULL with a Python exception set on failure. An empty `frames`
   * raises ValueError.
   */
  ZstdBufferWithSegmentsCollection *
  decompress_from_framesources(ZstdDecompressor *decompressor,
                               FrameSources *frames, Py_ssize_t threadCount) {
      Py_ssize_t i = 0;
      int errored = 0;
      Py_ssize_t segmentsCount;
      ZstdBufferWithSegments *bws = NULL;
      PyObject *resultArg = NULL;
      Py_ssize_t resultIndex;
      ZstdBufferWithSegmentsCollection *result = NULL;
      FramePointer *framePointers = frames->frames;
      unsigned long long workerBytes = 0;
      Py_ssize_t currentThread = 0;
      Py_ssize_t workerStartOffset = 0;
      POOL_ctx *pool = NULL;
      DecompressorWorkerState *workerStates = NULL;
      unsigned long long bytesPerWorker;

      /* Caller should normalize 0 and negative values to 1 or larger. */
      assert(threadCount >= 1);

      /* Without any frames, threadCount would be clamped to 0 below and the
         bytes-per-worker computation would divide by zero. */
      if (frames->framesSize < 1) {
          PyErr_SetString(PyExc_ValueError, "no source elements found");
          return NULL;
      }

      /* More threads than inputs makes no sense under any conditions. */
      threadCount =
          frames->framesSize < threadCount ? frames->framesSize : threadCount;

      /* TODO lower thread count if input size is too small and threads would
         just add overhead. */

      if (decompressor->dict) {
          if (ensure_ddict(decompressor->dict)) {
              return NULL;
          }
      }

      /* If threadCount==1, we don't start a thread pool. But we do leverage
         the same API for dispatching work. */
      workerStates = PyMem_Malloc(threadCount * sizeof(DecompressorWorkerState));
      if (NULL == workerStates) {
          PyErr_NoMemory();
          goto finally;
      }

      memset(workerStates, 0, threadCount * sizeof(DecompressorWorkerState));

      if (threadCount > 1) {
          pool = POOL_create(threadCount, 1);
          if (NULL == pool) {
              PyErr_SetString(ZstdError, "could not initialize zstd thread pool");
              goto finally;
          }
      }

      bytesPerWorker = frames->compressedSize / threadCount;

      if (bytesPerWorker > SIZE_MAX) {
          PyErr_SetString(ZstdError,
                          "too much data per worker for this platform");
          goto finally;
      }

      for (i = 0; i < threadCount; i++) {
          size_t zresult;

          workerStates[i].dctx = ZSTD_createDCtx();
          if (NULL == workerStates[i].dctx) {
              PyErr_NoMemory();
              goto finally;
          }

          if (decompressor->dict) {
              zresult = ZSTD_DCtx_refDDict(workerStates[i].dctx,
                                           decompressor->dict->ddict);
              if (ZSTD_isError(zresult)) {
                  PyErr_Format(ZstdError,
                               "unable to reference prepared dictionary: %s",
                               ZSTD_getErrorName(zresult));
                  goto finally;
              }
          }

          workerStates[i].framePointers = framePointers;
          workerStates[i].requireOutputSizes = 1;
      }

      Py_BEGIN_ALLOW_THREADS

      /* There are many ways to split work among workers.

         For now, we take a simple approach of splitting work so each worker
         gets roughly the same number of input bytes. This will result in more
         starvation than running N>threadCount jobs. But it avoids
         complications around state tracking, which could involve extra
         locking.
      */
      for (i = 0; i < frames->framesSize; i++) {
          workerBytes += frames->frames[i].sourceSize;

          /*
           * The last worker/thread needs to handle all remaining work. Don't
           * trigger it prematurely. Defer to the block outside of the loop.
           * (But still process this loop so workerBytes is correct.)
           */
          if (currentThread == threadCount - 1) {
              continue;
          }

          if (workerBytes >= bytesPerWorker) {
              workerStates[currentThread].startOffset = workerStartOffset;
              workerStates[currentThread].endOffset = i;
              workerStates[currentThread].totalSourceSize = workerBytes;

              if (threadCount > 1) {
                  POOL_add(pool, (POOL_function)decompress_worker,
                           &workerStates[currentThread]);
              }
              else {
                  decompress_worker(&workerStates[currentThread]);
              }

              currentThread++;
              workerStartOffset = i + 1;
              workerBytes = 0;
          }
      }

      /* Dispatch whatever frames are still unassigned. Testing the offset
         (rather than workerBytes) ensures trailing frames with empty sources
         are still processed, and reported as errors, instead of being
         silently dropped from the result. */
      if (workerStartOffset < frames->framesSize) {
          workerStates[currentThread].startOffset = workerStartOffset;
          workerStates[currentThread].endOffset = frames->framesSize - 1;
          workerStates[currentThread].totalSourceSize = workerBytes;

          if (threadCount > 1) {
              POOL_add(pool, (POOL_function)decompress_worker,
                       &workerStates[currentThread]);
          }
          else {
              decompress_worker(&workerStates[currentThread]);
          }
      }

      if (threadCount > 1) {
          POOL_free(pool);
          pool = NULL;
      }

      Py_END_ALLOW_THREADS

      for (i = 0; i < threadCount; i++) {
          switch (workerStates[i].error) {
          case DecompressorWorkerError_none:
              break;

          case DecompressorWorkerError_zstd:
              PyErr_Format(ZstdError, "error decompressing item %zd: %s",
                           workerStates[i].errorOffset,
                           ZSTD_getErrorName(workerStates[i].zresult));
              errored = 1;
              break;

          case DecompressorWorkerError_memory:
              PyErr_NoMemory();
              errored = 1;
              break;

          case DecompressorWorkerError_sizeMismatch:
              PyErr_Format(ZstdError,
                           "error decompressing item %zd: decompressed %zu "
                           "bytes; expected %zu",
                           workerStates[i].errorOffset, workerStates[i].zresult,
                           framePointers[workerStates[i].errorOffset].destSize);
              errored = 1;
              break;

          case DecompressorWorkerError_unknownSize:
              PyErr_Format(PyExc_ValueError,
                           "could not determine decompressed size of item %zd",
                           workerStates[i].errorOffset);
              errored = 1;
              break;

          default:
              PyErr_Format(ZstdError, "unhandled error type: %d; this is a bug",
                           (int)workerStates[i].error);
              errored = 1;
              break;
          }

          if (errored) {
              break;
          }
      }

      if (errored) {
          goto finally;
      }

      segmentsCount = 0;
      for (i = 0; i < threadCount; i++) {
          segmentsCount += workerStates[i].destCount;
      }

      resultArg = PyTuple_New(segmentsCount);
      if (NULL == resultArg) {
          goto finally;
      }

      resultIndex = 0;

      for (i = 0; i < threadCount; i++) {
          Py_ssize_t bufferIndex;
          DecompressorWorkerState *state = &workerStates[i];

          for (bufferIndex = 0; bufferIndex < state->destCount; bufferIndex++) {
              DecompressorDestBuffer *destBuffer =
                  &state->destBuffers[bufferIndex];

              bws = BufferWithSegments_FromMemory(
                  destBuffer->dest, destBuffer->destSize, destBuffer->segments,
                  destBuffer->segmentsSize);
              if (NULL == bws) {
                  goto finally;
              }

              /*
               * Memory for buffer and segments was allocated using malloc() in
               * worker and the memory is transferred to the BufferWithSegments
               * instance. So tell instance to use free() and NULL the
               * reference in the state struct so it isn't freed below.
               */
              bws->useFree = 1;
              destBuffer->dest = NULL;
              destBuffer->segments = NULL;

              PyTuple_SET_ITEM(resultArg, resultIndex++, (PyObject *)bws);
          }
      }

      result = (ZstdBufferWithSegmentsCollection *)PyObject_CallObject(
          (PyObject *)ZstdBufferWithSegmentsCollectionType, resultArg);

  finally:
      Py_CLEAR(resultArg);

      if (workerStates) {
          for (i = 0; i < threadCount; i++) {
              Py_ssize_t bufferIndex;
              DecompressorWorkerState *state = &workerStates[i];

              if (state->dctx) {
                  ZSTD_freeDCtx(state->dctx);
              }

              /* dest/segments are NULL if the memory was transferred to a
                 BufferWithSegments. Otherwise they are left over after an
                 error occurred and must be freed here. */
              if (state->destBuffers) {
                  for (bufferIndex = 0; bufferIndex < state->destCount;
                       bufferIndex++) {
                      free(state->destBuffers[bufferIndex].dest);
                      free(state->destBuffers[bufferIndex].segments);
                  }
              }

              free(state->destBuffers);
          }

          PyMem_Free(workerStates);
      }

      POOL_free(pool);

      return result;
  }
  #endif
  #ifdef HAVE_ZSTD_POOL_APIS
  /*
   * ZstdDecompressor.multi_decompress_to_buffer(frames,
   *                                             decompressed_sizes=None,
   *                                             threads=0)
   *
   * `frames` may be a BufferWithSegments, a BufferWithSegmentsCollection, or
   * a list of bytes-like objects. When given, `decompressed_sizes` is
   * interpreted as an array of native `unsigned long long`, one per frame,
   * so its byte length must be frameCount * sizeof(unsigned long long). A 0
   * entry means "unknown"; the size is then read from the frame header.
   * A negative `threads` means "use cpu_count()"; values below 2 mean one
   * thread.
   *
   * Builds a FramePointer array describing every input frame and hands it to
   * decompress_from_framesources(). Returns a new
   * BufferWithSegmentsCollection, or NULL with an exception set.
   */
  static ZstdBufferWithSegmentsCollection *
  Decompressor_multi_decompress_to_buffer(ZstdDecompressor *self, PyObject *args,
                                          PyObject *kwargs) {
      static char *kwlist[] = {"frames", "decompressed_sizes", "threads", NULL};

      PyObject *frames;
      Py_buffer frameSizes;
      int threads = 0;
      /* Initialized so cleanup never reads an indeterminate value. */
      Py_ssize_t frameCount = 0;
      Py_buffer *frameBuffers = NULL;
      FramePointer *framePointers = NULL;
      unsigned long long *frameSizesP = NULL;
      unsigned long long totalInputSize = 0;
      FrameSources frameSources;
      ZstdBufferWithSegmentsCollection *result = NULL;
      Py_ssize_t i;

      memset(&frameSizes, 0, sizeof(frameSizes));

      if (!PyArg_ParseTupleAndKeywords(args, kwargs,
                                       "O|y*i:multi_decompress_to_buffer", kwlist,
                                       &frames, &frameSizes, &threads)) {
          return NULL;
      }

      if (frameSizes.buf) {
          frameSizesP = (unsigned long long *)frameSizes.buf;
      }

      if (threads < 0) {
          threads = cpu_count();
      }

      if (threads < 2) {
          threads = 1;
      }

      if (PyObject_TypeCheck(frames, ZstdBufferWithSegmentsType)) {
          ZstdBufferWithSegments *buffer = (ZstdBufferWithSegments *)frames;
          frameCount = buffer->segmentCount;

          if (frameSizes.buf &&
              frameSizes.len !=
                  frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
              PyErr_Format(
                  PyExc_ValueError,
                  "decompressed_sizes size mismatch; expected %zd, got %zd",
                  frameCount * (Py_ssize_t)sizeof(unsigned long long),
                  frameSizes.len);
              goto finally;
          }

          framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
          if (!framePointers) {
              PyErr_NoMemory();
              goto finally;
          }

          for (i = 0; i < frameCount; i++) {
              void *sourceData;
              unsigned long long sourceSize;
              unsigned long long decompressedSize = 0;

              if (buffer->segments[i].offset + buffer->segments[i].length >
                  buffer->dataSize) {
                  PyErr_Format(PyExc_ValueError,
                               "item %zd has offset outside memory area", i);
                  goto finally;
              }

              sourceData = (char *)buffer->data + buffer->segments[i].offset;
              sourceSize = buffer->segments[i].length;
              totalInputSize += sourceSize;

              if (frameSizesP) {
                  decompressedSize = frameSizesP[i];
              }

              if (sourceSize > SIZE_MAX) {
                  PyErr_Format(PyExc_ValueError,
                               "item %zd is too large for this platform", i);
                  goto finally;
              }

              if (decompressedSize > SIZE_MAX) {
                  PyErr_Format(PyExc_ValueError,
                               "decompressed size of item %zd is too large for "
                               "this platform",
                               i);
                  goto finally;
              }

              framePointers[i].sourceData = sourceData;
              framePointers[i].sourceSize = (size_t)sourceSize;
              framePointers[i].destSize = (size_t)decompressedSize;
          }
      }
      else if (PyObject_TypeCheck(frames, ZstdBufferWithSegmentsCollectionType)) {
          Py_ssize_t offset = 0;
          ZstdBufferWithSegments *buffer;
          ZstdBufferWithSegmentsCollection *collection =
              (ZstdBufferWithSegmentsCollection *)frames;

          frameCount = BufferWithSegmentsCollection_length(collection);

          /* decompressed_sizes holds one unsigned long long per frame, so
             compare byte lengths rather than element counts. Comparing
             against frameCount alone rejected correctly sized input and let
             an undersized buffer through, so frameSizesP[offset] below could
             read past its end. */
          if (frameSizes.buf &&
              frameSizes.len !=
                  frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
              PyErr_Format(
                  PyExc_ValueError,
                  "decompressed_sizes size mismatch; expected %zd; got %zd",
                  frameCount * (Py_ssize_t)sizeof(unsigned long long),
                  frameSizes.len);
              goto finally;
          }

          framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
          if (NULL == framePointers) {
              PyErr_NoMemory();
              goto finally;
          }

          /* Iterate the data structure directly because it is faster. */
          for (i = 0; i < collection->bufferCount; i++) {
              Py_ssize_t segmentIndex;
              buffer = collection->buffers[i];

              for (segmentIndex = 0; segmentIndex < buffer->segmentCount;
                   segmentIndex++) {
                  unsigned long long decompressedSize =
                      frameSizesP ? frameSizesP[offset] : 0;

                  if (buffer->segments[segmentIndex].offset +
                          buffer->segments[segmentIndex].length >
                      buffer->dataSize) {
                      PyErr_Format(PyExc_ValueError,
                                   "item %zd has offset outside memory area",
                                   offset);
                      goto finally;
                  }

                  if (buffer->segments[segmentIndex].length > SIZE_MAX) {
                      PyErr_Format(
                          PyExc_ValueError,
                          "item %zd in buffer %zd is too large for this platform",
                          segmentIndex, i);
                      goto finally;
                  }

                  if (decompressedSize > SIZE_MAX) {
                      PyErr_Format(PyExc_ValueError,
                                   "decompressed size of item %zd in buffer %zd "
                                   "is too large for this platform",
                                   segmentIndex, i);
                      goto finally;
                  }

                  totalInputSize += buffer->segments[segmentIndex].length;

                  framePointers[offset].sourceData =
                      (char *)buffer->data +
                      buffer->segments[segmentIndex].offset;
                  framePointers[offset].sourceSize =
                      (size_t)buffer->segments[segmentIndex].length;
                  framePointers[offset].destSize = (size_t)decompressedSize;

                  offset++;
              }
          }
      }
      else if (PyList_Check(frames)) {
          frameCount = PyList_GET_SIZE(frames);

          if (frameSizes.buf &&
              frameSizes.len !=
                  frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
              PyErr_Format(
                  PyExc_ValueError,
                  "decompressed_sizes size mismatch; expected %zd, got %zd",
                  frameCount * (Py_ssize_t)sizeof(unsigned long long),
                  frameSizes.len);
              goto finally;
          }

          framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
          if (!framePointers) {
              PyErr_NoMemory();
              goto finally;
          }

          frameBuffers = PyMem_Malloc(frameCount * sizeof(Py_buffer));
          if (NULL == frameBuffers) {
              PyErr_NoMemory();
              goto finally;
          }

          /* Zeroed so PyBuffer_Release() in cleanup is a no-op for entries
             that were never filled in. */
          memset(frameBuffers, 0, frameCount * sizeof(Py_buffer));

          /* Do a pass to assemble info about our input buffers and output
             sizes. */
          for (i = 0; i < frameCount; i++) {
              unsigned long long decompressedSize =
                  frameSizesP ? frameSizesP[i] : 0;

              if (0 != PyObject_GetBuffer(PyList_GET_ITEM(frames, i),
                                          &frameBuffers[i], PyBUF_CONTIG_RO)) {
                  PyErr_Clear();
                  PyErr_Format(PyExc_TypeError,
                               "item %zd not a bytes like object", i);
                  goto finally;
              }

              if (decompressedSize > SIZE_MAX) {
                  PyErr_Format(PyExc_ValueError,
                               "decompressed size of item %zd is too large for "
                               "this platform",
                               i);
                  goto finally;
              }

              totalInputSize += frameBuffers[i].len;

              framePointers[i].sourceData = frameBuffers[i].buf;
              framePointers[i].sourceSize = frameBuffers[i].len;
              framePointers[i].destSize = (size_t)decompressedSize;
          }
      }
      else {
          PyErr_SetString(PyExc_TypeError,
                          "argument must be list or BufferWithSegments");
          goto finally;
      }

      /* We now have an array with info about our inputs and outputs. Feed it
         into our generic decompression function. */
      frameSources.frames = framePointers;
      frameSources.framesSize = frameCount;
      frameSources.compressedSize = totalInputSize;

      result = decompress_from_framesources(self, &frameSources, threads);

  finally:
      if (frameSizes.buf) {
          PyBuffer_Release(&frameSizes);
      }
      PyMem_Free(framePointers);

      if (frameBuffers) {
          for (i = 0; i < frameCount; i++) {
              PyBuffer_Release(&frameBuffers[i]);
          }

          PyMem_Free(frameBuffers);
      }

      return result;
  }
  #endif
  1428. static PyMethodDef Decompressor_methods[] = {
  1429. {"copy_stream", (PyCFunction)Decompressor_copy_stream,
  1430. METH_VARARGS | METH_KEYWORDS, NULL},
  1431. {"decompress", (PyCFunction)Decompressor_decompress,
  1432. METH_VARARGS | METH_KEYWORDS, NULL},
  1433. {"decompressobj", (PyCFunction)Decompressor_decompressobj,
  1434. METH_VARARGS | METH_KEYWORDS, NULL},
  1435. {"read_to_iter", (PyCFunction)Decompressor_read_to_iter,
  1436. METH_VARARGS | METH_KEYWORDS, NULL},
  1437. {"stream_reader", (PyCFunction)Decompressor_stream_reader,
  1438. METH_VARARGS | METH_KEYWORDS, NULL},
  1439. {"stream_writer", (PyCFunction)Decompressor_stream_writer,
  1440. METH_VARARGS | METH_KEYWORDS, NULL},
  1441. {"decompress_content_dict_chain",
  1442. (PyCFunction)Decompressor_decompress_content_dict_chain,
  1443. METH_VARARGS | METH_KEYWORDS, NULL},
  1444. #ifdef HAVE_ZSTD_POOL_APIS
  1445. {"multi_decompress_to_buffer",
  1446. (PyCFunction)Decompressor_multi_decompress_to_buffer,
  1447. METH_VARARGS | METH_KEYWORDS, NULL},
  1448. #endif
  1449. {"memory_size", (PyCFunction)Decompressor_memory_size, METH_NOARGS, NULL},
  1450. {NULL, NULL}};
  1451. PyType_Slot ZstdDecompressorSlots[] = {
  1452. {Py_tp_dealloc, Decompressor_dealloc},
  1453. {Py_tp_methods, Decompressor_methods},
  1454. {Py_tp_init, Decompressor_init},
  1455. {Py_tp_new, PyType_GenericNew},
  1456. {0, NULL},
  1457. };
  1458. PyType_Spec ZstdDecompressorSpec = {
  1459. "zstd.ZstdDecompressor",
  1460. sizeof(ZstdDecompressor),
  1461. 0,
  1462. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
  1463. ZstdDecompressorSlots,
  1464. };
  1465. PyTypeObject *ZstdDecompressorType;
  1466. void decompressor_module_init(PyObject *mod) {
  1467. ZstdDecompressorType =
  1468. (PyTypeObject *)PyType_FromSpec(&ZstdDecompressorSpec);
  1469. if (PyType_Ready(ZstdDecompressorType) < 0) {
  1470. return;
  1471. }
  1472. Py_INCREF((PyObject *)ZstdDecompressorType);
  1473. PyModule_AddObject(mod, "ZstdDecompressor",
  1474. (PyObject *)ZstdDecompressorType);
  1475. }