dict_test.go 14 KB


  1. package zstd
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "os"
  7. "strings"
  8. "testing"
  9. "github.com/klauspost/compress/zip"
  10. )
  11. func TestDecoder_SmallDict(t *testing.T) {
  12. // All files have CRC
  13. zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
  14. dicts := readDicts(t, zr)
  15. dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
  16. if err != nil {
  17. t.Fatal(err)
  18. return
  19. }
  20. defer dec.Close()
  21. for _, tt := range zr.File {
  22. if !strings.HasSuffix(tt.Name, ".zst") {
  23. continue
  24. }
  25. t.Run("decodeall-"+tt.Name, func(t *testing.T) {
  26. r, err := tt.Open()
  27. if err != nil {
  28. t.Fatal(err)
  29. }
  30. defer r.Close()
  31. in, err := io.ReadAll(r)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. got, err := dec.DecodeAll(in, nil)
  36. if err != nil {
  37. t.Fatal(err)
  38. }
  39. _, err = dec.DecodeAll(in, got[:0])
  40. if err != nil {
  41. t.Fatal(err)
  42. }
  43. })
  44. }
  45. }
  46. func TestEncoder_SmallDict(t *testing.T) {
  47. // All files have CRC
  48. zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
  49. var dicts [][]byte
  50. var encs []*Encoder
  51. var noDictEncs []*Encoder
  52. var encNames []string
  53. for _, tt := range zr.File {
  54. if !strings.HasSuffix(tt.Name, ".dict") {
  55. continue
  56. }
  57. func() {
  58. r, err := tt.Open()
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. defer r.Close()
  63. in, err := io.ReadAll(r)
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. dicts = append(dicts, in)
  68. for level := SpeedFastest; level < speedLast; level++ {
  69. if isRaceTest && level >= SpeedBestCompression {
  70. break
  71. }
  72. enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
  73. if err != nil {
  74. t.Fatal(err)
  75. }
  76. encs = append(encs, enc)
  77. encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
  78. enc, err = NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. noDictEncs = append(noDictEncs, enc)
  83. }
  84. }()
  85. }
  86. dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
  87. if err != nil {
  88. t.Fatal(err)
  89. return
  90. }
  91. defer dec.Close()
  92. for i, tt := range zr.File {
  93. if testing.Short() && i > 100 {
  94. break
  95. }
  96. if !strings.HasSuffix(tt.Name, ".zst") {
  97. continue
  98. }
  99. r, err := tt.Open()
  100. if err != nil {
  101. t.Fatal(err)
  102. }
  103. defer r.Close()
  104. in, err := io.ReadAll(r)
  105. if err != nil {
  106. t.Fatal(err)
  107. }
  108. decoded, err := dec.DecodeAll(in, nil)
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. if testing.Short() && len(decoded) > 1000 {
  113. continue
  114. }
  115. t.Run("encodeall-"+tt.Name, func(t *testing.T) {
  116. // Attempt to compress with all dicts
  117. var b []byte
  118. var tmp []byte
  119. for i := range encs {
  120. i := i
  121. t.Run(encNames[i], func(t *testing.T) {
  122. b = encs[i].EncodeAll(decoded, b[:0])
  123. tmp, err = dec.DecodeAll(in, tmp[:0])
  124. if err != nil {
  125. t.Fatal(err)
  126. }
  127. if !bytes.Equal(tmp, decoded) {
  128. t.Fatal("output mismatch")
  129. }
  130. tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
  131. if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
  132. t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
  133. // Check that we reduced this significantly
  134. if len(b) > 250 {
  135. t.Error("output was bigger than expected")
  136. }
  137. }
  138. })
  139. }
  140. })
  141. t.Run("stream-"+tt.Name, func(t *testing.T) {
  142. // Attempt to compress with all dicts
  143. var tmp []byte
  144. for i := range encs {
  145. i := i
  146. enc := encs[i]
  147. t.Run(encNames[i], func(t *testing.T) {
  148. var buf bytes.Buffer
  149. enc.ResetContentSize(&buf, int64(len(decoded)))
  150. _, err := enc.Write(decoded)
  151. if err != nil {
  152. t.Fatal(err)
  153. }
  154. err = enc.Close()
  155. if err != nil {
  156. t.Fatal(err)
  157. }
  158. tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
  159. if err != nil {
  160. t.Fatal(err)
  161. }
  162. if !bytes.Equal(tmp, decoded) {
  163. t.Fatal("output mismatch")
  164. }
  165. var buf2 bytes.Buffer
  166. noDictEncs[i].Reset(&buf2)
  167. noDictEncs[i].Write(decoded)
  168. noDictEncs[i].Close()
  169. if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
  170. t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
  171. // Check that we reduced this significantly
  172. if buf.Len() > 250 {
  173. t.Error("output was bigger than expected")
  174. }
  175. }
  176. })
  177. }
  178. })
  179. }
  180. }
  181. func TestEncoder_SmallDictFresh(t *testing.T) {
  182. // All files have CRC
  183. zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
  184. var dicts [][]byte
  185. var encs []func() *Encoder
  186. var noDictEncs []*Encoder
  187. var encNames []string
  188. for _, tt := range zr.File {
  189. if !strings.HasSuffix(tt.Name, ".dict") {
  190. continue
  191. }
  192. func() {
  193. r, err := tt.Open()
  194. if err != nil {
  195. t.Fatal(err)
  196. }
  197. defer r.Close()
  198. in, err := io.ReadAll(r)
  199. if err != nil {
  200. t.Fatal(err)
  201. }
  202. dicts = append(dicts, in)
  203. for level := SpeedFastest; level < speedLast; level++ {
  204. if isRaceTest && level >= SpeedBestCompression {
  205. break
  206. }
  207. level := level
  208. encs = append(encs, func() *Encoder {
  209. enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
  210. if err != nil {
  211. t.Fatal(err)
  212. }
  213. return enc
  214. })
  215. encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
  216. enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. noDictEncs = append(noDictEncs, enc)
  221. }
  222. }()
  223. }
  224. dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
  225. if err != nil {
  226. t.Fatal(err)
  227. return
  228. }
  229. defer dec.Close()
  230. for i, tt := range zr.File {
  231. if testing.Short() && i > 100 {
  232. break
  233. }
  234. if !strings.HasSuffix(tt.Name, ".zst") {
  235. continue
  236. }
  237. r, err := tt.Open()
  238. if err != nil {
  239. t.Fatal(err)
  240. }
  241. defer r.Close()
  242. in, err := io.ReadAll(r)
  243. if err != nil {
  244. t.Fatal(err)
  245. }
  246. decoded, err := dec.DecodeAll(in, nil)
  247. if err != nil {
  248. t.Fatal(err)
  249. }
  250. if testing.Short() && len(decoded) > 1000 {
  251. continue
  252. }
  253. t.Run("encodeall-"+tt.Name, func(t *testing.T) {
  254. // Attempt to compress with all dicts
  255. var b []byte
  256. var tmp []byte
  257. for i := range encs {
  258. i := i
  259. t.Run(encNames[i], func(t *testing.T) {
  260. enc := encs[i]()
  261. defer enc.Close()
  262. b = enc.EncodeAll(decoded, b[:0])
  263. tmp, err = dec.DecodeAll(in, tmp[:0])
  264. if err != nil {
  265. t.Fatal(err)
  266. }
  267. if !bytes.Equal(tmp, decoded) {
  268. t.Fatal("output mismatch")
  269. }
  270. tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
  271. if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
  272. t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
  273. // Check that we reduced this significantly
  274. if len(b) > 250 {
  275. t.Error("output was bigger than expected")
  276. }
  277. }
  278. })
  279. }
  280. })
  281. t.Run("stream-"+tt.Name, func(t *testing.T) {
  282. // Attempt to compress with all dicts
  283. var tmp []byte
  284. for i := range encs {
  285. i := i
  286. t.Run(encNames[i], func(t *testing.T) {
  287. enc := encs[i]()
  288. defer enc.Close()
  289. var buf bytes.Buffer
  290. enc.ResetContentSize(&buf, int64(len(decoded)))
  291. _, err := enc.Write(decoded)
  292. if err != nil {
  293. t.Fatal(err)
  294. }
  295. err = enc.Close()
  296. if err != nil {
  297. t.Fatal(err)
  298. }
  299. tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
  300. if err != nil {
  301. t.Fatal(err)
  302. }
  303. if !bytes.Equal(tmp, decoded) {
  304. t.Fatal("output mismatch")
  305. }
  306. var buf2 bytes.Buffer
  307. noDictEncs[i].Reset(&buf2)
  308. noDictEncs[i].Write(decoded)
  309. noDictEncs[i].Close()
  310. if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
  311. t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
  312. // Check that we reduced this significantly
  313. if buf.Len() > 250 {
  314. t.Error("output was bigger than expected")
  315. }
  316. }
  317. })
  318. }
  319. })
  320. }
  321. }
  322. func benchmarkEncodeAllLimitedBySize(b *testing.B, lowerLimit int, upperLimit int) {
  323. zr := testCreateZipReader("testdata/dict-tests-small.zip", b)
  324. t := testing.TB(b)
  325. var dicts [][]byte
  326. var encs []*Encoder
  327. var encNames []string
  328. for _, tt := range zr.File {
  329. if !strings.HasSuffix(tt.Name, ".dict") {
  330. continue
  331. }
  332. func() {
  333. r, err := tt.Open()
  334. if err != nil {
  335. t.Fatal(err)
  336. }
  337. defer r.Close()
  338. in, err := io.ReadAll(r)
  339. if err != nil {
  340. t.Fatal(err)
  341. }
  342. dicts = append(dicts, in)
  343. for level := SpeedFastest; level < speedLast; level++ {
  344. enc, err := NewWriter(nil, WithEncoderDict(in), WithEncoderLevel(level))
  345. if err != nil {
  346. t.Fatal(err)
  347. }
  348. encs = append(encs, enc)
  349. encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
  350. }
  351. }()
  352. }
  353. const nPer = int(speedLast - SpeedFastest)
  354. dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
  355. if err != nil {
  356. t.Fatal(err)
  357. return
  358. }
  359. defer dec.Close()
  360. tested := make(map[int]struct{})
  361. for j, tt := range zr.File {
  362. if !strings.HasSuffix(tt.Name, ".zst") {
  363. continue
  364. }
  365. r, err := tt.Open()
  366. if err != nil {
  367. t.Fatal(err)
  368. }
  369. defer r.Close()
  370. in, err := io.ReadAll(r)
  371. if err != nil {
  372. t.Fatal(err)
  373. }
  374. decoded, err := dec.DecodeAll(in, nil)
  375. if err != nil {
  376. t.Fatal(err)
  377. }
  378. // Only test each size once
  379. if _, ok := tested[len(decoded)]; ok {
  380. continue
  381. }
  382. tested[len(decoded)] = struct{}{}
  383. if len(decoded) < lowerLimit {
  384. continue
  385. }
  386. if upperLimit > 0 && len(decoded) > upperLimit {
  387. continue
  388. }
  389. for i := range encs {
  390. // Only do 1 dict (4 encoders) for now.
  391. if i == nPer-1 {
  392. break
  393. }
  394. // Attempt to compress with all dicts
  395. encIdx := (i + j*nPer) % len(encs)
  396. enc := encs[encIdx]
  397. b.Run(fmt.Sprintf("length-%d-%s", len(decoded), encNames[encIdx]), func(b *testing.B) {
  398. b.RunParallel(func(pb *testing.PB) {
  399. dst := make([]byte, 0, len(decoded)+10)
  400. b.SetBytes(int64(len(decoded)))
  401. b.ResetTimer()
  402. b.ReportAllocs()
  403. for pb.Next() {
  404. dst = enc.EncodeAll(decoded, dst[:0])
  405. }
  406. })
  407. })
  408. }
  409. }
  410. }
  411. func BenchmarkEncodeAllDict0_1024(b *testing.B) {
  412. benchmarkEncodeAllLimitedBySize(b, 0, 1024)
  413. }
  414. func BenchmarkEncodeAllDict1024_8192(b *testing.B) {
  415. benchmarkEncodeAllLimitedBySize(b, 1024, 8192)
  416. }
  417. func BenchmarkEncodeAllDict8192_16384(b *testing.B) {
  418. benchmarkEncodeAllLimitedBySize(b, 8192, 16384)
  419. }
  420. func BenchmarkEncodeAllDict16384_65536(b *testing.B) {
  421. benchmarkEncodeAllLimitedBySize(b, 16384, 65536)
  422. }
  423. func BenchmarkEncodeAllDict65536_0(b *testing.B) {
  424. benchmarkEncodeAllLimitedBySize(b, 65536, 0)
  425. }
  426. func TestDecoder_MoreDicts(t *testing.T) {
  427. // All files have CRC
  428. // https://files.klauspost.com/compress/zstd-dict-tests.zip
  429. fn := "testdata/zstd-dict-tests.zip"
  430. data, err := os.ReadFile(fn)
  431. if err != nil {
  432. t.Skip("extended dict test not found.")
  433. }
  434. zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
  435. if err != nil {
  436. t.Fatal(err)
  437. }
  438. var dicts [][]byte
  439. for _, tt := range zr.File {
  440. if !strings.HasSuffix(tt.Name, ".dict") {
  441. continue
  442. }
  443. func() {
  444. r, err := tt.Open()
  445. if err != nil {
  446. t.Fatal(err)
  447. }
  448. defer r.Close()
  449. in, err := io.ReadAll(r)
  450. if err != nil {
  451. t.Fatal(err)
  452. }
  453. dicts = append(dicts, in)
  454. }()
  455. }
  456. dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
  457. if err != nil {
  458. t.Fatal(err)
  459. return
  460. }
  461. defer dec.Close()
  462. for i, tt := range zr.File {
  463. if !strings.HasSuffix(tt.Name, ".zst") {
  464. continue
  465. }
  466. if testing.Short() && i > 50 {
  467. continue
  468. }
  469. t.Run("decodeall-"+tt.Name, func(t *testing.T) {
  470. r, err := tt.Open()
  471. if err != nil {
  472. t.Fatal(err)
  473. }
  474. defer r.Close()
  475. in, err := io.ReadAll(r)
  476. if err != nil {
  477. t.Fatal(err)
  478. }
  479. got, err := dec.DecodeAll(in, nil)
  480. if err != nil {
  481. t.Fatal(err)
  482. }
  483. _, err = dec.DecodeAll(in, got[:0])
  484. if err != nil {
  485. t.Fatal(err)
  486. }
  487. })
  488. }
  489. }
  490. func TestDecoder_MoreDicts2(t *testing.T) {
  491. // All files have CRC
  492. // https://files.klauspost.com/compress/zstd-dict-tests.zip
  493. fn := "testdata/zstd-dict-tests.zip"
  494. data, err := os.ReadFile(fn)
  495. if err != nil {
  496. t.Skip("extended dict test not found.")
  497. }
  498. zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
  499. if err != nil {
  500. t.Fatal(err)
  501. }
  502. var dicts [][]byte
  503. for _, tt := range zr.File {
  504. if !strings.HasSuffix(tt.Name, ".dict") {
  505. continue
  506. }
  507. func() {
  508. r, err := tt.Open()
  509. if err != nil {
  510. t.Fatal(err)
  511. }
  512. defer r.Close()
  513. in, err := io.ReadAll(r)
  514. if err != nil {
  515. t.Fatal(err)
  516. }
  517. dicts = append(dicts, in)
  518. }()
  519. }
  520. dec, err := NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...))
  521. if err != nil {
  522. t.Fatal(err)
  523. return
  524. }
  525. defer dec.Close()
  526. for i, tt := range zr.File {
  527. if !strings.HasSuffix(tt.Name, ".zst") {
  528. continue
  529. }
  530. if testing.Short() && i > 50 {
  531. continue
  532. }
  533. t.Run("decodeall-"+tt.Name, func(t *testing.T) {
  534. r, err := tt.Open()
  535. if err != nil {
  536. t.Fatal(err)
  537. }
  538. defer r.Close()
  539. in, err := io.ReadAll(r)
  540. if err != nil {
  541. t.Fatal(err)
  542. }
  543. got, err := dec.DecodeAll(in, nil)
  544. if err != nil {
  545. t.Fatal(err)
  546. }
  547. _, err = dec.DecodeAll(in, got[:0])
  548. if err != nil {
  549. t.Fatal(err)
  550. }
  551. })
  552. }
  553. }
  // readDicts returns the contents of every entry in zr whose name ends in
  // ".dict", in archive order. Failing to open or read an entry aborts the
  // calling test or benchmark via tb.Fatal.
  func readDicts(tb testing.TB, zr *zip.Reader) [][]byte {
  	var out [][]byte
  	for _, f := range zr.File {
  		if !strings.HasSuffix(f.Name, ".dict") {
  			continue
  		}
  		rc, err := f.Open()
  		if err != nil {
  			tb.Fatal(err)
  		}
  		content, err := io.ReadAll(rc)
  		// Close before inspecting the read error so the entry is released either way.
  		rc.Close()
  		if err != nil {
  			tb.Fatal(err)
  		}
  		out = append(out, content)
  	}
  	return out
  }
  575. // Test decoding of zstd --patch-from output.
  576. func TestDecoderRawDict(t *testing.T) {
  577. t.Parallel()
  578. dict, err := os.ReadFile("testdata/delta/source.txt")
  579. if err != nil {
  580. t.Fatal(err)
  581. }
  582. delta, err := os.Open("testdata/delta/target.txt.zst")
  583. if err != nil {
  584. t.Fatal(err)
  585. }
  586. defer delta.Close()
  587. dec, err := NewReader(delta, WithDecoderDictRaw(0, dict))
  588. if err != nil {
  589. t.Fatal(err)
  590. }
  591. out, err := io.ReadAll(dec)
  592. if err != nil {
  593. t.Fatal(err)
  594. }
  595. ref, err := os.ReadFile("testdata/delta/target.txt")
  596. if err != nil {
  597. t.Fatal(err)
  598. }
  599. if !bytes.Equal(out, ref) {
  600. t.Errorf("mismatch: got %q, wanted %q", out, ref)
  601. }
  602. }