dict_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Copyright (c) 2023+ Klaus Post. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package s2
  5. import (
  6. "archive/tar"
  7. "bytes"
  8. "compress/gzip"
  9. "fmt"
  10. "io"
  11. "math/rand"
  12. "os"
  13. "testing"
  14. "github.com/klauspost/compress/internal/fuzz"
  15. "github.com/klauspost/compress/zstd"
  16. )
  17. func TestDict(t *testing.T) {
  18. rng := rand.New(rand.NewSource(1))
  19. data := make([]byte, 128<<10)
  20. for i := range data {
  21. data[i] = uint8(rng.Intn(256))
  22. }
  23. // Should match the first 64K
  24. d := NewDict(append([]byte{0}, data[:65536]...))
  25. encoded := make([]byte, MaxEncodedLen(len(data)))
  26. res := encodeBlockDictGo(encoded, data, d)
  27. if res == 0 || res > len(data)-65500 {
  28. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  29. }
  30. encoded = encoded[:res]
  31. t.Log("saved", len(data)-res, "bytes")
  32. decoded := make([]byte, len(data))
  33. res = s2DecodeDict(decoded, encoded, d)
  34. if res != 0 {
  35. t.Fatalf("got result: %d", res)
  36. }
  37. if !bytes.Equal(decoded, data) {
  38. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  39. //os.WriteFile("original.bin", data, os.ModePerm)
  40. t.Fatal("decoded mismatch")
  41. }
  42. // Add dict that will produce a full match 5000 chars into the input.
  43. d = NewDict(append([]byte{0}, data[5000:65536+5000]...))
  44. encoded = make([]byte, MaxEncodedLen(len(data)))
  45. res = encodeBlockDictGo(encoded, data, d)
  46. if res == 0 || res > len(data)-65500 {
  47. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  48. }
  49. encoded = encoded[:res]
  50. t.Log("saved", len(data)-res, "bytes")
  51. decoded = make([]byte, len(data))
  52. res = s2DecodeDict(decoded, encoded, d)
  53. if res != 0 {
  54. t.Fatalf("got result: %d", res)
  55. }
  56. if !bytes.Equal(decoded, data) {
  57. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  58. //os.WriteFile("original.bin", data, os.ModePerm)
  59. t.Fatal("decoded mismatch")
  60. }
  61. // generate copies
  62. for i := 1; i < len(data); {
  63. n := rng.Intn(32) + 4
  64. off := rng.Intn(len(data) - n)
  65. copy(data[i:], data[off:off+n])
  66. i += n
  67. }
  68. dict := make([]byte, 65536)
  69. for i := 1; i < len(dict); {
  70. n := rng.Intn(32) + 4
  71. off := rng.Intn(65536 - n)
  72. copy(dict[i:], data[off:off+n])
  73. i += n
  74. }
  75. d = NewDict(dict)
  76. encoded = make([]byte, MaxEncodedLen(len(data)))
  77. res = encodeBlockDictGo(encoded, data, d)
  78. if res == 0 || res > len(data)-20000 {
  79. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  80. }
  81. encoded = encoded[:res]
  82. t.Log("saved", len(data)-res, "bytes")
  83. decoded = make([]byte, len(data))
  84. res = s2DecodeDict(decoded, encoded, d)
  85. if res != 0 {
  86. t.Fatalf("got result: %d", res)
  87. }
  88. if !bytes.Equal(decoded, data) {
  89. os.WriteFile("decoded.bin", decoded, os.ModePerm)
  90. os.WriteFile("original.bin", data, os.ModePerm)
  91. t.Fatal("decoded mismatch")
  92. }
  93. }
  94. func TestDictBetter(t *testing.T) {
  95. rng := rand.New(rand.NewSource(1))
  96. data := make([]byte, 128<<10)
  97. for i := range data {
  98. data[i] = uint8(rng.Intn(256))
  99. }
  100. // Should match the first 64K
  101. d := NewDict(append([]byte{0}, data[:65536]...))
  102. encoded := make([]byte, MaxEncodedLen(len(data)))
  103. res := encodeBlockBetterDict(encoded, data, d)
  104. if res == 0 || res > len(data)-65500 {
  105. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  106. }
  107. encoded = encoded[:res]
  108. t.Log("saved", len(data)-res, "bytes")
  109. decoded := make([]byte, len(data))
  110. res = s2DecodeDict(decoded, encoded, d)
  111. if res != 0 {
  112. t.Fatalf("got result: %d", res)
  113. }
  114. if !bytes.Equal(decoded, data) {
  115. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  116. //os.WriteFile("original.bin", data, os.ModePerm)
  117. t.Fatal("decoded mismatch")
  118. }
  119. // Add dict that will produce a full match 5000 chars into the input.
  120. d = NewDict(append([]byte{0}, data[5000:65536+5000]...))
  121. encoded = make([]byte, MaxEncodedLen(len(data)))
  122. res = encodeBlockBetterDict(encoded, data, d)
  123. if res == 0 || res > len(data)-65500 {
  124. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  125. }
  126. encoded = encoded[:res]
  127. t.Log("saved", len(data)-res, "bytes")
  128. decoded = make([]byte, len(data))
  129. res = s2DecodeDict(decoded, encoded, d)
  130. if res != 0 {
  131. t.Fatalf("got result: %d", res)
  132. }
  133. if !bytes.Equal(decoded, data) {
  134. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  135. //os.WriteFile("original.bin", data, os.ModePerm)
  136. t.Fatal("decoded mismatch")
  137. }
  138. // generate copies
  139. for i := 1; i < len(data); {
  140. n := rng.Intn(32) + 4
  141. off := rng.Intn(len(data) - n)
  142. copy(data[i:], data[off:off+n])
  143. i += n
  144. }
  145. dict := make([]byte, 65536)
  146. for i := 1; i < len(dict); {
  147. n := rng.Intn(32) + 4
  148. off := rng.Intn(65536 - n)
  149. copy(dict[i:], data[off:off+n])
  150. i += n
  151. }
  152. d = NewDict(dict)
  153. encoded = make([]byte, MaxEncodedLen(len(data)))
  154. res = encodeBlockBetterDict(encoded, data, d)
  155. if res == 0 || res > len(data)-20000 {
  156. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  157. }
  158. encoded = encoded[:res]
  159. t.Log("saved", len(data)-res, "bytes")
  160. decoded = make([]byte, len(data))
  161. res = s2DecodeDict(decoded, encoded, d)
  162. if res != 0 {
  163. t.Fatalf("got result: %d", res)
  164. }
  165. if !bytes.Equal(decoded, data) {
  166. os.WriteFile("decoded.bin", decoded, os.ModePerm)
  167. os.WriteFile("original.bin", data, os.ModePerm)
  168. t.Fatal("decoded mismatch")
  169. }
  170. }
  171. func TestDictBest(t *testing.T) {
  172. rng := rand.New(rand.NewSource(1))
  173. data := make([]byte, 128<<10)
  174. for i := range data {
  175. data[i] = uint8(rng.Intn(256))
  176. }
  177. // Should match the first 64K
  178. d := NewDict(append([]byte{0}, data[:65536]...))
  179. encoded := make([]byte, MaxEncodedLen(len(data)))
  180. res := encodeBlockBest(encoded, data, d)
  181. if res == 0 || res > len(data)-65500 {
  182. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  183. }
  184. encoded = encoded[:res]
  185. t.Log("saved", len(data)-res, "bytes")
  186. decoded := make([]byte, len(data))
  187. res = s2DecodeDict(decoded, encoded, d)
  188. if res != 0 {
  189. t.Fatalf("got result: %d", res)
  190. }
  191. if !bytes.Equal(decoded, data) {
  192. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  193. //os.WriteFile("original.bin", data, os.ModePerm)
  194. t.Fatal("decoded mismatch")
  195. }
  196. // Add dict that will produce a full match 5000 chars into the input.
  197. d = NewDict(append([]byte{0}, data[5000:65536+5000]...))
  198. encoded = make([]byte, MaxEncodedLen(len(data)))
  199. res = encodeBlockBest(encoded, data, d)
  200. if res == 0 || res > len(data)-65500 {
  201. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  202. }
  203. encoded = encoded[:res]
  204. t.Log("saved", len(data)-res, "bytes")
  205. decoded = make([]byte, len(data))
  206. res = s2DecodeDict(decoded, encoded, d)
  207. if res != 0 {
  208. t.Fatalf("got result: %d", res)
  209. }
  210. if !bytes.Equal(decoded, data) {
  211. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  212. //os.WriteFile("original.bin", data, os.ModePerm)
  213. t.Fatal("decoded mismatch")
  214. }
  215. // generate copies
  216. for i := 1; i < len(data); {
  217. n := rng.Intn(32) + 4
  218. off := rng.Intn(len(data) - n)
  219. copy(data[i:], data[off:off+n])
  220. i += n
  221. }
  222. dict := make([]byte, 65536)
  223. for i := 1; i < len(dict); {
  224. n := rng.Intn(32) + 4
  225. off := rng.Intn(65536 - n)
  226. copy(dict[i:], data[off:off+n])
  227. i += n
  228. }
  229. d = NewDict(dict)
  230. encoded = make([]byte, MaxEncodedLen(len(data)))
  231. res = encodeBlockBest(encoded, data, d)
  232. if res == 0 || res > len(data)-20000 {
  233. t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res)
  234. }
  235. encoded = encoded[:res]
  236. t.Log("saved", len(data)-res, "bytes")
  237. decoded = make([]byte, len(data))
  238. res = s2DecodeDict(decoded, encoded, d)
  239. if res != 0 {
  240. t.Fatalf("got result: %d", res)
  241. }
  242. if !bytes.Equal(decoded, data) {
  243. os.WriteFile("decoded.bin", decoded, os.ModePerm)
  244. os.WriteFile("original.bin", data, os.ModePerm)
  245. t.Fatal("decoded mismatch")
  246. }
  247. }
  248. func TestDictBetter2(t *testing.T) {
  249. // Should match the first 64K
  250. data := []byte("10 bananas which were brown were added")
  251. d := NewDict(append([]byte{6}, []byte("Yesterday 25 bananas were added to Benjamins brown bag")...))
  252. encoded := make([]byte, MaxEncodedLen(len(data)))
  253. res := encodeBlockBetterDict(encoded, data, d)
  254. encoded = encoded[:res]
  255. t.Log("saved", len(data)-res, "bytes")
  256. t.Log(string(encoded))
  257. decoded := make([]byte, len(data))
  258. res = s2DecodeDict(decoded, encoded, d)
  259. if res != 0 {
  260. t.Fatalf("got result: %d", res)
  261. }
  262. if !bytes.Equal(decoded, data) {
  263. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  264. //os.WriteFile("original.bin", data, os.ModePerm)
  265. t.Fatal("decoded mismatch")
  266. }
  267. }
  268. func TestDictBest2(t *testing.T) {
  269. // Should match the first 64K
  270. data := []byte("10 bananas which were brown were added")
  271. d := NewDict(append([]byte{6}, []byte("Yesterday 25 bananas were added to Benjamins brown bag")...))
  272. encoded := make([]byte, MaxEncodedLen(len(data)))
  273. res := encodeBlockBest(encoded, data, d)
  274. encoded = encoded[:res]
  275. t.Log("saved", len(data)-res, "bytes")
  276. t.Log(string(encoded))
  277. decoded := make([]byte, len(data))
  278. res = s2DecodeDict(decoded, encoded, d)
  279. if res != 0 {
  280. t.Fatalf("got result: %d", res)
  281. }
  282. if !bytes.Equal(decoded, data) {
  283. //os.WriteFile("decoded.bin", decoded, os.ModePerm)
  284. //os.WriteFile("original.bin", data, os.ModePerm)
  285. t.Fatal("decoded mismatch")
  286. }
  287. }
  288. func TestDictSize(t *testing.T) {
  289. //f, err := os.Open("testdata/xlmeta.tar.s2")
  290. //f, err := os.Open("testdata/broken.tar.s2")
  291. f, err := os.Open("testdata/github_users_sample_set.tar.s2")
  292. //f, err := os.Open("testdata/gofiles2.tar.s2")
  293. //f, err := os.Open("testdata/gosrc.tar.s2")
  294. if err != nil {
  295. t.Skip(err)
  296. }
  297. stream := NewReader(f)
  298. in := tar.NewReader(stream)
  299. //rawDict, err := os.ReadFile("testdata/godict.dictator")
  300. rawDict, err := os.ReadFile("testdata/gofiles.dict")
  301. //rawDict, err := os.ReadFile("testdata/gosrc2.dict")
  302. //rawDict, err := os.ReadFile("testdata/td.dict")
  303. //rawDict, err := os.ReadFile("testdata/users.dict")
  304. //rawDict, err := os.ReadFile("testdata/xlmeta.dict")
  305. if err != nil {
  306. t.Fatal(err)
  307. }
  308. lidx := -1
  309. if di, err := zstd.InspectDictionary(rawDict); err == nil {
  310. rawDict = di.Content()
  311. lidx = len(rawDict) - di.Offsets()[0]
  312. } else {
  313. t.Errorf("Loading dict: %v", err)
  314. return
  315. }
  316. searchFor := ""
  317. if false {
  318. searchFor = "// Copyright 2022"
  319. }
  320. d := MakeDict(rawDict, []byte(searchFor))
  321. if d == nil {
  322. t.Fatal("no dict", lidx)
  323. }
  324. var totalIn int
  325. var totalOut int
  326. var totalCount int
  327. for {
  328. h, err := in.Next()
  329. if err != nil {
  330. break
  331. }
  332. if h.Size == 0 {
  333. continue
  334. }
  335. data := make([]byte, 65536)
  336. t.Run(h.Name, func(t *testing.T) {
  337. if int(h.Size) < 65536 {
  338. data = data[:h.Size]
  339. } else {
  340. data = data[:65536]
  341. }
  342. _, err := io.ReadFull(in, data)
  343. if err != nil {
  344. t.Skip()
  345. }
  346. if d == nil {
  347. // Use first file as dict
  348. d = MakeDict(data, nil)
  349. }
  350. // encode
  351. encoded := make([]byte, MaxEncodedLen(len(data)))
  352. totalIn += len(data)
  353. totalCount++
  354. //res := encodeBlockBest(encoded, data, nil)
  355. res := encodeBlockBest(encoded, data, d)
  356. //res := encodeBlockBetterDict(encoded, data, d)
  357. //res := encodeBlockBetterGo(encoded, data)
  358. //res := encodeBlockDictGo(encoded, data, d)
  359. // res := encodeBlockGo(encoded, data)
  360. if res == 0 {
  361. totalOut += len(data)
  362. return
  363. }
  364. totalOut += res
  365. encoded = encoded[:res]
  366. //t.Log("encoded", len(data), "->", res, "saved", len(data)-res, "bytes")
  367. decoded := make([]byte, len(data))
  368. res = s2DecodeDict(decoded, encoded, d)
  369. if res != 0 {
  370. t.Fatalf("got result: %d", res)
  371. }
  372. if !bytes.Equal(decoded, data) {
  373. os.WriteFile("decoded.bin", decoded, os.ModePerm)
  374. os.WriteFile("original.bin", data, os.ModePerm)
  375. t.Fatal("decoded mismatch")
  376. }
  377. })
  378. }
  379. fmt.Printf("%d files, %d -> %d (%.2f%%) - %.02f bytes saved/file\n", totalCount, totalIn, totalOut, float64(totalOut*100)/float64(totalIn), float64(totalIn-totalOut)/float64(totalCount))
  380. }
  381. func FuzzDictBlocks(f *testing.F) {
  382. fuzz.AddFromZip(f, "testdata/enc_regressions.zip", fuzz.TypeRaw, false)
  383. fuzz.AddFromZip(f, "testdata/fuzz/block-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
  384. fuzz.AddFromZip(f, "testdata/fuzz/block-corpus-enc.zip", fuzz.TypeGoFuzz, testing.Short())
  385. // Fuzzing tweaks:
  386. const (
  387. // Max input size:
  388. maxSize = 8 << 20
  389. )
  390. file, err := os.Open("testdata/s2-dict.bin.gz")
  391. if err != nil {
  392. f.Fatal(err)
  393. }
  394. gzr, err := gzip.NewReader(file)
  395. if err != nil {
  396. f.Fatal(err)
  397. }
  398. dictBytes, err := io.ReadAll(gzr)
  399. if err != nil {
  400. f.Fatal(err)
  401. }
  402. dict := NewDict(dictBytes)
  403. if dict == nil {
  404. f.Fatal("invalid dict")
  405. }
  406. f.Fuzz(func(t *testing.T, data []byte) {
  407. if len(data) > maxSize {
  408. return
  409. }
  410. writeDst := make([]byte, MaxEncodedLen(len(data)), MaxEncodedLen(len(data))+4)
  411. writeDst = append(writeDst, 1, 2, 3, 4)
  412. defer func() {
  413. got := writeDst[MaxEncodedLen(len(data)):]
  414. want := []byte{1, 2, 3, 4}
  415. if !bytes.Equal(got, want) {
  416. t.Fatalf("want %v, got %v - dest modified outside cap", want, got)
  417. }
  418. }()
  419. compDst := writeDst[:MaxEncodedLen(len(data)):MaxEncodedLen(len(data))] // Hard cap
  420. decDst := make([]byte, len(data))
  421. comp := dict.Encode(compDst, data)
  422. decoded, err := dict.Decode(decDst, comp)
  423. if err != nil {
  424. t.Error(err)
  425. return
  426. }
  427. if !bytes.Equal(data, decoded) {
  428. t.Error("block decoder mismatch")
  429. return
  430. }
  431. if mel := MaxEncodedLen(len(data)); len(comp) > mel {
  432. t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
  433. return
  434. }
  435. comp = dict.EncodeBetter(compDst, data)
  436. decoded, err = dict.Decode(decDst, comp)
  437. if err != nil {
  438. t.Error(err)
  439. return
  440. }
  441. if !bytes.Equal(data, decoded) {
  442. t.Error("block decoder mismatch")
  443. return
  444. }
  445. if mel := MaxEncodedLen(len(data)); len(comp) > mel {
  446. t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
  447. return
  448. }
  449. comp = dict.EncodeBest(compDst, data)
  450. decoded, err = dict.Decode(decDst, comp)
  451. if err != nil {
  452. t.Error(err)
  453. return
  454. }
  455. if !bytes.Equal(data, decoded) {
  456. t.Error("block decoder mismatch")
  457. return
  458. }
  459. if mel := MaxEncodedLen(len(data)); len(comp) > mel {
  460. t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp)))
  461. return
  462. }
  463. })
  464. }