huff.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #pragma once
  2. #include <util/system/defaults.h>
  3. #include <util/generic/yexception.h>
  4. #include <util/generic/ptr.h>
  5. #include <util/generic/vector.h>
  6. #include <util/generic/algorithm.h>
  7. #include <utility>
  8. #include <queue>
  9. #include "compressor.h"
  10. namespace NCompProto {
  11. template <size_t CacheSize, typename TEntry>
  12. struct TCache {
  13. ui32 CacheKey[CacheSize];
  14. TEntry CacheVal[CacheSize];
  15. size_t Hits;
  16. size_t Misses;
  17. ui32 Hash(ui32 key) {
  18. return key % CacheSize;
  19. }
  20. TCache() {
  21. Hits = 0;
  22. Misses = 0;
  23. Clear();
  24. }
  25. void Clear() {
  26. for (size_t i = 0; i < CacheSize; ++i) {
  27. ui32 j = 0;
  28. for (; Hash(j) == i; ++j)
  29. ;
  30. CacheKey[i] = j;
  31. }
  32. }
  33. };
  34. struct TCode {
  35. i64 Probability;
  36. ui32 Start;
  37. ui32 Bits;
  38. ui32 Prefix;
  39. ui32 PrefLength;
  40. TCode(i64 probability = 0, ui32 start = 0, ui32 bits = 0)
  41. : Probability(probability)
  42. , Start(start)
  43. , Bits(bits)
  44. {
  45. }
  46. bool operator<(const TCode& code) const {
  47. return Probability < code.Probability;
  48. }
  49. bool operator>(const TCode& code) const {
  50. return Probability > code.Probability;
  51. }
  52. };
  53. struct TAccum {
  54. struct TTable {
  55. TAutoPtr<TTable> Tables[16];
  56. i64 Counts[16];
  57. TTable(const TTable& other) {
  58. for (size_t i = 0; i < 16; ++i) {
  59. Counts[i] = other.Counts[i];
  60. if (other.Tables[i].Get()) {
  61. Tables[i].Reset(new TTable(*other.Tables[i].Get()));
  62. }
  63. }
  64. }
  65. TTable() {
  66. for (auto& count : Counts)
  67. count = 0;
  68. }
  69. i64 GetCellCount(size_t i) {
  70. i64 count = Counts[i];
  71. if (Tables[i].Get()) {
  72. for (size_t j = 0; j < 16; ++j) {
  73. count += Tables[i]->GetCellCount(j);
  74. }
  75. }
  76. return count;
  77. }
  78. i64 GetCount() {
  79. i64 count = 0;
  80. for (size_t j = 0; j < 16; ++j) {
  81. count += GetCellCount(j);
  82. }
  83. return count;
  84. }
  85. void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int depth, int termDepth, ui32 code, i64 cnt) {
  86. if (depth == termDepth) {
  87. for (size_t i = 0; i < 16; ++i) {
  88. i64 iCount = GetCellCount(i);
  89. if (Tables[i].Get()) {
  90. Counts[i] = iCount;
  91. Tables[i].Reset(nullptr);
  92. }
  93. if (iCount > cnt || (termDepth == 0 && iCount > 0)) {
  94. std::pair<i64, TCode> codep;
  95. codep.first = iCount;
  96. codep.second.Probability = iCount;
  97. codep.second.Start = code + (i << (28 - depth));
  98. codep.second.Bits = 28 - depth;
  99. codes.push_back(codep);
  100. Counts[i] = 0;
  101. }
  102. }
  103. }
  104. for (size_t i = 0; i < 16; ++i) {
  105. if (Tables[i].Get()) {
  106. Tables[i]->GenerateFreqs(codes, depth + 4, termDepth, code + (i << (28 - depth)), cnt);
  107. }
  108. }
  109. }
  110. };
  111. TTable Root;
  112. int TableCount;
  113. i64 Total;
  114. ui64 Max;
  115. TAccum() {
  116. TableCount = 0;
  117. Total = 0;
  118. Max = 0;
  119. }
  120. void GenerateFreqs(TVector<std::pair<i64, TCode>>& codes, int mul) const {
  121. TTable root(Root);
  122. for (int i = 28; i > 0; i -= 4) {
  123. root.GenerateFreqs(codes, 0, i, 0, Total / mul);
  124. }
  125. i64 iCount = root.GetCount();
  126. if (iCount == 0)
  127. return;
  128. std::pair<i64, TCode> codep;
  129. codep.first = iCount;
  130. codep.second.Probability = iCount;
  131. codep.second.Start = 0;
  132. ui32 bits = 0;
  133. while (1) {
  134. if ((1ULL << bits) > Max)
  135. break;
  136. ++bits;
  137. }
  138. codep.second.Bits = bits;
  139. codes.push_back(codep);
  140. }
  141. TCache<256, i64*> Cache;
  142. void AddMap(ui32 value, i64 weight = 1) {
  143. ui32 index = Cache.Hash(value);
  144. if (Cache.CacheKey[index] == value) {
  145. Cache.CacheVal[index][0] += weight;
  146. return;
  147. }
  148. TTable* root = &Root;
  149. for (size_t i = 0; i < 15; ++i) {
  150. ui32 index2 = (value >> (28 - i * 4)) & 0xf;
  151. if (!root->Tables[index2].Get()) {
  152. if (TableCount < 1024) {
  153. ++TableCount;
  154. root->Tables[index2].Reset(new TTable);
  155. } else {
  156. Cache.CacheKey[index2] = value;
  157. Cache.CacheVal[index2] = &root->Counts[index2];
  158. root->Counts[index2] += weight;
  159. return;
  160. }
  161. }
  162. root = root->Tables[index2].Get();
  163. }
  164. Cache.CacheKey[index] = value;
  165. Cache.CacheVal[index] = &root->Counts[value & 0xf];
  166. root->Counts[value & 0xf] += weight;
  167. }
  168. void Add(ui32 value, i64 weight = 1) {
  169. Max = ::Max(Max, (ui64)value);
  170. Total += weight;
  171. AddMap(value, weight);
  172. }
  173. };
  174. struct THuffNode {
  175. i64 Weight;
  176. i64 Priority;
  177. THuffNode* Nodes[2];
  178. TCode* Code;
  179. THuffNode(i64 weight, i64 priority, TCode* code)
  180. : Weight(weight)
  181. , Priority(priority)
  182. , Code(code)
  183. {
  184. Nodes[0] = nullptr;
  185. Nodes[1] = nullptr;
  186. }
  187. void BuildPrefixes(ui32 depth, ui32 prefix) {
  188. if (Code) {
  189. Code->Prefix = prefix;
  190. Code->PrefLength = depth;
  191. return;
  192. }
  193. Nodes[0]->BuildPrefixes(depth + 1, prefix + (0UL << depth));
  194. Nodes[1]->BuildPrefixes(depth + 1, prefix + (1UL << depth));
  195. }
  196. i64 Iterate(size_t depth) const {
  197. if (Code) {
  198. return (depth + Code->Bits) * Code->Probability;
  199. }
  200. return Nodes[0]->Iterate(depth + 1) + Nodes[1]->Iterate(depth + 1);
  201. }
  202. size_t Depth() const {
  203. if (Code) {
  204. return 0;
  205. }
  206. return Max(Nodes[0]->Depth(), Nodes[1]->Depth()) + 1;
  207. }
  208. };
  209. struct THLess {
  210. bool operator()(const THuffNode* a, const THuffNode* b) {
  211. if (a->Weight > b->Weight)
  212. return 1;
  213. if (a->Weight == b->Weight && a->Priority > b->Priority)
  214. return 1;
  215. return 0;
  216. }
  217. };
  218. inline i64 BuildHuff(TVector<TCode>& codes) {
  219. TVector<TSimpleSharedPtr<THuffNode>> hold;
  220. std::priority_queue<THuffNode*, TVector<THuffNode*>, THLess> nodes;
  221. i64 ret = 0;
  222. int priority = 0;
  223. for (size_t i = 0; i < codes.size(); ++i) {
  224. TSimpleSharedPtr<THuffNode> node(new THuffNode(codes[i].Probability, priority++, &codes[i]));
  225. hold.push_back(node);
  226. nodes.push(node.Get());
  227. }
  228. while (nodes.size() > 1) {
  229. THuffNode* nodea = nodes.top();
  230. nodes.pop();
  231. THuffNode* nodeb = nodes.top();
  232. nodes.pop();
  233. TSimpleSharedPtr<THuffNode> node(new THuffNode(nodea->Weight + nodeb->Weight, priority++, nullptr));
  234. node->Nodes[0] = nodea;
  235. node->Nodes[1] = nodeb;
  236. hold.push_back(node);
  237. nodes.push(node.Get());
  238. }
  239. if (nodes.size()) {
  240. THuffNode* node = nodes.top();
  241. node->BuildPrefixes(0, 0);
  242. ret = node->Iterate(0);
  243. }
  244. return ret;
  245. }
  246. struct TCoderEntry {
  247. ui32 MinValue;
  248. ui16 Prefix;
  249. ui8 PrefixBits;
  250. ui8 AllBits;
  251. ui64 MaxValue() const {
  252. return MinValue + (1ULL << (AllBits - PrefixBits));
  253. }
  254. };
  255. inline i64 Analyze(const TAccum& acc, TVector<TCoderEntry>& retCodes) {
  256. i64 ret;
  257. for (int k = 256; k > 0; --k) {
  258. retCodes.clear();
  259. TVector<std::pair<i64, TCode>> pairs;
  260. acc.GenerateFreqs(pairs, k);
  261. TVector<TCode> codes;
  262. for (size_t i = 0; i < pairs.size(); ++i) {
  263. codes.push_back(pairs[i].second);
  264. }
  265. StableSort(codes.begin(), codes.end(), std::greater<TCode>());
  266. ret = BuildHuff(codes);
  267. bool valid = true;
  268. for (size_t i = 0; i < codes.size(); ++i) {
  269. TCoderEntry code;
  270. code.MinValue = codes[i].Start;
  271. code.Prefix = codes[i].Prefix;
  272. code.PrefixBits = codes[i].PrefLength;
  273. if (code.PrefixBits > 6)
  274. valid = false;
  275. code.AllBits = code.PrefixBits + codes[i].Bits;
  276. retCodes.push_back(code);
  277. }
  278. if (valid)
  279. return ret;
  280. }
  281. return ret;
  282. }
  283. struct TComparer {
  284. bool operator()(const TCoderEntry& e0, const TCoderEntry& e1) const {
  285. return e0.AllBits < e1.AllBits;
  286. }
  287. };
  288. struct TCoder {
  289. TVector<TCoderEntry> Entries;
  290. void Normalize() {
  291. TComparer comp;
  292. StableSort(Entries.begin(), Entries.end(), comp);
  293. }
  294. TCoder() {
  295. InitDefault();
  296. }
  297. void InitDefault() {
  298. ui64 cum = 0;
  299. Cache.Clear();
  300. Entries.clear();
  301. ui16 b = 1;
  302. for (ui16 i = 0; i < 40; ++i) {
  303. ui16 bits = Min(b, (ui16)(32));
  304. b = (b * 16) / 10 + 1;
  305. if (b > 32)
  306. b = 32;
  307. TCoderEntry entry;
  308. entry.PrefixBits = i + 1;
  309. entry.AllBits = entry.PrefixBits + bits;
  310. entry.MinValue = (ui32)Min(cum, (ui64)(ui32)(-1));
  311. cum += (1ULL << bits);
  312. entry.Prefix = ((1UL << i) - 1);
  313. Entries.push_back(entry);
  314. if (cum > (ui32)(-1)) {
  315. return;
  316. }
  317. }
  318. }
  319. TCache<1024, TCoderEntry> Cache;
  320. ui64 RealCode(ui32 value, const TCoderEntry& entry, size_t& length) {
  321. length = entry.AllBits;
  322. return (ui64(value - entry.MinValue) << entry.PrefixBits) + entry.Prefix;
  323. }
  324. bool Empty() const {
  325. return Entries.empty();
  326. }
  327. const TCoderEntry& GetEntry(ui32 code, ui8& id) const {
  328. for (size_t i = 0; i < Entries.size(); ++i) {
  329. const TCoderEntry& entry = Entries[i];
  330. ui32 prefMask = (1UL << entry.PrefixBits) - 1UL;
  331. if (entry.Prefix == (code & prefMask)) {
  332. id = ui8(i);
  333. return entry;
  334. }
  335. }
  336. ythrow yexception() << "bad entry";
  337. return Entries[0];
  338. }
  339. ui64 Code(ui32 entry, size_t& length) {
  340. ui32 index = Cache.Hash(entry);
  341. if (Cache.CacheKey[index] == entry) {
  342. ++Cache.Hits;
  343. return RealCode(entry, Cache.CacheVal[index], length);
  344. }
  345. ++Cache.Misses;
  346. for (size_t i = 0; i < Entries.size(); ++i) {
  347. if (entry >= Entries[i].MinValue && entry < Entries[i].MaxValue()) {
  348. Cache.CacheKey[index] = entry;
  349. Cache.CacheVal[index] = Entries[i];
  350. return RealCode(entry, Cache.CacheVal[index], length);
  351. }
  352. }
  353. ythrow yexception() << "bad huff tree";
  354. return 0;
  355. }
  356. };
  357. }