huffman_codec.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. #include "huffman_codec.h"
  2. #include <library/cpp/bit_io/bitinput.h>
  3. #include <library/cpp/bit_io/bitoutput.h>
  4. #include <util/generic/algorithm.h>
  5. #include <util/generic/bitops.h>
  6. #include <util/stream/length.h>
  7. #include <util/string/printf.h>
  8. namespace NCodecs {
  // Ordering used for canonical Huffman codes: entries with a shorter
  // CodeLength go first; entries of equal length are ordered by Char.
  template <typename T>
  struct TCanonicalCmp {
      bool operator()(const T& lhs, const T& rhs) const {
          if (lhs.CodeLength != rhs.CodeLength) {
              return lhs.CodeLength < rhs.CodeLength;
          }
          return lhs.Char < rhs.Char;
      }
  };
  // Orders entries by their Char field only.
  template <typename T>
  struct TByCharCmp {
      bool operator()(const T& lhs, const T& rhs) const {
          const bool lhsFirst = lhs.Char < rhs.Char;
          return lhsFirst;
      }
  };
  25. struct TTreeEntry {
  26. static const ui32 InvalidBranch = (ui32)-1;
  27. ui64 Freq = 0;
  28. ui32 Branches[2]{InvalidBranch, InvalidBranch};
  29. ui32 CodeLength = 0;
  30. ui8 Char = 0;
  31. bool Invalid = false;
  32. TTreeEntry() = default;
  33. static bool ByFreq(const TTreeEntry& a, const TTreeEntry& b) {
  34. return a.Freq < b.Freq;
  35. }
  36. static bool ByFreqRev(const TTreeEntry& a, const TTreeEntry& b) {
  37. return a.Freq > b.Freq;
  38. }
  39. };
  40. using TCodeTree = TVector<TTreeEntry>;
  41. void InitTreeByFreqs(TCodeTree& tree, const ui64 freqs[256]) {
  42. tree.reserve(255 * 256 / 2); // worst case - balanced tree
  43. for (ui32 i = 0; i < 256; ++i) {
  44. tree.emplace_back();
  45. tree.back().Char = i;
  46. tree.back().Freq = freqs[i];
  47. }
  48. StableSort(tree.begin(), tree.end(), TTreeEntry::ByFreq);
  49. }
  50. void InitTree(TCodeTree& tree, ISequenceReader* in) {
  51. using namespace NPrivate;
  52. ui64 freqs[256];
  53. Zero(freqs);
  54. TStringBuf r;
  55. while (in->NextRegion(r)) {
  56. for (ui64 i = 0; i < r.size(); ++i)
  57. ++freqs[(ui8)r[i]];
  58. }
  59. InitTreeByFreqs(tree, freqs);
  60. }
  61. void CalculateCodeLengths(TCodeTree& tree) {
  62. Y_ENSURE(tree.size() == 256, " ");
  63. const ui32 firstbranch = tree.size();
  64. ui32 curleaf = 0;
  65. ui32 curbranch = firstbranch;
  66. // building code tree. two priority queues are combined in one.
  67. while (firstbranch - curleaf + tree.size() - curbranch >= 2) {
  68. TTreeEntry e;
  69. for (auto& branche : e.Branches) {
  70. ui32 br;
  71. if (curleaf >= firstbranch)
  72. br = curbranch++;
  73. else if (curbranch >= tree.size())
  74. br = curleaf++;
  75. else if (tree[curleaf].Freq < tree[curbranch].Freq)
  76. br = curleaf++;
  77. else
  78. br = curbranch++;
  79. Y_ENSURE(br < tree.size(), " ");
  80. branche = br;
  81. e.Freq += tree[br].Freq;
  82. }
  83. tree.push_back(e);
  84. PushHeap(tree.begin() + curbranch, tree.end(), TTreeEntry::ByFreqRev);
  85. }
  86. // computing code lengths
  87. for (ui64 i = tree.size() - 1; i >= firstbranch; --i) {
  88. TTreeEntry e = tree[i];
  89. for (auto branche : e.Branches)
  90. tree[branche].CodeLength = e.CodeLength + 1;
  91. }
  92. // chopping off the branches
  93. tree.resize(firstbranch);
  94. Sort(tree.begin(), tree.end(), TCanonicalCmp<TTreeEntry>());
  95. // simplification: we are stripping codes longer than 64 bits
  96. while (!tree.empty() && tree.back().CodeLength > 64)
  97. tree.pop_back();
  98. // will not compress
  99. if (tree.empty())
  100. return;
  101. // special invalid code word
  102. tree.back().Invalid = true;
  103. }
  104. struct TEncoderEntry {
  105. ui64 Code = 0;
  106. ui8 CodeLength = 0;
  107. ui8 Char = 0;
  108. ui8 Invalid = true;
  109. explicit TEncoderEntry(TTreeEntry e)
  110. : CodeLength(e.CodeLength)
  111. , Char(e.Char)
  112. , Invalid(e.Invalid)
  113. {
  114. }
  115. TEncoderEntry() = default;
  116. };
  117. struct TEncoderTable {
  118. TEncoderEntry Entries[256];
  119. void Save(IOutputStream* out) const {
  120. ui16 nval = 0;
  121. for (auto entrie : Entries)
  122. nval += !entrie.Invalid;
  123. ::Save(out, nval);
  124. for (auto entrie : Entries) {
  125. if (!entrie.Invalid) {
  126. ::Save(out, entrie.Char);
  127. ::Save(out, entrie.CodeLength);
  128. }
  129. }
  130. }
  131. void Load(IInputStream* in) {
  132. ui16 nval = 0;
  133. ::Load(in, nval);
  134. for (ui32 i = 0; i < 256; ++i)
  135. Entries[i].Char = i;
  136. for (ui32 i = 0; i < nval; ++i) {
  137. ui8 ch = 0;
  138. ui8 len = 0;
  139. ::Load(in, ch);
  140. ::Load(in, len);
  141. Entries[ch].CodeLength = len;
  142. Entries[ch].Invalid = false;
  143. }
  144. }
  145. };
  146. struct TDecoderEntry {
  147. ui32 NextTable : 10;
  148. ui32 Char : 8;
  149. ui32 Invalid : 1;
  150. ui32 Bad : 1;
  151. TDecoderEntry()
  152. : NextTable()
  153. , Char()
  154. , Invalid()
  155. , Bad()
  156. {
  157. }
  158. };
  159. struct TDecoderTable: public TIntrusiveListItem<TDecoderTable> {
  160. ui64 Length = 0;
  161. ui64 BaseCode = 0;
  162. TDecoderEntry Entries[256];
  163. TDecoderTable() {
  164. Zero(Entries);
  165. }
  166. };
  // Width in bits of the index into the decode cache (it holds 1 << CACHE_BITS_COUNT entries).
  constexpr int CACHE_BITS_COUNT = 16;
  168. class THuffmanCodec::TImpl: public TAtomicRefCount<TImpl> {
  169. TEncoderTable Encoder;
  170. TDecoderTable Decoder[256];
  171. TEncoderEntry Invalid;
  172. ui32 SubTablesNum;
  173. class THuffmanCache {
  174. struct TCacheEntry {
  175. int EndOffset : 24;
  176. int BitsLeft : 8;
  177. };
  178. TVector<char> DecodeCache;
  179. TVector<TCacheEntry> CacheEntries;
  180. const TImpl& Original;
  181. public:
  182. THuffmanCache(const THuffmanCodec::TImpl& encoder);
  183. void Decode(NBitIO::TBitInput& in, TBuffer& out) const;
  184. };
  185. THolder<THuffmanCache> Cache;
  186. public:
  187. TImpl()
  188. : SubTablesNum(1)
  189. {
  190. Invalid.CodeLength = 255;
  191. }
  192. ui8 Encode(TStringBuf in, TBuffer& out) const {
  193. out.Clear();
  194. if (in.empty()) {
  195. return 0;
  196. }
  197. out.Reserve(in.size() * 2);
  198. {
  199. NBitIO::TBitOutputVector<TBuffer> bout(&out);
  200. TStringBuf tin = in;
  201. // data is under compression
  202. bout.Write(1, 1);
  203. for (auto t : tin) {
  204. const TEncoderEntry& ce = Encoder.Entries[(ui8)t];
  205. bout.Write(ce.Code, ce.CodeLength);
  206. if (ce.Invalid) {
  207. bout.Write(t, 8);
  208. }
  209. }
  210. // in canonical huffman coding there cannot be a code having no 0 in the suffix
  211. // and shorter than 8 bits.
  212. bout.Write((ui64)-1, bout.GetByteReminder());
  213. return bout.GetByteReminder();
  214. }
  215. }
  216. void Decode(TStringBuf in, TBuffer& out) const {
  217. out.Clear();
  218. if (in.empty()) {
  219. return;
  220. }
  221. NBitIO::TBitInput bin(in);
  222. ui64 f = 0;
  223. bin.ReadK<1>(f);
  224. // if data is uncompressed
  225. if (!f) {
  226. in.Skip(1);
  227. out.Append(in.data(), in.size());
  228. } else {
  229. out.Reserve(in.size() * 8);
  230. if (Cache.Get()) {
  231. Cache->Decode(bin, out);
  232. } else {
  233. while (ReadNextChar(bin, out)) {
  234. }
  235. }
  236. }
  237. }
  238. Y_FORCE_INLINE int ReadNextChar(NBitIO::TBitInput& bin, TBuffer& out) const {
  239. const TDecoderTable* table = Decoder;
  240. TDecoderEntry e;
  241. int bitsRead = 0;
  242. while (true) {
  243. ui64 code = 0;
  244. if (Y_UNLIKELY(!bin.Read(code, table->Length)))
  245. return 0;
  246. bitsRead += table->Length;
  247. if (Y_UNLIKELY(code < table->BaseCode))
  248. return 0;
  249. code -= table->BaseCode;
  250. if (Y_UNLIKELY(code > 255))
  251. return 0;
  252. e = table->Entries[code];
  253. if (Y_UNLIKELY(e.Bad))
  254. return 0;
  255. if (e.NextTable) {
  256. table = Decoder + e.NextTable;
  257. } else {
  258. if (e.Invalid) {
  259. code = 0;
  260. bin.ReadK<8>(code);
  261. bitsRead += 8;
  262. out.Append((ui8)code);
  263. } else {
  264. out.Append((ui8)e.Char);
  265. }
  266. return bitsRead;
  267. }
  268. }
  269. Y_ENSURE(false, " could not decode input");
  270. return 0;
  271. }
  272. void GenerateEncoder(TCodeTree& tree) {
  273. const ui64 sz = tree.size();
  274. TEncoderEntry lastcode = Encoder.Entries[tree[0].Char] = TEncoderEntry(tree[0]);
  275. for (ui32 i = 1; i < sz; ++i) {
  276. const TTreeEntry& te = tree[i];
  277. TEncoderEntry& e = Encoder.Entries[te.Char];
  278. e = TEncoderEntry(te);
  279. e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
  280. lastcode = e;
  281. e.Code = ReverseBits(e.Code, e.CodeLength);
  282. if (e.Invalid)
  283. Invalid = e;
  284. }
  285. for (auto& e : Encoder.Entries) {
  286. if (e.Invalid)
  287. e = Invalid;
  288. Y_ENSURE(e.CodeLength, " ");
  289. }
  290. }
  291. void RegenerateEncoder() {
  292. for (auto& entrie : Encoder.Entries) {
  293. if (entrie.Invalid)
  294. entrie.CodeLength = Invalid.CodeLength;
  295. }
  296. Sort(Encoder.Entries, Encoder.Entries + 256, TCanonicalCmp<TEncoderEntry>());
  297. TEncoderEntry lastcode = Encoder.Entries[0];
  298. for (ui32 i = 1; i < 256; ++i) {
  299. TEncoderEntry& e = Encoder.Entries[i];
  300. e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
  301. lastcode = e;
  302. e.Code = ReverseBits(e.Code, e.CodeLength);
  303. }
  304. for (auto& entrie : Encoder.Entries) {
  305. if (entrie.Invalid) {
  306. Invalid = entrie;
  307. break;
  308. }
  309. }
  310. Sort(Encoder.Entries, Encoder.Entries + 256, TByCharCmp<TEncoderEntry>());
  311. for (auto& entrie : Encoder.Entries) {
  312. if (entrie.Invalid)
  313. entrie = Invalid;
  314. }
  315. }
  316. void BuildDecoder() {
  317. TEncoderTable enc = Encoder;
  318. Sort(enc.Entries, enc.Entries + 256, TCanonicalCmp<TEncoderEntry>());
  319. TEncoderEntry& e1 = enc.Entries[0];
  320. Decoder[0].BaseCode = e1.Code;
  321. Decoder[0].Length = e1.CodeLength;
  322. for (auto e2 : enc.Entries) {
  323. SetEntry(Decoder, e2.Code, e2.CodeLength, e2);
  324. }
  325. Cache.Reset(new THuffmanCache(*this));
  326. }
  327. void SetEntry(TDecoderTable* t, ui64 code, ui64 len, TEncoderEntry e) {
  328. Y_ENSURE(len >= t->Length, len << " < " << t->Length);
  329. ui64 idx = (code & MaskLowerBits(t->Length)) - t->BaseCode;
  330. TDecoderEntry& d = t->Entries[idx];
  331. if (len == t->Length) {
  332. Y_ENSURE(!d.NextTable, " ");
  333. d.Char = e.Char;
  334. d.Invalid = e.Invalid;
  335. return;
  336. }
  337. if (!d.NextTable) {
  338. Y_ENSURE(SubTablesNum < Y_ARRAY_SIZE(Decoder), " ");
  339. d.NextTable = SubTablesNum++;
  340. TDecoderTable* nt = Decoder + d.NextTable;
  341. nt->Length = Min<ui64>(8, len - t->Length);
  342. nt->BaseCode = (code >> t->Length) & MaskLowerBits(nt->Length);
  343. }
  344. SetEntry(Decoder + d.NextTable, code >> t->Length, len - t->Length, e);
  345. }
  346. void Learn(ISequenceReader* in) {
  347. {
  348. TCodeTree tree;
  349. InitTree(tree, in);
  350. CalculateCodeLengths(tree);
  351. Y_ENSURE(!tree.empty(), " ");
  352. GenerateEncoder(tree);
  353. }
  354. BuildDecoder();
  355. }
  356. void LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
  357. TCodeTree tree;
  358. ui64 freqsArray[256];
  359. Zero(freqsArray);
  360. for (const auto& freq : freqs)
  361. freqsArray[static_cast<ui8>(freq.first)] += freq.second;
  362. InitTreeByFreqs(tree, freqsArray);
  363. CalculateCodeLengths(tree);
  364. Y_ENSURE(!tree.empty(), " ");
  365. GenerateEncoder(tree);
  366. BuildDecoder();
  367. }
  368. void Save(IOutputStream* out) {
  369. ::Save(out, Invalid.CodeLength);
  370. Encoder.Save(out);
  371. }
  372. void Load(IInputStream* in) {
  373. ::Load(in, Invalid.CodeLength);
  374. Encoder.Load(in);
  375. RegenerateEncoder();
  376. BuildDecoder();
  377. }
  378. };
  379. THuffmanCodec::TImpl::THuffmanCache::THuffmanCache(const THuffmanCodec::TImpl& codec)
  380. : Original(codec)
  381. {
  382. CacheEntries.resize(1 << CACHE_BITS_COUNT);
  383. DecodeCache.reserve(CacheEntries.size() * 2);
  384. char buffer[2];
  385. TBuffer decoded;
  386. for (size_t i = 0; i < CacheEntries.size(); i++) {
  387. buffer[1] = i >> 8;
  388. buffer[0] = i;
  389. NBitIO::TBitInput bin(buffer, buffer + sizeof(buffer));
  390. int totalBits = 0;
  391. while (true) {
  392. decoded.Resize(0);
  393. int bits = codec.ReadNextChar(bin, decoded);
  394. if (totalBits + bits > 16 || !bits) {
  395. TCacheEntry e = {static_cast<int>(DecodeCache.size()), 16 - totalBits};
  396. CacheEntries[i] = e;
  397. break;
  398. }
  399. for (TBuffer::TConstIterator it = decoded.Begin(); it != decoded.End(); ++it) {
  400. DecodeCache.push_back(*it);
  401. }
  402. totalBits += bits;
  403. }
  404. }
  405. DecodeCache.push_back(0);
  406. CacheEntries.shrink_to_fit();
  407. DecodeCache.shrink_to_fit();
  408. }
  409. void THuffmanCodec::TImpl::THuffmanCache::Decode(NBitIO::TBitInput& bin, TBuffer& out) const {
  410. int bits = 0;
  411. ui64 code = 0;
  412. while (!bin.Eof()) {
  413. ui64 f = 0;
  414. const int toRead = 16 - bits;
  415. if (toRead > 0 && bin.Read(f, toRead)) {
  416. code = (code >> (16 - bits)) | (f << bits);
  417. code &= 0xFFFF;
  418. TCacheEntry entry = CacheEntries[code];
  419. int start = code > 0 ? CacheEntries[code - 1].EndOffset : 0;
  420. out.Append((const char*)&DecodeCache[start], (const char*)&DecodeCache[entry.EndOffset]);
  421. bits = entry.BitsLeft;
  422. } else { // should never happen until there are exceptions or unaligned input
  423. bin.Back(bits);
  424. if (!Original.ReadNextChar(bin, out))
  425. break;
  426. code = 0;
  427. bits = 0;
  428. }
  429. }
  430. }
  431. THuffmanCodec::THuffmanCodec()
  432. : Impl(new TImpl)
  433. {
  434. MyTraits.NeedsTraining = true;
  435. MyTraits.PreservesPrefixGrouping = true;
  436. MyTraits.PaddingBit = 1;
  437. MyTraits.SizeOnEncodeMultiplier = 2;
  438. MyTraits.SizeOnDecodeMultiplier = 8;
  439. MyTraits.RecommendedSampleSize = 1 << 21;
  440. }
  441. THuffmanCodec::~THuffmanCodec() = default;
  442. ui8 THuffmanCodec::Encode(TStringBuf in, TBuffer& bbb) const {
  443. if (Y_UNLIKELY(!Trained))
  444. ythrow TCodecException() << " not trained";
  445. return Impl->Encode(in, bbb);
  446. }
  447. void THuffmanCodec::Decode(TStringBuf in, TBuffer& bbb) const {
  448. Impl->Decode(in, bbb);
  449. }
  450. void THuffmanCodec::Save(IOutputStream* out) const {
  451. Impl->Save(out);
  452. }
  453. void THuffmanCodec::Load(IInputStream* in) {
  454. Impl->Load(in);
  455. }
  456. void THuffmanCodec::DoLearn(ISequenceReader& in) {
  457. Impl->Learn(&in);
  458. }
  459. void THuffmanCodec::LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
  460. Impl->LearnByFreqs(freqs);
  461. Trained = true;
  462. }
  463. }