fuzz_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. //go:build go1.18
  2. // +build go1.18
  3. package flate
  4. import (
  5. "bytes"
  6. "flag"
  7. "io"
  8. "os"
  9. "strconv"
  10. "testing"
  11. "github.com/klauspost/compress/internal/fuzz"
  12. )
  13. // Fuzzing tweaks:
  14. var fuzzStartF = flag.Int("start", HuffmanOnly, "Start fuzzing at this level")
  15. var fuzzEndF = flag.Int("end", BestCompression, "End fuzzing at this level (inclusive)")
  16. var fuzzMaxF = flag.Int("max", 1<<20, "Maximum input size")
  17. var fuzzSLF = flag.Bool("sl", true, "Include stateless encodes")
  18. func TestMain(m *testing.M) {
  19. flag.Parse()
  20. os.Exit(m.Run())
  21. }
  22. func FuzzEncoding(f *testing.F) {
  23. fuzz.AddFromZip(f, "testdata/regression.zip", fuzz.TypeRaw, false)
  24. fuzz.AddFromZip(f, "testdata/fuzz/encode-raw-corpus.zip", fuzz.TypeRaw, testing.Short())
  25. fuzz.AddFromZip(f, "testdata/fuzz/FuzzEncoding.zip", fuzz.TypeGoFuzz, testing.Short())
  26. startFuzz := *fuzzStartF
  27. endFuzz := *fuzzEndF
  28. maxSize := *fuzzMaxF
  29. stateless := *fuzzSLF
  30. decoder := NewReader(nil)
  31. buf := new(bytes.Buffer)
  32. encs := make([]*Writer, endFuzz-startFuzz+1)
  33. for i := range encs {
  34. var err error
  35. encs[i], err = NewWriter(nil, i+startFuzz)
  36. if err != nil {
  37. f.Fatal(err.Error())
  38. }
  39. }
  40. f.Fuzz(func(t *testing.T, data []byte) {
  41. if len(data) > maxSize {
  42. return
  43. }
  44. for level := startFuzz; level <= endFuzz; level++ {
  45. msg := "level " + strconv.Itoa(level) + ":"
  46. buf.Reset()
  47. fw := encs[level-startFuzz]
  48. fw.Reset(buf)
  49. n, err := fw.Write(data)
  50. if n != len(data) {
  51. t.Fatal(msg + "short write")
  52. }
  53. if err != nil {
  54. t.Fatal(msg + err.Error())
  55. }
  56. err = fw.Close()
  57. if err != nil {
  58. t.Fatal(msg + err.Error())
  59. }
  60. decoder.(Resetter).Reset(buf, nil)
  61. data2, err := io.ReadAll(decoder)
  62. if err != nil {
  63. t.Fatal(msg + err.Error())
  64. }
  65. if !bytes.Equal(data, data2) {
  66. t.Fatal(msg + "not equal")
  67. }
  68. // Do it again...
  69. msg = "level " + strconv.Itoa(level) + " (reset):"
  70. buf.Reset()
  71. fw.Reset(buf)
  72. n, err = fw.Write(data)
  73. if n != len(data) {
  74. t.Fatal(msg + "short write")
  75. }
  76. if err != nil {
  77. t.Fatal(msg + err.Error())
  78. }
  79. err = fw.Close()
  80. if err != nil {
  81. t.Fatal(msg + err.Error())
  82. }
  83. decoder.(Resetter).Reset(buf, nil)
  84. data2, err = io.ReadAll(decoder)
  85. if err != nil {
  86. t.Fatal(msg + err.Error())
  87. }
  88. if !bytes.Equal(data, data2) {
  89. t.Fatal(msg + "not equal")
  90. }
  91. }
  92. if !stateless {
  93. return
  94. }
  95. // Split into two and use history...
  96. buf.Reset()
  97. err := StatelessDeflate(buf, data[:len(data)/2], false, nil)
  98. if err != nil {
  99. t.Error(err)
  100. }
  101. // Use top half as dictionary...
  102. dict := data[:len(data)/2]
  103. err = StatelessDeflate(buf, data[len(data)/2:], true, dict)
  104. if err != nil {
  105. t.Error(err)
  106. }
  107. decoder.(Resetter).Reset(buf, nil)
  108. data2, err := io.ReadAll(decoder)
  109. if err != nil {
  110. t.Error(err)
  111. }
  112. if !bytes.Equal(data, data2) {
  113. //fmt.Printf("want:%x\ngot: %x\n", data1, data2)
  114. t.Error("not equal")
  115. }
  116. })
  117. }