decode.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. // Copyright 2011 The Snappy-Go Authors. All rights reserved.
  2. // Copyright (c) 2019 Klaus Post. All rights reserved.
  3. // Use of this source code is governed by a BSD-style
  4. // license that can be found in the LICENSE file.
  5. package s2
  6. import (
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "strconv"
  11. )
  // Sentinel errors for decoding. They are package-level values, so callers can
  // compare a returned error against them.
  var (
  	// ErrCorrupt reports that the input is invalid.
  	ErrCorrupt = errors.New("s2: corrupt input")
  	// ErrCRC reports that the input failed CRC validation (streams only).
  	ErrCRC = errors.New("s2: corrupt input, crc mismatch")
  	// ErrTooLarge reports that the uncompressed length is too large.
  	ErrTooLarge = errors.New("s2: decoded block is too large")
  	// ErrUnsupported reports that the input isn't supported.
  	ErrUnsupported = errors.New("s2: unsupported input")
  )
  22. // DecodedLen returns the length of the decoded block.
  23. func DecodedLen(src []byte) (int, error) {
  24. v, _, err := decodedLen(src)
  25. return v, err
  26. }
  27. // decodedLen returns the length of the decoded block and the number of bytes
  28. // that the length header occupied.
  29. func decodedLen(src []byte) (blockLen, headerLen int, err error) {
  30. v, n := binary.Uvarint(src)
  31. if n <= 0 || v > 0xffffffff {
  32. return 0, 0, ErrCorrupt
  33. }
  34. const wordSize = 32 << (^uint(0) >> 32 & 1)
  35. if wordSize == 32 && v > 0x7fffffff {
  36. return 0, 0, ErrTooLarge
  37. }
  38. return int(v), n, nil
  39. }
  // Status codes returned by the block decoders; 0 means success.
  const (
  	// decodeErrCodeCorrupt is returned by s2DecodeDict when the compressed
  	// input is malformed.
  	decodeErrCodeCorrupt = 1
  )
  43. // Decode returns the decoded form of src. The returned slice may be a sub-
  44. // slice of dst if dst was large enough to hold the entire decoded block.
  45. // Otherwise, a newly allocated slice will be returned.
  46. //
  47. // The dst and src must not overlap. It is valid to pass a nil dst.
  48. func Decode(dst, src []byte) ([]byte, error) {
  49. dLen, s, err := decodedLen(src)
  50. if err != nil {
  51. return nil, err
  52. }
  53. if dLen <= cap(dst) {
  54. dst = dst[:dLen]
  55. } else {
  56. dst = make([]byte, dLen)
  57. }
  58. if s2Decode(dst, src[s:]) != 0 {
  59. return nil, ErrCorrupt
  60. }
  61. return dst, nil
  62. }
  // s2DecodeDict writes the decoding of src to dst, using dict as the source for
  // copies that reach back before the start of dst. It assumes that the
  // varint-encoded length of the decompressed bytes has already been read, and
  // that len(dst) equals that length.
  //
  // If dict is nil the call is forwarded to s2Decode. Otherwise the repeat offset
  // starts out as len(dict.dict) - dict.repeat, and a copy whose offset is larger
  // than the number of bytes written so far is read from dict.dict. Such a copy
  // must lie entirely inside the dictionary and is only accepted while the write
  // position is at most MaxDictSrcOffset.
  //
  // The input is processed by two loops with the same logic: a first loop that
  // runs while more than 5 bytes of src remain and reads tag headers without
  // explicit length checks, and a second loop that handles the tail and checks
  // the read position after every advance.
  //
  // It returns 0 on success or a decodeErrCodeXxx error code on failure.
  func s2DecodeDict(dst, src []byte, dict *Dict) int {
  	if dict == nil {
  		return s2Decode(dst, src)
  	}
  	// Compile-time switches for trace output. With both false, every fmt call
  	// below sits behind a constant-false condition.
  	const debug = false
  	const debugErrs = debug
  	if debug {
  		fmt.Println("Starting decode, dst len:", len(dst))
  	}
  	// d is the write position in dst, s is the read position in src and length
  	// is the size of the literal run or copy currently being processed.
  	var d, s, length int
  	// offset is the distance back from d that a copy reads from. It is kept
  	// across iterations so that repeat operations can reuse the previous value.
  	offset := len(dict.dict) - dict.repeat
  	// As long as we can read at least 5 bytes...
  	// No tag handled in this loop reads more than 5 bytes starting at s.
  	for s < len(src)-5 {
  		// Removing the bounds checks by slicing up front with
  		// in := src[s:s+5]
  		// is SLOWER.
  		// Checked on Go 1.18
  		// The low two bits of the tag byte select the element type.
  		switch src[s] & 0x03 {
  		case tagLiteral:
  			// The upper six bits of the tag hold length-1 when below 60. The
  			// values 60..63 mean that length-1 follows in the next 1..4 bytes,
  			// stored little-endian.
  			x := uint32(src[s] >> 2)
  			switch {
  			case x < 60:
  				s++
  			case x == 60:
  				s += 2
  				x = uint32(src[s-1])
  			case x == 61:
  				in := src[s : s+3]
  				x = uint32(in[1]) | uint32(in[2])<<8
  				s += 3
  			case x == 62:
  				in := src[s : s+4]
  				// Load as 32 bit and shift down.
  				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
  				x >>= 8
  				s += 4
  			case x == 63:
  				in := src[s : s+5]
  				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
  				s += 5
  			}
  			length = int(x) + 1
  			if debug {
  				fmt.Println("literals, length:", length, "d-after:", d+length)
  			}
  			// Reject literals that do not fit in what is left of dst or src. The
  			// length <= 0 test covers int(x)+1 wrapping around when int is 32 bits.
  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
  				if debugErrs {
  					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
  				}
  				return decodeErrCodeCorrupt
  			}
  			copy(dst[d:], src[s:s+length])
  			d += length
  			s += length
  			continue
  		case tagCopy1:
  			s += 2
  			// 11-bit offset: the top three bits of the tag byte are the high bits
  			// and the following byte supplies the low eight bits.
  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
  			// Bits 2..4 of the tag byte are the length code; 4 is added below.
  			length = int(src[s-2]) >> 2 & 0x7
  			if toffset == 0 {
  				if debug {
  					fmt.Print("(repeat) ")
  				}
  				// keep last offset
  				// A zero offset marks a repeat of the previous offset. Length
  				// codes 5, 6 and 7 take an extended length from the next 1, 2 or
  				// 3 bytes (little-endian); other codes are used as they are.
  				switch length {
  				case 5:
  					length = int(src[s]) + 4
  					s += 1
  				case 6:
  					in := src[s : s+2]
  					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
  					s += 2
  				case 7:
  					in := src[s : s+3]
  					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
  					s += 3
  				default: // 0-> 4
  				}
  			} else {
  				offset = toffset
  			}
  			length += 4
  		case tagCopy2:
  			// Tag byte followed by a 2-byte little-endian offset; the length is
  			// the upper six bits of the tag plus one.
  			in := src[s : s+3]
  			offset = int(uint32(in[1]) | uint32(in[2])<<8)
  			length = 1 + int(in[0])>>2
  			s += 3
  		case tagCopy4:
  			// Tag byte followed by a 4-byte little-endian offset; the length is
  			// the upper six bits of the tag plus one.
  			in := src[s : s+5]
  			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
  			length = 1 + int(in[0])>>2
  			s += 5
  		}
  		// A copy needs a positive offset and must fit in the rest of dst.
  		if offset <= 0 || length > len(dst)-d {
  			if debugErrs {
  				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
  			}
  			return decodeErrCodeCorrupt
  		}
  		// copy from dict
  		// The offset reaches back past the start of dst, so the bytes come from
  		// the dictionary.
  		if d < offset {
  			if d > MaxDictSrcOffset {
  				if debugErrs {
  					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
  				}
  				return decodeErrCodeCorrupt
  			}
  			// startOff is the position in dict.dict that lies offset bytes
  			// before d when the dictionary is placed directly in front of dst.
  			startOff := len(dict.dict) - offset + d
  			// The whole copy must lie inside the dictionary.
  			if startOff < 0 || startOff+length > len(dict.dict) {
  				if debugErrs {
  					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
  				}
  				return decodeErrCodeCorrupt
  			}
  			if debug {
  				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
  			}
  			copy(dst[d:d+length], dict.dict[startOff:])
  			d += length
  			continue
  		}
  		if debug {
  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
  		}
  		// Copy from an earlier sub-slice of dst to a later sub-slice.
  		// If no overlap, use the built-in copy:
  		if offset > length {
  			copy(dst[d:d+length], dst[d-offset:])
  			d += length
  			continue
  		}
  		// Unlike the built-in copy function, this byte-by-byte copy always runs
  		// forwards, even if the slices overlap. Conceptually, this is:
  		//
  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
  		//
  		// We align the slices into a and b and show the compiler they are the same size.
  		// This allows the loop to run without bounds checks.
  		a := dst[d : d+length]
  		b := dst[d-offset:]
  		b = b[:len(a)]
  		for i := range a {
  			a[i] = b[i]
  		}
  		d += length
  	}
  	// Remaining with extra checks...
  	// Same decoding as the loop above, but s is advanced first and compared
  	// against len(src) before any of the bytes behind it are read.
  	for s < len(src) {
  		switch src[s] & 0x03 {
  		case tagLiteral:
  			// Upper six bits: length-1 when below 60, otherwise 60..63 select
  			// 1..4 following little-endian bytes holding length-1.
  			x := uint32(src[s] >> 2)
  			switch {
  			case x < 60:
  				s++
  			case x == 60:
  				s += 2
  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  					if debugErrs {
  						fmt.Println("src went oob")
  					}
  					return decodeErrCodeCorrupt
  				}
  				x = uint32(src[s-1])
  			case x == 61:
  				s += 3
  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  					if debugErrs {
  						fmt.Println("src went oob")
  					}
  					return decodeErrCodeCorrupt
  				}
  				x = uint32(src[s-2]) | uint32(src[s-1])<<8
  			case x == 62:
  				s += 4
  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  					if debugErrs {
  						fmt.Println("src went oob")
  					}
  					return decodeErrCodeCorrupt
  				}
  				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
  			case x == 63:
  				s += 5
  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  					if debugErrs {
  						fmt.Println("src went oob")
  					}
  					return decodeErrCodeCorrupt
  				}
  				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
  			}
  			length = int(x) + 1
  			// Reject literals that do not fit in what is left of dst or src. The
  			// length <= 0 test covers int(x)+1 wrapping around when int is 32 bits.
  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
  				if debugErrs {
  					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
  				}
  				return decodeErrCodeCorrupt
  			}
  			if debug {
  				fmt.Println("literals, length:", length, "d-after:", d+length)
  			}
  			copy(dst[d:], src[s:s+length])
  			d += length
  			s += length
  			continue
  		case tagCopy1:
  			s += 2
  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  				if debugErrs {
  					fmt.Println("src went oob")
  				}
  				return decodeErrCodeCorrupt
  			}
  			// Bits 2..4 of the tag byte are the length code; 4 is added below.
  			length = int(src[s-2]) >> 2 & 0x7
  			// 11-bit offset: the top three bits of the tag byte are the high bits
  			// and the following byte supplies the low eight bits.
  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
  			if toffset == 0 {
  				if debug {
  					fmt.Print("(repeat) ")
  				}
  				// keep last offset
  				// A zero offset marks a repeat of the previous offset. Length
  				// codes 5, 6 and 7 take an extended length from the next 1, 2 or
  				// 3 bytes (little-endian); other codes are used as they are.
  				switch length {
  				case 5:
  					s += 1
  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  						if debugErrs {
  							fmt.Println("src went oob")
  						}
  						return decodeErrCodeCorrupt
  					}
  					length = int(uint32(src[s-1])) + 4
  				case 6:
  					s += 2
  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  						if debugErrs {
  							fmt.Println("src went oob")
  						}
  						return decodeErrCodeCorrupt
  					}
  					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
  				case 7:
  					s += 3
  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  						if debugErrs {
  							fmt.Println("src went oob")
  						}
  						return decodeErrCodeCorrupt
  					}
  					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
  				default: // 0-> 4
  				}
  			} else {
  				offset = toffset
  			}
  			length += 4
  		case tagCopy2:
  			s += 3
  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  				if debugErrs {
  					fmt.Println("src went oob")
  				}
  				return decodeErrCodeCorrupt
  			}
  			// Length is the upper six bits of the tag plus one; the offset is the
  			// following two bytes, little-endian.
  			length = 1 + int(src[s-3])>>2
  			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
  		case tagCopy4:
  			s += 5
  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
  				if debugErrs {
  					fmt.Println("src went oob")
  				}
  				return decodeErrCodeCorrupt
  			}
  			// Length is the upper six bits of the tag plus one; the offset is the
  			// following four bytes, little-endian.
  			length = 1 + int(src[s-5])>>2
  			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
  		}
  		// A copy needs a positive offset and must fit in the rest of dst.
  		if offset <= 0 || length > len(dst)-d {
  			if debugErrs {
  				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
  			}
  			return decodeErrCodeCorrupt
  		}
  		// copy from dict
  		// The offset reaches back past the start of dst, so the bytes come from
  		// the dictionary.
  		if d < offset {
  			if d > MaxDictSrcOffset {
  				if debugErrs {
  					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
  				}
  				return decodeErrCodeCorrupt
  			}
  			// rOff is the position in dict.dict that lies offset bytes before d
  			// when the dictionary is placed directly in front of dst.
  			rOff := len(dict.dict) - (offset - d)
  			if debug {
  				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
  			}
  			// The whole copy must lie inside the dictionary: it may neither run
  			// past the end nor start before the beginning.
  			if rOff+length > len(dict.dict) {
  				if debugErrs {
  					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
  				}
  				return decodeErrCodeCorrupt
  			}
  			if rOff < 0 {
  				if debugErrs {
  					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
  				}
  				return decodeErrCodeCorrupt
  			}
  			copy(dst[d:d+length], dict.dict[rOff:])
  			d += length
  			continue
  		}
  		if debug {
  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
  		}
  		// Copy from an earlier sub-slice of dst to a later sub-slice.
  		// If no overlap, use the built-in copy:
  		if offset > length {
  			copy(dst[d:d+length], dst[d-offset:])
  			d += length
  			continue
  		}
  		// Unlike the built-in copy function, this byte-by-byte copy always runs
  		// forwards, even if the slices overlap. Conceptually, this is:
  		//
  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
  		//
  		// We align the slices into a and b and show the compiler they are the same size.
  		// This allows the loop to run without bounds checks.
  		a := dst[d : d+length]
  		b := dst[d-offset:]
  		b = b[:len(a)]
  		for i := range a {
  			a[i] = b[i]
  		}
  		d += length
  	}
  	// All of src has been consumed; the output must fill dst exactly.
  	if d != len(dst) {
  		if debugErrs {
  			fmt.Println("wanted length", len(dst), "got", d)
  		}
  		return decodeErrCodeCorrupt
  	}
  	return 0
  }