cover.c 43 KB


  1. /*
  2. * Copyright (c) Meta Platforms, Inc. and affiliates.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under both the BSD-style license (found in the
  6. * LICENSE file in the root directory of this source tree) and the GPLv2 (found
  7. * in the COPYING file in the root directory of this source tree).
  8. * You may select, at your option, one of the above-listed licenses.
  9. */
  10. /* *****************************************************************************
  11. * Constructs a dictionary using a heuristic based on the following paper:
  12. *
  13. * Liao, Petri, Moffat, Wirth
  14. * Effective Construction of Relative Lempel-Ziv Dictionaries
  15. * Published in WWW 2016.
  16. *
  17. * Adapted from code originally written by @ot (Giuseppe Ottaviano).
  18. ******************************************************************************/
  19. /*-*************************************
  20. * Dependencies
  21. ***************************************/
  22. /* qsort_r is an extension. */
  23. #if defined(__linux) || defined(__linux__) || defined(linux) || defined(__gnu_linux__) || \
  24. defined(__CYGWIN__) || defined(__MSYS__)
  25. #if !defined(_GNU_SOURCE) && !defined(__ANDROID__) /* NDK doesn't ship qsort_r(). */
  26. #define _GNU_SOURCE
  27. #endif
  28. #endif
  29. #include <stdio.h> /* fprintf */
  30. #include <stdlib.h> /* malloc, free, qsort_r */
  31. #include <string.h> /* memset */
  32. #include <time.h> /* clock */
  33. #ifndef ZDICT_STATIC_LINKING_ONLY
  34. # define ZDICT_STATIC_LINKING_ONLY
  35. #endif
  36. #include "../common/mem.h" /* read */
  37. #include "../common/pool.h" /* POOL_ctx */
  38. #include "../common/threading.h" /* ZSTD_pthread_mutex_t */
  39. #include "../common/zstd_internal.h" /* includes zstd.h */
  40. #include "../common/bits.h" /* ZSTD_highbit32 */
  41. #include "../zdict.h"
  42. #include "cover.h"
/*-*************************************
* Constants
***************************************/
/**
 * Exclusive upper bound on the total size of all samples (COVER_ctx_init()
 * rejects totalSamplesSize >= COVER_MAX_SAMPLES_SIZE).
 * There are 32bit indexes used to ref samples, so limit samples size to 4GB
 * on 64bit builds.
 * For 32bit builds we choose 1 GB.
 * Most 32bit platforms have 2GB user-mode addressable space and we allocate a large
 * contiguous buffer, so 1GB is already a high limit.
 */
#define COVER_MAX_SAMPLES_SIZE (sizeof(size_t) == 8 ? ((unsigned)-1) : ((unsigned)1 GB))
/* A split point of 1.0 trains on every sample (see COVER_ctx_init()).
 * NOTE(review): presumably the default splitPoint used elsewhere in this file — confirm. */
#define COVER_DEFAULT_SPLITPOINT 1.0
/*-*************************************
* Console display
***************************************/
/* NOTE(review): the #ifndef guards below appear intended to avoid redefining
 * g_displayLevel / g_refreshRate / g_time when this file is compiled in the
 * same translation unit as another file defining these macros — confirm. */
#ifndef LOCALDISPLAYLEVEL
static int g_displayLevel = 0;
#endif
#undef DISPLAY
/* Print to stderr and flush immediately. */
#define DISPLAY(...) \
{ \
fprintf(stderr, __VA_ARGS__); \
fflush(stderr); \
}
#undef LOCALDISPLAYLEVEL
/* Print only when displayLevel >= l. */
#define LOCALDISPLAYLEVEL(displayLevel, l, ...) \
if (displayLevel >= l) { \
DISPLAY(__VA_ARGS__); \
} /* 0 : no display; 1: errors; 2: default; 3: details; 4: debug */
#undef DISPLAYLEVEL
/* LOCALDISPLAYLEVEL using the file-level g_displayLevel. */
#define DISPLAYLEVEL(l, ...) LOCALDISPLAYLEVEL(g_displayLevel, l, __VA_ARGS__)
#ifndef LOCALDISPLAYUPDATE
/* Minimum interval between two progress updates: 0.15 s of clock() time. */
static const clock_t g_refreshRate = CLOCKS_PER_SEC * 15 / 100;
/* clock() value at the last printed progress update. */
static clock_t g_time = 0;
#endif
#undef LOCALDISPLAYUPDATE
/* Rate-limited progress output: prints when displayLevel >= l and either
 * g_refreshRate has elapsed since the last update or displayLevel >= 4
 * (debug level prints every update). */
#define LOCALDISPLAYUPDATE(displayLevel, l, ...) \
if (displayLevel >= l) { \
if ((clock() - g_time > g_refreshRate) || (displayLevel >= 4)) { \
g_time = clock(); \
DISPLAY(__VA_ARGS__); \
} \
}
#undef DISPLAYUPDATE
/* LOCALDISPLAYUPDATE using the file-level g_displayLevel. */
#define DISPLAYUPDATE(l, ...) LOCALDISPLAYUPDATE(g_displayLevel, l, __VA_ARGS__)
/*-*************************************
* Hash table
***************************************
* A small specialized hash map for storing activeDmers.
* The map does not resize, so if it becomes full it will loop forever.
* Thus, the map must be large enough to store every value.
* The map implements linear probing and keeps its load less than 0.5.
*/
/* Value marking an unused slot; every byte of it is 0xFF. */
#define MAP_EMPTY_VALUE ((U32)-1)
/* One slot of the table: a key (a dmerId) and its associated count. */
typedef struct COVER_map_pair_t_s {
U32 key;
U32 value;
} COVER_map_pair_t;
typedef struct COVER_map_s {
COVER_map_pair_t *data; /* table of `size` slots */
U32 sizeLog;            /* log2(size) */
U32 size;               /* number of slots, a power of two */
U32 sizeMask;           /* size - 1, wraps probe indexes around the table */
} COVER_map_t;
  107. /**
  108. * Clear the map.
  109. */
  110. static void COVER_map_clear(COVER_map_t *map) {
  111. memset(map->data, MAP_EMPTY_VALUE, map->size * sizeof(COVER_map_pair_t));
  112. }
  113. /**
  114. * Initializes a map of the given size.
  115. * Returns 1 on success and 0 on failure.
  116. * The map must be destroyed with COVER_map_destroy().
  117. * The map is only guaranteed to be large enough to hold size elements.
  118. */
  119. static int COVER_map_init(COVER_map_t *map, U32 size) {
  120. map->sizeLog = ZSTD_highbit32(size) + 2;
  121. map->size = (U32)1 << map->sizeLog;
  122. map->sizeMask = map->size - 1;
  123. map->data = (COVER_map_pair_t *)malloc(map->size * sizeof(COVER_map_pair_t));
  124. if (!map->data) {
  125. map->sizeLog = 0;
  126. map->size = 0;
  127. return 0;
  128. }
  129. COVER_map_clear(map);
  130. return 1;
  131. }
  132. /**
  133. * Internal hash function
  134. */
  135. static const U32 COVER_prime4bytes = 2654435761U;
  136. static U32 COVER_map_hash(COVER_map_t *map, U32 key) {
  137. return (key * COVER_prime4bytes) >> (32 - map->sizeLog);
  138. }
  139. /**
  140. * Helper function that returns the index that a key should be placed into.
  141. */
  142. static U32 COVER_map_index(COVER_map_t *map, U32 key) {
  143. const U32 hash = COVER_map_hash(map, key);
  144. U32 i;
  145. for (i = hash;; i = (i + 1) & map->sizeMask) {
  146. COVER_map_pair_t *pos = &map->data[i];
  147. if (pos->value == MAP_EMPTY_VALUE) {
  148. return i;
  149. }
  150. if (pos->key == key) {
  151. return i;
  152. }
  153. }
  154. }
  155. /**
  156. * Returns the pointer to the value for key.
  157. * If key is not in the map, it is inserted and the value is set to 0.
  158. * The map must not be full.
  159. */
  160. static U32 *COVER_map_at(COVER_map_t *map, U32 key) {
  161. COVER_map_pair_t *pos = &map->data[COVER_map_index(map, key)];
  162. if (pos->value == MAP_EMPTY_VALUE) {
  163. pos->key = key;
  164. pos->value = 0;
  165. }
  166. return &pos->value;
  167. }
