compressiondict.c 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. /**
  2. * Copyright (c) 2016-present, Gregory Szorc
  3. * All rights reserved.
  4. *
  5. * This software may be modified and distributed under the terms
  6. * of the BSD license. See the LICENSE file for details.
  7. */
  8. #include "python-zstandard.h"
  9. extern PyObject *ZstdError;
  10. ZstdCompressionDict *train_dictionary(PyObject *self, PyObject *args,
  11. PyObject *kwargs) {
  12. static char *kwlist[] = {
  13. "dict_size", "samples", "k", "d",
  14. "f", "split_point", "accel", "notifications",
  15. "dict_id", "level", "steps", "threads",
  16. NULL};
  17. size_t capacity;
  18. PyObject *samples;
  19. unsigned k = 0;
  20. unsigned d = 0;
  21. unsigned f = 0;
  22. double splitPoint = 0.0;
  23. unsigned accel = 0;
  24. unsigned notifications = 0;
  25. unsigned dictID = 0;
  26. int level = 0;
  27. unsigned steps = 0;
  28. int threads = 0;
  29. ZDICT_fastCover_params_t params;
  30. Py_ssize_t samplesLen;
  31. Py_ssize_t i;
  32. size_t samplesSize = 0;
  33. void *sampleBuffer = NULL;
  34. size_t *sampleSizes = NULL;
  35. void *sampleOffset;
  36. Py_ssize_t sampleSize;
  37. void *dict = NULL;
  38. size_t zresult;
  39. ZstdCompressionDict *result = NULL;
  40. if (!PyArg_ParseTupleAndKeywords(
  41. args, kwargs, "nO!|IIIdIIIiIi:train_dictionary", kwlist, &capacity,
  42. &PyList_Type, &samples, &k, &d, &f, &splitPoint, &accel,
  43. &notifications, &dictID, &level, &steps, &threads)) {
  44. return NULL;
  45. }
  46. if (threads < 0) {
  47. threads = cpu_count();
  48. }
  49. if (!steps && !threads) {
  50. /* Defaults from ZDICT_trainFromBuffer() */
  51. d = d ? d : 8;
  52. steps = steps ? steps : 4;
  53. level = level ? level : 3;
  54. }
  55. memset(&params, 0, sizeof(params));
  56. params.k = k;
  57. params.d = d;
  58. params.f = f;
  59. params.steps = steps;
  60. params.nbThreads = threads;
  61. params.splitPoint = splitPoint;
  62. params.accel = accel;
  63. params.zParams.compressionLevel = level;
  64. params.zParams.dictID = dictID;
  65. params.zParams.notificationLevel = notifications;
  66. /* Figure out total size of input samples. */
  67. samplesLen = PyList_Size(samples);
  68. for (i = 0; i < samplesLen; i++) {
  69. PyObject *sampleItem = PyList_GET_ITEM(samples, i);
  70. if (!PyBytes_Check(sampleItem)) {
  71. PyErr_SetString(PyExc_ValueError, "samples must be bytes");
  72. return NULL;
  73. }
  74. samplesSize += PyBytes_GET_SIZE(sampleItem);
  75. }
  76. sampleBuffer = PyMem_Malloc(samplesSize);
  77. if (!sampleBuffer) {
  78. PyErr_NoMemory();
  79. goto finally;
  80. }
  81. sampleSizes = PyMem_Malloc(samplesLen * sizeof(size_t));
  82. if (!sampleSizes) {
  83. PyErr_NoMemory();
  84. goto finally;
  85. }
  86. sampleOffset = sampleBuffer;
  87. for (i = 0; i < samplesLen; i++) {
  88. PyObject *sampleItem = PyList_GET_ITEM(samples, i);
  89. sampleSize = PyBytes_GET_SIZE(sampleItem);
  90. sampleSizes[i] = sampleSize;
  91. memcpy(sampleOffset, PyBytes_AS_STRING(sampleItem), sampleSize);
  92. sampleOffset = (char *)sampleOffset + sampleSize;
  93. }
  94. dict = PyMem_Malloc(capacity);
  95. if (!dict) {
  96. PyErr_NoMemory();
  97. goto finally;
  98. }
  99. Py_BEGIN_ALLOW_THREADS zresult = ZDICT_optimizeTrainFromBuffer_fastCover(
  100. dict, capacity, sampleBuffer, sampleSizes, (unsigned)samplesLen,
  101. &params);
  102. Py_END_ALLOW_THREADS
  103. if (ZDICT_isError(zresult)) {
  104. PyMem_Free(dict);
  105. PyErr_Format(ZstdError, "cannot train dict: %s",
  106. ZDICT_getErrorName(zresult));
  107. goto finally;
  108. }
  109. result = PyObject_New(ZstdCompressionDict, ZstdCompressionDictType);
  110. if (!result) {
  111. PyMem_Free(dict);
  112. goto finally;
  113. }
  114. result->dictData = dict;
  115. result->dictSize = zresult;
  116. result->dictType = ZSTD_dct_fullDict;
  117. result->d = params.d;
  118. result->k = params.k;
  119. result->cdict = NULL;
  120. result->ddict = NULL;
  121. finally:
  122. PyMem_Free(sampleBuffer);
  123. PyMem_Free(sampleSizes);
  124. return result;
  125. }
  126. int ensure_ddict(ZstdCompressionDict *dict) {
  127. if (dict->ddict) {
  128. return 0;
  129. }
  130. Py_BEGIN_ALLOW_THREADS dict->ddict = ZSTD_createDDict_advanced(
  131. dict->dictData, dict->dictSize, ZSTD_dlm_byRef, dict->dictType,
  132. ZSTD_defaultCMem);
  133. Py_END_ALLOW_THREADS if (!dict->ddict) {
  134. PyErr_SetString(ZstdError, "could not create decompression dict");
  135. return 1;
  136. }
  137. return 0;
  138. }
  139. static int ZstdCompressionDict_init(ZstdCompressionDict *self, PyObject *args,
  140. PyObject *kwargs) {
  141. static char *kwlist[] = {"data", "dict_type", NULL};
  142. int result = -1;
  143. Py_buffer source;
  144. unsigned dictType = ZSTD_dct_auto;
  145. self->dictData = NULL;
  146. self->dictSize = 0;
  147. self->cdict = NULL;
  148. self->ddict = NULL;
  149. if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|I:ZstdCompressionDict",
  150. kwlist, &source, &dictType)) {
  151. return -1;
  152. }
  153. if (dictType != ZSTD_dct_auto && dictType != ZSTD_dct_rawContent &&
  154. dictType != ZSTD_dct_fullDict) {
  155. PyErr_Format(
  156. PyExc_ValueError,
  157. "invalid dictionary load mode: %d; must use DICT_TYPE_* constants",
  158. dictType);
  159. goto finally;
  160. }
  161. self->dictType = dictType;
  162. self->dictData = PyMem_Malloc(source.len);
  163. if (!self->dictData) {
  164. PyErr_NoMemory();
  165. goto finally;
  166. }
  167. memcpy(self->dictData, source.buf, source.len);
  168. self->dictSize = source.len;
  169. result = 0;
  170. finally:
  171. PyBuffer_Release(&source);
  172. return result;
  173. }
  174. static void ZstdCompressionDict_dealloc(ZstdCompressionDict *self) {
  175. if (self->cdict) {
  176. ZSTD_freeCDict(self->cdict);
  177. self->cdict = NULL;
  178. }
  179. if (self->ddict) {
  180. ZSTD_freeDDict(self->ddict);
  181. self->ddict = NULL;
  182. }
  183. if (self->dictData) {
  184. PyMem_Free(self->dictData);
  185. self->dictData = NULL;
  186. }
  187. PyObject_Del(self);
  188. }
  189. static PyObject *
  190. ZstdCompressionDict_precompute_compress(ZstdCompressionDict *self,
  191. PyObject *args, PyObject *kwargs) {
  192. static char *kwlist[] = {"level", "compression_params", NULL};
  193. int level = 0;
  194. ZstdCompressionParametersObject *compressionParams = NULL;
  195. ZSTD_compressionParameters cParams;
  196. size_t zresult;
  197. if (!PyArg_ParseTupleAndKeywords(
  198. args, kwargs, "|iO!:precompute_compress", kwlist, &level,
  199. ZstdCompressionParametersType, &compressionParams)) {
  200. return NULL;
  201. }
  202. if (level && compressionParams) {
  203. PyErr_SetString(PyExc_ValueError,
  204. "must only specify one of level or compression_params");
  205. return NULL;
  206. }
  207. if (!level && !compressionParams) {
  208. PyErr_SetString(PyExc_ValueError,
  209. "must specify one of level or compression_params");
  210. return NULL;
  211. }
  212. if (self->cdict) {
  213. zresult = ZSTD_freeCDict(self->cdict);
  214. self->cdict = NULL;
  215. if (ZSTD_isError(zresult)) {
  216. PyErr_Format(ZstdError, "unable to free CDict: %s",
  217. ZSTD_getErrorName(zresult));
  218. return NULL;
  219. }
  220. }
  221. if (level) {
  222. cParams = ZSTD_getCParams(level, 0, self->dictSize);
  223. }
  224. else {
  225. if (to_cparams(compressionParams, &cParams)) {
  226. return NULL;
  227. }
  228. }
  229. assert(!self->cdict);
  230. self->cdict = ZSTD_createCDict_advanced(self->dictData, self->dictSize,
  231. ZSTD_dlm_byRef, self->dictType,
  232. cParams, ZSTD_defaultCMem);
  233. if (!self->cdict) {
  234. PyErr_SetString(ZstdError, "unable to precompute dictionary");
  235. return NULL;
  236. }
  237. Py_RETURN_NONE;
  238. }
  239. static PyObject *ZstdCompressionDict_dict_id(ZstdCompressionDict *self) {
  240. unsigned dictID = ZDICT_getDictID(self->dictData, self->dictSize);
  241. return PyLong_FromLong(dictID);
  242. }
  243. static PyObject *ZstdCompressionDict_as_bytes(ZstdCompressionDict *self) {
  244. return PyBytes_FromStringAndSize(self->dictData, self->dictSize);
  245. }
  246. static PyMethodDef ZstdCompressionDict_methods[] = {
  247. {"dict_id", (PyCFunction)ZstdCompressionDict_dict_id, METH_NOARGS,
  248. PyDoc_STR("dict_id() -- obtain the numeric dictionary ID")},
  249. {"as_bytes", (PyCFunction)ZstdCompressionDict_as_bytes, METH_NOARGS,
  250. PyDoc_STR("as_bytes() -- obtain the raw bytes constituting the dictionary "
  251. "data")},
  252. {"precompute_compress",
  253. (PyCFunction)ZstdCompressionDict_precompute_compress,
  254. METH_VARARGS | METH_KEYWORDS, NULL},
  255. {NULL, NULL}};
  256. static PyMemberDef ZstdCompressionDict_members[] = {
  257. {"k", T_UINT, offsetof(ZstdCompressionDict, k), READONLY, "segment size"},
  258. {"d", T_UINT, offsetof(ZstdCompressionDict, d), READONLY, "dmer size"},
  259. {NULL}};
  260. static Py_ssize_t ZstdCompressionDict_length(ZstdCompressionDict *self) {
  261. return self->dictSize;
  262. }
  263. PyType_Slot ZstdCompressionDictSlots[] = {
  264. {Py_tp_dealloc, ZstdCompressionDict_dealloc},
  265. {Py_sq_length, ZstdCompressionDict_length},
  266. {Py_tp_methods, ZstdCompressionDict_methods},
  267. {Py_tp_members, ZstdCompressionDict_members},
  268. {Py_tp_init, ZstdCompressionDict_init},
  269. {Py_tp_new, PyType_GenericNew},
  270. {0, NULL},
  271. };
  272. PyType_Spec ZstdCompressionDictSpec = {
  273. "zstd.ZstdCompressionDict",
  274. sizeof(ZstdCompressionDict),
  275. 0,
  276. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
  277. ZstdCompressionDictSlots,
  278. };
  279. PyTypeObject *ZstdCompressionDictType;
  280. void compressiondict_module_init(PyObject *mod) {
  281. ZstdCompressionDictType =
  282. (PyTypeObject *)PyType_FromSpec(&ZstdCompressionDictSpec);
  283. if (PyType_Ready(ZstdCompressionDictType) < 0) {
  284. return;
  285. }
  286. Py_INCREF((PyObject *)ZstdCompressionDictType);
  287. PyModule_AddObject(mod, "ZstdCompressionDict",
  288. (PyObject *)ZstdCompressionDictType);
  289. }