seqdec.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. )
  10. type seq struct {
  11. litLen uint32
  12. matchLen uint32
  13. offset uint32
  14. // Codes are stored here for the encoder
  15. // so they only have to be looked up once.
  16. llCode, mlCode, ofCode uint8
  17. }
  18. type seqVals struct {
  19. ll, ml, mo int
  20. }
  21. func (s seq) String() string {
  22. if s.offset <= 3 {
  23. if s.offset == 0 {
  24. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
  25. }
  26. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
  27. }
  28. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
  29. }
  30. type seqCompMode uint8
  31. const (
  32. compModePredefined seqCompMode = iota
  33. compModeRLE
  34. compModeFSE
  35. compModeRepeat
  36. )
  37. type sequenceDec struct {
  38. // decoder keeps track of the current state and updates it from the bitstream.
  39. fse *fseDecoder
  40. state fseState
  41. repeat bool
  42. }
  43. // init the state of the decoder with input from stream.
  44. func (s *sequenceDec) init(br *bitReader) error {
  45. if s.fse == nil {
  46. return errors.New("sequence decoder not defined")
  47. }
  48. s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
  49. return nil
  50. }
  51. // sequenceDecs contains all 3 sequence decoders and their state.
  52. type sequenceDecs struct {
  53. litLengths sequenceDec
  54. offsets sequenceDec
  55. matchLengths sequenceDec
  56. prevOffset [3]int
  57. dict []byte
  58. literals []byte
  59. out []byte
  60. nSeqs int
  61. br *bitReader
  62. seqSize int
  63. windowSize int
  64. maxBits uint8
  65. maxSyncLen uint64
  66. }
  67. // initialize all 3 decoders from the stream input.
  68. func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) error {
  69. if err := s.litLengths.init(br); err != nil {
  70. return errors.New("litLengths:" + err.Error())
  71. }
  72. if err := s.offsets.init(br); err != nil {
  73. return errors.New("offsets:" + err.Error())
  74. }
  75. if err := s.matchLengths.init(br); err != nil {
  76. return errors.New("matchLengths:" + err.Error())
  77. }
  78. s.br = br
  79. s.prevOffset = hist.recentOffsets
  80. s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
  81. s.windowSize = hist.windowSize
  82. s.out = out
  83. s.dict = nil
  84. if hist.dict != nil {
  85. s.dict = hist.dict.content
  86. }
  87. return nil
  88. }
  89. func (s *sequenceDecs) freeDecoders() {
  90. if f := s.litLengths.fse; f != nil && !f.preDefined {
  91. fseDecoderPool.Put(f)
  92. s.litLengths.fse = nil
  93. }
  94. if f := s.offsets.fse; f != nil && !f.preDefined {
  95. fseDecoderPool.Put(f)
  96. s.offsets.fse = nil
  97. }
  98. if f := s.matchLengths.fse; f != nil && !f.preDefined {
  99. fseDecoderPool.Put(f)
  100. s.matchLengths.fse = nil
  101. }
  102. }