/**
 * Deletes key from the map if present.
 * Deletion uses backward shifting instead of tombstones: once the key's slot
 * becomes a hole, later entries of the same probe run are moved back into the
 * hole whenever that keeps them at or after their own hash slot. The run
 * therefore stays contiguous, and COVER_map_index() can keep stopping at the
 * first empty slot.
 */
static void COVER_map_remove(COVER_map_t *map, U32 key) {
  U32 i = COVER_map_index(map, key);
  /* The current hole to fill. */
  COVER_map_pair_t *del = &map->data[i];
  /* Distance (in slots) from the hole to the slot being examined. */
  U32 shift = 1;
  if (del->value == MAP_EMPTY_VALUE) {
    /* COVER_map_index() stopped on an empty slot: key is not in the map. */
    return;
  }
  for (i = (i + 1) & map->sizeMask;; i = (i + 1) & map->sizeMask) {
    COVER_map_pair_t *const pos = &map->data[i];
    /* If the position is empty we are done */
    if (pos->value == MAP_EMPTY_VALUE) {
      del->value = MAP_EMPTY_VALUE;
      return;
    }
    /* If pos can be moved to del do so: its distance from its own hash slot
     * is at least the distance back to the hole, so moving it into the hole
     * does not put it before its hash slot. */
    if (((i - COVER_map_hash(map, pos->key)) & map->sizeMask) >= shift) {
      del->key = pos->key;
      del->value = pos->value;
      /* The moved entry's old slot becomes the new hole. */
      del = pos;
      shift = 1;
    } else {
      ++shift;
    }
  }
}
  196. /**
  197. * Destroys a map that is inited with COVER_map_init().
  198. */
  199. static void COVER_map_destroy(COVER_map_t *map) {
  200. if (map->data) {
  201. free(map->data);
  202. }
  203. map->data = NULL;
  204. map->size = 0;
  205. }
/*-*************************************
* Context
***************************************/
/* State shared by the COVER dictionary-building steps (see COVER_ctx_init()). */
typedef struct {
const BYTE *samples;        /* all samples back to back; not owned */
size_t *offsets;            /* nbSamples + 1 entries: offsets[i] = start of sample i,
                               offsets[nbSamples] = total size; owned */
const size_t *samplesSizes; /* size of each sample; not owned */
size_t nbSamples;           /* total number of samples */
size_t nbTrainSamples;      /* samples the suffix array is built from */
size_t nbTestSamples;       /* samples used for evaluation (all of them when splitPoint >= 1) */
U32 *suffix;                /* partial suffix array; handed over to `freqs` and set
                               to NULL at the end of COVER_ctx_init() */
size_t suffixSize;          /* number of dmer positions in the training samples */
U32 *freqs;                 /* freqs[dmerId] = number of samples containing the dmer */
U32 *dmerAt;                /* dmerAt[position] = dmerId of the dmer starting there */
unsigned d;                 /* dmer length in bytes */
} COVER_ctx_t;
#if !defined(_GNU_SOURCE) && !defined(__APPLE__) && !defined(_MSC_VER)
/* C90 only offers qsort() that needs a global context.
 * Used by the comparators in the OpenBSD and C90 sorting paths. */
