decompress.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. package fse
  2. import (
  3. "errors"
  4. "fmt"
  5. )
  6. const (
  7. tablelogAbsoluteMax = 15
  8. )
  9. // Decompress a block of data.
  10. // You can provide a scratch buffer to avoid allocations.
  11. // If nil is provided a temporary one will be allocated.
  12. // It is possible, but by no way guaranteed that corrupt data will
  13. // return an error.
  14. // It is up to the caller to verify integrity of the returned data.
  15. // Use a predefined Scrach to set maximum acceptable output size.
  16. func Decompress(b []byte, s *Scratch) ([]byte, error) {
  17. s, err := s.prepare(b)
  18. if err != nil {
  19. return nil, err
  20. }
  21. s.Out = s.Out[:0]
  22. err = s.readNCount()
  23. if err != nil {
  24. return nil, err
  25. }
  26. err = s.buildDtable()
  27. if err != nil {
  28. return nil, err
  29. }
  30. err = s.decompress()
  31. if err != nil {
  32. return nil, err
  33. }
  34. return s.Out, nil
  35. }
// readNCount will read the symbol distribution so decoding tables can be constructed.
//
// It consumes the table header from s.br and, on success:
//   - sets s.actualTableLog to minTablelog plus the first 4 bits read;
//   - fills s.norm[:s.symbolLen] with one normalized count per symbol, where
//     -1 marks a low-probability symbol that still receives one table cell
//     and 0 marks a symbol that does not occur;
//   - sets s.symbolLen to the number of symbols described;
//   - advances s.br past the header (rounded up to a whole byte).
//
// It returns an error if fewer than 4 bytes are available, the table log
// exceeds tablelogAbsoluteMax, the symbol count is out of range, or the
// decoded counts do not add up to exactly 1<<s.actualTableLog.
//
// NOTE(review): iend is computed with b.remain() but compared against the
// absolute offset b.off (and the fallback path uses len(b.b)); this looks
// like it assumes s.br.off == 0 on entry — confirm against the byteReader.
func (s *Scratch) readNCount() error {
	var (
		charnum   uint16
		previous0 bool
		b         = &s.br
	)
	iend := b.remain()
	if iend < 4 {
		return errors.New("input too small")
	}
	// NOTE(review): b.Uint32 is assumed to read 32 bits at b.off without
	// advancing; the offset only moves via b.advance / b.off below — confirm.
	bitStream := b.Uint32()
	nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog
	if nbBits > tablelogAbsoluteMax {
		return errors.New("tableLog too large")
	}
	bitStream >>= 4
	bitCount := uint(4)

	s.actualTableLog = uint8(nbBits)
	// remaining is the probability mass still to be assigned, plus one; the
	// header is complete once it drops to 1.
	remaining := int32((1 << nbBits) + 1)
	// Each count is coded in nbBits-1 or nbBits bits; after the increment
	// below threshold == 1<<(nbBits-1), and the two shrink together.
	threshold := int32(1 << nbBits)
	gotTotal := int32(0)
	nbBits++

	for remaining > 1 {
		if previous0 {
			// The previous count was 0, so a run length of further zero
			// counts follows: each 0xFFFF (eight 2-bit 3s) adds 24, each
			// 2-bit 3 adds 3, and the terminating 2-bit value (0-2) adds itself.
			n0 := charnum
			for (bitStream & 0xFFFF) == 0xFFFF {
				n0 += 24
				if b.off < iend-5 {
					b.advance(2)
					bitStream = b.Uint32() >> bitCount
				} else {
					// Too close to the end to reload; consume the bits already held.
					bitStream >>= 16
					bitCount += 16
				}
			}
			for (bitStream & 3) == 3 {
				n0 += 3
				bitStream >>= 2
				bitCount += 2
			}
			n0 += uint16(bitStream & 3)
			bitCount += 2
			if n0 > maxSymbolValue {
				return errors.New("maxSymbolValue too small")
			}
			for charnum < n0 {
				s.norm[charnum&0xff] = 0
				charnum++
			}

			// Refill by the whole bytes consumed when a 32-bit read still
			// fits; otherwise just drop the 2 terminator bits counted above.
			if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
				b.advance(bitCount >> 3)
				bitCount &= 7
				bitStream = b.Uint32() >> bitCount
			} else {
				bitStream >>= 2
			}
		}

		// Values below max fit in nbBits-1 bits; otherwise nbBits bits are
		// read and values >= threshold are reduced by max.
		max := (2*(threshold) - 1) - (remaining)
		var count int32

		if (int32(bitStream) & (threshold - 1)) < max {
			count = int32(bitStream) & (threshold - 1)
			bitCount += nbBits - 1
		} else {
			count = int32(bitStream) & (2*threshold - 1)
			if count >= threshold {
				count -= max
			}
			bitCount += nbBits
		}

		count-- // extra accuracy
		if count < 0 {
			// -1 means +1
			remaining += count
			gotTotal -= count
		} else {
			remaining -= count
			gotTotal += count
		}
		s.norm[charnum&0xff] = int16(count)
		charnum++
		previous0 = count == 0
		// Less probability left means the next count needs fewer bits.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}
		if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
			b.advance(bitCount >> 3)
			bitCount &= 7
		} else {
			// Near the end of the input: pin the read window to the last 4
			// bytes and carry the difference in bitCount instead.
			bitCount -= (uint)(8 * (len(b.b) - 4 - b.off))
			b.off = len(b.b) - 4
		}
		// & 31 keeps the shift in range; an over-long bitCount is reported
		// as corruption after the loop.
		bitStream = b.Uint32() >> (bitCount & 31)
	}
	s.symbolLen = charnum

	if s.symbolLen <= 1 {
		return fmt.Errorf("symbolLen (%d) too small", s.symbolLen)
	}
	if s.symbolLen > maxSymbolValue+1 {
		return fmt.Errorf("symbolLen (%d) too big", s.symbolLen)
	}
	if remaining != 1 {
		return fmt.Errorf("corruption detected (remaining %d != 1)", remaining)
	}
	if bitCount > 32 {
		return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount)
	}
	if gotTotal != 1<<s.actualTableLog {
		return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog)
	}
	// Consume the header, including a partially used final byte.
	b.advance((bitCount + 7) >> 3)
	return nil
}
// decSymbol is one entry of the decoding table, indexed by decoder state.
// It holds the state offset base, the output symbol, and the number of bits
// to read for the low part of the destination state: a decoder in this
// state emits symbol and moves to newState plus the next nbBits bits read.
type decSymbol struct {
	newState uint16 // base of the next state, before the low bits are added
	symbol   uint8  // byte emitted for this state
	nbBits   uint8  // number of bits to read for the low part of the next state
}
  158. // allocDtable will allocate decoding tables if they are not big enough.
  159. func (s *Scratch) allocDtable() {
  160. tableSize := 1 << s.actualTableLog
  161. if cap(s.decTable) < tableSize {
  162. s.decTable = make([]decSymbol, tableSize)
  163. }
  164. s.decTable = s.decTable[:tableSize]
  165. if cap(s.ct.tableSymbol) < 256 {
  166. s.ct.tableSymbol = make([]byte, 256)
  167. }
  168. s.ct.tableSymbol = s.ct.tableSymbol[:256]
  169. if cap(s.ct.stateTable) < 256 {
  170. s.ct.stateTable = make([]uint16, 256)
  171. }
  172. s.ct.stateTable = s.ct.stateTable[:256]
  173. }
