diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go index e96bb49017b7d72eb14ec65be13222596e4e3280..33f97de8cf73be01313e0a8b27815489dbaceb64 100644 --- a/paddle/go/crecordio/crecordio.go +++ b/paddle/go/crecordio/crecordio.go @@ -25,7 +25,7 @@ type writer struct { } type reader struct { - scanner *recordio.MultiScanner + scanner *recordio.Scanner } func cArrayToSlice(p unsafe.Pointer, len int) []byte { @@ -55,21 +55,21 @@ func create_recordio_writer(path *C.char) C.writer { return addWriter(writer) } -//export write_recordio -func write_recordio(writer C.writer, buf *C.uchar, size C.int) int { +//export recordio_write +func recordio_write(writer C.writer, buf *C.uchar, size C.int) C.int { w := getWriter(writer) b := cArrayToSlice(unsafe.Pointer(buf), int(size)) - _, err := w.w.Write(b) + c, err := w.w.Write(b) if err != nil { log.Println(err) return -1 } - return 0 + return C.int(c) } -//export release_recordio -func release_recordio(writer C.writer) { +//export release_recordio_writer +func release_recordio_writer(writer C.writer) { w := removeWriter(writer) w.w.Close() w.f.Close() @@ -78,7 +78,7 @@ func release_recordio(writer C.writer) { //export create_recordio_reader func create_recordio_reader(path *C.char) C.reader { p := C.GoString(path) - s, err := recordio.NewMultiScanner(strings.Split(p, ",")) + s, err := recordio.NewScanner(strings.Split(p, ",")...) if err != nil { log.Println(err) return -1 @@ -88,24 +88,23 @@ func create_recordio_reader(path *C.char) C.reader { return addReader(r) } -//export read_next_item -func read_next_item(reader C.reader, size *C.int) *C.uchar { +//export recordio_read +func recordio_read(reader C.reader, record **C.uchar) C.int { r := getReader(reader) if r.scanner.Scan() { buf := r.scanner.Record() - *size = C.int(len(buf)) - if len(buf) == 0 { - return (*C.uchar)(nullPtr) + *record = (*C.uchar)(nullPtr) + return 0 } - ptr := C.malloc(C.size_t(len(buf))) - C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - return (*C.uchar)(ptr) + size := C.int(len(buf)) + *record = (*C.uchar)(C.malloc(C.size_t(len(buf)))) + C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + return size } - *size = -1 - return (*C.uchar)(nullPtr) + return -1 } //export release_recordio_reader diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index 54c3773ee94336f7f4e713ff3571e35275465fbf..b25536a9d76a8654cf1b15075c76887495e1d9bd 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -12,44 +12,43 @@ void fail() { int main() { writer w = create_recordio_writer("/tmp/test_recordio_0"); - write_recordio(w, "hello", 6); - write_recordio(w, "hi", 3); - release_recordio(w); + recordio_write(w, "hello", 6); + recordio_write(w, "hi", 3); + release_recordio_writer(w); w = create_recordio_writer("/tmp/test_recordio_1"); - write_recordio(w, "dog", 4); - write_recordio(w, "cat", 4); - release_recordio(w); + recordio_write(w, "dog", 4); + recordio_write(w, "cat", 4); + release_recordio_writer(w); reader r = create_recordio_reader("/tmp/test_recordio_*"); - int size; - unsigned char* item = read_next_item(r, &size); + unsigned char* item = NULL; + int size = recordio_read(r, &item); if (strcmp(item, "hello") || size != 6) { fail(); } - free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "hi") || size != 3) { fail(); } free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "dog") || size != 4) { fail(); } free(item); - item = read_next_item(r, &size); + size = recordio_read(r, &item); if (strcmp(item, "cat") || size != 4) { fail(); } free(item); - item = read_next_item(r, &size); - if (item != NULL || size != -1) { + size = recordio_read(r, &item); + if (size != -1) { fail(); } diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/range_scanner.go similarity index 88% rename from paddle/go/recordio/reader.go rename to paddle/go/recordio/range_scanner.go index d00aef7ca991e79bf5c53f97ebe2b0da0b45386e..46e2eee68c7b7fc6bb1b69f60a75fd85cfe85576 100644 --- a/paddle/go/recordio/reader.go +++ b/paddle/go/recordio/range_scanner.go @@ -74,8 +74,8 @@ func (r *Index) Locate(recordIndex int) (int, int) { return -1, -1 } -// Scanner scans records in a specified range within [0, numRecords). -type Scanner struct { +// RangeScanner scans records in a specified range within [0, numRecords). +type RangeScanner struct { reader io.ReadSeeker index *Index start, end, cur int @@ -84,10 +84,10 @@ type Scanner struct { err error } -// NewScanner creates a scanner that sequencially reads records in the +// NewRangeScanner creates a scanner that sequencially reads records in the // range [start, start+len). If start < 0, it scans from the // beginning. If len < 0, it scans till the end of file. -func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { +func NewRangeScanner(r io.ReadSeeker, index *Index, start, len int) *RangeScanner { if start < 0 { start = 0 } @@ -95,7 +95,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { len = index.NumRecords() - start } - return &Scanner{ + return &RangeScanner{ reader: r, index: index, start: start, @@ -108,7 +108,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { // Scan moves the cursor forward for one record and loads the chunk // containing the record if not yet. -func (s *Scanner) Scan() bool { +func (s *RangeScanner) Scan() bool { s.cur++ if s.cur >= s.end { @@ -124,14 +124,14 @@ func (s *Scanner) Scan() bool { } // Record returns the record under the current cursor. -func (s *Scanner) Record() []byte { +func (s *RangeScanner) Record() []byte { _, ri := s.index.Locate(s.cur) return s.chunk.records[ri] } // Err returns the first non-EOF error that was encountered by the // Scanner. -func (s *Scanner) Err() error { +func (s *RangeScanner) Err() error { if s.err == io.EOF { return nil } diff --git a/paddle/go/recordio/recordio_internal_test.go b/paddle/go/recordio/recordio_internal_test.go index e0f7dd0407caaf38e8113660239d1a0c6eb8afa1..30e317925d8c95e64a42bd8ac5a1dd43b95ee81d 100644 --- a/paddle/go/recordio/recordio_internal_test.go +++ b/paddle/go/recordio/recordio_internal_test.go @@ -68,7 +68,7 @@ func TestWriteAndRead(t *testing.T) { 2*4)}, // two record legnths idx.chunkOffsets) - s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { assert.Equal(data[i], string(s.Record())) diff --git a/paddle/go/recordio/recordio_test.go b/paddle/go/recordio/recordio_test.go index 8bf1b020ab75ca66c12b713526e010756c364217..ab117d2050e6ac18ef63021081a684f259299803 100644 --- a/paddle/go/recordio/recordio_test.go +++ b/paddle/go/recordio/recordio_test.go @@ -29,7 +29,7 @@ func TestWriteRead(t *testing.T) { t.Fatal("num record does not match:", idx.NumRecords(), total) } - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { @@ -66,7 +66,7 @@ func TestChunkIndex(t *testing.T) { for i := 0; i < total; i++ { newIdx := idx.ChunkIndex(i) - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) j := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { diff --git a/paddle/go/recordio/multi_reader.go b/paddle/go/recordio/scanner.go similarity index 77% rename from paddle/go/recordio/multi_reader.go rename to paddle/go/recordio/scanner.go index 07e28342118d866595902f17bd5dd76d0a91e1d9..865228ff651c6eee2cf1fa05ec38a4964394b6dc 100644 --- a/paddle/go/recordio/multi_reader.go +++ b/paddle/go/recordio/scanner.go @@ -6,18 +6,18 @@ import ( "path/filepath" ) -// MultiScanner is a scanner for multiple recordio files. -type MultiScanner struct { +// Scanner is a scanner for multiple recordio files. +type Scanner struct { paths []string curFile *os.File - curScanner *Scanner + curScanner *RangeScanner pathIdx int end bool err error } -// NewMultiScanner creates a new MultiScanner. -func NewMultiScanner(paths []string) (*MultiScanner, error) { +// NewScanner creates a new Scanner. +func NewScanner(paths ...string) (*Scanner, error) { var ps []string for _, s := range paths { match, err := filepath.Glob(s) @@ -32,12 +32,12 @@ func NewMultiScanner(paths []string) (*MultiScanner, error) { return nil, fmt.Errorf("no valid path provided: %v", paths) } - return &MultiScanner{paths: ps}, nil + return &Scanner{paths: ps}, nil } // Scan moves the cursor forward for one record and loads the chunk // containing the record if not yet. -func (s *MultiScanner) Scan() bool { +func (s *Scanner) Scan() bool { if s.err != nil { return false } @@ -92,12 +92,12 @@ func (s *MultiScanner) Scan() bool { // Err returns the first non-EOF error that was encountered by the // Scanner. -func (s *MultiScanner) Err() error { +func (s *Scanner) Err() error { return s.err } // Record returns the record under the current cursor. -func (s *MultiScanner) Record() []byte { +func (s *Scanner) Record() []byte { if s.curScanner == nil { return nil } @@ -106,7 +106,7 @@ func (s *MultiScanner) Record() []byte { } // Close release the resources. -func (s *MultiScanner) Close() error { +func (s *Scanner) Close() error { s.curScanner = nil if s.curFile != nil { err := s.curFile.Close() @@ -116,7 +116,7 @@ func (s *MultiScanner) Close() error { return nil } -func (s *MultiScanner) nextFile() (bool, error) { +func (s *Scanner) nextFile() (bool, error) { if s.pathIdx >= len(s.paths) { return false, nil } @@ -135,6 +135,6 @@ func (s *MultiScanner) nextFile() (bool, error) { } s.curFile = f - s.curScanner = NewScanner(f, idx, 0, -1) + s.curScanner = NewRangeScanner(f, idx, 0, -1) return true, nil }