// execute will execute the decoded sequence with the provided history.
// The sequence must be evaluated before being sent.
//
// Without a dictionary the work is delegated to executeSimple. Otherwise
// s.out is given room for s.seqSize more bytes, and each sequence's literals
// and match are written in place at position t. Matches are sourced from the
// dictionary, hist or the output itself. The remaining s.literals are copied
// last. An error is returned when a match offset points before the start of
// history+dictionary.
func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
	if len(s.dict) == 0 {
		return s.executeSimple(seqs, hist)
	}

	// Ensure we have enough output size...
	// The append/reslice pair grows capacity only; len(s.out) is unchanged.
	if len(s.out)+s.seqSize > cap(s.out) {
		addBytes := s.seqSize + len(s.out)
		s.out = append(s.out, make([]byte, addBytes)...)
		s.out = s.out[:len(s.out)-addBytes]
	}

	if debugDecoder {
		printf("Execute %d seqs with hist %d, dict %d, literals: %d into %d bytes\n", len(seqs), len(hist), len(s.dict), len(s.literals), s.seqSize)
	}

	// t is the write position. out is pre-sized to its final length, so the
	// copies below write in place rather than appending.
	var t = len(s.out)
	out := s.out[:t+s.seqSize]

	for _, seq := range seqs {
		// Add literals
		copy(out[t:], s.literals[:seq.ll])
		t += seq.ll
		s.literals = s.literals[seq.ll:]

		// Copy from dictionary...
		// An offset reaching past out+hist (or beyond the window) can only be
		// satisfied from the dictionary.
		if seq.mo > t+len(hist) || seq.mo > s.windowSize {
			// NOTE(review): unreachable here, since the function returns early
			// when s.dict is empty; kept as a defensive guard.
			if len(s.dict) == 0 {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, t+len(hist))
			}

			// we may be in dictionary.
			// dictO is the index in s.dict where the match starts.
			dictO := len(s.dict) - (seq.mo - (t + len(hist)))
			if dictO < 0 || dictO >= len(s.dict) {
				return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", seq.mo, t+len(hist)+len(s.dict))
			}
			end := dictO + seq.ml
			if end > len(s.dict) {
				// The match runs off the end of the dictionary: copy its tail
				// and continue with the remaining length below.
				n := len(s.dict) - dictO
				copy(out[t:], s.dict[dictO:])
				t += n
				seq.ml -= n
			} else {
				// The match lies entirely inside the dictionary.
				copy(out[t:], s.dict[dictO:end])
				t += end - dictO
				continue
			}
		}

		// Copy from history.
		if v := seq.mo - t; v > 0 {
			// v is the start position in history from end.
			start := len(hist) - v
			if seq.ml > v {
				// Some goes into current block.
				// Copy remainder of history
				copy(out[t:], hist[start:])
				t += v
				seq.ml -= v
			} else {
				copy(out[t:], hist[start:start+seq.ml])
				t += seq.ml
				continue
			}
		}
		// We must be in current buffer now
		if seq.ml > 0 {
			start := t - seq.mo
			if seq.ml <= t-start {
				// No overlap
				copy(out[t:], out[start:start+seq.ml])
				t += seq.ml
				continue
			} else {
				// Overlapping copy
				// Extend destination slice and copy one byte at the time.
				// A forward byte loop is required (not copy, which behaves like
				// memmove) so bytes written earlier in this match are re-read,
				// repeating the pattern when the match is longer than its offset.
				src := out[start : start+seq.ml]
				dst := out[t:]
				dst = dst[:len(src)]
				t += len(src)
				// Destination is the space we just added.
				for i := range src {
					dst[i] = src[i]
				}
			}
		}
	}

	// Add final literals
	copy(out[t:], s.literals)
	if debugDecoder {
		// Sanity check: sequences plus trailing literals must fill exactly
		// the seqSize bytes reserved above.
		t += len(s.literals)
		if t != len(out) {
			panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
		}
	}
	s.out = out

	return nil
}
// decode sequences from the stream with the provided history.
//
// decodeSync decodes s.nSeqs sequences from s.br and executes each one
// immediately, appending the output to s.out. It uses hist (and s.dict,
// when set) as the source for back-references. decodeSyncSimple is tried
// first; if it reports the configuration as supported, its result is
// returned as-is.
//
// It returns io.ErrUnexpectedEOF when the bitstream is overread, and a
// descriptive error for corrupt sequences or for output exceeding the block
// limit. After a successful loop it returns the result of br.close().
func (s *sequenceDecs) decodeSync(hist []byte) error {
	supported, err := s.decodeSyncSimple(hist)
	if supported {
		return err
	}

	br := s.br
	seqs := s.nSeqs
	// startSize marks where this block's output begins inside out; size
	// limits below are measured from here.
	startSize := len(s.out)
	// Grab full sizes tables, to avoid bounds checks.
	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
	// The three FSE states live in locals for the duration of the loop.
	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
	out := s.out
	// Output for one block is capped by both the block limit and the window.
	maxBlockSize := maxCompressedBlockSize
	if s.windowSize < maxBlockSize {
		maxBlockSize = s.windowSize
	}

	if debugDecoder {
		println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream")
	}
	// Count down so that i == 0 is the final sequence, after which the
	// states are not advanced.
	for i := seqs - 1; i >= 0; i-- {
		// NOTE(review): printf/println may be package-level debug helpers
		// rather than the builtins; check their definitions.
		if br.overread() {
			printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain())
			return io.ErrUnexpectedEOF
		}
		var ll, mo, ml int
		// Fast path while plenty of input remains (presumably what makes
		// fillFast safe — confirm br.off semantics in bitReader).
		if br.off > 4+((maxOffsetBits+16+16)>>3) {
			// inlined function:
			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

			// Final will not read from stream.
			var llB, mlB, moB uint8
			ll, llB = llState.final()
			ml, mlB = mlState.final()
			mo, moB = ofState.final()

			// extra bits are stored in reverse order.
			br.fillFast()
			mo += br.getBits(moB)
			if s.maxBits > 32 {
				br.fillFast()
			}
			ml += br.getBits(mlB)
			ll += br.getBits(llB)

			// The offset handling below mirrors adjustOffset; keep them in sync.
			if moB > 1 {
				// A new (non-repeat) offset: push it onto the repeat history.
				s.prevOffset[2] = s.prevOffset[1]
				s.prevOffset[1] = s.prevOffset[0]
				s.prevOffset[0] = mo
			} else {
				// mo = s.adjustOffset(mo, ll, moB)
				// Inlined for rather big speedup
				if ll == 0 {
					// There is an exception though, when current sequence's literals_length = 0.
					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
					mo++
				}

				if mo == 0 {
					mo = s.prevOffset[0]
				} else {
					var temp int
					if mo == 3 {
						temp = s.prevOffset[0] - 1
					} else {
						temp = s.prevOffset[mo]
					}

					if temp == 0 {
						// 0 is not valid; input is corrupted; force offset to 1
						println("WARNING: temp was 0")
						temp = 1
					}

					if mo != 1 {
						s.prevOffset[2] = s.prevOffset[1]
					}
					s.prevOffset[1] = s.prevOffset[0]
					s.prevOffset[0] = temp
					mo = temp
				}
			}
			br.fillFast()
		} else {
			// Near the end of input: use the checked reader path.
			ll, mo, ml = s.next(br, llState, mlState, ofState)
			br.fill()
		}

		if debugSequences {
			println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
		}

		if ll > len(s.literals) {
			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
		}
		// Validate the block limit before growing the buffer.
		size := ll + ml + len(out)
		if size-startSize > maxBlockSize {
			return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
		}
		if size > cap(out) {
			// Not enough size, which can happen under high volume block streaming conditions
			// but could be if destination slice is too small for sync operations.
			// over-allocating here can create a large amount of GC pressure so we try to keep
			// it as contained as possible
			used := len(out) - startSize
			addBytes := 256 + ll + ml + used>>2
			// Clamp to max block size.
			if used+addBytes > maxBlockSize {
				addBytes = maxBlockSize - used
			}
			out = append(out, make([]byte, addBytes)...)
			out = out[:len(out)-addBytes]
		}
		if ml > maxMatchLen {
			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
		}

		// Add literals
		out = append(out, s.literals[:ll]...)
		s.literals = s.literals[ll:]

		if mo == 0 && ml > 0 {
			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
		}

		// Offsets reaching past out+hist (or beyond the window) can only be
		// satisfied from the dictionary.
		if mo > len(out)+len(hist) || mo > s.windowSize {
			if len(s.dict) == 0 {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
			}

			// we may be in dictionary.
			// dictO is the index in s.dict where the match starts.
			dictO := len(s.dict) - (mo - (len(out) + len(hist)))
			if dictO < 0 || dictO >= len(s.dict) {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
			}
			end := dictO + ml
			if end > len(s.dict) {
				// Match runs off the dictionary end: copy its tail and
				// continue with the remaining length below.
				out = append(out, s.dict[dictO:]...)
				ml -= len(s.dict) - dictO
			} else {
				// Match lies entirely in the dictionary; nothing left to copy.
				out = append(out, s.dict[dictO:end]...)
				mo = 0
				ml = 0
			}
		}

		// Copy from history.
		// TODO: Blocks without history could be made to ignore this completely.
		if v := mo - len(out); v > 0 {
			// v is the start position in history from end.
			start := len(hist) - v
			if ml > v {
				// Some goes into current block.
				// Copy remainder of history
				out = append(out, hist[start:]...)
				ml -= v
			} else {
				out = append(out, hist[start:start+ml]...)
				ml = 0
			}
		}
		// We must be in current buffer now
		if ml > 0 {
			start := len(out) - mo
			if ml <= len(out)-start {
				// No overlap
				out = append(out, out[start:start+ml]...)
			} else {
				// Overlapping copy
				// Extend destination slice and copy one byte at the time.
				// A forward byte loop (not copy/memmove) re-reads bytes written
				// earlier in this match, repeating the pattern.
				out = out[:len(out)+ml]
				src := out[start : start+ml]
				// Destination is the space we just added.
				dst := out[len(out)-ml:]
				dst = dst[:len(src)]
				for i := range src {
					dst[i] = src[i]
				}
			}
		}
		if i == 0 {
			// This is the last sequence, so we shouldn't update state.
			break
		}

		// Manually inlined, ~ 5-20% faster
		// Update all 3 states at once. Approx 20% faster.
		// The bits for all three states are fetched in one read: literal
		// length uses the highest bits, then match length, then offset
		// (lowest).
		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
		if nBits == 0 {
			llState = llTable[llState.newState()&maxTableMask]
			mlState = mlTable[mlState.newState()&maxTableMask]
			ofState = ofTable[ofState.newState()&maxTableMask]
		} else {
			bits := br.get32BitsFast(nBits)

			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
			llState = llTable[(llState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits >> (ofState.nbBits() & 31))
			lowBits &= bitMask[mlState.nbBits()&15]
			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
		}
	}

	// The trailing literals count toward the block limit too.
	if size := len(s.literals) + len(out) - startSize; size > maxBlockSize {
		return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
	}

	// Add final literals
	s.out = append(out, s.literals...)
	return br.close()
}
  393. var bitMask [16]uint16
  394. func init() {
  395. for i := range bitMask[:] {
  396. bitMask[i] = uint16((1 << uint(i)) - 1)
  397. }
  398. }
  399. func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
  400. // Final will not read from stream.
  401. ll, llB := llState.final()
  402. ml, mlB := mlState.final()
  403. mo, moB := ofState.final()
  404. // extra bits are stored in reverse order.
  405. br.fill()
  406. if s.maxBits <= 32 {
  407. mo += br.getBits(moB)
  408. ml += br.getBits(mlB)
  409. ll += br.getBits(llB)
  410. } else {
  411. mo += br.getBits(moB)
  412. br.fill()
  413. // matchlength+literal length, max 32 bits
  414. ml += br.getBits(mlB)
  415. ll += br.getBits(llB)
  416. }
  417. mo = s.adjustOffset(mo, ll, moB)
  418. return
  419. }
  420. func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
  421. if offsetB > 1 {
  422. s.prevOffset[2] = s.prevOffset[1]
  423. s.prevOffset[1] = s.prevOffset[0]
  424. s.prevOffset[0] = offset
  425. return offset
  426. }
  427. if litLen == 0 {
  428. // There is an exception though, when current sequence's literals_length = 0.
  429. // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
  430. // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
  431. offset++
  432. }
  433. if offset == 0 {
  434. return s.prevOffset[0]
  435. }
  436. var temp int
  437. if offset == 3 {
  438. temp = s.prevOffset[0] - 1
  439. } else {
  440. temp = s.prevOffset[offset]
  441. }
  442. if temp == 0 {
  443. // 0 is not valid; input is corrupted; force offset to 1
  444. println("temp was 0")
  445. temp = 1
  446. }
  447. if offset != 1 {
  448. s.prevOffset[2] = s.prevOffset[1]
  449. }
  450. s.prevOffset[1] = s.prevOffset[0]
  451. s.prevOffset[0] = temp
  452. return temp
  453. }