static COVER_ctx_t *g_coverCtx = NULL;
#endif
  226. /*-*************************************
  227. * Helper functions
  228. ***************************************/
  229. /**
  230. * Returns the sum of the sample sizes.
  231. */
  232. size_t COVER_sum(const size_t *samplesSizes, unsigned nbSamples) {
  233. size_t sum = 0;
  234. unsigned i;
  235. for (i = 0; i < nbSamples; ++i) {
  236. sum += samplesSizes[i];
  237. }
  238. return sum;
  239. }
  240. /**
  241. * Returns -1 if the dmer at lp is less than the dmer at rp.
  242. * Return 0 if the dmers at lp and rp are equal.
  243. * Returns 1 if the dmer at lp is greater than the dmer at rp.
  244. */
  245. static int COVER_cmp(COVER_ctx_t *ctx, const void *lp, const void *rp) {
  246. U32 const lhs = *(U32 const *)lp;
  247. U32 const rhs = *(U32 const *)rp;
  248. return memcmp(ctx->samples + lhs, ctx->samples + rhs, ctx->d);
  249. }
  250. /**
  251. * Faster version for d <= 8.
  252. */
  253. static int COVER_cmp8(COVER_ctx_t *ctx, const void *lp, const void *rp) {
  254. U64 const mask = (ctx->d == 8) ? (U64)-1 : (((U64)1 << (8 * ctx->d)) - 1);
  255. U64 const lhs = MEM_readLE64(ctx->samples + *(U32 const *)lp) & mask;
  256. U64 const rhs = MEM_readLE64(ctx->samples + *(U32 const *)rp) & mask;
  257. if (lhs < rhs) {
  258. return -1;
  259. }
  260. return (lhs > rhs);
  261. }
  262. /**
  263. * Same as COVER_cmp() except ties are broken by pointer value
  264. */
  265. #if (defined(_WIN32) && defined(_MSC_VER)) || defined(__APPLE__)
  266. static int WIN_CDECL COVER_strict_cmp(void* g_coverCtx, const void* lp, const void* rp) {
  267. #elif defined(_GNU_SOURCE)
  268. static int COVER_strict_cmp(const void *lp, const void *rp, void *g_coverCtx) {
  269. #else /* C90 fallback.*/
  270. static int COVER_strict_cmp(const void *lp, const void *rp) {
  271. #endif
  272. int result = COVER_cmp((COVER_ctx_t*)g_coverCtx, lp, rp);
  273. if (result == 0) {
  274. result = lp < rp ? -1 : 1;
  275. }
  276. return result;
  277. }
  278. /**
  279. * Faster version for d <= 8.
  280. */
  281. #if (defined(_WIN32) && defined(_MSC_VER)) || defined(__APPLE__)
  282. static int WIN_CDECL COVER_strict_cmp8(void* g_coverCtx, const void* lp, const void* rp) {
  283. #elif defined(_GNU_SOURCE)
  284. static int COVER_strict_cmp8(const void *lp, const void *rp, void *g_coverCtx) {
  285. #else /* C90 fallback.*/
  286. static int COVER_strict_cmp8(const void *lp, const void *rp) {
  287. #endif
  288. int result = COVER_cmp8((COVER_ctx_t*)g_coverCtx, lp, rp);
  289. if (result == 0) {
  290. result = lp < rp ? -1 : 1;
  291. }
  292. return result;
  293. }
/**
 * Sorts the ctx->suffixSize U32 positions in ctx->suffix by the dmer starting
 * at each position. Uses COVER_strict_cmp8() when d <= 8 and
 * COVER_strict_cmp() otherwise.
 * Abstract away divergence of qsort_r() parameters:
 *  - Apple qsort_r() takes the context before the comparator;
 *  - GNU qsort_r() and MSVC qsort_s() take it last;
 *  - OpenBSD mergesort() and the C90 qsort() fallback take no context, so
 *    ctx is passed through the global g_coverCtx. Those builds therefore
 *    cannot sort two contexts concurrently.
 * NOTE(review): qsort()/qsort_r()/qsort_s() do not promise stability; the
 * order of equal dmers comes from the comparators' tie-break.
 * Hopefully when C11 become the norm, we will be able
 * to clean it up.
 */
static void stableSort(COVER_ctx_t *ctx) {
#if defined(__APPLE__)
  qsort_r(ctx->suffix, ctx->suffixSize, sizeof(U32),
          ctx,
          (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp));
#elif defined(_GNU_SOURCE)
  qsort_r(ctx->suffix, ctx->suffixSize, sizeof(U32),
          (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp),
          ctx);
#elif defined(_WIN32) && defined(_MSC_VER)
  qsort_s(ctx->suffix, ctx->suffixSize, sizeof(U32),
          (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp),
          ctx);
#elif defined(__OpenBSD__)
  g_coverCtx = ctx;
  mergesort(ctx->suffix, ctx->suffixSize, sizeof(U32),
            (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp));
#else /* C90 fallback.*/
  g_coverCtx = ctx;
  /* TODO(cavalcanti): implement a reentrant qsort() when is not available. */
  qsort(ctx->suffix, ctx->suffixSize, sizeof(U32),
        (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp));
#endif
}
  323. /**
  324. * Returns the first pointer in [first, last) whose element does not compare
  325. * less than value. If no such element exists it returns last.
  326. */
  327. static const size_t *COVER_lower_bound(const size_t* first, const size_t* last,
  328. size_t value) {
  329. size_t count = (size_t)(last - first);
  330. assert(last >= first);
  331. while (count != 0) {
  332. size_t step = count / 2;
  333. const size_t *ptr = first;
  334. ptr += step;
  335. if (*ptr < value) {
  336. first = ++ptr;
  337. count -= step + 1;
  338. } else {
  339. count = step;
  340. }
  341. }
  342. return first;
  343. }
  344. /**
  345. * Generic groupBy function.
  346. * Groups an array sorted by cmp into groups with equivalent values.
  347. * Calls grp for each group.
  348. */
  349. static void
  350. COVER_groupBy(const void *data, size_t count, size_t size, COVER_ctx_t *ctx,
  351. int (*cmp)(COVER_ctx_t *, const void *, const void *),
  352. void (*grp)(COVER_ctx_t *, const void *, const void *)) {
  353. const BYTE *ptr = (const BYTE *)data;
  354. size_t num = 0;
  355. while (num < count) {
  356. const BYTE *grpEnd = ptr + size;
  357. ++num;
  358. while (num < count && cmp(ctx, ptr, grpEnd) == 0) {
  359. grpEnd += size;
  360. ++num;
  361. }
  362. grp(ctx, ptr, grpEnd);
  363. ptr = grpEnd;
  364. }
  365. }
  366. /*-*************************************
  367. * Cover functions
  368. ***************************************/
  369. /**
  370. * Called on each group of positions with the same dmer.
  371. * Counts the frequency of each dmer and saves it in the suffix array.
  372. * Fills `ctx->dmerAt`.
  373. */
  374. static void COVER_group(COVER_ctx_t *ctx, const void *group,
  375. const void *groupEnd) {
  376. /* The group consists of all the positions with the same first d bytes. */
  377. const U32 *grpPtr = (const U32 *)group;
  378. const U32 *grpEnd = (const U32 *)groupEnd;
  379. /* The dmerId is how we will reference this dmer.
  380. * This allows us to map the whole dmer space to a much smaller space, the
  381. * size of the suffix array.
  382. */
  383. const U32 dmerId = (U32)(grpPtr - ctx->suffix);
  384. /* Count the number of samples this dmer shows up in */
  385. U32 freq = 0;
  386. /* Details */
  387. const size_t *curOffsetPtr = ctx->offsets;
  388. const size_t *offsetsEnd = ctx->offsets + ctx->nbSamples;
  389. /* Once *grpPtr >= curSampleEnd this occurrence of the dmer is in a
  390. * different sample than the last.
  391. */
  392. size_t curSampleEnd = ctx->offsets[0];
  393. for (; grpPtr != grpEnd; ++grpPtr) {
  394. /* Save the dmerId for this position so we can get back to it. */
  395. ctx->dmerAt[*grpPtr] = dmerId;
  396. /* Dictionaries only help for the first reference to the dmer.
  397. * After that zstd can reference the match from the previous reference.
  398. * So only count each dmer once for each sample it is in.
  399. */
  400. if (*grpPtr < curSampleEnd) {
  401. continue;
  402. }
  403. freq += 1;
  404. /* Binary search to find the end of the sample *grpPtr is in.
  405. * In the common case that grpPtr + 1 == grpEnd we can skip the binary
  406. * search because the loop is over.
  407. */
  408. if (grpPtr + 1 != grpEnd) {
  409. const size_t *sampleEndPtr =
  410. COVER_lower_bound(curOffsetPtr, offsetsEnd, *grpPtr);
  411. curSampleEnd = *sampleEndPtr;
  412. curOffsetPtr = sampleEndPtr + 1;
  413. }
  414. }
  415. /* At this point we are never going to look at this segment of the suffix
  416. * array again. We take advantage of this fact to save memory.
  417. * We store the frequency of the dmer in the first position of the group,
  418. * which is dmerId.
  419. */
  420. ctx->suffix[dmerId] = freq;
  421. }
/**
 * Selects the best segment in an epoch.
 * Segments are scored according to the function:
 *
 * Let F(d) be the frequency of dmer d.
 * Let S_i be the dmer at position i of segment S which has length k.
 *
 * Score(S) = F(S_1) + F(S_2) + ... + F(S_{k-d+1})
 *
 * Once the dmer d is in the dictionary we set F(d) = 0.
 *
 * A distinct dmer contributes its frequency once per segment. A window of at
 * most k - d + 1 dmers slides over the dmer positions [begin, end).
 * `activeDmers` is cleared on entry; it then counts the occurrences of each
 * dmerId in the current window.
 * Returns the highest-scoring window as a range of dmer positions
 * [begin, end), trimmed of zero-frequency dmers at both ends. If no window
 * has a positive score, the result is an empty segment with score 0.
 * Side effect: sets freqs[] to 0 for every dmer in the returned segment.
 */
static COVER_segment_t COVER_selectSegment(const COVER_ctx_t *ctx, U32 *freqs,
                                           COVER_map_t *activeDmers, U32 begin,
                                           U32 end,
                                           ZDICT_cover_params_t parameters) {
  /* Constants */
  const U32 k = parameters.k;
  const U32 d = parameters.d;
  /* Number of dmers in a segment of k bytes. */
  const U32 dmersInK = k - d + 1;
  /* Try each segment (activeSegment) and save the best (bestSegment) */
  COVER_segment_t bestSegment = {0, 0, 0};
  COVER_segment_t activeSegment;
  /* Reset the activeDmers in the segment */
  COVER_map_clear(activeDmers);
  /* The activeSegment starts at the beginning of the epoch. */
  activeSegment.begin = begin;
  activeSegment.end = begin;
  activeSegment.score = 0;
  /* Slide the activeSegment through the whole epoch.
   * Save the best segment in bestSegment.
   */
  while (activeSegment.end < end) {
    /* The dmerId for the dmer at the next position */
    U32 newDmer = ctx->dmerAt[activeSegment.end];
    /* The entry in activeDmers for this dmerId */
    U32 *newDmerOcc = COVER_map_at(activeDmers, newDmer);
    /* If the dmer isn't already present in the segment add its score. */
    if (*newDmerOcc == 0) {
      /* The paper suggests using the L-0.5 norm, but experiments show that it
       * doesn't help.
       */
      activeSegment.score += freqs[newDmer];
    }
    /* Add the dmer to the segment */
    activeSegment.end += 1;
    *newDmerOcc += 1;
    /* If the window is now too large, drop the first position */
    if (activeSegment.end - activeSegment.begin == dmersInK + 1) {
      U32 delDmer = ctx->dmerAt[activeSegment.begin];
      /* delDmer was added when it entered the window, so this lookup finds
       * its existing entry. */
      U32 *delDmerOcc = COVER_map_at(activeDmers, delDmer);
      activeSegment.begin += 1;
      *delDmerOcc -= 1;
      /* If this is the last occurrence of the dmer, subtract its score */
      if (*delDmerOcc == 0) {
        COVER_map_remove(activeDmers, delDmer);
        activeSegment.score -= freqs[delDmer];
      }
    }
    /* If this segment is the best so far save it */
    if (activeSegment.score > bestSegment.score) {
      bestSegment = activeSegment;
    }
  }
  {
    /* Trim off the zero frequency head and tail from the segment. */
    U32 newBegin = bestSegment.end;
    U32 newEnd = bestSegment.begin;
    U32 pos;
    for (pos = bestSegment.begin; pos != bestSegment.end; ++pos) {
      U32 freq = freqs[ctx->dmerAt[pos]];
      if (freq != 0) {
        newBegin = MIN(newBegin, pos);
        newEnd = pos + 1;
      }
    }
    bestSegment.begin = newBegin;
    bestSegment.end = newEnd;
  }
  {
    /* Zero out the frequency of each dmer covered by the chosen segment,
     * so later selections do not score this content again. */
    U32 pos;
    for (pos = bestSegment.begin; pos != bestSegment.end; ++pos) {
      freqs[ctx->dmerAt[pos]] = 0;
    }
  }
  return bestSegment;
}
  509. /**
  510. * Check the validity of the parameters.
  511. * Returns non-zero if the parameters are valid and 0 otherwise.
  512. */
  513. static int COVER_checkParameters(ZDICT_cover_params_t parameters,
  514. size_t maxDictSize) {
  515. /* k and d are required parameters */
  516. if (parameters.d == 0 || parameters.k == 0) {
  517. return 0;
  518. }
  519. /* k <= maxDictSize */
  520. if (parameters.k > maxDictSize) {
  521. return 0;
  522. }
  523. /* d <= k */
  524. if (parameters.d > parameters.k) {
  525. return 0;
  526. }
  527. /* 0 < splitPoint <= 1 */
  528. if (parameters.splitPoint <= 0 || parameters.splitPoint > 1){
  529. return 0;
  530. }
  531. return 1;
  532. }
  533. /**
  534. * Clean up a context initialized with `COVER_ctx_init()`.
  535. */
  536. static void COVER_ctx_destroy(COVER_ctx_t *ctx) {
  537. if (!ctx) {
  538. return;
  539. }
  540. if (ctx->suffix) {
  541. free(ctx->suffix);
  542. ctx->suffix = NULL;
  543. }
  544. if (ctx->freqs) {
  545. free(ctx->freqs);
  546. ctx->freqs = NULL;
  547. }
  548. if (ctx->dmerAt) {
  549. free(ctx->dmerAt);
  550. ctx->dmerAt = NULL;
  551. }
  552. if (ctx->offsets) {
  553. free(ctx->offsets);
  554. ctx->offsets = NULL;
  555. }
  556. }
  557. /**
  558. * Prepare a context for dictionary building.
  559. * The context is only dependent on the parameter `d` and can be used multiple
  560. * times.
  561. * Returns 0 on success or error code on error.
  562. * The context must be destroyed with `COVER_ctx_destroy()`.
  563. */
  564. static size_t COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
  565. const size_t *samplesSizes, unsigned nbSamples,
  566. unsigned d, double splitPoint)
  567. {
  568. const BYTE *const samples = (const BYTE *)samplesBuffer;
  569. const size_t totalSamplesSize = COVER_sum(samplesSizes, nbSamples);
  570. /* Split samples into testing and training sets */
  571. const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
  572. const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
  573. const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples) : totalSamplesSize;
  574. const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes + nbTrainSamples, nbTestSamples) : totalSamplesSize;
  575. /* Checks */
  576. if (totalSamplesSize < MAX(d, sizeof(U64)) ||
  577. totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
  578. DISPLAYLEVEL(1, "Total samples size is too large (%u MB), maximum size is %u MB\n",
  579. (unsigned)(totalSamplesSize>>20), (COVER_MAX_SAMPLES_SIZE >> 20));
  580. return ERROR(srcSize_wrong);
  581. }
  582. /* Check if there are at least 5 training samples */
  583. if (nbTrainSamples < 5) {
  584. DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples);
  585. return ERROR(srcSize_wrong);
  586. }
  587. /* Check if there's testing sample */
  588. if (nbTestSamples < 1) {
  589. DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
  590. return ERROR(srcSize_wrong);
  591. }
  592. /* Zero the context */
  593. memset(ctx, 0, sizeof(*ctx));
  594. DISPLAYLEVEL(2, "Training on %u samples of total size %u\n", nbTrainSamples,
  595. (unsigned)trainingSamplesSize);
  596. DISPLAYLEVEL(2, "Testing on %u samples of total size %u\n", nbTestSamples,
  597. (unsigned)testSamplesSize);
  598. ctx->samples = samples;
  599. ctx->samplesSizes = samplesSizes;
  600. ctx->nbSamples = nbSamples;
  601. ctx->nbTrainSamples = nbTrainSamples;
  602. ctx->nbTestSamples = nbTestSamples;
  603. /* Partial suffix array */
  604. ctx->suffixSize = trainingSamplesSize - MAX(d, sizeof(U64)) + 1;
  605. ctx->suffix = (U32 *)malloc(ctx->suffixSize * sizeof(U32));
  606. /* Maps index to the dmerID */
  607. ctx->dmerAt = (U32 *)malloc(ctx->suffixSize * sizeof(U32));
  608. /* The offsets of each file */
  609. ctx->offsets = (size_t *)malloc((nbSamples + 1) * sizeof(size_t));
  610. if (!ctx->suffix || !ctx->dmerAt || !ctx->offsets) {
  611. DISPLAYLEVEL(1, "Failed to allocate scratch buffers\n");
  612. COVER_ctx_destroy(ctx);
  613. return ERROR(memory_allocation);
  614. }
  615. ctx->freqs = NULL;
  616. ctx->d = d;
  617. /* Fill offsets from the samplesSizes */
  618. {
  619. U32 i;
  620. ctx->offsets[0] = 0;
  621. for (i = 1; i <= nbSamples; ++i) {
  622. ctx->offsets[i] = ctx->offsets[i - 1] + samplesSizes[i - 1];
  623. }
  624. }
  625. DISPLAYLEVEL(2, "Constructing partial suffix array\n");
  626. {
  627. /* suffix is a partial suffix array.
  628. * It only sorts suffixes by their first parameters.d bytes.
  629. * The sort is stable, so each dmer group is sorted by position in input.
  630. */
  631. U32 i;
  632. for (i = 0; i < ctx->suffixSize; ++i) {
  633. ctx->suffix[i] = i;
  634. }
  635. stableSort(ctx);
  636. }
  637. DISPLAYLEVEL(2, "Computing frequencies\n");
  638. /* For each dmer group (group of positions with the same first d bytes):
  639. * 1. For each position we set dmerAt[position] = dmerID. The dmerID is
  640. * (groupBeginPtr - suffix). This allows us to go from position to
  641. * dmerID so we can look up values in freq.
  642. * 2. We calculate how many samples the dmer occurs in and save it in
  643. * freqs[dmerId].
  644. */
  645. COVER_groupBy(ctx->suffix, ctx->suffixSize, sizeof(U32), ctx,
  646. (ctx->d <= 8 ? &COVER_cmp8 : &COVER_cmp), &COVER_group);
  647. ctx->freqs = ctx->suffix;
  648. ctx->suffix = NULL;
  649. return 0;
  650. }
  651. void COVER_warnOnSmallCorpus(size_t maxDictSize, size_t nbDmers, int displayLevel)
  652. {
  653. const double ratio = (double)nbDmers / (double)maxDictSize;
  654. if (ratio >= 10) {
  655. return;
  656. }
  657. LOCALDISPLAYLEVEL(displayLevel, 1,
  658. "WARNING: The maximum dictionary size %u is too large "
  659. "compared to the source size %u! "
  660. "size(source)/size(dictionary) = %f, but it should be >= "
  661. "10! This may lead to a subpar dictionary! We recommend "
  662. "training on sources at least 10x, and preferably 100x "
  663. "the size of the dictionary! \n", (U32)maxDictSize,
  664. (U32)nbDmers, ratio);
  665. }
  666. COVER_epoch_info_t COVER_computeEpochs(U32 maxDictSize,
  667. U32 nbDmers, U32 k, U32 passes)
  668. {
  669. const U32 minEpochSize = k * 10;
  670. COVER_epoch_info_t epochs;
  671. epochs.num = MAX(1, maxDictSize / k / passes);
  672. epochs.size = nbDmers / epochs.num;
  673. if (epochs.size >= minEpochSize) {
  674. assert(epochs.size * epochs.num <= nbDmers);
  675. return epochs;
  676. }
  677. epochs.size = MIN(minEpochSize, nbDmers);
  678. epochs.num = nbDmers / epochs.size;
  679. assert(epochs.size * epochs.num <= nbDmers);
  680. return epochs;
  681. }
