zlib.cpp 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. #include "zlib.h"
  2. #include <util/memory/addstorage.h>
  3. #include <util/generic/scope.h>
  4. #include <util/generic/utility.h>
  5. #include <zlib.h>
  6. #include <cstring>
  7. namespace {
  8. static const int opts[] = {
  9. // Auto
  10. 15 + 32,
  11. // ZLib
  12. 15 + 0,
  13. // GZip
  14. 15 + 16,
  15. // Raw
  16. -15};
  17. class TZLibCommon {
  18. public:
  19. inline TZLibCommon() noexcept {
  20. memset(Z(), 0, sizeof(*Z()));
  21. }
  22. inline ~TZLibCommon() = default;
  23. inline const char* GetErrMsg() const noexcept {
  24. return Z()->msg != nullptr ? Z()->msg : "unknown error";
  25. }
  26. inline z_stream* Z() const noexcept {
  27. return (z_stream*)(&Z_);
  28. }
  29. private:
  30. z_stream Z_;
  31. };
  32. static inline ui32 MaxPortion(size_t s) noexcept {
  33. return (ui32)Min<size_t>(Max<ui32>(), s);
  34. }
  35. struct TChunkedZeroCopyInput {
  36. inline TChunkedZeroCopyInput(IZeroCopyInput* in)
  37. : In(in)
  38. , Buf(nullptr)
  39. , Len(0)
  40. {
  41. }
  42. template <class P, class T>
  43. inline bool Next(P** buf, T* len) {
  44. if (!Len) {
  45. Len = In->Next(&Buf);
  46. if (!Len) {
  47. return false;
  48. }
  49. }
  50. const T toread = (T)Min((size_t)Max<T>(), Len);
  51. *len = toread;
  52. *buf = (P*)Buf;
  53. Buf += toread;
  54. Len -= toread;
  55. return true;
  56. }
  57. IZeroCopyInput* In;
  58. const char* Buf;
  59. size_t Len;
  60. };
  61. } // namespace
  62. class TZLibDecompress::TImpl: private TZLibCommon, public TChunkedZeroCopyInput {
  63. public:
  64. inline TImpl(IZeroCopyInput* in, ZLib::StreamType type, TStringBuf dict)
  65. : TChunkedZeroCopyInput(in)
  66. , Dict(dict)
  67. {
  68. if (inflateInit2(Z(), opts[type]) != Z_OK) {
  69. ythrow TZLibDecompressorError() << "can not init inflate engine";
  70. }
  71. if (dict.size() && type == ZLib::Raw) {
  72. SetDict();
  73. }
  74. }
  75. virtual ~TImpl() {
  76. inflateEnd(Z());
  77. }
  78. void SetAllowMultipleStreams(bool allowMultipleStreams) {
  79. AllowMultipleStreams_ = allowMultipleStreams;
  80. }
  81. inline size_t Read(void* buf, size_t size) {
  82. Z()->next_out = (unsigned char*)buf;
  83. Z()->avail_out = size;
  84. while (true) {
  85. if (Z()->avail_in == 0) {
  86. if (!FillInputBuffer()) {
  87. return 0;
  88. }
  89. }
  90. switch (inflate(Z(), Z_SYNC_FLUSH)) {
  91. case Z_NEED_DICT: {
  92. SetDict();
  93. continue;
  94. }
  95. case Z_STREAM_END: {
  96. if (AllowMultipleStreams_) {
  97. if (inflateReset(Z()) != Z_OK) {
  98. ythrow TZLibDecompressorError() << "inflate reset error(" << GetErrMsg() << ")";
  99. }
  100. } else {
  101. return size - Z()->avail_out;
  102. }
  103. [[fallthrough]];
  104. }
  105. case Z_OK: {
  106. const size_t processed = size - Z()->avail_out;
  107. if (processed) {
  108. return processed;
  109. }
  110. break;
  111. }
  112. default:
  113. ythrow TZLibDecompressorError() << "inflate error(" << GetErrMsg() << ")";
  114. }
  115. }
  116. }
  117. private:
  118. inline bool FillInputBuffer() {
  119. return Next(&Z()->next_in, &Z()->avail_in);
  120. }
  121. void SetDict() {
  122. if (inflateSetDictionary(Z(), (const Bytef*)Dict.data(), Dict.size()) != Z_OK) {
  123. ythrow TZLibCompressorError() << "can not set inflate dictionary";
  124. }
  125. }
  126. bool AllowMultipleStreams_ = true;
  127. TStringBuf Dict;
  128. };
  129. namespace {
  130. class TDecompressStream: public IZeroCopyInput, public TZLibDecompress::TImpl, public TAdditionalStorage<TDecompressStream> {
  131. public:
  132. inline TDecompressStream(IInputStream* input, ZLib::StreamType type, TStringBuf dict)
  133. : TZLibDecompress::TImpl(this, type, dict)
  134. , Stream_(input)
  135. {
  136. }
  137. ~TDecompressStream() override = default;
  138. private:
  139. size_t DoNext(const void** ptr, size_t len) override {
  140. void* buf = AdditionalData();
  141. *ptr = buf;
  142. return Stream_->Read(buf, Min(len, AdditionalDataLength()));
  143. }
  144. private:
  145. IInputStream* Stream_;
  146. };
// Implementation used when the caller already provides an IZeroCopyInput:
// the plain TImpl, with no intermediate read buffer.
using TZeroCopyDecompress = TZLibDecompress::TImpl;
  148. } // namespace
  149. class TZLibCompress::TImpl: public TAdditionalStorage<TImpl>, private TZLibCommon {
  150. static inline ZLib::StreamType Type(ZLib::StreamType type) {
  151. if (type == ZLib::Auto) {
  152. return ZLib::ZLib;
  153. }
  154. if (type >= ZLib::Invalid) {
  155. ythrow TZLibError() << "invalid compression type: " << static_cast<unsigned long>(type);
  156. }
  157. return type;
  158. }
  159. public:
  160. inline TImpl(const TParams& p)
  161. : Stream_(p.Out)
  162. {
  163. if (deflateInit2(Z(), Min<size_t>(9, p.CompressionLevel), Z_DEFLATED, opts[Type(p.Type)], 8, Z_DEFAULT_STRATEGY)) {
  164. ythrow TZLibCompressorError() << "can not init inflate engine";
  165. }
  166. // Create exactly the same files on all platforms by fixing OS field in the header.
  167. if (p.Type == ZLib::GZip) {
  168. GZHeader_ = MakeHolder<gz_header>();
  169. GZHeader_->os = 3; // UNIX
  170. deflateSetHeader(Z(), GZHeader_.Get());
  171. }
  172. if (p.Dict.size()) {
  173. if (deflateSetDictionary(Z(), (const Bytef*)p.Dict.data(), p.Dict.size())) {
  174. ythrow TZLibCompressorError() << "can not set deflate dictionary";
  175. }
  176. }
  177. Z()->next_out = TmpBuf();
  178. Z()->avail_out = TmpBufLen();
  179. }
  180. inline ~TImpl() {
  181. deflateEnd(Z());
  182. }
  183. inline void Write(const void* buf, size_t size) {
  184. const Bytef* b = (const Bytef*)buf;
  185. const Bytef* e = b + size;
  186. Y_DEFER {
  187. Z()->next_in = nullptr;
  188. Z()->avail_in = 0;
  189. };
  190. do {
  191. b = WritePart(b, e);
  192. } while (b < e);
  193. }
  194. inline const Bytef* WritePart(const Bytef* b, const Bytef* e) {
  195. Z()->next_in = const_cast<Bytef*>(b);
  196. Z()->avail_in = MaxPortion(e - b);
  197. while (Z()->avail_in) {
  198. const int ret = deflate(Z(), Z_NO_FLUSH);
  199. switch (ret) {
  200. case Z_OK:
  201. continue;
  202. case Z_BUF_ERROR:
  203. FlushBuffer();
  204. break;
  205. default:
  206. ythrow TZLibCompressorError() << "deflate error(" << GetErrMsg() << ")";
  207. }
  208. }
  209. return Z()->next_in;
  210. }
  211. inline void Flush() {
  212. int ret = deflate(Z(), Z_SYNC_FLUSH);
  213. while ((ret == Z_OK || ret == Z_BUF_ERROR) && !Z()->avail_out) {
  214. FlushBuffer();
  215. ret = deflate(Z(), Z_SYNC_FLUSH);
  216. }
  217. if (ret != Z_OK && ret != Z_BUF_ERROR) {
  218. ythrow TZLibCompressorError() << "deflate flush error(" << GetErrMsg() << ")";
  219. }
  220. if (Z()->avail_out < TmpBufLen()) {
  221. FlushBuffer();
  222. }
  223. }
  224. inline void FlushBuffer() {
  225. Stream_->Write(TmpBuf(), TmpBufLen() - Z()->avail_out);
  226. Z()->next_out = TmpBuf();
  227. Z()->avail_out = TmpBufLen();
  228. }
  229. inline void Finish() {
  230. int ret = deflate(Z(), Z_FINISH);
  231. while (ret == Z_OK || ret == Z_BUF_ERROR) {
  232. FlushBuffer();
  233. ret = deflate(Z(), Z_FINISH);
  234. }
  235. if (ret == Z_STREAM_END) {
  236. Stream_->Write(TmpBuf(), TmpBufLen() - Z()->avail_out);
  237. } else {
  238. ythrow TZLibCompressorError() << "deflate finish error(" << GetErrMsg() << ")";
  239. }
  240. }
  241. private:
  242. inline unsigned char* TmpBuf() noexcept {
  243. return (unsigned char*)AdditionalData();
  244. }
  245. inline size_t TmpBufLen() const noexcept {
  246. return AdditionalDataLength();
  247. }
  248. private:
  249. IOutputStream* Stream_;
  250. THolder<gz_header> GZHeader_;
  251. };
