huffman_codec.cpp 17 KB


  1. #include "huffman_codec.h"
  2. #include <library/cpp/bit_io/bitinput.h>
  3. #include <library/cpp/bit_io/bitoutput.h>
  4. #include <util/generic/algorithm.h>
  5. #include <util/generic/bitops.h>
  6. #include <util/stream/buffer.h>
  7. #include <util/stream/length.h>
  8. #include <util/string/printf.h>
  9. namespace NCodecs {
  /// Strict weak ordering used for canonical Huffman code assignment:
  /// entries are ordered by code length first and by byte value second.
  /// T must expose `CodeLength` and `Char` members.
  template <typename T>
  struct TCanonicalCmp {
      bool operator()(const T& lhs, const T& rhs) const {
          return lhs.CodeLength != rhs.CodeLength
                     ? lhs.CodeLength < rhs.CodeLength
                     : lhs.Char < rhs.Char;
      }
  };
  /// Orders entries by their byte value only.
  /// T must expose a `Char` member.
  template <typename T>
  struct TByCharCmp {
      bool operator()(const T& lhs, const T& rhs) const {
          const bool lhsFirst = lhs.Char < rhs.Char;
          return lhsFirst;
      }
  };
  26. struct TTreeEntry {
  27. static const ui32 InvalidBranch = (ui32)-1;
  28. ui64 Freq = 0;
  29. ui32 Branches[2]{InvalidBranch, InvalidBranch};
  30. ui32 CodeLength = 0;
  31. ui8 Char = 0;
  32. bool Invalid = false;
  33. TTreeEntry() = default;
  34. static bool ByFreq(const TTreeEntry& a, const TTreeEntry& b) {
  35. return a.Freq < b.Freq;
  36. }
  37. static bool ByFreqRev(const TTreeEntry& a, const TTreeEntry& b) {
  38. return a.Freq > b.Freq;
  39. }
  40. };
  41. using TCodeTree = TVector<TTreeEntry>;
  42. void InitTreeByFreqs(TCodeTree& tree, const ui64 freqs[256]) {
  43. tree.reserve(255 * 256 / 2); // worst case - balanced tree
  44. for (ui32 i = 0; i < 256; ++i) {
  45. tree.emplace_back();
  46. tree.back().Char = i;
  47. tree.back().Freq = freqs[i];
  48. }
  49. StableSort(tree.begin(), tree.end(), TTreeEntry::ByFreq);
  50. }
  51. void InitTree(TCodeTree& tree, ISequenceReader* in) {
  52. using namespace NPrivate;
  53. ui64 freqs[256];
  54. Zero(freqs);
  55. TStringBuf r;
  56. while (in->NextRegion(r)) {
  57. for (ui64 i = 0; i < r.size(); ++i)
  58. ++freqs[(ui8)r[i]];
  59. }
  60. InitTreeByFreqs(tree, freqs);
  61. }
  62. void CalculateCodeLengths(TCodeTree& tree) {
  63. Y_ENSURE(tree.size() == 256, " ");
  64. const ui32 firstbranch = tree.size();
  65. ui32 curleaf = 0;
  66. ui32 curbranch = firstbranch;
  67. // building code tree. two priority queues are combined in one.
  68. while (firstbranch - curleaf + tree.size() - curbranch >= 2) {
  69. TTreeEntry e;
  70. for (auto& branche : e.Branches) {
  71. ui32 br;
  72. if (curleaf >= firstbranch)
  73. br = curbranch++;
  74. else if (curbranch >= tree.size())
  75. br = curleaf++;
  76. else if (tree[curleaf].Freq < tree[curbranch].Freq)
  77. br = curleaf++;
  78. else
  79. br = curbranch++;
  80. Y_ENSURE(br < tree.size(), " ");
  81. branche = br;
  82. e.Freq += tree[br].Freq;
  83. }
  84. tree.push_back(e);
  85. PushHeap(tree.begin() + curbranch, tree.end(), TTreeEntry::ByFreqRev);
  86. }
  87. // computing code lengths
  88. for (ui64 i = tree.size() - 1; i >= firstbranch; --i) {
  89. TTreeEntry e = tree[i];
  90. for (auto branche : e.Branches)
  91. tree[branche].CodeLength = e.CodeLength + 1;
  92. }
  93. // chopping off the branches
  94. tree.resize(firstbranch);
  95. Sort(tree.begin(), tree.end(), TCanonicalCmp<TTreeEntry>());
  96. // simplification: we are stripping codes longer than 64 bits
  97. while (!tree.empty() && tree.back().CodeLength > 64)
  98. tree.pop_back();
  99. // will not compress
  100. if (tree.empty())
  101. return;
  102. // special invalid code word
  103. tree.back().Invalid = true;
  104. }
  105. struct TEncoderEntry {
  106. ui64 Code = 0;
  107. ui8 CodeLength = 0;
  108. ui8 Char = 0;
  109. ui8 Invalid = true;
  110. explicit TEncoderEntry(TTreeEntry e)
  111. : CodeLength(e.CodeLength)
  112. , Char(e.Char)
  113. , Invalid(e.Invalid)
  114. {
  115. }
  116. TEncoderEntry() = default;
  117. };
  118. struct TEncoderTable {
  119. TEncoderEntry Entries[256];
  120. void Save(IOutputStream* out) const {
  121. ui16 nval = 0;
  122. for (auto entrie : Entries)
  123. nval += !entrie.Invalid;
  124. ::Save(out, nval);
  125. for (auto entrie : Entries) {
  126. if (!entrie.Invalid) {
  127. ::Save(out, entrie.Char);
  128. ::Save(out, entrie.CodeLength);
  129. }
  130. }
  131. }
  132. void Load(IInputStream* in) {
  133. ui16 nval = 0;
  134. ::Load(in, nval);
  135. for (ui32 i = 0; i < 256; ++i)
  136. Entries[i].Char = i;
  137. for (ui32 i = 0; i < nval; ++i) {
  138. ui8 ch = 0;
  139. ui8 len = 0;
  140. ::Load(in, ch);
  141. ::Load(in, len);
  142. Entries[ch].CodeLength = len;
  143. Entries[ch].Invalid = false;
  144. }
  145. }
  146. };
  147. struct TDecoderEntry {
  148. ui32 NextTable : 10;
  149. ui32 Char : 8;
  150. ui32 Invalid : 1;
  151. ui32 Bad : 1;
  152. TDecoderEntry()
  153. : NextTable()
  154. , Char()
  155. , Invalid()
  156. , Bad()
  157. {
  158. }
  159. };
  160. struct TDecoderTable: public TIntrusiveListItem<TDecoderTable> {
  161. ui64 Length = 0;
  162. ui64 BaseCode = 0;
  163. TDecoderEntry Entries[256];
  164. TDecoderTable() {
  165. Zero(Entries);
  166. }
  167. };
  /// Width in bits of the window used to index the decode cache
  /// (the cache holds 1 << CACHE_BITS_COUNT entries).
  constexpr int CACHE_BITS_COUNT = 16;
  169. class THuffmanCodec::TImpl: public TAtomicRefCount<TImpl> {
  170. TEncoderTable Encoder;
  171. TDecoderTable Decoder[256];
  172. TEncoderEntry Invalid;
  173. ui32 SubTablesNum;
  174. class THuffmanCache {
  175. struct TCacheEntry {
  176. int EndOffset : 24;
  177. int BitsLeft : 8;
  178. };
  179. TVector<char> DecodeCache;
  180. TVector<TCacheEntry> CacheEntries;
  181. const TImpl& Original;
  182. public:
  183. THuffmanCache(const THuffmanCodec::TImpl& encoder);
  184. void Decode(NBitIO::TBitInput& in, TBuffer& out) const;
  185. };
  186. THolder<THuffmanCache> Cache;
  187. public:
  188. TImpl()
  189. : SubTablesNum(1)
  190. {
  191. Invalid.CodeLength = 255;
  192. }
  193. ui8 Encode(TStringBuf in, TBuffer& out) const {
  194. out.Clear();
  195. if (in.empty()) {
  196. return 0;
  197. }
  198. out.Reserve(in.size() * 2);
  199. {
  200. NBitIO::TBitOutputVector<TBuffer> bout(&out);
  201. TStringBuf tin = in;
  202. // data is under compression
  203. bout.Write(1, 1);
  204. for (auto t : tin) {
  205. const TEncoderEntry& ce = Encoder.Entries[(ui8)t];
  206. bout.Write(ce.Code, ce.CodeLength);
  207. if (ce.Invalid) {
  208. bout.Write(t, 8);
  209. }
  210. }
  211. // in canonical huffman coding there cannot be a code having no 0 in the suffix
  212. // and shorter than 8 bits.
  213. bout.Write((ui64)-1, bout.GetByteReminder());
  214. return bout.GetByteReminder();
  215. }
  216. }
  217. void Decode(TStringBuf in, TBuffer& out) const {
  218. out.Clear();
  219. if (in.empty()) {
  220. return;
  221. }
  222. NBitIO::TBitInput bin(in);
  223. ui64 f = 0;
  224. bin.ReadK<1>(f);
  225. // if data is uncompressed
  226. if (!f) {
  227. in.Skip(1);
  228. out.Append(in.data(), in.size());
  229. } else {
  230. out.Reserve(in.size() * 8);
  231. if (Cache.Get()) {
  232. Cache->Decode(bin, out);
  233. } else {
  234. while (ReadNextChar(bin, out)) {
  235. }
  236. }
  237. }
  238. }
  239. Y_FORCE_INLINE int ReadNextChar(NBitIO::TBitInput& bin, TBuffer& out) const {
  240. const TDecoderTable* table = Decoder;
  241. TDecoderEntry e;
  242. int bitsRead = 0;
  243. while (true) {
  244. ui64 code = 0;
  245. if (Y_UNLIKELY(!bin.Read(code, table->Length)))
  246. return 0;
  247. bitsRead += table->Length;
  248. if (Y_UNLIKELY(code < table->BaseCode))
  249. return 0;
  250. code -= table->BaseCode;
  251. if (Y_UNLIKELY(code > 255))
  252. return 0;
  253. e = table->Entries[code];
  254. if (Y_UNLIKELY(e.Bad))
  255. return 0;
  256. if (e.NextTable) {
  257. table = Decoder + e.NextTable;
  258. } else {
  259. if (e.Invalid) {
  260. code = 0;
  261. bin.ReadK<8>(code);
  262. bitsRead += 8;
  263. out.Append((ui8)code);
  264. } else {
  265. out.Append((ui8)e.Char);
  266. }
  267. return bitsRead;
  268. }
  269. }
  270. Y_ENSURE(false, " could not decode input");
  271. return 0;
  272. }
  273. void GenerateEncoder(TCodeTree& tree) {
  274. const ui64 sz = tree.size();
  275. TEncoderEntry lastcode = Encoder.Entries[tree[0].Char] = TEncoderEntry(tree[0]);
  276. for (ui32 i = 1; i < sz; ++i) {
  277. const TTreeEntry& te = tree[i];
  278. TEncoderEntry& e = Encoder.Entries[te.Char];
  279. e = TEncoderEntry(te);
  280. e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
  281. lastcode = e;
  282. e.Code = ReverseBits(e.Code, e.CodeLength);
  283. if (e.Invalid)
  284. Invalid = e;
  285. }
  286. for (auto& e : Encoder.Entries) {
  287. if (e.Invalid)
  288. e = Invalid;
  289. Y_ENSURE(e.CodeLength, " ");
  290. }
  291. }
  292. void RegenerateEncoder() {
  293. for (auto& entrie : Encoder.Entries) {
  294. if (entrie.Invalid)
  295. entrie.CodeLength = Invalid.CodeLength;
  296. }
  297. Sort(Encoder.Entries, Encoder.Entries + 256, TCanonicalCmp<TEncoderEntry>());
  298. TEncoderEntry lastcode = Encoder.Entries[0];
  299. for (ui32 i = 1; i < 256; ++i) {
  300. TEncoderEntry& e = Encoder.Entries[i];
  301. e.Code = (lastcode.Code + 1) << (e.CodeLength - lastcode.CodeLength);
  302. lastcode = e;
  303. e.Code = ReverseBits(e.Code, e.CodeLength);
  304. }
  305. for (auto& entrie : Encoder.Entries) {
  306. if (entrie.Invalid) {
  307. Invalid = entrie;
  308. break;
  309. }
  310. }
  311. Sort(Encoder.Entries, Encoder.Entries + 256, TByCharCmp<TEncoderEntry>());
  312. for (auto& entrie : Encoder.Entries) {
  313. if (entrie.Invalid)
  314. entrie = Invalid;
  315. }
  316. }
  317. void BuildDecoder() {
  318. TEncoderTable enc = Encoder;
  319. Sort(enc.Entries, enc.Entries + 256, TCanonicalCmp<TEncoderEntry>());
  320. TEncoderEntry& e1 = enc.Entries[0];
  321. Decoder[0].BaseCode = e1.Code;
  322. Decoder[0].Length = e1.CodeLength;
  323. for (auto e2 : enc.Entries) {
  324. SetEntry(Decoder, e2.Code, e2.CodeLength, e2);
  325. }
  326. Cache.Reset(new THuffmanCache(*this));
  327. }
  328. void SetEntry(TDecoderTable* t, ui64 code, ui64 len, TEncoderEntry e) {
  329. Y_ENSURE(len >= t->Length, len << " < " << t->Length);
  330. ui64 idx = (code & MaskLowerBits(t->Length)) - t->BaseCode;
  331. TDecoderEntry& d = t->Entries[idx];
  332. if (len == t->Length) {
  333. Y_ENSURE(!d.NextTable, " ");
  334. d.Char = e.Char;
  335. d.Invalid = e.Invalid;
  336. return;
  337. }
  338. if (!d.NextTable) {
  339. Y_ENSURE(SubTablesNum < Y_ARRAY_SIZE(Decoder), " ");
  340. d.NextTable = SubTablesNum++;
  341. TDecoderTable* nt = Decoder + d.NextTable;
  342. nt->Length = Min<ui64>(8, len - t->Length);
  343. nt->BaseCode = (code >> t->Length) & MaskLowerBits(nt->Length);
  344. }
  345. SetEntry(Decoder + d.NextTable, code >> t->Length, len - t->Length, e);
  346. }
  347. void Learn(ISequenceReader* in) {
  348. {
  349. TCodeTree tree;
  350. InitTree(tree, in);
  351. CalculateCodeLengths(tree);
  352. Y_ENSURE(!tree.empty(), " ");
  353. GenerateEncoder(tree);
  354. }
  355. BuildDecoder();
  356. }
  357. void LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
  358. TCodeTree tree;
  359. ui64 freqsArray[256];
  360. Zero(freqsArray);
  361. for (const auto& freq : freqs)
  362. freqsArray[static_cast<ui8>(freq.first)] += freq.second;
  363. InitTreeByFreqs(tree, freqsArray);
  364. CalculateCodeLengths(tree);
  365. Y_ENSURE(!tree.empty(), " ");
  366. GenerateEncoder(tree);
  367. BuildDecoder();
  368. }
  369. void Save(IOutputStream* out) {
  370. ::Save(out, Invalid.CodeLength);
  371. Encoder.Save(out);
  372. }
  373. void Load(IInputStream* in) {
  374. ::Load(in, Invalid.CodeLength);
  375. Encoder.Load(in);
  376. RegenerateEncoder();
  377. BuildDecoder();
  378. }
  379. };
  380. THuffmanCodec::TImpl::THuffmanCache::THuffmanCache(const THuffmanCodec::TImpl& codec)
  381. : Original(codec)
  382. {
  383. CacheEntries.resize(1 << CACHE_BITS_COUNT);
  384. DecodeCache.reserve(CacheEntries.size() * 2);
  385. char buffer[2];
  386. TBuffer decoded;
  387. for (size_t i = 0; i < CacheEntries.size(); i++) {
  388. buffer[1] = i >> 8;
  389. buffer[0] = i;
  390. NBitIO::TBitInput bin(buffer, buffer + sizeof(buffer));
  391. int totalBits = 0;
  392. while (true) {
  393. decoded.Resize(0);
  394. int bits = codec.ReadNextChar(bin, decoded);
  395. if (totalBits + bits > 16 || !bits) {
  396. TCacheEntry e = {static_cast<int>(DecodeCache.size()), 16 - totalBits};
  397. CacheEntries[i] = e;
  398. break;
  399. }
  400. for (TBuffer::TConstIterator it = decoded.Begin(); it != decoded.End(); ++it) {
  401. DecodeCache.push_back(*it);
  402. }
  403. totalBits += bits;
  404. }
  405. }
  406. DecodeCache.push_back(0);
  407. CacheEntries.shrink_to_fit();
  408. DecodeCache.shrink_to_fit();
  409. }
  410. void THuffmanCodec::TImpl::THuffmanCache::Decode(NBitIO::TBitInput& bin, TBuffer& out) const {
  411. int bits = 0;
  412. ui64 code = 0;
  413. while (!bin.Eof()) {
  414. ui64 f = 0;
  415. const int toRead = 16 - bits;
  416. if (toRead > 0 && bin.Read(f, toRead)) {
  417. code = (code >> (16 - bits)) | (f << bits);
  418. code &= 0xFFFF;
  419. TCacheEntry entry = CacheEntries[code];
  420. int start = code > 0 ? CacheEntries[code - 1].EndOffset : 0;
  421. out.Append((const char*)&DecodeCache[start], (const char*)&DecodeCache[entry.EndOffset]);
  422. bits = entry.BitsLeft;
  423. } else { // should never happen until there are exceptions or unaligned input
  424. bin.Back(bits);
  425. if (!Original.ReadNextChar(bin, out))
  426. break;
  427. code = 0;
  428. bits = 0;
  429. }
  430. }
  431. }
  432. THuffmanCodec::THuffmanCodec()
  433. : Impl(new TImpl)
  434. {
  435. MyTraits.NeedsTraining = true;
  436. MyTraits.PreservesPrefixGrouping = true;
  437. MyTraits.PaddingBit = 1;
  438. MyTraits.SizeOnEncodeMultiplier = 2;
  439. MyTraits.SizeOnDecodeMultiplier = 8;
  440. MyTraits.RecommendedSampleSize = 1 << 21;
  441. }
  442. THuffmanCodec::~THuffmanCodec() = default;
  443. ui8 THuffmanCodec::Encode(TStringBuf in, TBuffer& bbb) const {
  444. if (Y_UNLIKELY(!Trained))
  445. ythrow TCodecException() << " not trained";
  446. return Impl->Encode(in, bbb);
  447. }
  448. void THuffmanCodec::Decode(TStringBuf in, TBuffer& bbb) const {
  449. Impl->Decode(in, bbb);
  450. }
  451. void THuffmanCodec::Save(IOutputStream* out) const {
  452. Impl->Save(out);
  453. }
  454. void THuffmanCodec::Load(IInputStream* in) {
  455. Impl->Load(in);
  456. }
  457. void THuffmanCodec::DoLearn(ISequenceReader& in) {
  458. Impl->Learn(&in);
  459. }
  460. void THuffmanCodec::LearnByFreqs(const TArrayRef<std::pair<char, ui64>>& freqs) {
  461. Impl->LearnByFreqs(freqs);
  462. Trained = true;
  463. }
  464. }