/**
 * Given the prepared context build the dictionary.
 * Divides the dmer positions into epochs (COVER_computeEpochs() with 4
 * passes) and cycles through them. For each epoch, it copies the best
 * segment (COVER_selectSegment()) into `dictBuffer`, filling the buffer from
 * the end towards the front.
 * Stops when any of these happens:
 *  - the buffer is full;
 *  - a segment (possibly truncated to the remaining space) is shorter than
 *    d bytes;
 *  - maxZeroScoreRun consecutive epochs yield no scoring segment.
 * Returns `tail`; the content occupies dictBuffer[tail, dictBufferCapacity).
 * `freqs` is modified through COVER_selectSegment().
 */
static size_t COVER_buildDictionary(const COVER_ctx_t *ctx, U32 *freqs,
                                    COVER_map_t *activeDmers, void *dictBuffer,
                                    size_t dictBufferCapacity,
                                    ZDICT_cover_params_t parameters) {
  BYTE *const dict = (BYTE *)dictBuffer;
  size_t tail = dictBufferCapacity;
  /* Divide the data into epochs. We will select one segment from each epoch. */
  const COVER_epoch_info_t epochs = COVER_computeEpochs(
      (U32)dictBufferCapacity, (U32)ctx->suffixSize, parameters.k, 4);
  /* epochs.num / 8, clamped to [10, 100] */
  const size_t maxZeroScoreRun = MAX(10, MIN(100, epochs.num >> 3));
  size_t zeroScoreRun = 0;
  size_t epoch;
  DISPLAYLEVEL(2, "Breaking content into %u epochs of size %u\n",
               (U32)epochs.num, (U32)epochs.size);
  /* Loop through the epochs until there are no more segments or the dictionary
   * is full.
   */
  for (epoch = 0; tail > 0; epoch = (epoch + 1) % epochs.num) {
    const U32 epochBegin = (U32)(epoch * epochs.size);
    const U32 epochEnd = epochBegin + epochs.size;
    size_t segmentSize;
    /* Select a segment */
    COVER_segment_t segment = COVER_selectSegment(
        ctx, freqs, activeDmers, epochBegin, epochEnd, parameters);
    /* If the segment covers no dmers, then we are out of content.
     * There may be new content in other epochs, so continue for some time.
     */
    if (segment.score == 0) {
      if (++zeroScoreRun >= maxZeroScoreRun) {
        break;
      }
      continue;
    }
    zeroScoreRun = 0;
    /* Trim the segment if necessary and if it is too small then we are done.
     * A run of n dmers spans n + d - 1 bytes. */
    segmentSize = MIN(segment.end - segment.begin + parameters.d - 1, tail);
    if (segmentSize < parameters.d) {
      break;
    }
    /* We fill the dictionary from the back to allow the best segments to be
     * referenced with the smallest offsets.
     */
    tail -= segmentSize;
    memcpy(dict + tail, ctx->samples + segment.begin, segmentSize);
    DISPLAYUPDATE(
        2, "\r%u%% ",
        (unsigned)(((dictBufferCapacity - tail) * 100) / dictBufferCapacity));
  }
  DISPLAYLEVEL(2, "\r%79s\r", "");
  return tail;
}
/**
 * Trains a dictionary with the COVER algorithm.
 * `samplesBuffer` holds `nbSamples` samples back to back, with their sizes in
 * `samplesSizes`.
 * parameters.k and parameters.d are required. parameters.splitPoint is
 * ignored and forced to 1.0, so every sample is used for training.
 * The dictionary is written to `dictBuffer`, whose capacity must be at least
 * ZDICT_DICTSIZE_MIN, and finalized with ZDICT_finalizeDictionary().
 * Returns the dictionary size, or an error code (test with ZSTD_isError()).
 * NOTE(review): writes the file-level g_displayLevel, so concurrent calls
 * race on it.
 */
