From c6b8c13721e5a05e5c8787546c4b870fa32b54ef Mon Sep 17 00:00:00 2001
From: Helin Wang
Date: Wed, 24 May 2017 20:16:08 -0400
Subject: [PATCH] move recordio to Paddle from github.com/wangkuiyi/recordio

---
 paddle/go/recordio/README.md                 |  36 ++++
 paddle/go/recordio/chunk.go                  | 181 +++++++++++++++++++
 paddle/go/recordio/header.go                 |  59 ++++++
 paddle/go/recordio/reader.go                 | 135 ++++++++++++++
 paddle/go/recordio/recordio_internal_test.go |  90 +++++++++
 paddle/go/recordio/recordio_test.go          |  81 +++++++++
 paddle/go/recordio/writer.go                 |  60 ++++++
 7 files changed, 642 insertions(+)
 create mode 100644 paddle/go/recordio/README.md
 create mode 100644 paddle/go/recordio/chunk.go
 create mode 100644 paddle/go/recordio/header.go
 create mode 100644 paddle/go/recordio/reader.go
 create mode 100644 paddle/go/recordio/recordio_internal_test.go
 create mode 100644 paddle/go/recordio/recordio_test.go
 create mode 100644 paddle/go/recordio/writer.go

diff --git a/paddle/go/recordio/README.md b/paddle/go/recordio/README.md
new file mode 100644
index 000000000..8b0b9308b
--- /dev/null
+++ b/paddle/go/recordio/README.md
@@ -0,0 +1,36 @@
# RecordIO

## Write

```go
f, e := os.Create("a_file.recordio")
w := recordio.NewWriter(f, -1, -1) // negative values select the default chunk size and compressor.
w.Write([]byte("Hello"))
w.Write([]byte("World!"))
w.Close()
```

## Read

1. Load the chunk index:

   ```go
   f, e := os.Open("a_file.recordio")
   idx, e := recordio.LoadIndex(f)
   fmt.Println("Total records: ", idx.NumRecords())
   ```

2. Create one or more scanners to read a range of records.  The
   following example reads the range [1, 3), i.e., the second and
   third records:

   ```go
   f, e := os.Open("a_file.recordio")
   s := recordio.NewScanner(f, idx, 1, 2)
   for s.Scan() {
       fmt.Println(string(s.Record()))
   }
   if s.Error() != nil && s.Error() != io.EOF {
       log.Fatalf("Something wrong with scanning: %v", s.Error())
   }
   ```

diff --git a/paddle/go/recordio/chunk.go b/paddle/go/recordio/chunk.go
new file mode 100644
index 000000000..4e983ab72
--- /dev/null
+++ b/paddle/go/recordio/chunk.go
@@ -0,0 +1,181 @@
package recordio

import (
	"bytes"
	"compress/gzip"
	"encoding/binary"
	"fmt"
	"hash/crc32"
	"io"

	"github.com/golang/snappy"
)

// A Chunk buffers records that are dumped together, prefixed by a
// Header and optionally compressed.  To create a chunk, just use
// ch := &Chunk{}.
type Chunk struct {
	records  [][]byte
	numBytes int // sum of record lengths.
}

func (ch *Chunk) add(record []byte) {
	ch.records = append(ch.records, record)
	ch.numBytes += len(record)
}

// dump writes the chunk to w, then clears the chunk, making it ready
// for the next add invocation.
func (ch *Chunk) dump(w io.Writer, compressorIndex int) error {
	// NOTE: don't check ch.numBytes == 0 instead, because records of
	// zero length are allowed.
	if len(ch.records) == 0 {
		return nil
	}

	// Write raw records and their lengths into a data buffer.
	var data bytes.Buffer

	for _, r := range ch.records {
		var rs [4]byte
		binary.LittleEndian.PutUint32(rs[:], uint32(len(r)))

		if _, e := data.Write(rs[:]); e != nil {
			return fmt.Errorf("Failed to write record length: %v", e)
		}

		if _, e := data.Write(r); e != nil {
			return fmt.Errorf("Failed to write record: %v", e)
		}
	}

	compressed, e := compressData(&data, compressorIndex)
	if e != nil {
		return e
	}

	// Write chunk header and compressed data.
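	// The on-disk chunk layout, as defined by Header.write, is five
	// little-endian uint32 fields followed by the compressed payload:
	//
	//	magic | checkSum | compressor | compressedSize | numRecords | payload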
	hdr := &Header{
		checkSum:       crc32.ChecksumIEEE(compressed.Bytes()),
		compressor:     uint32(compressorIndex),
		compressedSize: uint32(compressed.Len()),
		numRecords:     uint32(len(ch.records)),
	}

	if _, e := hdr.write(w); e != nil {
		return fmt.Errorf("Failed to write chunk header: %v", e)
	}

	if _, e := w.Write(compressed.Bytes()); e != nil {
		return fmt.Errorf("Failed to write chunk data: %v", e)
	}

	// Clear the current chunk.
	ch.records = nil
	ch.numBytes = 0

	return nil
}

type noopCompressor struct {
	*bytes.Buffer
}

func (c *noopCompressor) Close() error {
	return nil
}

func compressData(src io.Reader, compressorIndex int) (*bytes.Buffer, error) {
	compressed := new(bytes.Buffer)
	var compressor io.WriteCloser

	switch compressorIndex {
	case NoCompression:
		compressor = &noopCompressor{compressed}
	case Snappy:
		compressor = snappy.NewBufferedWriter(compressed)
	case Gzip:
		compressor = gzip.NewWriter(compressed)
	default:
		return nil, fmt.Errorf("Unknown compression algorithm: %d", compressorIndex)
	}

	if _, e := io.Copy(compressor, src); e != nil {
		return nil, fmt.Errorf("Failed to compress chunk data: %v", e)
	}
	compressor.Close()

	return compressed, nil
}

// parseChunk reads the chunk at chunkOffset from r.
func parseChunk(r io.ReadSeeker, chunkOffset int64) (*Chunk, error) {
	var e error
	var hdr *Header

	if _, e = r.Seek(chunkOffset, io.SeekStart); e != nil {
		return nil, fmt.Errorf("Failed to seek chunk: %v", e)
	}

	hdr, e = parseHeader(r)
	if e != nil {
		return nil, fmt.Errorf("Failed to parse chunk header: %v", e)
	}

	var buf bytes.Buffer
	if _, e = io.CopyN(&buf, r, int64(hdr.compressedSize)); e != nil {
		return nil, fmt.Errorf("Failed to read chunk data: %v", e)
	}

	if hdr.checkSum != crc32.ChecksumIEEE(buf.Bytes()) {
		return nil, fmt.Errorf("Checksum checking failed")
	}

	decompressed, e := decompressData(&buf, int(hdr.compressor))
	if e != nil {
		return nil, e
	}

	ch := &Chunk{}
	for i := 0; i < int(hdr.numRecords); i++ {
		var rs [4]byte
		if _, e = decompressed.Read(rs[:]); e != nil {
			return nil, fmt.Errorf("Failed to read record length: %v", e)
		}

		r := make([]byte, binary.LittleEndian.Uint32(rs[:]))
		if _, e = decompressed.Read(r); e != nil {
			return nil, fmt.Errorf("Failed to read a record: %v", e)
		}

		ch.records = append(ch.records, r)
		ch.numBytes += len(r)
	}

	return ch, nil
}

// decompressData decompresses the chunk data read from src.
func decompressData(src io.Reader, compressorIndex int) (*bytes.Buffer, error) {
	var e error
	var decompressor io.Reader

	switch compressorIndex {
	case NoCompression:
		decompressor = src
	case Snappy:
		decompressor = snappy.NewReader(src)
	case Gzip:
		decompressor, e = gzip.NewReader(src)
		if e != nil {
			return nil, fmt.Errorf("Failed to create gzip reader: %v", e)
		}
	default:
		return nil, fmt.Errorf("Unknown compression algorithm: %d", compressorIndex)
	}

	decompressed := new(bytes.Buffer)
	if _, e = io.Copy(decompressed, decompressor); e != nil {
		return nil, fmt.Errorf("Failed to decompress chunk data: %v", e)
	}

	return decompressed, nil
}

diff --git a/paddle/go/recordio/header.go b/paddle/go/recordio/header.go
new file mode 100644
index 000000000..d3aefae36
--- /dev/null
+++ b/paddle/go/recordio/header.go
@@ -0,0 +1,59 @@
package recordio

import (
	"encoding/binary"
	"fmt"
	"io"
)

const (
	// NoCompression means writing raw chunk data into files.
	// With other choices, chunks are compressed before being written.
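	//
	// The chosen value is persisted in the compressor field of every
	// chunk header, so these constants are part of the on-disk format
	// and must not be reordered.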
	NoCompression = iota
	// Snappy is the default compression algorithm.  It is widely used
	// at Google and strikes a balance between speed and compression
	// ratio.
	Snappy
	// Gzip is a well-known compression algorithm.  It is recommended
	// only when you are looking for a better compression ratio.
	Gzip

	magicNumber uint32 = 0x01020304
	defaultCompressor  = Snappy
)

// Header is the metadata of a Chunk.
type Header struct {
	checkSum       uint32
	compressor     uint32
	compressedSize uint32
	numRecords     uint32
}

func (c *Header) write(w io.Writer) (int, error) {
	var buf [20]byte
	binary.LittleEndian.PutUint32(buf[0:4], magicNumber)
	binary.LittleEndian.PutUint32(buf[4:8], c.checkSum)
	binary.LittleEndian.PutUint32(buf[8:12], c.compressor)
	binary.LittleEndian.PutUint32(buf[12:16], c.compressedSize)
	binary.LittleEndian.PutUint32(buf[16:20], c.numRecords)
	return w.Write(buf[:])
}

func parseHeader(r io.Reader) (*Header, error) {
	var buf [20]byte
	// io.ReadFull, unlike r.Read, never succeeds with a partially
	// filled buffer on a short read.
	if _, e := io.ReadFull(r, buf[:]); e != nil {
		return nil, e
	}

	if v := binary.LittleEndian.Uint32(buf[0:4]); v != magicNumber {
		return nil, fmt.Errorf("Failed to parse magic number")
	}

	return &Header{
		checkSum:       binary.LittleEndian.Uint32(buf[4:8]),
		compressor:     binary.LittleEndian.Uint32(buf[8:12]),
		compressedSize: binary.LittleEndian.Uint32(buf[12:16]),
		numRecords:     binary.LittleEndian.Uint32(buf[16:20]),
	}, nil
}

diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/reader.go
new file mode 100644
index 000000000..a12c604f7
--- /dev/null
+++ b/paddle/go/recordio/reader.go
@@ -0,0 +1,135 @@
package recordio

import "io"

// Index consists of the offsets and record counts of the consecutive
// chunks in a RecordIO file.
type Index struct {
	chunkOffsets []int64
	chunkLens    []uint32 // the number of records in each chunk.
	numRecords   int      // the number of all records in a file.
	chunkRecords []int    // the number of records in chunks.
}

// LoadIndex scans the file and parses chunkOffsets, chunkLens, and
// numRecords.
func LoadIndex(r io.ReadSeeker) (*Index, error) {
	f := &Index{}
	offset := int64(0)
	var e error
	var hdr *Header

	for {
		hdr, e = parseHeader(r)
		if e != nil {
			break
		}

		f.chunkOffsets = append(f.chunkOffsets, offset)
		f.chunkLens = append(f.chunkLens, hdr.numRecords)
		f.chunkRecords = append(f.chunkRecords, int(hdr.numRecords))
		f.numRecords += int(hdr.numRecords)

		offset, e = r.Seek(int64(hdr.compressedSize), io.SeekCurrent)
		if e != nil {
			break
		}
	}

	if e == io.EOF {
		return f, nil
	}
	return nil, e
}

// NumRecords returns the total number of records in a RecordIO file.
func (r *Index) NumRecords() int {
	return r.numRecords
}

// NumChunks returns the total number of chunks in a RecordIO file.
func (r *Index) NumChunks() int {
	return len(r.chunkLens)
}

// ChunkIndex returns the Index of the i-th chunk.
func (r *Index) ChunkIndex(i int) *Index {
	idx := &Index{}
	idx.chunkOffsets = []int64{r.chunkOffsets[i]}
	idx.chunkLens = []uint32{r.chunkLens[i]}
	idx.chunkRecords = []int{r.chunkRecords[i]}
	idx.numRecords = idx.chunkRecords[0]
	return idx
}

// Locate returns the index of the chunk that contains the given
// record, and the record's index within that chunk.  It returns
// (-1, -1) if the record is out of range.
func (r *Index) Locate(recordIndex int) (int, int) {
	sum := 0
	for i, l := range r.chunkLens {
		sum += int(l)
		if recordIndex < sum {
			return i, recordIndex - sum + int(l)
		}
	}
	return -1, -1
}

// Scanner scans records in a specified range within [0, numRecords).
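//
// A minimal usage sketch, assuming f is an io.ReadSeeker over a
// RecordIO file and idx was built by LoadIndex (Scan reports io.EOF
// through Error once the range is exhausted):
//
//	s := NewScanner(f, idx, 0, -1)
//	for s.Scan() {
//		fmt.Println(string(s.Record()))
//	}
//	if s.Error() != nil && s.Error() != io.EOF {
//		log.Fatalf("scan failed: %v", s.Error())
//	}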
type Scanner struct {
	reader          io.ReadSeeker
	index           *Index
	start, end, cur int
	chunkIndex      int
	chunk           *Chunk
	err             error
}

// NewScanner creates a scanner that sequentially reads records in the
// range [start, start+len).  If start < 0, it scans from the
// beginning.  If len < 0, it scans till the end of file.
func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner {
	if start < 0 {
		start = 0
	}
	if len < 0 || start+len >= index.NumRecords() {
		len = index.NumRecords() - start
	}

	return &Scanner{
		reader:     r,
		index:      index,
		start:      start,
		end:        start + len,
		cur:        start - 1, // the initial state expected by Scan.
		chunkIndex: -1,
		chunk:      &Chunk{},
	}
}

// Scan advances the cursor by one record and, if necessary, loads the
// chunk that contains the record.
func (s *Scanner) Scan() bool {
	s.cur++

	if s.cur >= s.end {
		s.err = io.EOF
	} else {
		if ci, _ := s.index.Locate(s.cur); s.chunkIndex != ci {
			s.chunkIndex = ci
			s.chunk, s.err = parseChunk(s.reader, s.index.chunkOffsets[ci])
		}
	}

	return s.err == nil
}

// Record returns the record under the current cursor.
func (s *Scanner) Record() []byte {
	_, ri := s.index.Locate(s.cur)
	return s.chunk.records[ri]
}

// Error returns the error that stopped Scan.
func (s *Scanner) Error() error {
	return s.err
}

diff --git a/paddle/go/recordio/recordio_internal_test.go b/paddle/go/recordio/recordio_internal_test.go
new file mode 100644
index 000000000..e0f7dd040
--- /dev/null
+++ b/paddle/go/recordio/recordio_internal_test.go
@@ -0,0 +1,90 @@
package recordio

import (
	"bytes"
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestChunkHead(t *testing.T) {
	assert := assert.New(t)

	c := &Header{
		checkSum:       123,
		compressor:     456,
		compressedSize: 789,
	}

	var buf bytes.Buffer
	_, e := c.write(&buf)
	assert.Nil(e)

	cc, e := parseHeader(&buf)
	assert.Nil(e)
	assert.Equal(c, cc)
}

func TestWriteAndRead(t *testing.T) {
	assert := assert.New(t)

	data := []string{
		"12345",
		"1234",
		"12"}

	var buf bytes.Buffer
	w := NewWriter(&buf, 10, NoCompression) // use a small maxChunkSize.

	n, e := w.Write([]byte(data[0])) // does not exceed chunk size.
	assert.Nil(e)
	assert.Equal(5, n)

	n, e = w.Write([]byte(data[1])) // does not exceed chunk size.
	assert.Nil(e)
	assert.Equal(4, n)

	n, e = w.Write([]byte(data[2])) // exceeds chunk size; dump and create a new chunk.
	assert.Nil(e)
	assert.Equal(n, 2)

	assert.Nil(w.Close()) // flush the second chunk.
	assert.Nil(w.Writer)

	n, e = w.Write([]byte("anything")) // has no effect after Close.
	assert.NotNil(e)
	assert.Equal(n, 0)

	idx, e := LoadIndex(bytes.NewReader(buf.Bytes()))
	assert.Nil(e)
	assert.Equal([]uint32{2, 1}, idx.chunkLens)
	assert.Equal(
		[]int64{0,
			int64(4+ // magic number
				unsafe.Sizeof(Header{})+
				5+ // first record
				4+ // second record
				2*4)}, // two record lengths
		idx.chunkOffsets)

	s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
	i := 0
	for s.Scan() {
		assert.Equal(data[i], string(s.Record()))
		i++
	}
}

func TestWriteEmptyFile(t *testing.T) {
	assert := assert.New(t)

	var buf bytes.Buffer
	w := NewWriter(&buf, 10, NoCompression) // use a small maxChunkSize.
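	// Closing without a single Write must produce an empty file,
	// because Chunk.dump is a no-op for a chunk with zero records.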
	assert.Nil(w.Close())
	assert.Equal(0, buf.Len())

	idx, e := LoadIndex(bytes.NewReader(buf.Bytes()))
	assert.Nil(e)
	assert.Equal(0, idx.NumRecords())
}

diff --git a/paddle/go/recordio/recordio_test.go b/paddle/go/recordio/recordio_test.go
new file mode 100644
index 000000000..8bf1b020a
--- /dev/null
+++ b/paddle/go/recordio/recordio_test.go
@@ -0,0 +1,81 @@
package recordio_test

import (
	"bytes"
	"reflect"
	"testing"

	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

func TestWriteRead(t *testing.T) {
	const total = 1000
	var buf bytes.Buffer
	w := recordio.NewWriter(&buf, 0, -1)
	for i := 0; i < total; i++ {
		_, err := w.Write(make([]byte, i))
		if err != nil {
			t.Fatal(err)
		}
	}
	w.Close()

	idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
	if err != nil {
		t.Fatal(err)
	}

	if idx.NumRecords() != total {
		t.Fatal("num record does not match:", idx.NumRecords(), total)
	}

	s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
	i := 0
	for s.Scan() {
		if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
			t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
		}
		i++
	}

	if i != total {
		t.Fatal("total count not match:", i, total)
	}
}

func TestChunkIndex(t *testing.T) {
	const total = 1000
	var buf bytes.Buffer
	w := recordio.NewWriter(&buf, 0, -1)
	for i := 0; i < total; i++ {
		_, err := w.Write(make([]byte, i))
		if err != nil {
			t.Fatal(err)
		}
	}
	w.Close()

	idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
	if err != nil {
		t.Fatal(err)
	}

	if idx.NumChunks() != total {
		t.Fatal("unexpected chunk num:", idx.NumChunks(), total)
	}

	for i := 0; i < total; i++ {
		newIdx := idx.ChunkIndex(i)
		s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1)
		j := 0
		for s.Scan() {
			if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
				t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
			}
			j++
		}
		if j != 1 {
			t.Fatal("unexpected record per chunk:", j)
		}
	}
}

diff --git a/paddle/go/recordio/writer.go b/paddle/go/recordio/writer.go
new file mode 100644
index 000000000..39112e518
--- /dev/null
+++ b/paddle/go/recordio/writer.go
@@ -0,0 +1,60 @@
package recordio

import (
	"fmt"
	"io"
)

const (
	defaultMaxChunkSize = 32 * 1024 * 1024
)

// Writer creates a RecordIO file.
type Writer struct {
	io.Writer    // Set to nil to mark a closed writer.
	chunk        *Chunk
	maxChunkSize int // total record size per chunk, excluding metadata, before compression.
	compressor   int
}

// NewWriter creates a RecordIO file writer.  maxChunkSize is the
// maximum total record size per chunk before compression; a negative
// value selects the default (32 MB).  compressor selects the
// compression algorithm (NoCompression, Snappy, or Gzip); a negative
// value selects the default, Snappy.
func NewWriter(w io.Writer, maxChunkSize, compressor int) *Writer {
	if maxChunkSize < 0 {
		maxChunkSize = defaultMaxChunkSize
	}

	if compressor < 0 {
		compressor = defaultCompressor
	}

	return &Writer{
		Writer:       w,
		chunk:        &Chunk{},
		maxChunkSize: maxChunkSize,
		compressor:   compressor}
}

// Write writes a record.  It returns an error if Close has been called.
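// Records accumulate in an in-memory chunk; the chunk is flushed to
// the underlying writer once adding another record would exceed
// maxChunkSize, and finally on Close.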
func (w *Writer) Write(record []byte) (int, error) {
	if w.Writer == nil {
		return 0, fmt.Errorf("Cannot write: the writer has been closed")
	}

	if w.chunk.numBytes+len(record) > w.maxChunkSize {
		if e := w.chunk.dump(w.Writer, w.compressor); e != nil {
			return 0, e
		}
	}

	w.chunk.add(record)
	return len(record), nil
}

// Close flushes the current chunk and makes the writer invalid.
func (w *Writer) Close() error {
	e := w.chunk.dump(w.Writer, w.compressor)
	w.Writer = nil
	return e
}
--
GitLab