123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624 |
- // Copyright 2019+ Klaus Post. All rights reserved.
- // License information can be found in the LICENSE file.
- // Based on work by Yann Collet, released under BSD License.
- package zstd
- import (
- "crypto/rand"
- "fmt"
- "io"
- "math"
- rdebug "runtime/debug"
- "sync"
- "github.com/klauspost/compress/zstd/internal/xxhash"
- )
- // Encoder provides encoding to Zstandard.
- // An Encoder can be used for either compressing a stream via the
- // io.WriteCloser interface supported by the Encoder or as multiple independent
- // tasks via the EncodeAll function.
- // Smaller encodes are encouraged to use the EncodeAll function.
- // Use NewWriter to create a new instance.
- type Encoder struct {
- o encoderOptions
- encoders chan encoder
- state encoderState
- init sync.Once
- }
- type encoder interface {
- Encode(blk *blockEnc, src []byte)
- EncodeNoHist(blk *blockEnc, src []byte)
- Block() *blockEnc
- CRC() *xxhash.Digest
- AppendCRC([]byte) []byte
- WindowSize(size int64) int32
- UseBlock(*blockEnc)
- Reset(d *dict, singleBlock bool)
- }
- type encoderState struct {
- w io.Writer
- filling []byte
- current []byte
- previous []byte
- encoder encoder
- writing *blockEnc
- err error
- writeErr error
- nWritten int64
- nInput int64
- frameContentSize int64
- headerWritten bool
- eofWritten bool
- fullFrameWritten bool
- // This waitgroup indicates an encode is running.
- wg sync.WaitGroup
- // This waitgroup indicates we have a block encoding/writing.
- wWg sync.WaitGroup
- }
- // NewWriter will create a new Zstandard encoder.
- // If the encoder will be used for encoding blocks a nil writer can be used.
- func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
- initPredefined()
- var e Encoder
- e.o.setDefault()
- for _, o := range opts {
- err := o(&e.o)
- if err != nil {
- return nil, err
- }
- }
- if w != nil {
- e.Reset(w)
- }
- return &e, nil
- }
- func (e *Encoder) initialize() {
- if e.o.concurrent == 0 {
- e.o.setDefault()
- }
- e.encoders = make(chan encoder, e.o.concurrent)
- for i := 0; i < e.o.concurrent; i++ {
- enc := e.o.encoder()
- e.encoders <- enc
- }
- }
- // Reset will re-initialize the writer and new writes will encode to the supplied writer
- // as a new, independent stream.
- func (e *Encoder) Reset(w io.Writer) {
- s := &e.state
- s.wg.Wait()
- s.wWg.Wait()
- if cap(s.filling) == 0 {
- s.filling = make([]byte, 0, e.o.blockSize)
- }
- if e.o.concurrent > 1 {
- if cap(s.current) == 0 {
- s.current = make([]byte, 0, e.o.blockSize)
- }
- if cap(s.previous) == 0 {
- s.previous = make([]byte, 0, e.o.blockSize)
- }
- s.current = s.current[:0]
- s.previous = s.previous[:0]
- if s.writing == nil {
- s.writing = &blockEnc{lowMem: e.o.lowMem}
- s.writing.init()
- }
- s.writing.initNewEncode()
- }
- if s.encoder == nil {
- s.encoder = e.o.encoder()
- }
- s.filling = s.filling[:0]
- s.encoder.Reset(e.o.dict, false)
- s.headerWritten = false
- s.eofWritten = false
- s.fullFrameWritten = false
- s.w = w
- s.err = nil
- s.nWritten = 0
- s.nInput = 0
- s.writeErr = nil
- s.frameContentSize = 0
- }
- // ResetContentSize will reset and set a content size for the next stream.
- // If the bytes written does not match the size given an error will be returned
- // when calling Close().
- // This is removed when Reset is called.
- // Sizes <= 0 results in no content size set.
- func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
- e.Reset(w)
- if size >= 0 {
- e.state.frameContentSize = size
- }
- }
- // Write data to the encoder.
- // Input data will be buffered and as the buffer fills up
- // content will be compressed and written to the output.
- // When done writing, use Close to flush the remaining output
- // and write CRC if requested.
- func (e *Encoder) Write(p []byte) (n int, err error) {
- s := &e.state
- for len(p) > 0 {
- if len(p)+len(s.filling) < e.o.blockSize {
- if e.o.crc {
- _, _ = s.encoder.CRC().Write(p)
- }
- s.filling = append(s.filling, p...)
- return n + len(p), nil
- }
- add := p
- if len(p)+len(s.filling) > e.o.blockSize {
- add = add[:e.o.blockSize-len(s.filling)]
- }
- if e.o.crc {
- _, _ = s.encoder.CRC().Write(add)
- }
- s.filling = append(s.filling, add...)
- p = p[len(add):]
- n += len(add)
- if len(s.filling) < e.o.blockSize {
- return n, nil
- }
- err := e.nextBlock(false)
- if err != nil {
- return n, err
- }
- if debugAsserts && len(s.filling) > 0 {
- panic(len(s.filling))
- }
- }
- return n, nil
- }
- // nextBlock will synchronize and start compressing input in e.state.filling.
- // If an error has occurred during encoding it will be returned.
- func (e *Encoder) nextBlock(final bool) error {
- s := &e.state
- // Wait for current block.
- s.wg.Wait()
- if s.err != nil {
- return s.err
- }
- if len(s.filling) > e.o.blockSize {
- return fmt.Errorf("block > maxStoreBlockSize")
- }
- if !s.headerWritten {
- // If we have a single block encode, do a sync compression.
- if final && len(s.filling) == 0 && !e.o.fullZero {
- s.headerWritten = true
- s.fullFrameWritten = true
- s.eofWritten = true
- return nil
- }
- if final && len(s.filling) > 0 {
- s.current = e.EncodeAll(s.filling, s.current[:0])
- var n2 int
- n2, s.err = s.w.Write(s.current)
- if s.err != nil {
- return s.err
- }
- s.nWritten += int64(n2)
- s.nInput += int64(len(s.filling))
- s.current = s.current[:0]
- s.filling = s.filling[:0]
- s.headerWritten = true
- s.fullFrameWritten = true
- s.eofWritten = true
- return nil
- }
- var tmp [maxHeaderSize]byte
- fh := frameHeader{
- ContentSize: uint64(s.frameContentSize),
- WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)),
- SingleSegment: false,
- Checksum: e.o.crc,
- DictID: e.o.dict.ID(),
- }
- dst, err := fh.appendTo(tmp[:0])
- if err != nil {
- return err
- }
- s.headerWritten = true
- s.wWg.Wait()
- var n2 int
- n2, s.err = s.w.Write(dst)
- if s.err != nil {
- return s.err
- }
- s.nWritten += int64(n2)
- }
- if s.eofWritten {
- // Ensure we only write it once.
- final = false
- }
- if len(s.filling) == 0 {
- // Final block, but no data.
- if final {
- enc := s.encoder
- blk := enc.Block()
- blk.reset(nil)
- blk.last = true
- blk.encodeRaw(nil)
- s.wWg.Wait()
- _, s.err = s.w.Write(blk.output)
- s.nWritten += int64(len(blk.output))
- s.eofWritten = true
- }
- return s.err
- }
- // SYNC:
- if e.o.concurrent == 1 {
- src := s.filling
- s.nInput += int64(len(s.filling))
- if debugEncoder {
- println("Adding sync block,", len(src), "bytes, final:", final)
- }
- enc := s.encoder
- blk := enc.Block()
- blk.reset(nil)
- enc.Encode(blk, src)
- blk.last = final
- if final {
- s.eofWritten = true
- }
- s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
- if s.err != nil {
- return s.err
- }
- _, s.err = s.w.Write(blk.output)
- s.nWritten += int64(len(blk.output))
- s.filling = s.filling[:0]
- return s.err
- }
- // Move blocks forward.
- s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
- s.nInput += int64(len(s.current))
- s.wg.Add(1)
- go func(src []byte) {
- if debugEncoder {
- println("Adding block,", len(src), "bytes, final:", final)
- }
- defer func() {
- if r := recover(); r != nil {
- s.err = fmt.Errorf("panic while encoding: %v", r)
- rdebug.PrintStack()
- }
- s.wg.Done()
- }()
- enc := s.encoder
- blk := enc.Block()
- enc.Encode(blk, src)
- blk.last = final
- if final {
- s.eofWritten = true
- }
- // Wait for pending writes.
- s.wWg.Wait()
- if s.writeErr != nil {
- s.err = s.writeErr
- return
- }
- // Transfer encoders from previous write block.
- blk.swapEncoders(s.writing)
- // Transfer recent offsets to next.
- enc.UseBlock(s.writing)
- s.writing = blk
- s.wWg.Add(1)
- go func() {
- defer func() {
- if r := recover(); r != nil {
- s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
- rdebug.PrintStack()
- }
- s.wWg.Done()
- }()
- s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
- if s.writeErr != nil {
- return
- }
- _, s.writeErr = s.w.Write(blk.output)
- s.nWritten += int64(len(blk.output))
- }()
- }(s.current)
- return nil
- }
- // ReadFrom reads data from r until EOF or error.
- // The return value n is the number of bytes read.
- // Any error except io.EOF encountered during the read is also returned.
- //
- // The Copy function uses ReaderFrom if available.
- func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
- if debugEncoder {
- println("Using ReadFrom")
- }
- // Flush any current writes.
- if len(e.state.filling) > 0 {
- if err := e.nextBlock(false); err != nil {
- return 0, err
- }
- }
- e.state.filling = e.state.filling[:e.o.blockSize]
- src := e.state.filling
- for {
- n2, err := r.Read(src)
- if e.o.crc {
- _, _ = e.state.encoder.CRC().Write(src[:n2])
- }
- // src is now the unfilled part...
- src = src[n2:]
- n += int64(n2)
- switch err {
- case io.EOF:
- e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
- if debugEncoder {
- println("ReadFrom: got EOF final block:", len(e.state.filling))
- }
- return n, nil
- case nil:
- default:
- if debugEncoder {
- println("ReadFrom: got error:", err)
- }
- e.state.err = err
- return n, err
- }
- if len(src) > 0 {
- if debugEncoder {
- println("ReadFrom: got space left in source:", len(src))
- }
- continue
- }
- err = e.nextBlock(false)
- if err != nil {
- return n, err
- }
- e.state.filling = e.state.filling[:e.o.blockSize]
- src = e.state.filling
- }
- }
- // Flush will send the currently written data to output
- // and block until everything has been written.
- // This should only be used on rare occasions where pushing the currently queued data is critical.
- func (e *Encoder) Flush() error {
- s := &e.state
- if len(s.filling) > 0 {
- err := e.nextBlock(false)
- if err != nil {
- return err
- }
- }
- s.wg.Wait()
- s.wWg.Wait()
- if s.err != nil {
- return s.err
- }
- return s.writeErr
- }
- // Close will flush the final output and close the stream.
- // The function will block until everything has been written.
- // The Encoder can still be re-used after calling this.
- func (e *Encoder) Close() error {
- s := &e.state
- if s.encoder == nil {
- return nil
- }
- err := e.nextBlock(true)
- if err != nil {
- return err
- }
- if s.frameContentSize > 0 {
- if s.nInput != s.frameContentSize {
- return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
- }
- }
- if e.state.fullFrameWritten {
- return s.err
- }
- s.wg.Wait()
- s.wWg.Wait()
- if s.err != nil {
- return s.err
- }
- if s.writeErr != nil {
- return s.writeErr
- }
- // Write CRC
- if e.o.crc && s.err == nil {
- // heap alloc.
- var tmp [4]byte
- _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
- s.nWritten += 4
- }
- // Add padding with content from crypto/rand.Reader
- if s.err == nil && e.o.pad > 0 {
- add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
- frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
- if err != nil {
- return err
- }
- _, s.err = s.w.Write(frame)
- }
- return s.err
- }
- // EncodeAll will encode all input in src and append it to dst.
- // This function can be called concurrently, but each call will only run on a single goroutine.
- // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
- // Encoded blocks can be concatenated and the result will be the combined input stream.
- // Data compressed with EncodeAll can be decoded with the Decoder,
- // using either a stream or DecodeAll.
- func (e *Encoder) EncodeAll(src, dst []byte) []byte {
- if len(src) == 0 {
- if e.o.fullZero {
- // Add frame header.
- fh := frameHeader{
- ContentSize: 0,
- WindowSize: MinWindowSize,
- SingleSegment: true,
- // Adding a checksum would be a waste of space.
- Checksum: false,
- DictID: 0,
- }
- dst, _ = fh.appendTo(dst)
- // Write raw block as last one only.
- var blk blockHeader
- blk.setSize(0)
- blk.setType(blockTypeRaw)
- blk.setLast(true)
- dst = blk.appendTo(dst)
- }
- return dst
- }
- e.init.Do(e.initialize)
- enc := <-e.encoders
- defer func() {
- // Release encoder reference to last block.
- // If a non-single block is needed the encoder will reset again.
- e.encoders <- enc
- }()
- // Use single segments when above minimum window and below window size.
- single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
- if e.o.single != nil {
- single = *e.o.single
- }
- fh := frameHeader{
- ContentSize: uint64(len(src)),
- WindowSize: uint32(enc.WindowSize(int64(len(src)))),
- SingleSegment: single,
- Checksum: e.o.crc,
- DictID: e.o.dict.ID(),
- }
- // If less than 1MB, allocate a buffer up front.
- if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
- dst = make([]byte, 0, len(src))
- }
- dst, err := fh.appendTo(dst)
- if err != nil {
- panic(err)
- }
- // If we can do everything in one block, prefer that.
- if len(src) <= e.o.blockSize {
- enc.Reset(e.o.dict, true)
- // Slightly faster with no history and everything in one block.
- if e.o.crc {
- _, _ = enc.CRC().Write(src)
- }
- blk := enc.Block()
- blk.last = true
- if e.o.dict == nil {
- enc.EncodeNoHist(blk, src)
- } else {
- enc.Encode(blk, src)
- }
- // If we got the exact same number of literals as input,
- // assume the literals cannot be compressed.
- oldout := blk.output
- // Output directly to dst
- blk.output = dst
- err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
- if err != nil {
- panic(err)
- }
- dst = blk.output
- blk.output = oldout
- } else {
- enc.Reset(e.o.dict, false)
- blk := enc.Block()
- for len(src) > 0 {
- todo := src
- if len(todo) > e.o.blockSize {
- todo = todo[:e.o.blockSize]
- }
- src = src[len(todo):]
- if e.o.crc {
- _, _ = enc.CRC().Write(todo)
- }
- blk.pushOffsets()
- enc.Encode(blk, todo)
- if len(src) == 0 {
- blk.last = true
- }
- err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
- if err != nil {
- panic(err)
- }
- dst = append(dst, blk.output...)
- blk.reset(nil)
- }
- }
- if e.o.crc {
- dst = enc.AppendCRC(dst)
- }
- // Add padding with content from crypto/rand.Reader
- if e.o.pad > 0 {
- add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
- dst, err = skippableFrame(dst, add, rand.Reader)
- if err != nil {
- panic(err)
- }
- }
- return dst
- }
- // MaxEncodedSize returns the expected maximum
- // size of an encoded block or stream.
- func (e *Encoder) MaxEncodedSize(size int) int {
- frameHeader := 4 + 2 // magic + frame header & window descriptor
- if e.o.dict != nil {
- frameHeader += 4
- }
- // Frame content size:
- if size < 256 {
- frameHeader++
- } else if size < 65536+256 {
- frameHeader += 2
- } else if size < math.MaxInt32 {
- frameHeader += 4
- } else {
- frameHeader += 8
- }
- // Final crc
- if e.o.crc {
- frameHeader += 4
- }
- // Max overhead is 3 bytes/block.
- // There cannot be 0 blocks.
- blocks := (size + e.o.blockSize) / e.o.blockSize
- // Combine, add padding.
- maxSz := frameHeader + 3*blocks + size
- if e.o.pad > 1 {
- maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
- }
- return maxSz
- }
|