// buildDtable will build the decoding table.
//
// It expects s.norm[:s.symbolLen] and s.actualTableLog to have been set by
// readNCount. After (re)allocating the tables it:
//  1. places each low-probability symbol (norm == -1) in its own cell,
//     filling s.decTable downward from the last cell;
//  2. spreads the other symbols over the remaining cells, stepping by
//     tableStep(tableSize) and skipping the low-probability area;
//  3. computes nbBits and newState for every cell.
//
// It also sets s.zeroBits when any count is >= 1<<(s.actualTableLog-1);
// decompress then uses next() instead of nextFast(), which must not see
// 0-bit states. An error is returned if the spread does not visit every
// cell exactly once or a computed state falls outside the table.
func (s *Scratch) buildDtable() error {
	tableSize := uint32(1 << s.actualTableLog)
	highThreshold := tableSize - 1
	s.allocDtable()
	// symbolNext reuses s.ct.stateTable as scratch: for each symbol, the next
	// state number to hand out while building the table (starts at its count).
	symbolNext := s.ct.stateTable[:256]

	// Init, lay down lowprob symbols
	s.zeroBits = false
	{
		largeLimit := int16(1 << (s.actualTableLog - 1))
		for i, v := range s.norm[:s.symbolLen] {
			if v == -1 {
				s.decTable[highThreshold].symbol = uint8(i)
				highThreshold--
				symbolNext[i] = 1
			} else {
				if v >= largeLimit {
					s.zeroBits = true
				}
				symbolNext[i] = uint16(v)
			}
		}
	}
	// Spread symbols
	{
		tableMask := tableSize - 1
		step := tableStep(tableSize)
		position := uint32(0)
		for ss, v := range s.norm[:s.symbolLen] {
			// Counts of 0 and -1 place nothing here.
			for i := 0; i < int(v); i++ {
				s.decTable[position].symbol = uint8(ss)
				position = (position + step) & tableMask
				for position > highThreshold {
					// lowprob area
					position = (position + step) & tableMask
				}
			}
		}
		if position != 0 {
			// position must reach all cells once, otherwise normalizedCounter is incorrect
			return errors.New("corrupted input (position != 0)")
		}
	}

	// Build Decoding table
	{
		tableSize := uint16(1 << s.actualTableLog)
		for u, v := range s.decTable {
			symbol := v.symbol
			// Each occurrence of a symbol takes the next state number for it.
			nextState := symbolNext[symbol]
			symbolNext[symbol] = nextState + 1
			// NOTE(review): highBits presumably returns the index of the highest
			// set bit, so nBits scales nextState back into [tableSize, 2*tableSize).
			nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
			s.decTable[u].nbBits = nBits
			newState := (nextState << nBits) - tableSize
			if newState >= tableSize {
				return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
			}
			if newState == uint16(u) && nBits == 0 {
				// Seems weird that this is possible with nbits > 0.
				return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
			}
			s.decTable[u].newState = newState
		}
	}
	return nil
}
  239. // decompress will decompress the bitstream.
  240. // If the buffer is over-read an error is returned.
  241. func (s *Scratch) decompress() error {
  242. br := &s.bits
  243. if err := br.init(s.br.unread()); err != nil {
  244. return err
  245. }
  246. var s1, s2 decoder
  247. // Initialize and decode first state and symbol.
  248. s1.init(br, s.decTable, s.actualTableLog)
  249. s2.init(br, s.decTable, s.actualTableLog)
  250. // Use temp table to avoid bound checks/append penalty.
  251. var tmp = s.ct.tableSymbol[:256]
  252. var off uint8
  253. // Main part
  254. if !s.zeroBits {
  255. for br.off >= 8 {
  256. br.fillFast()
  257. tmp[off+0] = s1.nextFast()
  258. tmp[off+1] = s2.nextFast()
  259. br.fillFast()
  260. tmp[off+2] = s1.nextFast()
  261. tmp[off+3] = s2.nextFast()
  262. off += 4
  263. // When off is 0, we have overflowed and should write.
  264. if off == 0 {
  265. s.Out = append(s.Out, tmp...)
  266. if len(s.Out) >= s.DecompressLimit {
  267. return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
  268. }
  269. }
  270. }
  271. } else {
  272. for br.off >= 8 {
  273. br.fillFast()
  274. tmp[off+0] = s1.next()
  275. tmp[off+1] = s2.next()
  276. br.fillFast()
  277. tmp[off+2] = s1.next()
  278. tmp[off+3] = s2.next()
  279. off += 4
  280. if off == 0 {
  281. s.Out = append(s.Out, tmp...)
  282. // When off is 0, we have overflowed and should write.
  283. if len(s.Out) >= s.DecompressLimit {
  284. return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
  285. }
  286. }
  287. }
  288. }
  289. s.Out = append(s.Out, tmp[:off]...)
  290. // Final bits, a bit more expensive check
  291. for {
  292. if s1.finished() {
  293. s.Out = append(s.Out, s1.final(), s2.final())
  294. break
  295. }
  296. br.fill()
  297. s.Out = append(s.Out, s1.next())
  298. if s2.finished() {
  299. s.Out = append(s.Out, s2.final(), s1.final())
  300. break
  301. }
  302. s.Out = append(s.Out, s2.next())
  303. if len(s.Out) >= s.DecompressLimit {
  304. return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
  305. }
  306. }
  307. return br.close()
  308. }
