encoder.go 15 KB


  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. "math"
  10. rdebug "runtime/debug"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the encoder options; NewWriter starts from the defaults and applies each EOption.
	o encoderOptions
	// encoders is the pool of block encoders used by EncodeAll.
	// It is filled lazily by initialize, guarded by init.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state encoderState
	// init ensures the encoders pool is created only once.
	init sync.Once
}
  26. type encoder interface {
  27. Encode(blk *blockEnc, src []byte)
  28. EncodeNoHist(blk *blockEnc, src []byte)
  29. Block() *blockEnc
  30. CRC() *xxhash.Digest
  31. AppendCRC([]byte) []byte
  32. WindowSize(size int64) int32
  33. UseBlock(*blockEnc)
  34. Reset(d *dict, singleBlock bool)
  35. }
  36. type encoderState struct {
  37. w io.Writer
  38. filling []byte
  39. current []byte
  40. previous []byte
  41. encoder encoder
  42. writing *blockEnc
  43. err error
  44. writeErr error
  45. nWritten int64
  46. nInput int64
  47. frameContentSize int64
  48. headerWritten bool
  49. eofWritten bool
  50. fullFrameWritten bool
  51. // This waitgroup indicates an encode is running.
  52. wg sync.WaitGroup
  53. // This waitgroup indicates we have a block encoding/writing.
  54. wWg sync.WaitGroup
  55. }
  56. // NewWriter will create a new Zstandard encoder.
  57. // If the encoder will be used for encoding blocks a nil writer can be used.
  58. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  59. initPredefined()
  60. var e Encoder
  61. e.o.setDefault()
  62. for _, o := range opts {
  63. err := o(&e.o)
  64. if err != nil {
  65. return nil, err
  66. }
  67. }
  68. if w != nil {
  69. e.Reset(w)
  70. }
  71. return &e, nil
  72. }
  73. func (e *Encoder) initialize() {
  74. if e.o.concurrent == 0 {
  75. e.o.setDefault()
  76. }
  77. e.encoders = make(chan encoder, e.o.concurrent)
  78. for i := 0; i < e.o.concurrent; i++ {
  79. enc := e.o.encoder()
  80. e.encoders <- enc
  81. }
  82. }
  83. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  84. // as a new, independent stream.
  85. func (e *Encoder) Reset(w io.Writer) {
  86. s := &e.state
  87. s.wg.Wait()
  88. s.wWg.Wait()
  89. if cap(s.filling) == 0 {
  90. s.filling = make([]byte, 0, e.o.blockSize)
  91. }
  92. if e.o.concurrent > 1 {
  93. if cap(s.current) == 0 {
  94. s.current = make([]byte, 0, e.o.blockSize)
  95. }
  96. if cap(s.previous) == 0 {
  97. s.previous = make([]byte, 0, e.o.blockSize)
  98. }
  99. s.current = s.current[:0]
  100. s.previous = s.previous[:0]
  101. if s.writing == nil {
  102. s.writing = &blockEnc{lowMem: e.o.lowMem}
  103. s.writing.init()
  104. }
  105. s.writing.initNewEncode()
  106. }
  107. if s.encoder == nil {
  108. s.encoder = e.o.encoder()
  109. }
  110. s.filling = s.filling[:0]
  111. s.encoder.Reset(e.o.dict, false)
  112. s.headerWritten = false
  113. s.eofWritten = false
  114. s.fullFrameWritten = false
  115. s.w = w
  116. s.err = nil
  117. s.nWritten = 0
  118. s.nInput = 0
  119. s.writeErr = nil
  120. s.frameContentSize = 0
  121. }
  122. // ResetContentSize will reset and set a content size for the next stream.
  123. // If the bytes written does not match the size given an error will be returned
  124. // when calling Close().
  125. // This is removed when Reset is called.
  126. // Sizes <= 0 results in no content size set.
  127. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  128. e.Reset(w)
  129. if size >= 0 {
  130. e.state.frameContentSize = size
  131. }
  132. }
  133. // Write data to the encoder.
  134. // Input data will be buffered and as the buffer fills up
  135. // content will be compressed and written to the output.
  136. // When done writing, use Close to flush the remaining output
  137. // and write CRC if requested.
  138. func (e *Encoder) Write(p []byte) (n int, err error) {
  139. s := &e.state
  140. for len(p) > 0 {
  141. if len(p)+len(s.filling) < e.o.blockSize {
  142. if e.o.crc {
  143. _, _ = s.encoder.CRC().Write(p)
  144. }
  145. s.filling = append(s.filling, p...)
  146. return n + len(p), nil
  147. }
  148. add := p
  149. if len(p)+len(s.filling) > e.o.blockSize {
  150. add = add[:e.o.blockSize-len(s.filling)]
  151. }
  152. if e.o.crc {
  153. _, _ = s.encoder.CRC().Write(add)
  154. }
  155. s.filling = append(s.filling, add...)
  156. p = p[len(add):]
  157. n += len(add)
  158. if len(s.filling) < e.o.blockSize {
  159. return n, nil
  160. }
  161. err := e.nextBlock(false)
  162. if err != nil {
  163. return n, err
  164. }
  165. if debugAsserts && len(s.filling) > 0 {
  166. panic(len(s.filling))
  167. }
  168. }
  169. return n, nil
  170. }
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// final marks the block as the last one of the frame. The frame header is
// written before the first block. If no header has been written yet and
// final is set, the whole buffered input is encoded as one frame through
// EncodeAll. For empty input without fullZero, nothing is written at all.
// With a single concurrent encoder the block is encoded and written
// synchronously. Otherwise encoding runs in a goroutine tracked by s.wg, and
// the output write runs in a nested goroutine tracked by s.wWg.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// Nothing written yet and nothing buffered: emit no frame at all
		// (unless fullZero asks for an explicit empty frame).
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		// The whole stream fits in what is buffered: encode it as a complete
		// frame synchronously. EncodeAll adds its own header, CRC and padding.
		if final && len(s.filling) > 0 {
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Streaming frame: write the header before the first block.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data: terminate the frame with an empty raw last block.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// SYNC: single encoder, encode and write on the caller's goroutine.
	if e.o.concurrent == 1 {
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if s.err != nil {
			return s.err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Move blocks forward.
	// The just-filled buffer becomes current (handed to the goroutine
	// below). The old current, whose write may still be running, becomes
	// previous. The old previous is free by now because its write finished
	// before the last encode goroutine (waited on above) completed, so it is
	// recycled as the new filling buffer.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// Convert panics in the encoder into a stream error.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		// Finish encoding blk and write it out while the caller keeps filling
		// the next block.
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			if s.writeErr != nil {
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
  323. // ReadFrom reads data from r until EOF or error.
  324. // The return value n is the number of bytes read.
  325. // Any error except io.EOF encountered during the read is also returned.
  326. //
  327. // The Copy function uses ReaderFrom if available.
  328. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  329. if debugEncoder {
  330. println("Using ReadFrom")
  331. }
  332. // Flush any current writes.
  333. if len(e.state.filling) > 0 {
  334. if err := e.nextBlock(false); err != nil {
  335. return 0, err
  336. }
  337. }
  338. e.state.filling = e.state.filling[:e.o.blockSize]
  339. src := e.state.filling
  340. for {
  341. n2, err := r.Read(src)
  342. if e.o.crc {
  343. _, _ = e.state.encoder.CRC().Write(src[:n2])
  344. }
  345. // src is now the unfilled part...
  346. src = src[n2:]
  347. n += int64(n2)
  348. switch err {
  349. case io.EOF:
  350. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  351. if debugEncoder {
  352. println("ReadFrom: got EOF final block:", len(e.state.filling))
  353. }
  354. return n, nil
  355. case nil:
  356. default:
  357. if debugEncoder {
  358. println("ReadFrom: got error:", err)
  359. }
  360. e.state.err = err
  361. return n, err
  362. }
  363. if len(src) > 0 {
  364. if debugEncoder {
  365. println("ReadFrom: got space left in source:", len(src))
  366. }
  367. continue
  368. }
  369. err = e.nextBlock(false)
  370. if err != nil {
  371. return n, err
  372. }
  373. e.state.filling = e.state.filling[:e.o.blockSize]
  374. src = e.state.filling
  375. }
  376. }
  377. // Flush will send the currently written data to output
  378. // and block until everything has been written.
  379. // This should only be used on rare occasions where pushing the currently queued data is critical.
  380. func (e *Encoder) Flush() error {
  381. s := &e.state
  382. if len(s.filling) > 0 {
  383. err := e.nextBlock(false)
  384. if err != nil {
  385. return err
  386. }
  387. }
  388. s.wg.Wait()
  389. s.wWg.Wait()
  390. if s.err != nil {
  391. return s.err
  392. }
  393. return s.writeErr
  394. }
  395. // Close will flush the final output and close the stream.
  396. // The function will block until everything has been written.
  397. // The Encoder can still be re-used after calling this.
  398. func (e *Encoder) Close() error {
  399. s := &e.state
  400. if s.encoder == nil {
  401. return nil
  402. }
  403. err := e.nextBlock(true)
  404. if err != nil {
  405. return err
  406. }
  407. if s.frameContentSize > 0 {
  408. if s.nInput != s.frameContentSize {
  409. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  410. }
  411. }
  412. if e.state.fullFrameWritten {
  413. return s.err
  414. }
  415. s.wg.Wait()
  416. s.wWg.Wait()
  417. if s.err != nil {
  418. return s.err
  419. }
  420. if s.writeErr != nil {
  421. return s.writeErr
  422. }
  423. // Write CRC
  424. if e.o.crc && s.err == nil {
  425. // heap alloc.
  426. var tmp [4]byte
  427. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  428. s.nWritten += 4
  429. }
  430. // Add padding with content from crypto/rand.Reader
  431. if s.err == nil && e.o.pad > 0 {
  432. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  433. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  434. if err != nil {
  435. return err
  436. }
  437. _, s.err = s.w.Write(frame)
  438. }
  439. return s.err
  440. }
// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, dst is returned unchanged, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
//
// Each call borrows one encoder from the e.encoders pool and returns it when
// done. Calls beyond the pool size block until an encoder is free. Internal
// encoding errors cause a panic, since the signature has no error return.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			// Note: this empty frame gets neither a checksum nor padding.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below window size.
	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= e.o.blockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// Temporarily point the block's output buffer at dst so the encoded
		// block is appended in place, then restore the block's own buffer.
		oldout := blk.output
		// Output directly to dst
		blk.output = dst

		err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if err != nil {
			panic(err)
		}
		dst = blk.output
		blk.output = oldout
	} else {
		// Split src into blockSize chunks, encoding them with shared history.
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			if err != nil {
				panic(err)
			}
			dst = append(dst, blk.output...)
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
  559. // MaxEncodedSize returns the expected maximum
  560. // size of an encoded block or stream.
  561. func (e *Encoder) MaxEncodedSize(size int) int {
  562. frameHeader := 4 + 2 // magic + frame header & window descriptor
  563. if e.o.dict != nil {
  564. frameHeader += 4
  565. }
  566. // Frame content size:
  567. if size < 256 {
  568. frameHeader++
  569. } else if size < 65536+256 {
  570. frameHeader += 2
  571. } else if size < math.MaxInt32 {
  572. frameHeader += 4
  573. } else {
  574. frameHeader += 8
  575. }
  576. // Final crc
  577. if e.o.crc {
  578. frameHeader += 4
  579. }
  580. // Max overhead is 3 bytes/block.
  581. // There cannot be 0 blocks.
  582. blocks := (size + e.o.blockSize) / e.o.blockSize
  583. // Combine, add padding.
  584. maxSz := frameHeader + 3*blocks + size
  585. if e.o.pad > 1 {
  586. maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
  587. }
  588. return maxSz
  589. }