// zstd.cpp - streaming zstd compression / decompression wrappers
// (TZstdCompress and TZstdDecompress).

  1. #include "zstd.h"
  2. #include <util/generic/buffer.h>
  3. #include <util/generic/yexception.h>
  4. #define ZSTD_STATIC_LINKING_ONLY
  5. #include <contrib/libs/zstd/include/zstd.h>
  6. namespace {
  7. inline void CheckError(const char* op, size_t code) {
  8. if (::ZSTD_isError(code)) {
  9. ythrow yexception() << op << TStringBuf(" zstd error: ") << ::ZSTD_getErrorName(code);
  10. }
  11. }
  12. struct DestroyZCStream {
  13. static void Destroy(::ZSTD_CStream* p) noexcept {
  14. ::ZSTD_freeCStream(p);
  15. }
  16. };
  17. struct DestroyZDStream {
  18. static void Destroy(::ZSTD_DStream* p) noexcept {
  19. ::ZSTD_freeDStream(p);
  20. }
  21. };
  22. }
  23. class TZstdCompress::TImpl {
  24. public:
  25. TImpl(IOutputStream* slave, int quality)
  26. : Slave_(slave)
  27. , ZCtx_(::ZSTD_createCStream())
  28. , Buffer_(::ZSTD_CStreamOutSize()) // do reserve
  29. {
  30. Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_CStream");
  31. Y_ENSURE(0 != Buffer_.Capacity(), "ZSTD_CStreamOutSize was too small");
  32. CheckError("init", ZSTD_initCStream(ZCtx_.Get(), quality));
  33. }
  34. void Write(const void* buffer, size_t size) {
  35. ::ZSTD_inBuffer zIn{buffer, size, 0};
  36. auto zOut = OutBuffer();
  37. while (0 != zIn.size) {
  38. CheckError("compress", ::ZSTD_compressStream(ZCtx_.Get(), &zOut, &zIn));
  39. DoWrite(zOut);
  40. // forget about the data we already compressed
  41. zIn.src = static_cast<const unsigned char*>(zIn.src) + zIn.pos;
  42. zIn.size -= zIn.pos;
  43. zIn.pos = 0;
  44. }
  45. }
  46. void Flush() {
  47. auto zOut = OutBuffer();
  48. CheckError("flush", ::ZSTD_flushStream(ZCtx_.Get(), &zOut));
  49. DoWrite(zOut);
  50. }
  51. void Finish() {
  52. auto zOut = OutBuffer();
  53. size_t returnCode;
  54. do {
  55. returnCode = ::ZSTD_endStream(ZCtx_.Get(), &zOut);
  56. CheckError("finish", returnCode);
  57. DoWrite(zOut);
  58. } while (0 != returnCode); // zero means there is no more bytes to flush
  59. }
  60. private:
  61. ::ZSTD_outBuffer OutBuffer() {
  62. return {Buffer_.Data(), Buffer_.Capacity(), 0};
  63. }
  64. void DoWrite(::ZSTD_outBuffer& buffer) {
  65. Slave_->Write(buffer.dst, buffer.pos);
  66. buffer.pos = 0;
  67. }
  68. private:
  69. IOutputStream* Slave_;
  70. THolder<::ZSTD_CStream, DestroyZCStream> ZCtx_;
  71. TBuffer Buffer_;
  72. };
  73. TZstdCompress::TZstdCompress(IOutputStream* slave, int quality)
  74. : Impl_(new TImpl(slave, quality)) {
  75. }
  76. TZstdCompress::~TZstdCompress() {
  77. try {
  78. Finish();
  79. } catch (...) {
  80. }
  81. }
  82. void TZstdCompress::DoWrite(const void* buffer, size_t size) {
  83. Y_ENSURE(Impl_, "Cannot use stream after finish.");
  84. Impl_->Write(buffer, size);
  85. }
  86. void TZstdCompress::DoFlush() {
  87. Y_ENSURE(Impl_, "Cannot use stream after finish.");
  88. Impl_->Flush();
  89. }
  90. void TZstdCompress::DoFinish() {
  91. // Finish should be idempotent
  92. if (Impl_) {
  93. auto impl = std::move(Impl_);
  94. impl->Finish();
  95. }
  96. }
  97. ////////////////////////////////////////////////////////////////////////////////
  98. class TZstdDecompress::TImpl {
  99. public:
  100. TImpl(IInputStream* slave, size_t bufferSize)
  101. : Slave_(slave)
  102. , ZCtx_(::ZSTD_createDStream())
  103. , Buffer_(bufferSize) // do reserve
  104. , Offset_(0)
  105. {
  106. Y_ENSURE(nullptr != ZCtx_.Get(), "Failed to allocate ZSTD_DStream");
  107. Y_ENSURE(0 != Buffer_.Capacity(), "Buffer size was too small");
  108. }
  109. size_t Read(void* buffer, size_t size) {
  110. Y_ASSERT(size > 0);
  111. ::ZSTD_outBuffer zOut{buffer, size, 0};
  112. ::ZSTD_inBuffer zIn{Buffer_.Data(), Buffer_.Size(), Offset_};
  113. size_t returnCode = 0;
  114. while (zOut.pos != zOut.size) {
  115. if (zIn.pos == zIn.size) {
  116. zIn.size = Slave_->Read(Buffer_.Data(), Buffer_.Capacity());
  117. Buffer_.Resize(zIn.size);
  118. zIn.pos = Offset_ = 0;
  119. if (0 == zIn.size) {
  120. // end of stream, need to check that there is no uncompleted blocks
  121. Y_ENSURE(0 == returnCode, "Incomplete block");
  122. break;
  123. }
  124. }
  125. returnCode = ::ZSTD_decompressStream(ZCtx_.Get(), &zOut, &zIn);
  126. CheckError("decompress", returnCode);
  127. if (0 == returnCode) {
  128. // The frame is over, prepare to (maybe) start a new frame
  129. ZSTD_initDStream(ZCtx_.Get());
  130. }
  131. }
  132. Offset_ = zIn.pos;
  133. return zOut.pos;
  134. }
  135. private:
  136. IInputStream* Slave_;
  137. THolder<::ZSTD_DStream, DestroyZDStream> ZCtx_;
  138. TBuffer Buffer_;
  139. size_t Offset_;
  140. };
  141. TZstdDecompress::TZstdDecompress(IInputStream* slave, size_t bufferSize)
  142. : Impl_(new TImpl(slave, bufferSize)) {
  143. }
  144. TZstdDecompress::~TZstdDecompress() = default;
  145. size_t TZstdDecompress::DoRead(void* buffer, size_t size) {
  146. return Impl_->Read(buffer, size);
  147. }