// gorilla.cc (18 KB)
  1. // SPDX-License-Identifier: GPL-3.0-or-later
  2. #include "gorilla.h"
  3. #include <cassert>
  4. #include <climits>
  5. #include <cstdio>
  6. #include <cstring>
  7. using std::size_t;
  /*
   * Number of bits in the word type T.
   *
   * Instantiating this for a type that is not exactly 32 or 64 bits wide
   * is a compile-time error.
   */
  template <typename T>
  static constexpr size_t bit_size() noexcept
  {
      static_assert((CHAR_BIT * sizeof(T)) == 32 || (CHAR_BIT * sizeof(T)) == 64,
                    "Word size should be 32 or 64 bits.");

      return CHAR_BIT * sizeof(T);
  }
  15. /*
  16. * Low-level bitstream operations, allowing us to read/write individual bits.
  17. */
  /*
   * Cursor over an array of Words that is addressed one bit at a time.
   */
  template <typename Word>
  struct bit_stream_t {
      Word  *buffer;     // words the stream reads from / writes into
      size_t capacity;   // total size of the stream, counted in bits
      size_t position;   // bit index of the next read or write
  };
  24. template<typename Word>
  25. static bit_stream_t<Word> bit_stream_new(Word *buffer, Word capacity) {
  26. bit_stream_t<Word> bs;
  27. bs.buffer = buffer;
  28. bs.capacity = capacity * bit_size<Word>();
  29. bs.position = 0;
  30. return bs;
  31. }
  32. template<typename Word>
  33. static bool bit_stream_write(bit_stream_t<Word> *bs, Word value, size_t nbits) {
  34. assert(nbits > 0 && nbits <= bit_size<Word>());
  35. assert(bs->capacity >= (bs->position + nbits));
  36. if (bs->position + nbits > bs->capacity) {
  37. return false;
  38. }
  39. const size_t index = bs->position / bit_size<Word>();
  40. const size_t offset = bs->position % bit_size<Word>();
  41. bs->position += nbits;
  42. if (offset == 0) {
  43. bs->buffer[index] = value;
  44. } else {
  45. const size_t remaining_bits = bit_size<Word>() - offset;
  46. // write the lower part of the value
  47. const Word low_bits_mask = ((Word) 1 << remaining_bits) - 1;
  48. const Word lowest_bits_in_value = value & low_bits_mask;
  49. bs->buffer[index] |= (lowest_bits_in_value << offset);
  50. if (nbits > remaining_bits) {
  51. // write the upper part of the value
  52. const Word high_bits_mask = ~low_bits_mask;
  53. const Word highest_bits_in_value = (value & high_bits_mask) >> (remaining_bits);
  54. bs->buffer[index + 1] = highest_bits_in_value;
  55. }
  56. }
  57. return true;
  58. }
  59. template<typename Word>
  60. static bool bit_stream_read(bit_stream_t<Word> *bs, Word *value, size_t nbits) {
  61. assert(nbits > 0 && nbits <= bit_size<Word>());
  62. assert(bs->capacity >= (bs->position + nbits));
  63. if (bs->position + nbits > bs->capacity) {
  64. return false;
  65. }
  66. const size_t index = bs->position / bit_size<Word>();
  67. const size_t offset = bs->position % bit_size<Word>();
  68. bs->position += nbits;
  69. if (offset == 0) {
  70. *value = (nbits == bit_size<Word>()) ?
  71. bs->buffer[index] :
  72. bs->buffer[index] & (((Word) 1 << nbits) - 1);
  73. } else {
  74. const size_t remaining_bits = bit_size<Word>() - offset;
  75. // extract the lower part of the value
  76. if (nbits < remaining_bits) {
  77. *value = (bs->buffer[index] >> offset) & (((Word) 1 << nbits) - 1);
  78. } else {
  79. *value = (bs->buffer[index] >> offset) & (((Word) 1 << remaining_bits) - 1);
  80. nbits -= remaining_bits;
  81. *value |= (bs->buffer[index + 1] & (((Word) 1 << nbits) - 1)) << remaining_bits;
  82. }
  83. }
  84. return true;
  85. }
  86. /*
  87. * High-level Gorilla codec implementation
  88. */
  89. template<typename Word>
  90. struct bit_code_t {
  91. bit_stream_t<Word> bs;
  92. Word entries;
  93. Word prev_number;
  94. Word prev_xor;
  95. Word prev_xor_lzc;
  96. };
  97. template<typename Word>
  98. static void bit_code_init(bit_code_t<Word> *bc, Word *buffer, Word capacity) {
  99. bc->bs = bit_stream_new(buffer, capacity);
  100. bc->entries = 0;
  101. bc->prev_number = 0;
  102. bc->prev_xor = 0;
  103. bc->prev_xor_lzc = 0;
  104. // reserved two words:
  105. // Buffer[0] -> number of entries written
  106. // Buffer[1] -> number of bits written
  107. bc->bs.position += 2 * bit_size<Word>();
  108. }
  109. template<typename Word>
  110. static bool bit_code_read(bit_code_t<Word> *bc, Word *number) {
  111. bit_stream_t<Word> *bs = &bc->bs;
  112. bc->entries++;
  113. // read the first number
  114. if (bc->entries == 1) {
  115. bool ok = bit_stream_read(bs, number, bit_size<Word>());
  116. bc->prev_number = *number;
  117. return ok;
  118. }
  119. // process same-number bit
  120. Word is_same_number;
  121. if (!bit_stream_read(bs, &is_same_number, 1)) {
  122. return false;
  123. }
  124. if (is_same_number) {
  125. *number = bc->prev_number;
  126. return true;
  127. }
  128. // proceess same-xor-lzc bit
  129. Word xor_lzc = bc->prev_xor_lzc;
  130. Word same_xor_lzc;
  131. if (!bit_stream_read(bs, &same_xor_lzc, 1)) {
  132. return false;
  133. }
  134. if (!same_xor_lzc) {
  135. if (!bit_stream_read(bs, &xor_lzc, (bit_size<Word>() == 32) ? 5 : 6)) {
  136. return false;
  137. }
  138. }
  139. // process the non-lzc suffix
  140. Word xor_value = 0;
  141. if (!bit_stream_read(bs, &xor_value, bit_size<Word>() - xor_lzc)) {
  142. return false;
  143. }
  144. *number = (bc->prev_number ^ xor_value);
  145. bc->prev_number = *number;
  146. bc->prev_xor_lzc = xor_lzc;
  147. bc->prev_xor = xor_value;
  148. return true;
  149. }
  150. template<typename Word>
  151. static bool bit_code_write(bit_code_t<Word> *bc, const Word number) {
  152. bit_stream_t<Word> *bs = &bc->bs;
  153. Word position = bs->position;
  154. bc->entries++;
  155. // this is the first number we are writing
  156. if (bc->entries == 1) {
  157. bc->prev_number = number;
  158. return bit_stream_write(bs, number, bit_size<Word>());
  159. }
  160. // write true/false based on whether we got the same number or not.
  161. if (number == bc->prev_number) {
  162. return bit_stream_write(bs, static_cast<Word>(1), 1);
  163. } else {
  164. if (bit_stream_write(bs, static_cast<Word>(0), 1) == false) {
  165. return false;
  166. }
  167. }
  168. // otherwise:
  169. // - compute the non-zero xor
  170. // - find its leading-zero count
  171. Word xor_value = bc->prev_number ^ number;
  172. // FIXME: Use SFINAE
  173. Word xor_lzc = (bit_size<Word>() == 32) ? __builtin_clz(xor_value) : __builtin_clzll(xor_value);
  174. Word is_xor_lzc_same = (xor_lzc == bc->prev_xor_lzc) ? 1 : 0;
  175. if (is_xor_lzc_same) {
  176. // xor-lzc is same
  177. if (bit_stream_write(bs, static_cast<Word>(1), 1) == false) {
  178. goto RET_FALSE;
  179. }
  180. } else {
  181. // xor-lzc is different
  182. if (bit_stream_write(bs, static_cast<Word>(0), 1) == false) {
  183. goto RET_FALSE;
  184. }
  185. if (bit_stream_write(bs, xor_lzc, (bit_size<Word>() == 32) ? 5 : 6) == false) {
  186. goto RET_FALSE;
  187. }
  188. }
  189. // write the bits of the XOR value without the LZC prefix
  190. if (bit_stream_write(bs, xor_value, bit_size<Word>() - xor_lzc) == false) {
  191. goto RET_FALSE;
  192. }
  193. bc->prev_number = number;
  194. bc->prev_xor_lzc = xor_lzc;
  195. return true;
  196. RET_FALSE:
  197. bc->bs.position = position;
  198. return false;
  199. }
  200. // only valid for writers
  201. template<typename Word>
  202. static bool bit_code_flush(bit_code_t<Word> *bc) {
  203. bit_stream_t<Word> *bs = &bc->bs;
  204. Word num_entries_written = bc->entries;
  205. Word num_bits_written = bs->position;
  206. // we want to write these at the beginning
  207. bs->position = 0;
  208. if (!bit_stream_write(bs, num_entries_written, bit_size<Word>())) {
  209. return false;
  210. }
  211. if (!bit_stream_write(bs, num_bits_written, bit_size<Word>())) {
  212. return false;
  213. }
  214. bs->position = num_bits_written;
  215. return true;
  216. }
  217. // only valid for readers
  218. template<typename Word>
  219. static bool bit_code_info(bit_code_t<Word> *bc, Word *num_entries_written,
  220. Word *num_bits_written) {
  221. bit_stream_t<Word> *bs = &bc->bs;
  222. assert(bs->position == 2 * bit_size<Word>());
  223. if (bs->capacity < (2 * bit_size<Word>())) {
  224. return false;
  225. }
  226. if (num_entries_written) {
  227. *num_entries_written = bs->buffer[0];
  228. }
  229. if (num_bits_written) {
  230. *num_bits_written = bs->buffer[1];
  231. }
  232. return true;
  233. }
  234. template<typename Word>
  235. static size_t gorilla_encode(Word *dst, Word dst_len, const Word *src, Word src_len) {
  236. bit_code_t<Word> bcw;
  237. bit_code_init(&bcw, dst, dst_len);
  238. for (size_t i = 0; i != src_len; i++) {
  239. if (!bit_code_write(&bcw, src[i]))
  240. return 0;
  241. }
  242. if (!bit_code_flush(&bcw))
  243. return 0;
  244. return src_len;
  245. }
  246. template<typename Word>
  247. static size_t gorilla_decode(Word *dst, Word dst_len, const Word *src, Word src_len) {
  248. bit_code_t<Word> bcr;
  249. bit_code_init(&bcr, (Word *) src, src_len);
  250. Word num_entries;
  251. if (!bit_code_info(&bcr, &num_entries, (Word *) NULL)) {
  252. return 0;
  253. }
  254. if (num_entries > dst_len) {
  255. return 0;
  256. }
  257. for (size_t i = 0; i != num_entries; i++) {
  258. if (!bit_code_read(&bcr, &dst[i]))
  259. return 0;
  260. }
  261. return num_entries;
  262. }
  263. /*
  264. * Low-level public API
  265. */
  266. // 32-bit API
  267. void bit_code_writer_u32_init(bit_code_writer_u32_t *bcw, uint32_t *buffer, uint32_t capacity) {
  268. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcw;
  269. bit_code_init(bc, buffer, capacity);
  270. }
  271. bool bit_code_writer_u32_write(bit_code_writer_u32_t *bcw, const uint32_t number) {
  272. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcw;
  273. return bit_code_write(bc, number);
  274. }
  275. bool bit_code_writer_u32_flush(bit_code_writer_u32_t *bcw) {
  276. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcw;
  277. return bit_code_flush(bc);
  278. }
  279. void bit_code_reader_u32_init(bit_code_reader_u32_t *bcr, uint32_t *buffer, uint32_t capacity) {
  280. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcr;
  281. bit_code_init(bc, buffer, capacity);
  282. }
  283. bool bit_code_reader_u32_read(bit_code_reader_u32_t *bcr, uint32_t *number) {
  284. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcr;
  285. return bit_code_read(bc, number);
  286. }
  287. bool bit_code_reader_u32_info(bit_code_reader_u32_t *bcr, uint32_t *num_entries_written,
  288. uint32_t *num_bits_written) {
  289. bit_code_t<uint32_t> *bc = (bit_code_t<uint32_t> *) bcr;
  290. return bit_code_info(bc, num_entries_written, num_bits_written);
  291. }
  292. // 64-bit API
  293. void bit_code_writer_u64_init(bit_code_writer_u64_t *bcw, uint64_t *buffer, uint64_t capacity) {
  294. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcw;
  295. bit_code_init(bc, buffer, capacity);
  296. }
  297. bool bit_code_writer_u64_write(bit_code_writer_u64_t *bcw, const uint64_t number) {
  298. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcw;
  299. return bit_code_write(bc, number);
  300. }
  301. bool bit_code_writer_u64_flush(bit_code_writer_u64_t *bcw) {
  302. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcw;
  303. return bit_code_flush(bc);
  304. }
  305. void bit_code_reader_u64_init(bit_code_reader_u64_t *bcr, uint64_t *buffer, uint64_t capacity) {
  306. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcr;
  307. bit_code_init(bc, buffer, capacity);
  308. }
  309. bool bit_code_reader_u64_read(bit_code_reader_u64_t *bcr, uint64_t *number) {
  310. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcr;
  311. return bit_code_read(bc, number);
  312. }
  313. bool bit_code_reader_u64_info(bit_code_reader_u64_t *bcr, uint64_t *num_entries_written,
  314. uint64_t *num_bits_written) {
  315. bit_code_t<uint64_t> *bc = (bit_code_t<uint64_t> *) bcr;
  316. return bit_code_info(bc, num_entries_written, num_bits_written);
  317. }
  318. /*
  319. * High-level public API
  320. */
  321. // 32-bit API
  322. size_t gorilla_encode_u32(uint32_t *dst, size_t dst_len, const uint32_t *src, size_t src_len) {
  323. return gorilla_encode(dst, (uint32_t) dst_len, src, (uint32_t) src_len);
  324. }
  325. size_t gorilla_decode_u32(uint32_t *dst, size_t dst_len, const uint32_t *src, size_t src_len) {
  326. return gorilla_decode(dst, (uint32_t) dst_len, src, (uint32_t) src_len);
  327. }
  328. // 64-bit API
  329. size_t gorilla_encode_u64(uint64_t *dst, size_t dst_len, const uint64_t *src, size_t src_len) {
  330. return gorilla_encode(dst, (uint64_t) dst_len, src, (uint64_t) src_len);
  331. }
  332. size_t gorilla_decode_u64(uint64_t *dst, size_t dst_len, const uint64_t *src, size_t src_len) {
  333. return gorilla_decode(dst, (uint64_t) dst_len, src, (uint64_t) src_len);
  334. }
  335. /*
  336. * Internal code used for fuzzing the library
  337. */
  338. #ifdef ENABLE_FUZZER
  339. #include <vector>
  /*
   * Reinterpret the fuzzer input as a sequence of Words.
   *
   * Words are taken from the END of `data` towards the front, so the first
   * element of the result comes from the last sizeof(Word) bytes. Leading
   * bytes that do not fill a whole Word are ignored. The vector reserves
   * room for at least 1024 elements before anything is appended.
   */
  template <typename Word>
  static std::vector<Word> random_vector(const uint8_t *data, size_t size) {
      std::vector<Word> words;
      words.reserve(1024);

      for (size_t offset = size; offset >= sizeof(Word); ) {
          offset -= sizeof(Word);

          // memcpy avoids an unaligned read
          Word word;
          memcpy(&word, data + offset, sizeof(Word));
          words.push_back(word);
      }

      return words;
  }
  /*
   * Assert that two buffers have the same length and the same contents.
   * All checks are asserts, so nothing is verified when NDEBUG is defined.
   */
  template <typename Word>
  static void check_equal_buffers(Word *lhs, Word lhs_size, Word *rhs, Word rhs_size) {
      assert((lhs_size == rhs_size) && "Buffers have different size.");

      size_t i = 0;
      while (i != lhs_size) {
          assert((lhs[i] == rhs[i]) && "Buffers differ");
          ++i;
      }
  }
  359. extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
  360. // 32-bit tests
  361. {
  362. if (Size < 4)
  363. return 0;
  364. std::vector<uint32_t> RandomData = random_vector<uint32_t>(Data, Size);
  365. std::vector<uint32_t> EncodedData(10 * RandomData.capacity(), 0);
  366. std::vector<uint32_t> DecodedData(10 * RandomData.capacity(), 0);
  367. size_t num_entries_written = gorilla_encode_u32(EncodedData.data(), EncodedData.size(),
  368. RandomData.data(), RandomData.size());
  369. size_t num_entries_read = gorilla_decode_u32(DecodedData.data(), DecodedData.size(),
  370. EncodedData.data(), EncodedData.size());
  371. assert(num_entries_written == num_entries_read);
  372. check_equal_buffers(RandomData.data(), (uint32_t) RandomData.size(),
  373. DecodedData.data(), (uint32_t) RandomData.size());
  374. }
  375. // 64-bit tests
  376. {
  377. if (Size < 8)
  378. return 0;
  379. std::vector<uint64_t> RandomData = random_vector<uint64_t>(Data, Size);
  380. std::vector<uint64_t> EncodedData(10 * RandomData.capacity(), 0);
  381. std::vector<uint64_t> DecodedData(10 * RandomData.capacity(), 0);
  382. size_t num_entries_written = gorilla_encode_u64(EncodedData.data(), EncodedData.size(),
  383. RandomData.data(), RandomData.size());
  384. size_t num_entries_read = gorilla_decode_u64(DecodedData.data(), DecodedData.size(),
  385. EncodedData.data(), EncodedData.size());
  386. assert(num_entries_written == num_entries_read);
  387. check_equal_buffers(RandomData.data(), (uint64_t) RandomData.size(),
  388. DecodedData.data(), (uint64_t) RandomData.size());
  389. }
  390. return 0;
  391. }
  392. #endif /* ENABLE_FUZZER */
  393. #ifdef ENABLE_BENCHMARK
  394. #include <benchmark/benchmark.h>
  395. #include <random>
  // Number of random values generated for each benchmark run.
  static size_t NumItems{1024};
  397. static void BM_EncodeU32Numbers(benchmark::State& state) {
  398. std::random_device rd;
  399. std::mt19937 mt(rd());
  400. std::uniform_int_distribution<uint32_t> dist(0x0, 0x0000FFFF);
  401. std::vector<uint32_t> RandomData;
  402. for (size_t idx = 0; idx != NumItems; idx++) {
  403. RandomData.push_back(dist(mt));
  404. }
  405. std::vector<uint32_t> EncodedData(10 * RandomData.capacity(), 0);
  406. for (auto _ : state) {
  407. benchmark::DoNotOptimize(
  408. gorilla_encode_u32(EncodedData.data(), EncodedData.size(),
  409. RandomData.data(), RandomData.size())
  410. );
  411. benchmark::ClobberMemory();
  412. }
  413. state.SetItemsProcessed(NumItems * state.iterations());
  414. state.SetBytesProcessed(NumItems * state.iterations() * sizeof(uint32_t));
  415. }
  416. BENCHMARK(BM_EncodeU32Numbers);
  417. static void BM_DecodeU32Numbers(benchmark::State& state) {
  418. std::random_device rd;
  419. std::mt19937 mt(rd());
  420. std::uniform_int_distribution<uint32_t> dist(0x0, 0xFFFFFFFF);
  421. std::vector<uint32_t> RandomData;
  422. for (size_t idx = 0; idx != NumItems; idx++) {
  423. RandomData.push_back(dist(mt));
  424. }
  425. std::vector<uint32_t> EncodedData(10 * RandomData.capacity(), 0);
  426. std::vector<uint32_t> DecodedData(10 * RandomData.capacity(), 0);
  427. gorilla_encode_u32(EncodedData.data(), EncodedData.size(),
  428. RandomData.data(), RandomData.size());
  429. for (auto _ : state) {
  430. benchmark::DoNotOptimize(
  431. gorilla_decode_u32(DecodedData.data(), DecodedData.size(),
  432. EncodedData.data(), EncodedData.size())
  433. );
  434. benchmark::ClobberMemory();
  435. }
  436. state.SetItemsProcessed(NumItems * state.iterations());
  437. state.SetBytesProcessed(NumItems * state.iterations() * sizeof(uint32_t));
  438. }
  439. // Register the function as a benchmark
  440. BENCHMARK(BM_DecodeU32Numbers);
  441. static void BM_EncodeU64Numbers(benchmark::State& state) {
  442. std::random_device rd;
  443. std::mt19937 mt(rd());
  444. std::uniform_int_distribution<uint64_t> dist(0x0, 0x0000FFFF);
  445. std::vector<uint64_t> RandomData;
  446. for (size_t idx = 0; idx != 1024; idx++) {
  447. RandomData.push_back(dist(mt));
  448. }
  449. std::vector<uint64_t> EncodedData(10 * RandomData.capacity(), 0);
  450. for (auto _ : state) {
  451. benchmark::DoNotOptimize(
  452. gorilla_encode_u64(EncodedData.data(), EncodedData.size(),
  453. RandomData.data(), RandomData.size())
  454. );
  455. benchmark::ClobberMemory();
  456. }
  457. state.SetItemsProcessed(NumItems * state.iterations());
  458. state.SetBytesProcessed(NumItems * state.iterations() * sizeof(uint64_t));
  459. }
  460. BENCHMARK(BM_EncodeU64Numbers);
  461. static void BM_DecodeU64Numbers(benchmark::State& state) {
  462. std::random_device rd;
  463. std::mt19937 mt(rd());
  464. std::uniform_int_distribution<uint64_t> dist(0x0, 0xFFFFFFFF);
  465. std::vector<uint64_t> RandomData;
  466. for (size_t idx = 0; idx != 1024; idx++) {
  467. RandomData.push_back(dist(mt));
  468. }
  469. std::vector<uint64_t> EncodedData(10 * RandomData.capacity(), 0);
  470. std::vector<uint64_t> DecodedData(10 * RandomData.capacity(), 0);
  471. gorilla_encode_u64(EncodedData.data(), EncodedData.size(),
  472. RandomData.data(), RandomData.size());
  473. for (auto _ : state) {
  474. benchmark::DoNotOptimize(
  475. gorilla_decode_u64(DecodedData.data(), DecodedData.size(),
  476. EncodedData.data(), EncodedData.size())
  477. );
  478. benchmark::ClobberMemory();
  479. }
  480. state.SetItemsProcessed(NumItems * state.iterations());
  481. state.SetBytesProcessed(NumItems * state.iterations() * sizeof(uint64_t));
  482. }
  483. // Register the function as a benchmark
  484. BENCHMARK(BM_DecodeU64Numbers);
  485. #endif /* ENABLE_BENCHMARK */