// decoder keeps track of the current state and updates it from the bitstream.
// decompress runs two decoders over the same bitReader and decoding table.
type decoder struct {
	state uint16      // index into dt of the current state
	br    *bitReader  // bit source, shared with the other decoder
	dt    []decSymbol // decoding table built by buildDtable
}
  315. // init will initialize the decoder and read the first state from the stream.
  316. func (d *decoder) init(in *bitReader, dt []decSymbol, tableLog uint8) {
  317. d.dt = dt
  318. d.br = in
  319. d.state = in.getBits(tableLog)
  320. }
  321. // next returns the next symbol and sets the next state.
  322. // At least tablelog bits must be available in the bit reader.
  323. func (d *decoder) next() uint8 {
  324. n := &d.dt[d.state]
  325. lowBits := d.br.getBits(n.nbBits)
  326. d.state = n.newState + lowBits
  327. return n.symbol
  328. }
  329. // finished returns true if all bits have been read from the bitstream
  330. // and the next state would require reading bits from the input.
  331. func (d *decoder) finished() bool {
  332. return d.br.finished() && d.dt[d.state].nbBits > 0
  333. }
  334. // final returns the current state symbol without decoding the next.
  335. func (d *decoder) final() uint8 {
  336. return d.dt[d.state].symbol
  337. }
  338. // nextFast returns the next symbol and sets the next state.
  339. // This can only be used if no symbols are 0 bits.
  340. // At least tablelog bits must be available in the bit reader.
  341. func (d *decoder) nextFast() uint8 {
  342. n := d.dt[d.state]
  343. lowBits := d.br.getBitsFast(n.nbBits)
  344. d.state = n.newState + lowBits
  345. return n.symbol
  346. }