#include "zstd.h" #include #include #define ZSTD_STATIC_LINKING_ONLY #include namespace { inline void CheckError(const char* op, size_t code) { if (::ZSTD_isError(code)) { ythrow yexception() << op << TStringBuf(" zstd error: ") << ::ZSTD_getErrorName(code); } } struct DestroyZCStream { static void Destroy(::ZSTD_CStream* p) noexcept { ::ZSTD_freeCStream(p); } }; struct DestroyZDStream { static void Destroy(::ZSTD_DStream* p) noexcept { ::ZSTD_freeDStream(p); } }; } class TZstdCompress::TImpl { public: TImpl(IOutputStream* slave, int quality) : Slave_(slave) , ZCtx_(::ZSTD_createCStream()) , Buffer_(::ZSTD_CStreamOutSize()) // do reserve { Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_CStream"); Y_ENSURE(0 != Buffer_.Capacity(), "ZSTD_CStreamOutSize was too small"); CheckError("init", ZSTD_initCStream(ZCtx_.Get(), quality)); } void Write(const void* buffer, size_t size) { ::ZSTD_inBuffer zIn{buffer, size, 0}; auto zOut = OutBuffer(); while (0 != zIn.size) { CheckError("compress", ::ZSTD_compressStream(ZCtx_.Get(), &zOut, &zIn)); DoWrite(zOut); // forget about the data we already compressed zIn.src = static_cast(zIn.src) + zIn.pos; zIn.size -= zIn.pos; zIn.pos = 0; } } void Flush() { auto zOut = OutBuffer(); CheckError("flush", ::ZSTD_flushStream(ZCtx_.Get(), &zOut)); DoWrite(zOut); } void Finish() { auto zOut = OutBuffer(); size_t returnCode; do { returnCode = ::ZSTD_endStream(ZCtx_.Get(), &zOut); CheckError("finish", returnCode); DoWrite(zOut); } while (0 != returnCode); // zero means there is no more bytes to flush } private: ::ZSTD_outBuffer OutBuffer() { return {Buffer_.Data(), Buffer_.Capacity(), 0}; } void DoWrite(::ZSTD_outBuffer& buffer) { Slave_->Write(buffer.dst, buffer.pos); buffer.pos = 0; } private: IOutputStream* Slave_; THolder<::ZSTD_CStream, DestroyZCStream> ZCtx_; TBuffer Buffer_; }; TZstdCompress::TZstdCompress(IOutputStream* slave, int quality) : Impl_(new TImpl(slave, quality)) { } TZstdCompress::~TZstdCompress() { try { Finish(); } catch (...) { } } void TZstdCompress::DoWrite(const void* buffer, size_t size) { Y_ENSURE(Impl_, "Cannot use stream after finish."); Impl_->Write(buffer, size); } void TZstdCompress::DoFlush() { Y_ENSURE(Impl_, "Cannot use stream after finish."); Impl_->Flush(); } void TZstdCompress::DoFinish() { // Finish should be idempotent if (Impl_) { auto impl = std::move(Impl_); impl->Finish(); } } //////////////////////////////////////////////////////////////////////////////// class TZstdDecompress::TImpl { public: TImpl(IInputStream* slave, size_t bufferSize) : Slave_(slave) , ZCtx_(::ZSTD_createDStream()) , Buffer_(bufferSize) // do reserve , Offset_(0) { Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_DStream"); Y_ENSURE(0 != Buffer_.Capacity(), "Buffer size was too small"); } size_t Read(void* buffer, size_t size) { Y_ASSERT(size > 0); ::ZSTD_outBuffer zOut{buffer, size, 0}; ::ZSTD_inBuffer zIn{Buffer_.Data(), Buffer_.Size(), Offset_}; size_t returnCode = 0; while (zOut.pos != zOut.size) { if (zIn.pos == zIn.size) { zIn.size = Slave_->Read(Buffer_.Data(), Buffer_.Capacity()); Buffer_.Resize(zIn.size); zIn.pos = Offset_ = 0; if (0 == zIn.size) { // end of stream, need to check that there is no uncompleted blocks Y_ENSURE(0 == returnCode, "Incomplete block"); break; } } returnCode = ::ZSTD_decompressStream(ZCtx_.Get(), &zOut, &zIn); CheckError("decompress", returnCode); if (0 == returnCode) { // The frame is over, prepare to (maybe) start a new frame ZSTD_initDStream(ZCtx_.Get()); } } Offset_ = zIn.pos; return zOut.pos; } private: IInputStream* Slave_; THolder<::ZSTD_DStream, DestroyZDStream> ZCtx_; TBuffer Buffer_; size_t Offset_; }; TZstdDecompress::TZstdDecompress(IInputStream* slave, size_t bufferSize) : Impl_(new TImpl(slave, bufferSize)) { } TZstdDecompress::~TZstdDecompress() = default; size_t TZstdDecompress::DoRead(void* buffer, size_t size) { return Impl_->Read(buffer, size); }