// Decompresses from a zero-copy input: compressed bytes are taken straight
// from the chunks that input returns.
TZLibDecompress::TZLibDecompress(IZeroCopyInput* input, ZLib::StreamType type, TStringBuf dict)
    : Impl_(new TZeroCopyDecompress(input, type, dict))
{
}
// Decompresses from a plain input stream through TDecompressStream.
// buflen is forwarded to the placement form of operator new; presumably it
// sizes the additional storage that serves as the read buffer (confirm against
// TAdditionalStorage).
TZLibDecompress::TZLibDecompress(IInputStream* input, ZLib::StreamType type, size_t buflen, TStringBuf dict)
    : Impl_(new (buflen) TDecompressStream(input, type, dict))
{
}
// Forwards the setting to the implementation: when true, decoding continues
// with the next compressed stream after one ends; when false, reading stops at
// the end of the first stream.
void TZLibDecompress::SetAllowMultipleStreams(bool allowMultipleStreams) {
    Impl_->SetAllowMultipleStreams(allowMultipleStreams);
}
// Defined in this translation unit, where TImpl is a complete type.
TZLibDecompress::~TZLibDecompress() = default;
  264. size_t TZLibDecompress::DoRead(void* buf, size_t size) {
  265. return Impl_->Read(buf, MaxPortion(size));
  266. }
  267. void TZLibCompress::Init(const TParams& params) {
  268. Y_ENSURE(params.BufLen >= 16, "ZLib buffer too small");
  269. Impl_.Reset(new (params.BufLen) TImpl(params));
  270. }
// Deleter for the implementation object; defined in this translation unit,
// where TImpl is a complete type.
void TZLibCompress::TDestruct::Destroy(TImpl* impl) {
    delete impl;
}
// Finishes the compressed stream on destruction if the caller has not done so.
TZLibCompress::~TZLibCompress() {
    try {
        Finish();
    } catch (...) {
        // Best effort only: a destructor must not let an exception escape, so
        // a failure to finish the stream is deliberately dropped here.
    }
}
  281. void TZLibCompress::DoWrite(const void* buf, size_t size) {
  282. if (!Impl_) {
  283. ythrow TZLibCompressorError() << "can not write to finished zlib stream";
  284. }
  285. Impl_->Write(buf, size);
  286. }
  287. void TZLibCompress::DoFlush() {
  288. if (Impl_) {
  289. Impl_->Flush();
  290. }
  291. }
  292. void TZLibCompress::DoFinish() {
  293. THolder<TImpl> impl(Impl_.Release());
  294. if (impl) {
  295. impl->Finish();
  296. }
  297. }
// Out-of-line defaulted destructor for the buffered decompressor.
TBufferedZLibDecompress::~TBufferedZLibDecompress() = default;