ZDICTLIB_STATIC_API size_t ZDICT_trainFromBuffer_cover(
    void *dictBuffer, size_t dictBufferCapacity,
    const void *samplesBuffer, const size_t *samplesSizes, unsigned nbSamples,
    ZDICT_cover_params_t parameters)
{
  BYTE* const dict = (BYTE*)dictBuffer;
  COVER_ctx_t ctx;
  COVER_map_t activeDmers;
  parameters.splitPoint = 1.0;
  /* Initialize global data */
  g_displayLevel = (int)parameters.zParams.notificationLevel;
  /* Checks */
  if (!COVER_checkParameters(parameters, dictBufferCapacity)) {
    DISPLAYLEVEL(1, "Cover parameters incorrect\n");
    return ERROR(parameter_outOfBound);
  }
  if (nbSamples == 0) {
    DISPLAYLEVEL(1, "Cover must have at least one input file\n");
    return ERROR(srcSize_wrong);
  }
  if (dictBufferCapacity < ZDICT_DICTSIZE_MIN) {
    DISPLAYLEVEL(1, "dictBufferCapacity must be at least %u\n",
                 ZDICT_DICTSIZE_MIN);
    return ERROR(dstSize_tooSmall);
  }
  /* Initialize context and activeDmers */
  {
    size_t const initVal = COVER_ctx_init(&ctx, samplesBuffer, samplesSizes, nbSamples,
                                          parameters.d, parameters.splitPoint);
    if (ZSTD_isError(initVal)) {
      return initVal;
    }
  }
  COVER_warnOnSmallCorpus(dictBufferCapacity, ctx.suffixSize, g_displayLevel);
  /* The map holds the dmers of one segment: at most k - d + 1 of them. */
  if (!COVER_map_init(&activeDmers, parameters.k - parameters.d + 1)) {
    DISPLAYLEVEL(1, "Failed to allocate dmer map: out of memory\n");
    COVER_ctx_destroy(&ctx);
    return ERROR(memory_allocation);
  }
  DISPLAYLEVEL(2, "Building dictionary\n");
  {
    /* The raw content is placed at dict[tail, dictBufferCapacity). */
    const size_t tail =
        COVER_buildDictionary(&ctx, ctx.freqs, &activeDmers, dictBuffer,
                              dictBufferCapacity, parameters);
    /* The finalized dictionary is written into the same buffer. */
    const size_t dictionarySize = ZDICT_finalizeDictionary(
        dict, dictBufferCapacity, dict + tail, dictBufferCapacity - tail,
        samplesBuffer, samplesSizes, nbSamples, parameters.zParams);
    if (!ZSTD_isError(dictionarySize)) {
      DISPLAYLEVEL(2, "Constructed dictionary of size %u\n",
                   (unsigned)dictionarySize);
    }
    COVER_ctx_destroy(&ctx);
    COVER_map_destroy(&activeDmers);
    return dictionarySize;
  }
}
  792. size_t COVER_checkTotalCompressedSize(const ZDICT_cover_params_t parameters,
  793. const size_t *samplesSizes, const BYTE *samples,
  794. size_t *offsets,
  795. size_t nbTrainSamples, size_t nbSamples,
  796. BYTE *const dict, size_t dictBufferCapacity) {
  797. size_t totalCompressedSize = ERROR(GENERIC);
  798. /* Pointers */
  799. ZSTD_CCtx *cctx;
  800. ZSTD_CDict *cdict;
  801. void *dst;
  802. /* Local variables */
  803. size_t dstCapacity;
  804. size_t i;
  805. /* Allocate dst with enough space to compress the maximum sized sample */
  806. {
  807. size_t maxSampleSize = 0;
  808. i = parameters.splitPoint < 1.0 ? nbTrainSamples : 0;
  809. for (; i < nbSamples; ++i) {
  810. maxSampleSize = MAX(samplesSizes[i], maxSampleSize);
  811. }
  812. dstCapacity = ZSTD_compressBound(maxSampleSize);
  813. dst = malloc(dstCapacity);
  814. }
  815. /* Create the cctx and cdict */
  816. cctx = ZSTD_createCCtx();
  817. cdict = ZSTD_createCDict(dict, dictBufferCapacity,
  818. parameters.zParams.compressionLevel);
  819. if (!dst || !cctx || !cdict) {
  820. goto _compressCleanup;
  821. }
  822. /* Compress each sample and sum their sizes (or error) */
  823. totalCompressedSize = dictBufferCapacity;
  824. i = parameters.splitPoint < 1.0 ? nbTrainSamples : 0;
  825. for (; i < nbSamples; ++i) {
  826. const size_t size = ZSTD_compress_usingCDict(
  827. cctx, dst, dstCapacity, samples + offsets[i],
  828. samplesSizes[i], cdict);
  829. if (ZSTD_isError(size)) {
  830. totalCompressedSize = size;
  831. goto _compressCleanup;
  832. }
  833. totalCompressedSize += size;
  834. }
  835. _compressCleanup:
  836. ZSTD_freeCCtx(cctx);
  837. ZSTD_freeCDict(cdict);
  838. if (dst) {
  839. free(dst);
  840. }
  841. return totalCompressedSize;
  842. }
  843. /**
  844. * Initialize the `COVER_best_t`.
  845. */
  846. void COVER_best_init(COVER_best_t *best) {
  847. if (best==NULL) return; /* compatible with init on NULL */
  848. (void)ZSTD_pthread_mutex_init(&best->mutex, NULL);
  849. (void)ZSTD_pthread_cond_init(&best->cond, NULL);
  850. best->liveJobs = 0;
  851. best->dict = NULL;
  852. best->dictSize = 0;
  853. best->compressedSize = (size_t)-1;
  854. memset(&best->parameters, 0, sizeof(best->parameters));
  855. }
  856. /**
  857. * Wait until liveJobs == 0.
  858. */
  859. void COVER_best_wait(COVER_best_t *best) {
  860. if (!best) {
  861. return;
  862. }
  863. ZSTD_pthread_mutex_lock(&best->mutex);
  864. while (best->liveJobs != 0) {
  865. ZSTD_pthread_cond_wait(&best->cond, &best->mutex);
  866. }
  867. ZSTD_pthread_mutex_unlock(&best->mutex);
  868. }
  869. /**
  870. * Call COVER_best_wait() and then destroy the COVER_best_t.
  871. */
  872. void COVER_best_destroy(COVER_best_t *best) {
  873. if (!best) {
  874. return;
  875. }
  876. COVER_best_wait(best);
  877. if (best->dict) {
  878. free(best->dict);
  879. }
  880. ZSTD_pthread_mutex_destroy(&best->mutex);
  881. ZSTD_pthread_cond_destroy(&best->cond);
  882. }
  883. /**
  884. * Called when a thread is about to be launched.
  885. * Increments liveJobs.
  886. */
  887. void COVER_best_start(COVER_best_t *best) {
  888. if (!best) {
  889. return;
  890. }
  891. ZSTD_pthread_mutex_lock(&best->mutex);
  892. ++best->liveJobs;
  893. ZSTD_pthread_mutex_unlock(&best->mutex);
  894. }
  895. /**
  896. * Called when a thread finishes executing, both on error or success.
  897. * Decrements liveJobs and signals any waiting threads if liveJobs == 0.
  898. * If this dictionary is the best so far save it and its parameters.
  899. */
  900. void COVER_best_finish(COVER_best_t* best,
  901. ZDICT_cover_params_t parameters,
  902. COVER_dictSelection_t selection)
  903. {
  904. void* dict = selection.dictContent;
  905. size_t compressedSize = selection.totalCompressedSize;
  906. size_t dictSize = selection.dictSize;
  907. if (!best) {
  908. return;
  909. }
  910. {
  911. size_t liveJobs;
  912. ZSTD_pthread_mutex_lock(&best->mutex);
  913. --best->liveJobs;
  914. liveJobs = best->liveJobs;
  915. /* If the new dictionary is better */
  916. if (compressedSize < best->compressedSize) {
  917. /* Allocate space if necessary */
  918. if (!best->dict || best->dictSize < dictSize) {
  919. if (best->dict) {
  920. free(best->dict);
  921. }
  922. best->dict = malloc(dictSize);
  923. if (!best->dict) {
  924. best->compressedSize = ERROR(GENERIC);
  925. best->dictSize = 0;
  926. ZSTD_pthread_cond_signal(&best->cond);
  927. ZSTD_pthread_mutex_unlock(&best->mutex);
  928. return;
  929. }
  930. }
  931. /* Save the dictionary, parameters, and size */
  932. if (dict) {
  933. memcpy(best->dict, dict, dictSize);
  934. best->dictSize = dictSize;
  935. best->parameters = parameters;
  936. best->compressedSize = compressedSize;
  937. }
  938. }
  939. if (liveJobs == 0) {
  940. ZSTD_pthread_cond_broadcast(&best->cond);
  941. }
  942. ZSTD_pthread_mutex_unlock(&best->mutex);
  943. }
  944. }
  945. static COVER_dictSelection_t setDictSelection(BYTE* buf, size_t s, size_t csz)
  946. {
  947. COVER_dictSelection_t ds;
  948. ds.dictContent = buf;
  949. ds.dictSize = s;
  950. ds.totalCompressedSize = csz;
  951. return ds;
  952. }
  953. COVER_dictSelection_t COVER_dictSelectionError(size_t error) {
  954. return setDictSelection(NULL, 0, error);
  955. }
  956. unsigned COVER_dictSelectionIsError(COVER_dictSelection_t selection) {
  957. return (ZSTD_isError(selection.totalCompressedSize) || !selection.dictContent);
  958. }
  959. void COVER_dictSelectionFree(COVER_dictSelection_t selection){
  960. free(selection.dictContent);
  961. }
  962. COVER_dictSelection_t COVER_selectDict(BYTE* customDictContent, size_t dictBufferCapacity,
  963. size_t dictContentSize, const BYTE* samplesBuffer, const size_t* samplesSizes, unsigned nbFinalizeSamples,
  964. size_t nbCheckSamples, size_t nbSamples, ZDICT_cover_params_t params, size_t* offsets, size_t totalCompressedSize) {
  965. size_t largestDict = 0;
  966. size_t largestCompressed = 0;
  967. BYTE* customDictContentEnd = customDictContent + dictContentSize;
  968. BYTE* largestDictbuffer = (BYTE*)malloc(dictBufferCapacity);
  969. BYTE* candidateDictBuffer = (BYTE*)malloc(dictBufferCapacity);
  970. double regressionTolerance = ((double)params.shrinkDictMaxRegression / 100.0) + 1.00;
  971. if (!largestDictbuffer || !candidateDictBuffer) {
  972. free(largestDictbuffer);
  973. free(candidateDictBuffer);
  974. return COVER_dictSelectionError(dictContentSize);
  975. }
  976. /* Initial dictionary size and compressed size */
  977. memcpy(largestDictbuffer, customDictContent, dictContentSize);
  978. dictContentSize = ZDICT_finalizeDictionary(
  979. largestDictbuffer, dictBufferCapacity, customDictContent, dictContentSize,
  980. samplesBuffer, samplesSizes, nbFinalizeSamples, params.zParams);
  981. if (ZDICT_isError(dictContentSize)) {
  982. free(largestDictbuffer);
  983. free(candidateDictBuffer);
  984. return COVER_dictSelectionError(dictContentSize);
  985. }
  986. totalCompressedSize = COVER_checkTotalCompressedSize(params, samplesSizes,
  987. samplesBuffer, offsets,
  988. nbCheckSamples, nbSamples,
  989. largestDictbuffer, dictContentSize);
  990. if (ZSTD_isError(totalCompressedSize)) {
  991. free(largestDictbuffer);
  992. free(candidateDictBuffer);
  993. return COVER_dictSelectionError(totalCompressedSize);
  994. }
  995. if (params.shrinkDict == 0) {
  996. free(candidateDictBuffer);
  997. return setDictSelection(largestDictbuffer, dictContentSize, totalCompressedSize);
  998. }
  999. largestDict = dictContentSize;
  1000. largestCompressed = totalCompressedSize;
  1001. dictContentSize = ZDICT_DICTSIZE_MIN;
  1002. /* Largest dict is initially at least ZDICT_DICTSIZE_MIN */
  1003. while (dictContentSize < largestDict) {
  1004. memcpy(candidateDictBuffer, largestDictbuffer, largestDict);
  1005. dictContentSize = ZDICT_finalizeDictionary(
  1006. candidateDictBuffer, dictBufferCapacity, customDictContentEnd - dictContentSize, dictContentSize,
  1007. samplesBuffer, samplesSizes, nbFinalizeSamples, params.zParams);
  1008. if (ZDICT_isError(dictContentSize)) {
  1009. free(largestDictbuffer);
  1010. free(candidateDictBuffer);
  1011. return COVER_dictSelectionError(dictContentSize);
  1012. }
  1013. totalCompressedSize = COVER_checkTotalCompressedSize(params, samplesSizes,
  1014. samplesBuffer, offsets,
  1015. nbCheckSamples, nbSamples,
  1016. candidateDictBuffer, dictContentSize);
  1017. if (ZSTD_isError(totalCompressedSize)) {
  1018. free(largestDictbuffer);
  1019. free(candidateDictBuffer);
  1020. return COVER_dictSelectionError(totalCompressedSize);
  1021. }
  1022. if ((double)totalCompressedSize <= (double)largestCompressed * regressionTolerance) {
  1023. free(largestDictbuffer);
  1024. return setDictSelection( candidateDictBuffer, dictContentSize, totalCompressedSize );
  1025. }
  1026. dictContentSize *= 2;
  1027. }
  1028. dictContentSize = largestDict;
  1029. totalCompressedSize = largestCompressed;
  1030. free(candidateDictBuffer);
  1031. return setDictSelection( largestDictbuffer, dictContentSize, totalCompressedSize );
  1032. }
  1033. /**
  1034. * Parameters for COVER_tryParameters().
  1035. */
  1036. typedef struct COVER_tryParameters_data_s {
  1037. const COVER_ctx_t *ctx;
  1038. COVER_best_t *best;
  1039. size_t dictBufferCapacity;
  1040. ZDICT_cover_params_t parameters;
  1041. } COVER_tryParameters_data_t;
  1042. /**
  1043. * Tries a set of parameters and updates the COVER_best_t with the results.
  1044. * This function is thread safe if zstd is compiled with multithreaded support.
  1045. * It takes its parameters as an *OWNING* opaque pointer to support threading.
  1046. */
  1047. static void COVER_tryParameters(void *opaque)
  1048. {
  1049. /* Save parameters as local variables */
  1050. COVER_tryParameters_data_t *const data = (COVER_tryParameters_data_t*)opaque;
  1051. const COVER_ctx_t *const ctx = data->ctx;
  1052. const ZDICT_cover_params_t parameters = data->parameters;
  1053. size_t dictBufferCapacity = data->dictBufferCapacity;
  1054. size_t totalCompressedSize = ERROR(GENERIC);
  1055. /* Allocate space for hash table, dict, and freqs */
  1056. COVER_map_t activeDmers;
  1057. BYTE* const dict = (BYTE*)malloc(dictBufferCapacity);
  1058. COVER_dictSelection_t selection = COVER_dictSelectionError(ERROR(GENERIC));
  1059. U32* const freqs = (U32*)malloc(ctx->suffixSize * sizeof(U32));
  1060. if (!COVER_map_init(&activeDmers, parameters.k - parameters.d + 1)) {
  1061. DISPLAYLEVEL(1, "Failed to allocate dmer map: out of memory\n");
  1062. goto _cleanup;
  1063. }
  1064. if (!dict || !freqs) {
  1065. DISPLAYLEVEL(1, "Failed to allocate buffers: out of memory\n");
  1066. goto _cleanup;
  1067. }
  1068. /* Copy the frequencies because we need to modify them */
  1069. memcpy(freqs, ctx->freqs, ctx->suffixSize * sizeof(U32));
  1070. /* Build the dictionary */
  1071. {
  1072. const size_t tail = COVER_buildDictionary(ctx, freqs, &activeDmers, dict,
  1073. dictBufferCapacity, parameters);
  1074. selection = COVER_selectDict(dict + tail, dictBufferCapacity, dictBufferCapacity - tail,
  1075. ctx->samples, ctx->samplesSizes, (unsigned)ctx->nbTrainSamples, ctx->nbTrainSamples, ctx->nbSamples, parameters, ctx->offsets,
  1076. totalCompressedSize);
  1077. if (COVER_dictSelectionIsError(selection)) {
  1078. DISPLAYLEVEL(1, "Failed to select dictionary\n");
  1079. goto _cleanup;
  1080. }
  1081. }
  1082. _cleanup:
  1083. free(dict);
  1084. COVER_best_finish(data->best, parameters, selection);
  1085. free(data);
  1086. COVER_map_destroy(&activeDmers);
  1087. COVER_dictSelectionFree(selection);
  1088. free(freqs);
  1089. }
  1090. ZDICTLIB_STATIC_API size_t ZDICT_optimizeTrainFromBuffer_cover(
  1091. void* dictBuffer, size_t dictBufferCapacity, const void* samplesBuffer,
  1092. const size_t* samplesSizes, unsigned nbSamples,
  1093. ZDICT_cover_params_t* parameters)
  1094. {
  1095. /* constants */
  1096. const unsigned nbThreads = parameters->nbThreads;
  1097. const double splitPoint =
  1098. parameters->splitPoint <= 0.0 ? COVER_DEFAULT_SPLITPOINT : parameters->splitPoint;
  1099. const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
  1100. const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
  1101. const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
  1102. const unsigned kMaxK = parameters->k == 0 ? 2000 : parameters->k;
  1103. const unsigned kSteps = parameters->steps == 0 ? 40 : parameters->steps;
  1104. const unsigned kStepSize = MAX((kMaxK - kMinK) / kSteps, 1);
  1105. const unsigned kIterations =
  1106. (1 + (kMaxD - kMinD) / 2) * (1 + (kMaxK - kMinK) / kStepSize);
  1107. const unsigned shrinkDict = 0;
  1108. /* Local variables */
  1109. const int displayLevel = parameters->zParams.notificationLevel;
  1110. unsigned iteration = 1;
  1111. unsigned d;
  1112. unsigned k;
  1113. COVER_best_t best;
  1114. POOL_ctx *pool = NULL;
  1115. int warned = 0;
  1116. /* Checks */
  1117. if (splitPoint <= 0 || splitPoint > 1) {
  1118. LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
  1119. return ERROR(parameter_outOfBound);
  1120. }
  1121. if (kMinK < kMaxD || kMaxK < kMinK) {
  1122. LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
  1123. return ERROR(parameter_outOfBound);
  1124. }
  1125. if (nbSamples == 0) {
  1126. DISPLAYLEVEL(1, "Cover must have at least one input file\n");
  1127. return ERROR(srcSize_wrong);
  1128. }
  1129. if (dictBufferCapacity < ZDICT_DICTSIZE_MIN) {
  1130. DISPLAYLEVEL(1, "dictBufferCapacity must be at least %u\n",
  1131. ZDICT_DICTSIZE_MIN);
  1132. return ERROR(dstSize_tooSmall);
  1133. }
  1134. if (nbThreads > 1) {
  1135. pool = POOL_create(nbThreads, 1);
  1136. if (!pool) {
  1137. return ERROR(memory_allocation);
  1138. }
  1139. }
  1140. /* Initialization */
  1141. COVER_best_init(&best);
  1142. /* Turn down global display level to clean up display at level 2 and below */
  1143. g_displayLevel = displayLevel == 0 ? 0 : displayLevel - 1;
  1144. /* Loop through d first because each new value needs a new context */
  1145. LOCALDISPLAYLEVEL(displayLevel, 2, "Trying %u different sets of parameters\n",
  1146. kIterations);
  1147. for (d = kMinD; d <= kMaxD; d += 2) {
  1148. /* Initialize the context for this value of d */
  1149. COVER_ctx_t ctx;
  1150. LOCALDISPLAYLEVEL(displayLevel, 3, "d=%u\n", d);
  1151. {
  1152. const size_t initVal = COVER_ctx_init(&ctx, samplesBuffer, samplesSizes, nbSamples, d, splitPoint);
  1153. if (ZSTD_isError(initVal)) {
  1154. LOCALDISPLAYLEVEL(displayLevel, 1, "Failed to initialize context\n");
  1155. COVER_best_destroy(&best);
  1156. POOL_free(pool);
  1157. return initVal;
  1158. }
  1159. }
  1160. if (!warned) {
  1161. COVER_warnOnSmallCorpus(dictBufferCapacity, ctx.suffixSize, displayLevel);
  1162. warned = 1;
  1163. }
  1164. /* Loop through k reusing the same context */
  1165. for (k = kMinK; k <= kMaxK; k += kStepSize) {
  1166. /* Prepare the arguments */
  1167. COVER_tryParameters_data_t *data = (COVER_tryParameters_data_t *)malloc(
  1168. sizeof(COVER_tryParameters_data_t));
  1169. LOCALDISPLAYLEVEL(displayLevel, 3, "k=%u\n", k);
  1170. if (!data) {
  1171. LOCALDISPLAYLEVEL(displayLevel, 1, "Failed to allocate parameters\n");
  1172. COVER_best_destroy(&best);
  1173. COVER_ctx_destroy(&ctx);
  1174. POOL_free(pool);
  1175. return ERROR(memory_allocation);
  1176. }
  1177. data->ctx = &ctx;
  1178. data->best = &best;
  1179. data->dictBufferCapacity = dictBufferCapacity;
  1180. data->parameters = *parameters;
  1181. data->parameters.k = k;
  1182. data->parameters.d = d;
  1183. data->parameters.splitPoint = splitPoint;
  1184. data->parameters.steps = kSteps;
  1185. data->parameters.shrinkDict = shrinkDict;
  1186. data->parameters.zParams.notificationLevel = g_displayLevel;
  1187. /* Check the parameters */
  1188. if (!COVER_checkParameters(data->parameters, dictBufferCapacity)) {
  1189. DISPLAYLEVEL(1, "Cover parameters incorrect\n");
  1190. free(data);
  1191. continue;
  1192. }
  1193. /* Call the function and pass ownership of data to it */
  1194. COVER_best_start(&best);
  1195. if (pool) {
  1196. POOL_add(pool, &COVER_tryParameters, data);
  1197. } else {
  1198. COVER_tryParameters(data);
  1199. }
  1200. /* Print status */
  1201. LOCALDISPLAYUPDATE(displayLevel, 2, "\r%u%% ",
  1202. (unsigned)((iteration * 100) / kIterations));
  1203. ++iteration;
  1204. }
  1205. COVER_best_wait(&best);
  1206. COVER_ctx_destroy(&ctx);
  1207. }
  1208. LOCALDISPLAYLEVEL(displayLevel, 2, "\r%79s\r", "");
  1209. /* Fill the output buffer and parameters with output of the best parameters */
  1210. {
  1211. const size_t dictSize = best.dictSize;
  1212. if (ZSTD_isError(best.compressedSize)) {
  1213. const size_t compressedSize = best.compressedSize;
  1214. COVER_best_destroy(&best);
  1215. POOL_free(pool);
  1216. return compressedSize;
  1217. }
  1218. *parameters = best.parameters;
  1219. memcpy(dictBuffer, best.dict, dictSize);
  1220. COVER_best_destroy(&best);
  1221. POOL_free(pool);
  1222. return dictSize;
  1223. }
  1224. }