From 633171c2d3a1f6b8e245844fa4fb254895565da7 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 27 May 2017 14:49:05 +0000 Subject: [PATCH] fix according to comments --- paddle/go/cclient/test/main.c | 19 ++-- paddle/go/cmake/golang.cmake | 8 +- paddle/go/crecordio/crecordio.go | 169 +++++++---------------------- paddle/go/crecordio/test/test.c | 53 ++++++--- paddle/go/recordio/README.md | 2 +- paddle/go/recordio/multi_reader.go | 140 ++++++++++++++++++++++++ paddle/go/recordio/reader.go | 9 +- 7 files changed, 239 insertions(+), 161 deletions(-) create mode 100644 paddle/go/recordio/multi_reader.go diff --git a/paddle/go/cclient/test/main.c b/paddle/go/cclient/test/main.c index 28e3d03b7a0..abfb32e5603 100644 --- a/paddle/go/cclient/test/main.c +++ b/paddle/go/cclient/test/main.c @@ -1,11 +1,12 @@ -#include "libclient.h" +#include <stdio.h> -//#include "gtest/gtest.h" +#include "libclient.h" -void panic() { +void fail() { // TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now. - *(void*)0; + printf("test failed.\n"); + exit(-1); } int main() { @@ -35,7 +36,7 @@ retry: goto retry; } } else { - panic(); + fail(); } char content[] = {0x00, 0x11, 0x22}; @@ -44,25 +45,25 @@ retry: {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; if (!paddle_send_grads(c, grads, 2)) { - panic(); + fail(); } paddle_parameter* params[2] = {NULL, NULL}; char* names[] = {"param_a", "param_b"}; if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } // get parameters again by reusing the allocated parameter buffers. 
if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } paddle_release_param(params[0]); paddle_release_param(params[1]); if (!paddle_save_model(c, "/tmp/")) { - panic(); + fail(); } return 0; diff --git a/paddle/go/cmake/golang.cmake b/paddle/go/cmake/golang.cmake index caddaae1bf4..0ac17a967bf 100644 --- a/paddle/go/cmake/golang.cmake +++ b/paddle/go/cmake/golang.cmake @@ -31,11 +31,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) # make a symlink that references Paddle inside $GOPATH, so go get # will use the local changes in Paddle rather than checkout Paddle # in github. - if(NOT EXISTS ${PADDLE_IN_GOPATH}) - add_custom_target(copyPaddle - COMMAND ln -s ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) - add_dependencies(goGet copyPaddle) - endif() + add_custom_target(copyPaddle + COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) + add_dependencies(goGet copyPaddle) add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go index cfc15d29a66..e96bb49017b 100644 --- a/paddle/go/crecordio/crecordio.go +++ b/paddle/go/crecordio/crecordio.go @@ -9,10 +9,8 @@ typedef int writer; import "C" import ( - "io" "log" "os" - "path/filepath" "strings" "unsafe" @@ -27,84 +25,24 @@ type writer struct { } type reader struct { - buffer chan []byte - cancel chan struct{} + scanner *recordio.MultiScanner } -func read(paths []string, buffer chan<- []byte, cancel chan struct{}) { - var curFile *os.File - var curScanner *recordio.Scanner - var pathIdx int - - var nextFile func() bool - nextFile = func() bool { - if pathIdx >= len(paths) { - return false - } - - path := paths[pathIdx] - pathIdx++ - f, err := os.Open(path) - if err != nil { - return nextFile() - } - - idx, err := recordio.LoadIndex(f) - if err != nil { - log.Println(err) - err = f.Close() - if err != nil { - log.Println(err) - } - - return nextFile() - } - - curFile = f - 
curScanner = recordio.NewScanner(f, idx, 0, -1) - return true - } - - more := nextFile() - if !more { - close(buffer) - return - } - - closeFile := func() { - err := curFile.Close() - if err != nil { - log.Println(err) - } - curFile = nil +func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil } - for { - for curScanner.Scan() { - select { - case buffer <- curScanner.Record(): - case <-cancel: - close(buffer) - closeFile() - return - } - } - - if err := curScanner.Error(); err != nil && err != io.EOF { - log.Println(err) - } - - closeFile() - more := nextFile() - if !more { - close(buffer) - return - } - } + // create a Go slice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed properly. + return (*[1 << 30]byte)(p)[:len:len] } -//export paddle_new_writer -func paddle_new_writer(path *C.char) C.writer { +//export create_recordio_writer +func create_recordio_writer(path *C.char) C.writer { p := C.GoString(path) f, err := os.Create(p) if err != nil { @@ -117,21 +55,8 @@ func paddle_new_writer(path *C.char) C.writer { return addWriter(writer) } -func cArrayToSlice(p unsafe.Pointer, len int) []byte { - if p == nullPtr { - return nil - } - - // create a Go clice backed by a C array, reference: - // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices - // - // Go garbage collector will not interact with this data, need - // to be freed properly. 
- return (*[1 << 30]byte)(p)[:len:len] -} - -//export paddle_writer_write -func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { +//export write_recordio +func write_recordio(writer C.writer, buf *C.uchar, size C.int) int { w := getWriter(writer) b := cArrayToSlice(unsafe.Pointer(buf), int(size)) _, err := w.w.Write(b) @@ -143,66 +68,50 @@ func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { return 0 } -//export paddle_writer_release -func paddle_writer_release(writer C.writer) { +//export release_recordio +func release_recordio(writer C.writer) { w := removeWriter(writer) w.w.Close() w.f.Close() } -//export paddle_new_reader -func paddle_new_reader(path *C.char, bufferSize C.int) C.reader { +//export create_recordio_reader +func create_recordio_reader(path *C.char) C.reader { p := C.GoString(path) - ss := strings.Split(p, ",") - var paths []string - for _, s := range ss { - match, err := filepath.Glob(s) - if err != nil { - log.Printf("error applying glob to %s: %v\n", s, err) - return -1 - } - - paths = append(paths, match...) - } - - if len(paths) == 0 { - log.Println("no valid path provided.", p) + s, err := recordio.NewMultiScanner(strings.Split(p, ",")) + if err != nil { + log.Println(err) return -1 } - buffer := make(chan []byte, int(bufferSize)) - cancel := make(chan struct{}) - r := &reader{buffer: buffer, cancel: cancel} - go read(paths, buffer, cancel) + r := &reader{scanner: s} return addReader(r) } -//export paddle_reader_next_item -func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar { +//export read_next_item +func read_next_item(reader C.reader, size *C.int) *C.uchar { r := getReader(reader) - buf, ok := <-r.buffer - if !ok { - // channel closed and empty, reached EOF. 
- *size = -1 - return (*C.uchar)(nullPtr) - } + if r.scanner.Scan() { + buf := r.scanner.Record() + *size = C.int(len(buf)) + + if len(buf) == 0 { + return (*C.uchar)(nullPtr) + } - if len(buf) == 0 { - // empty item - *size = 0 - return (*C.uchar)(nullPtr) + ptr := C.malloc(C.size_t(len(buf))) + C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + return (*C.uchar)(ptr) } - ptr := C.malloc(C.size_t(len(buf))) - C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - *size = C.int(len(buf)) - return (*C.uchar)(ptr) + *size = -1 + return (*C.uchar)(nullPtr) } -//export paddle_reader_release -func paddle_reader_release(reader C.reader) { +//export release_recordio_reader +func release_recordio_reader(reader C.reader) { r := removeReader(reader) - close(r.cancel) + r.scanner.Close() } func main() {} // Required but ignored diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c index 5461a0911f5..54c3773ee94 100644 --- a/paddle/go/crecordio/test/test.c +++ b/paddle/go/crecordio/test/test.c @@ -3,30 +3,55 @@ #include "librecordio.h" -void panic() { +void fail() { // TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now. 
- *(void*)0; + printf("test failed.\n"); + exit(-1); } int main() { - writer w = paddle_new_writer("/tmp/test"); - paddle_writer_write(w, "hello", 6); - paddle_writer_write(w, "hi", 3); - paddle_writer_release(w); + writer w = create_recordio_writer("/tmp/test_recordio_0"); + write_recordio(w, "hello", 6); + write_recordio(w, "hi", 3); + release_recordio(w); - reader r = paddle_new_reader("/tmp/test", 10); + w = create_recordio_writer("/tmp/test_recordio_1"); + write_recordio(w, "dog", 4); + write_recordio(w, "cat", 4); + release_recordio(w); + + reader r = create_recordio_reader("/tmp/test_recordio_*"); int size; - unsigned char* item = paddle_reader_next_item(r, &size); - if (!strcmp(item, "hello") || size != 6) { - panic(); + unsigned char* item = read_next_item(r, &size); + if (strcmp(item, "hello") || size != 6) { + fail(); + } + + free(item); + + item = read_next_item(r, &size); + if (strcmp(item, "hi") || size != 3) { + fail(); } free(item); - item = paddle_reader_next_item(r, &size); - if (!strcmp(item, "hi") || size != 2) { - panic(); + item = read_next_item(r, &size); + if (strcmp(item, "dog") || size != 4) { + fail(); } free(item); - paddle_reader_release(r); + + item = read_next_item(r, &size); + if (strcmp(item, "cat") || size != 4) { + fail(); + } + free(item); + + item = read_next_item(r, &size); + if (item != NULL || size != -1) { + fail(); + } + + release_recordio_reader(r); } diff --git a/paddle/go/recordio/README.md b/paddle/go/recordio/README.md index fbf568ceba4..50e7e954764 100644 --- a/paddle/go/recordio/README.md +++ b/paddle/go/recordio/README.md @@ -32,7 +32,7 @@ f.Close() for s.Scan() { fmt.Println(string(s.Record())) } - if s.Error() != nil && s.Error() != io.EOF { + if s.Err() != nil { log.Fatalf("Something wrong with scanning: %v", e) } f.Close() diff --git a/paddle/go/recordio/multi_reader.go b/paddle/go/recordio/multi_reader.go new file mode 100644 index 00000000000..07e28342118 --- /dev/null +++ b/paddle/go/recordio/multi_reader.go 
@@ -0,0 +1,140 @@ +package recordio + +import ( + "fmt" + "os" + "path/filepath" +) + +// MultiScanner is a scanner for multiple recordio files. +type MultiScanner struct { + paths []string + curFile *os.File + curScanner *Scanner + pathIdx int + end bool + err error +} + +// NewMultiScanner creates a new MultiScanner. +func NewMultiScanner(paths []string) (*MultiScanner, error) { + var ps []string + for _, s := range paths { + match, err := filepath.Glob(s) + if err != nil { + return nil, err + } + + ps = append(ps, match...) + } + + if len(ps) == 0 { + return nil, fmt.Errorf("no valid path provided: %v", paths) + } + + return &MultiScanner{paths: ps}, nil +} + +// Scan moves the cursor forward for one record and loads the chunk +// containing the record if it is not yet loaded. +func (s *MultiScanner) Scan() bool { + if s.err != nil { + return false + } + + if s.end { + return false + } + + if s.curScanner == nil { + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + } + + curMore := s.curScanner.Scan() + s.err = s.curScanner.Err() + + if s.err != nil { + return curMore + } + + if !curMore { + err := s.curFile.Close() + if err != nil { + s.err = err + return false + } + s.curFile = nil + + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + + return s.Scan() + } + return true +} + +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *MultiScanner) Err() error { + return s.err +} + +// Record returns the record under the current cursor. +func (s *MultiScanner) Record() []byte { + if s.curScanner == nil { + return nil + } + + return s.curScanner.Record() +} + +// Close releases the resources. 
+func (s *MultiScanner) Close() error { + s.curScanner = nil + if s.curFile != nil { + err := s.curFile.Close() + s.curFile = nil + return err + } + return nil +} + +func (s *MultiScanner) nextFile() (bool, error) { + if s.pathIdx >= len(s.paths) { + return false, nil + } + + path := s.paths[s.pathIdx] + s.pathIdx++ + f, err := os.Open(path) + if err != nil { + return false, err + } + + idx, err := LoadIndex(f) + if err != nil { + f.Close() + return false, err + } + + s.curFile = f + s.curScanner = NewScanner(f, idx, 0, -1) + return true, nil +} diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/reader.go index a12c604f7b2..d00aef7ca99 100644 --- a/paddle/go/recordio/reader.go +++ b/paddle/go/recordio/reader.go @@ -129,7 +129,12 @@ func (s *Scanner) Record() []byte { return s.chunk.records[ri] } -// Error returns the error that stopped Scan. -func (s *Scanner) Error() error { +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *Scanner) Err() error { + if s.err == io.EOF { + return nil + } + return s.err } -- GitLab