write_parquet_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package schema
  2. import (
  3. "fmt"
  4. "github.com/parquet-go/parquet-go"
  5. "github.com/parquet-go/parquet-go/compress/zstd"
  6. "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
  7. "io"
  8. "os"
  9. "testing"
  10. )
  11. func TestWriteReadParquet(t *testing.T) {
  12. // create a schema_pb.RecordType
  13. recordType := RecordTypeBegin().
  14. WithField("ID", TypeInt64).
  15. WithField("CreatedAt", TypeInt64).
  16. WithRecordField("Person",
  17. RecordTypeBegin().
  18. WithField("zName", TypeString).
  19. WithField("emails", ListOf(TypeString)).
  20. RecordTypeEnd()).
  21. WithField("Company", TypeString).
  22. WithRecordField("Address",
  23. RecordTypeBegin().
  24. WithField("Street", TypeString).
  25. WithField("City", TypeString).
  26. RecordTypeEnd()).
  27. RecordTypeEnd()
  28. fmt.Printf("RecordType: %v\n", recordType)
  29. // create a parquet schema
  30. parquetSchema, err := ToParquetSchema("example", recordType)
  31. if err != nil {
  32. t.Fatalf("ToParquetSchema failed: %v", err)
  33. }
  34. fmt.Printf("ParquetSchema: %v\n", parquetSchema)
  35. fmt.Printf("Go Type: %+v\n", parquetSchema.GoType())
  36. filename := "example.parquet"
  37. count := 3
  38. testWritingParquetFile(t, count, filename, parquetSchema, recordType)
  39. total := testReadingParquetFile(t, filename, parquetSchema, recordType)
  40. if total != count {
  41. t.Fatalf("total != 128*1024: %v", total)
  42. }
  43. if err = os.Remove(filename); err != nil {
  44. t.Fatalf("os.Remove failed: %v", err)
  45. }
  46. }
  47. func testWritingParquetFile(t *testing.T, count int, filename string, parquetSchema *parquet.Schema, recordType *schema_pb.RecordType) {
  48. parquetLevels, err := ToParquetLevels(recordType)
  49. if err != nil {
  50. t.Fatalf("ToParquetLevels failed: %v", err)
  51. }
  52. // create a parquet file
  53. file, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0664)
  54. if err != nil {
  55. t.Fatalf("os.Open failed: %v", err)
  56. }
  57. defer file.Close()
  58. writer := parquet.NewWriter(file, parquetSchema, parquet.Compression(&zstd.Codec{Level: zstd.DefaultLevel}))
  59. rowBuilder := parquet.NewRowBuilder(parquetSchema)
  60. for i := 0; i < count; i++ {
  61. rowBuilder.Reset()
  62. // generate random data
  63. recordValue := RecordBegin().
  64. SetInt64("ID", 1+int64(i)).
  65. SetInt64("CreatedAt", 2+2*int64(i)).
  66. SetRecord("Person",
  67. RecordBegin().
  68. SetString("zName", fmt.Sprintf("john_%d", i)).
  69. SetStringList("emails",
  70. fmt.Sprintf("john_%d@a.com", i),
  71. fmt.Sprintf("john_%d@b.com", i),
  72. fmt.Sprintf("john_%d@c.com", i),
  73. fmt.Sprintf("john_%d@d.com", i),
  74. fmt.Sprintf("john_%d@e.com", i)).
  75. RecordEnd()).
  76. SetString("Company", fmt.Sprintf("company_%d", i)).
  77. RecordEnd()
  78. AddRecordValue(rowBuilder, recordType, parquetLevels, recordValue)
  79. if count < 10 {
  80. fmt.Printf("RecordValue: %v\n", recordValue)
  81. }
  82. row := rowBuilder.Row()
  83. if count < 10 {
  84. fmt.Printf("Row: %+v\n", row)
  85. }
  86. if err != nil {
  87. t.Fatalf("rowBuilder.Build failed: %v", err)
  88. }
  89. if _, err = writer.WriteRows([]parquet.Row{row}); err != nil {
  90. t.Fatalf("writer.Write failed: %v", err)
  91. }
  92. }
  93. if err = writer.Close(); err != nil {
  94. t.Fatalf("writer.WriteStop failed: %v", err)
  95. }
  96. }
  97. func testReadingParquetFile(t *testing.T, filename string, parquetSchema *parquet.Schema, recordType *schema_pb.RecordType) (total int) {
  98. parquetLevels, err := ToParquetLevels(recordType)
  99. if err != nil {
  100. t.Fatalf("ToParquetLevels failed: %v", err)
  101. }
  102. // read the parquet file
  103. file, err := os.Open(filename)
  104. if err != nil {
  105. t.Fatalf("os.Open failed: %v", err)
  106. }
  107. defer file.Close()
  108. reader := parquet.NewReader(file, parquetSchema)
  109. rows := make([]parquet.Row, 128)
  110. for {
  111. rowCount, err := reader.ReadRows(rows)
  112. if err != nil {
  113. if err == io.EOF {
  114. break
  115. }
  116. t.Fatalf("reader.Read failed: %v", err)
  117. }
  118. for i := 0; i < rowCount; i++ {
  119. row := rows[i]
  120. // convert parquet row to schema_pb.RecordValue
  121. recordValue, err := ToRecordValue(recordType, parquetLevels, row)
  122. if err != nil {
  123. t.Fatalf("ToRecordValue failed: %v", err)
  124. }
  125. if rowCount < 10 {
  126. fmt.Printf("RecordValue: %v\n", recordValue)
  127. }
  128. }
  129. total += rowCount
  130. }
  131. fmt.Printf("total: %v\n", total)
  132. return
  133. }