compressiondict.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. /**
  2. * Copyright (c) 2016-present, Gregory Szorc
  3. * All rights reserved.
  4. *
  5. * This software may be modified and distributed under the terms
  6. * of the BSD license. See the LICENSE file for details.
  7. */
  8. #include "python-zstandard.h"
  9. extern PyObject* ZstdError;
  10. ZstdCompressionDict* train_dictionary(PyObject* self, PyObject* args, PyObject* kwargs) {
  11. static char* kwlist[] = {
  12. "dict_size",
  13. "samples",
  14. "k",
  15. "d",
  16. "notifications",
  17. "dict_id",
  18. "level",
  19. "steps",
  20. "threads",
  21. NULL
  22. };
  23. size_t capacity;
  24. PyObject* samples;
  25. unsigned k = 0;
  26. unsigned d = 0;
  27. unsigned notifications = 0;
  28. unsigned dictID = 0;
  29. int level = 0;
  30. unsigned steps = 0;
  31. int threads = 0;
  32. ZDICT_cover_params_t params;
  33. Py_ssize_t samplesLen;
  34. Py_ssize_t i;
  35. size_t samplesSize = 0;
  36. void* sampleBuffer = NULL;
  37. size_t* sampleSizes = NULL;
  38. void* sampleOffset;
  39. Py_ssize_t sampleSize;
  40. void* dict = NULL;
  41. size_t zresult;
  42. ZstdCompressionDict* result = NULL;
  43. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "nO!|IIIIiIi:train_dictionary",
  44. kwlist, &capacity, &PyList_Type, &samples,
  45. &k, &d, &notifications, &dictID, &level, &steps, &threads)) {
  46. return NULL;
  47. }
  48. if (threads < 0) {
  49. threads = cpu_count();
  50. }
  51. memset(&params, 0, sizeof(params));
  52. params.k = k;
  53. params.d = d;
  54. params.steps = steps;
  55. params.nbThreads = threads;
  56. params.zParams.notificationLevel = notifications;
  57. params.zParams.dictID = dictID;
  58. params.zParams.compressionLevel = level;
  59. /* Figure out total size of input samples. */
  60. samplesLen = PyList_Size(samples);
  61. for (i = 0; i < samplesLen; i++) {
  62. PyObject* sampleItem = PyList_GET_ITEM(samples, i);
  63. if (!PyBytes_Check(sampleItem)) {
  64. PyErr_SetString(PyExc_ValueError, "samples must be bytes");
  65. return NULL;
  66. }
  67. samplesSize += PyBytes_GET_SIZE(sampleItem);
  68. }
  69. sampleBuffer = PyMem_Malloc(samplesSize);
  70. if (!sampleBuffer) {
  71. PyErr_NoMemory();
  72. goto finally;
  73. }
  74. sampleSizes = PyMem_Malloc(samplesLen * sizeof(size_t));
  75. if (!sampleSizes) {
  76. PyErr_NoMemory();
  77. goto finally;
  78. }
  79. sampleOffset = sampleBuffer;
  80. for (i = 0; i < samplesLen; i++) {
  81. PyObject* sampleItem = PyList_GET_ITEM(samples, i);
  82. sampleSize = PyBytes_GET_SIZE(sampleItem);
  83. sampleSizes[i] = sampleSize;
  84. memcpy(sampleOffset, PyBytes_AS_STRING(sampleItem), sampleSize);
  85. sampleOffset = (char*)sampleOffset + sampleSize;
  86. }
  87. dict = PyMem_Malloc(capacity);
  88. if (!dict) {
  89. PyErr_NoMemory();
  90. goto finally;
  91. }
  92. Py_BEGIN_ALLOW_THREADS
  93. /* No parameters uses the default function, which will use default params
  94. and call ZDICT_optimizeTrainFromBuffer_cover under the hood. */
  95. if (!params.k && !params.d && !params.zParams.compressionLevel
  96. && !params.zParams.notificationLevel && !params.zParams.dictID) {
  97. zresult = ZDICT_trainFromBuffer(dict, capacity, sampleBuffer,
  98. sampleSizes, (unsigned)samplesLen);
  99. }
  100. /* Use optimize mode if user controlled steps or threads explicitly. */
  101. else if (params.steps || params.nbThreads) {
  102. zresult = ZDICT_optimizeTrainFromBuffer_cover(dict, capacity,
  103. sampleBuffer, sampleSizes, (unsigned)samplesLen, &params);
  104. }
  105. /* Non-optimize mode with explicit control. */
  106. else {
  107. zresult = ZDICT_trainFromBuffer_cover(dict, capacity,
  108. sampleBuffer, sampleSizes, (unsigned)samplesLen, params);
  109. }
  110. Py_END_ALLOW_THREADS
  111. if (ZDICT_isError(zresult)) {
  112. PyMem_Free(dict);
  113. PyErr_Format(ZstdError, "cannot train dict: %s", ZDICT_getErrorName(zresult));
  114. goto finally;
  115. }
  116. result = PyObject_New(ZstdCompressionDict, &ZstdCompressionDictType);
  117. if (!result) {
  118. PyMem_Free(dict);
  119. goto finally;
  120. }
  121. result->dictData = dict;
  122. result->dictSize = zresult;
  123. result->dictType = ZSTD_dct_fullDict;
  124. result->d = params.d;
  125. result->k = params.k;
  126. result->cdict = NULL;
  127. result->ddict = NULL;
  128. finally:
  129. PyMem_Free(sampleBuffer);
  130. PyMem_Free(sampleSizes);
  131. return result;
  132. }
  133. int ensure_ddict(ZstdCompressionDict* dict) {
  134. if (dict->ddict) {
  135. return 0;
  136. }
  137. Py_BEGIN_ALLOW_THREADS
  138. dict->ddict = ZSTD_createDDict_advanced(dict->dictData, dict->dictSize,
  139. ZSTD_dlm_byRef, dict->dictType, ZSTD_defaultCMem);
  140. Py_END_ALLOW_THREADS
  141. if (!dict->ddict) {
  142. PyErr_SetString(ZstdError, "could not create decompression dict");
  143. return 1;
  144. }
  145. return 0;
  146. }
  147. PyDoc_STRVAR(ZstdCompressionDict__doc__,
  148. "ZstdCompressionDict(data) - Represents a computed compression dictionary\n"
  149. "\n"
  150. "This type holds the results of a computed Zstandard compression dictionary.\n"
  151. "Instances are obtained by calling ``train_dictionary()`` or by passing\n"
  152. "bytes obtained from another source into the constructor.\n"
  153. );
  154. static int ZstdCompressionDict_init(ZstdCompressionDict* self, PyObject* args, PyObject* kwargs) {
  155. static char* kwlist[] = {
  156. "data",
  157. "dict_type",
  158. NULL
  159. };
  160. int result = -1;
  161. Py_buffer source;
  162. unsigned dictType = ZSTD_dct_auto;
  163. self->dictData = NULL;
  164. self->dictSize = 0;
  165. self->cdict = NULL;
  166. self->ddict = NULL;
  167. #if PY_MAJOR_VERSION >= 3
  168. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|I:ZstdCompressionDict",
  169. #else
  170. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s*|I:ZstdCompressionDict",
  171. #endif
  172. kwlist, &source, &dictType)) {
  173. return -1;
  174. }
  175. if (!PyBuffer_IsContiguous(&source, 'C') || source.ndim > 1) {
  176. PyErr_SetString(PyExc_ValueError,
  177. "data buffer should be contiguous and have at most one dimension");
  178. goto finally;
  179. }
  180. if (dictType != ZSTD_dct_auto && dictType != ZSTD_dct_rawContent
  181. && dictType != ZSTD_dct_fullDict) {
  182. PyErr_Format(PyExc_ValueError,
  183. "invalid dictionary load mode: %d; must use DICT_TYPE_* constants",
  184. dictType);
  185. goto finally;
  186. }
  187. self->dictType = dictType;
  188. self->dictData = PyMem_Malloc(source.len);
  189. if (!self->dictData) {
  190. PyErr_NoMemory();
  191. goto finally;
  192. }
  193. memcpy(self->dictData, source.buf, source.len);
  194. self->dictSize = source.len;
  195. result = 0;
  196. finally:
  197. PyBuffer_Release(&source);
  198. return result;
  199. }
  200. static void ZstdCompressionDict_dealloc(ZstdCompressionDict* self) {
  201. if (self->cdict) {
  202. ZSTD_freeCDict(self->cdict);
  203. self->cdict = NULL;
  204. }
  205. if (self->ddict) {
  206. ZSTD_freeDDict(self->ddict);
  207. self->ddict = NULL;
  208. }
  209. if (self->dictData) {
  210. PyMem_Free(self->dictData);
  211. self->dictData = NULL;
  212. }
  213. PyObject_Del(self);
  214. }
  215. PyDoc_STRVAR(ZstdCompressionDict_precompute_compress__doc__,
  216. "Precompute a dictionary so it can be used by multiple compressors.\n"
  217. );
  218. static PyObject* ZstdCompressionDict_precompute_compress(ZstdCompressionDict* self, PyObject* args, PyObject* kwargs) {
  219. static char* kwlist[] = {
  220. "level",
  221. "compression_params",
  222. NULL
  223. };
  224. int level = 0;
  225. ZstdCompressionParametersObject* compressionParams = NULL;
  226. ZSTD_compressionParameters cParams;
  227. size_t zresult;
  228. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iO!:precompute_compress", kwlist,
  229. &level, &ZstdCompressionParametersType, &compressionParams)) {
  230. return NULL;
  231. }
  232. if (level && compressionParams) {
  233. PyErr_SetString(PyExc_ValueError,
  234. "must only specify one of level or compression_params");
  235. return NULL;
  236. }
  237. if (!level && !compressionParams) {
  238. PyErr_SetString(PyExc_ValueError,
  239. "must specify one of level or compression_params");
  240. return NULL;
  241. }
  242. if (self->cdict) {
  243. zresult = ZSTD_freeCDict(self->cdict);
  244. self->cdict = NULL;
  245. if (ZSTD_isError(zresult)) {
  246. PyErr_Format(ZstdError, "unable to free CDict: %s",
  247. ZSTD_getErrorName(zresult));
  248. return NULL;
  249. }
  250. }
  251. if (level) {
  252. cParams = ZSTD_getCParams(level, 0, self->dictSize);
  253. }
  254. else {
  255. if (to_cparams(compressionParams, &cParams)) {
  256. return NULL;
  257. }
  258. }
  259. assert(!self->cdict);
  260. self->cdict = ZSTD_createCDict_advanced(self->dictData, self->dictSize,
  261. ZSTD_dlm_byRef, self->dictType, cParams, ZSTD_defaultCMem);
  262. if (!self->cdict) {
  263. PyErr_SetString(ZstdError, "unable to precompute dictionary");
  264. return NULL;
  265. }
  266. Py_RETURN_NONE;
  267. }
  268. static PyObject* ZstdCompressionDict_dict_id(ZstdCompressionDict* self) {
  269. unsigned dictID = ZDICT_getDictID(self->dictData, self->dictSize);
  270. return PyLong_FromLong(dictID);
  271. }
  272. static PyObject* ZstdCompressionDict_as_bytes(ZstdCompressionDict* self) {
  273. return PyBytes_FromStringAndSize(self->dictData, self->dictSize);
  274. }
  275. static PyMethodDef ZstdCompressionDict_methods[] = {
  276. { "dict_id", (PyCFunction)ZstdCompressionDict_dict_id, METH_NOARGS,
  277. PyDoc_STR("dict_id() -- obtain the numeric dictionary ID") },
  278. { "as_bytes", (PyCFunction)ZstdCompressionDict_as_bytes, METH_NOARGS,
  279. PyDoc_STR("as_bytes() -- obtain the raw bytes constituting the dictionary data") },
  280. { "precompute_compress", (PyCFunction)ZstdCompressionDict_precompute_compress,
  281. METH_VARARGS | METH_KEYWORDS, ZstdCompressionDict_precompute_compress__doc__ },
  282. { NULL, NULL }
  283. };
  284. static PyMemberDef ZstdCompressionDict_members[] = {
  285. { "k", T_UINT, offsetof(ZstdCompressionDict, k), READONLY,
  286. "segment size" },
  287. { "d", T_UINT, offsetof(ZstdCompressionDict, d), READONLY,
  288. "dmer size" },
  289. { NULL }
  290. };
  291. static Py_ssize_t ZstdCompressionDict_length(ZstdCompressionDict* self) {
  292. return self->dictSize;
  293. }
  294. static PySequenceMethods ZstdCompressionDict_sq = {
  295. (lenfunc)ZstdCompressionDict_length, /* sq_length */
  296. 0, /* sq_concat */
  297. 0, /* sq_repeat */
  298. 0, /* sq_item */
  299. 0, /* sq_ass_item */
  300. 0, /* sq_contains */
  301. 0, /* sq_inplace_concat */
  302. 0 /* sq_inplace_repeat */
  303. };
  304. PyTypeObject ZstdCompressionDictType = {
  305. PyVarObject_HEAD_INIT(NULL, 0)
  306. "zstd.ZstdCompressionDict", /* tp_name */
  307. sizeof(ZstdCompressionDict), /* tp_basicsize */
  308. 0, /* tp_itemsize */
  309. (destructor)ZstdCompressionDict_dealloc, /* tp_dealloc */
  310. 0, /* tp_print */
  311. 0, /* tp_getattr */
  312. 0, /* tp_setattr */
  313. 0, /* tp_compare */
  314. 0, /* tp_repr */
  315. 0, /* tp_as_number */
  316. &ZstdCompressionDict_sq, /* tp_as_sequence */
  317. 0, /* tp_as_mapping */
  318. 0, /* tp_hash */
  319. 0, /* tp_call */
  320. 0, /* tp_str */
  321. 0, /* tp_getattro */
  322. 0, /* tp_setattro */
  323. 0, /* tp_as_buffer */
  324. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
  325. ZstdCompressionDict__doc__, /* tp_doc */
  326. 0, /* tp_traverse */
  327. 0, /* tp_clear */
  328. 0, /* tp_richcompare */
  329. 0, /* tp_weaklistoffset */
  330. 0, /* tp_iter */
  331. 0, /* tp_iternext */
  332. ZstdCompressionDict_methods, /* tp_methods */
  333. ZstdCompressionDict_members, /* tp_members */
  334. 0, /* tp_getset */
  335. 0, /* tp_base */
  336. 0, /* tp_dict */
  337. 0, /* tp_descr_get */
  338. 0, /* tp_descr_set */
  339. 0, /* tp_dictoffset */
  340. (initproc)ZstdCompressionDict_init, /* tp_init */
  341. 0, /* tp_alloc */
  342. PyType_GenericNew, /* tp_new */
  343. };
  344. void compressiondict_module_init(PyObject* mod) {
  345. Py_TYPE(&ZstdCompressionDictType) = &PyType_Type;
  346. if (PyType_Ready(&ZstdCompressionDictType) < 0) {
  347. return;
  348. }
  349. Py_INCREF((PyObject*)&ZstdCompressionDictType);
  350. PyModule_AddObject(mod, "ZstdCompressionDict",
  351. (PyObject*)&ZstdCompressionDictType);
  352. }