提交 633171c2 编写于 作者: H Helin Wang

fix according to comments

上级 0e80dadf
#include "libclient.h" #include <stdio.h>
//#include "gtest/gtest.h" #include "libclient.h"
void panic() { void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this // TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now. // hacky way for now.
*(void*)0; printf("test failed.\n");
exit(-1);
} }
int main() { int main() {
...@@ -35,7 +36,7 @@ retry: ...@@ -35,7 +36,7 @@ retry:
goto retry; goto retry;
} }
} else { } else {
panic(); fail();
} }
char content[] = {0x00, 0x11, 0x22}; char content[] = {0x00, 0x11, 0x22};
...@@ -44,25 +45,25 @@ retry: ...@@ -44,25 +45,25 @@ retry:
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) { if (!paddle_send_grads(c, grads, 2)) {
panic(); fail();
} }
paddle_parameter* params[2] = {NULL, NULL}; paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"}; char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) { if (!paddle_get_params(c, names, params, 2)) {
panic(); fail();
} }
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) { if (!paddle_get_params(c, names, params, 2)) {
panic(); fail();
} }
paddle_release_param(params[0]); paddle_release_param(params[0]);
paddle_release_param(params[1]); paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) { if (!paddle_save_model(c, "/tmp/")) {
panic(); fail();
} }
return 0; return 0;
......
...@@ -31,11 +31,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE) ...@@ -31,11 +31,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE)
# make a symlink that references Paddle inside $GOPATH, so go get # make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle # will use the local changes in Paddle rather than checkout Paddle
# in github. # in github.
if(NOT EXISTS ${PADDLE_IN_GOPATH})
add_custom_target(copyPaddle add_custom_target(copyPaddle
COMMAND ln -s ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
add_dependencies(goGet copyPaddle) add_dependencies(goGet copyPaddle)
endif()
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
......
...@@ -9,10 +9,8 @@ typedef int writer; ...@@ -9,10 +9,8 @@ typedef int writer;
import "C" import "C"
import ( import (
"io"
"log" "log"
"os" "os"
"path/filepath"
"strings" "strings"
"unsafe" "unsafe"
...@@ -27,84 +25,24 @@ type writer struct { ...@@ -27,84 +25,24 @@ type writer struct {
} }
type reader struct { type reader struct {
buffer chan []byte scanner *recordio.MultiScanner
cancel chan struct{}
} }
func read(paths []string, buffer chan<- []byte, cancel chan struct{}) { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
var curFile *os.File if p == nullPtr {
var curScanner *recordio.Scanner return nil
var pathIdx int
var nextFile func() bool
nextFile = func() bool {
if pathIdx >= len(paths) {
return false
}
path := paths[pathIdx]
pathIdx++
f, err := os.Open(path)
if err != nil {
return nextFile()
}
idx, err := recordio.LoadIndex(f)
if err != nil {
log.Println(err)
err = f.Close()
if err != nil {
log.Println(err)
}
return nextFile()
}
curFile = f
curScanner = recordio.NewScanner(f, idx, 0, -1)
return true
}
more := nextFile()
if !more {
close(buffer)
return
}
closeFile := func() {
err := curFile.Close()
if err != nil {
log.Println(err)
}
curFile = nil
}
for {
for curScanner.Scan() {
select {
case buffer <- curScanner.Record():
case <-cancel:
close(buffer)
closeFile()
return
}
}
if err := curScanner.Error(); err != nil && err != io.EOF {
log.Println(err)
} }
closeFile() // create a Go clice backed by a C array, reference:
more := nextFile() // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
if !more { //
close(buffer) // Go garbage collector will not interact with this data, need
return // to be freed properly.
} return (*[1 << 30]byte)(p)[:len:len]
}
} }
//export paddle_new_writer //export create_recordio_writer
func paddle_new_writer(path *C.char) C.writer { func create_recordio_writer(path *C.char) C.writer {
p := C.GoString(path) p := C.GoString(path)
f, err := os.Create(p) f, err := os.Create(p)
if err != nil { if err != nil {
...@@ -117,21 +55,8 @@ func paddle_new_writer(path *C.char) C.writer { ...@@ -117,21 +55,8 @@ func paddle_new_writer(path *C.char) C.writer {
return addWriter(writer) return addWriter(writer)
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { //export write_recordio
if p == nullPtr { func write_recordio(writer C.writer, buf *C.uchar, size C.int) int {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
//export paddle_writer_write
func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int {
w := getWriter(writer) w := getWriter(writer)
b := cArrayToSlice(unsafe.Pointer(buf), int(size)) b := cArrayToSlice(unsafe.Pointer(buf), int(size))
_, err := w.w.Write(b) _, err := w.w.Write(b)
...@@ -143,66 +68,50 @@ func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int { ...@@ -143,66 +68,50 @@ func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int {
return 0 return 0
} }
//export paddle_writer_release //export release_recordio
func paddle_writer_release(writer C.writer) { func release_recordio(writer C.writer) {
w := removeWriter(writer) w := removeWriter(writer)
w.w.Close() w.w.Close()
w.f.Close() w.f.Close()
} }
//export paddle_new_reader //export create_recordio_reader
func paddle_new_reader(path *C.char, bufferSize C.int) C.reader { func create_recordio_reader(path *C.char) C.reader {
p := C.GoString(path) p := C.GoString(path)
ss := strings.Split(p, ",") s, err := recordio.NewMultiScanner(strings.Split(p, ","))
var paths []string
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil { if err != nil {
log.Printf("error applying glob to %s: %v\n", s, err) log.Println(err)
return -1
}
paths = append(paths, match...)
}
if len(paths) == 0 {
log.Println("no valid path provided.", p)
return -1 return -1
} }
buffer := make(chan []byte, int(bufferSize)) r := &reader{scanner: s}
cancel := make(chan struct{})
r := &reader{buffer: buffer, cancel: cancel}
go read(paths, buffer, cancel)
return addReader(r) return addReader(r)
} }
//export paddle_reader_next_item //export read_next_item
func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar { func read_next_item(reader C.reader, size *C.int) *C.uchar {
r := getReader(reader) r := getReader(reader)
buf, ok := <-r.buffer if r.scanner.Scan() {
if !ok { buf := r.scanner.Record()
// channel closed and empty, reached EOF. *size = C.int(len(buf))
*size = -1
return (*C.uchar)(nullPtr)
}
if len(buf) == 0 { if len(buf) == 0 {
// empty item
*size = 0
return (*C.uchar)(nullPtr) return (*C.uchar)(nullPtr)
} }
ptr := C.malloc(C.size_t(len(buf))) ptr := C.malloc(C.size_t(len(buf)))
C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf))) C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
*size = C.int(len(buf))
return (*C.uchar)(ptr) return (*C.uchar)(ptr)
}
*size = -1
return (*C.uchar)(nullPtr)
} }
//export paddle_reader_release //export release_recordio_reader
func paddle_reader_release(reader C.reader) { func release_recordio_reader(reader C.reader) {
r := removeReader(reader) r := removeReader(reader)
close(r.cancel) r.scanner.Close()
} }
func main() {} // Required but ignored func main() {} // Required but ignored
...@@ -3,30 +3,55 @@ ...@@ -3,30 +3,55 @@
#include "librecordio.h" #include "librecordio.h"
void panic() { void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this // TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now. // hacky way for now.
*(void*)0; printf("test failed.\n");
exit(-1);
} }
int main() { int main() {
writer w = paddle_new_writer("/tmp/test"); writer w = create_recordio_writer("/tmp/test_recordio_0");
paddle_writer_write(w, "hello", 6); write_recordio(w, "hello", 6);
paddle_writer_write(w, "hi", 3); write_recordio(w, "hi", 3);
paddle_writer_release(w); release_recordio(w);
reader r = paddle_new_reader("/tmp/test", 10); w = create_recordio_writer("/tmp/test_recordio_1");
write_recordio(w, "dog", 4);
write_recordio(w, "cat", 4);
release_recordio(w);
reader r = create_recordio_reader("/tmp/test_recordio_*");
int size; int size;
unsigned char* item = paddle_reader_next_item(r, &size); unsigned char* item = read_next_item(r, &size);
if (!strcmp(item, "hello") || size != 6) { if (strcmp(item, "hello") || size != 6) {
panic(); fail();
}
free(item);
item = read_next_item(r, &size);
if (strcmp(item, "hi") || size != 3) {
fail();
} }
free(item); free(item);
item = paddle_reader_next_item(r, &size); item = read_next_item(r, &size);
if (!strcmp(item, "hi") || size != 2) { if (strcmp(item, "dog") || size != 4) {
panic(); fail();
} }
free(item); free(item);
paddle_reader_release(r);
item = read_next_item(r, &size);
if (strcmp(item, "cat") || size != 4) {
fail();
}
free(item);
item = read_next_item(r, &size);
if (item != NULL || size != -1) {
fail();
}
release_recordio_reader(r);
} }
...@@ -32,7 +32,7 @@ f.Close() ...@@ -32,7 +32,7 @@ f.Close()
for s.Scan() { for s.Scan() {
fmt.Println(string(s.Record())) fmt.Println(string(s.Record()))
} }
if s.Error() != nil && s.Error() != io.EOF { if s.Err() != nil {
log.Fatalf("Something wrong with scanning: %v", e) log.Fatalf("Something wrong with scanning: %v", e)
} }
f.Close() f.Close()
......
package recordio
import (
"fmt"
"os"
"path/filepath"
)
// MultiScanner is a scanner for multiple recordio files.
type MultiScanner struct {
paths []string
curFile *os.File
curScanner *Scanner
pathIdx int
end bool
err error
}
// NewMultiScanner creates a new MultiScanner.
func NewMultiScanner(paths []string) (*MultiScanner, error) {
var ps []string
for _, s := range paths {
match, err := filepath.Glob(s)
if err != nil {
return nil, err
}
ps = append(ps, match...)
}
if len(ps) == 0 {
return nil, fmt.Errorf("no valid path provided: %v", paths)
}
return &MultiScanner{paths: ps}, nil
}
// Scan moves the cursor forward for one record and loads the chunk
// containing the record if not yet.
func (s *MultiScanner) Scan() bool {
if s.err != nil {
return false
}
if s.end {
return false
}
if s.curScanner == nil {
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
}
curMore := s.curScanner.Scan()
s.err = s.curScanner.Err()
if s.err != nil {
return curMore
}
if !curMore {
err := s.curFile.Close()
if err != nil {
s.err = err
return false
}
s.curFile = nil
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
return s.Scan()
}
return true
}
// Err returns the first non-EOF error that was encountered by the
// Scanner.
func (s *MultiScanner) Err() error {
return s.err
}
// Record returns the record under the current cursor.
func (s *MultiScanner) Record() []byte {
if s.curScanner == nil {
return nil
}
return s.curScanner.Record()
}
// Close release the resources.
func (s *MultiScanner) Close() error {
s.curScanner = nil
if s.curFile != nil {
err := s.curFile.Close()
s.curFile = nil
return err
}
return nil
}
func (s *MultiScanner) nextFile() (bool, error) {
if s.pathIdx >= len(s.paths) {
return false, nil
}
path := s.paths[s.pathIdx]
s.pathIdx++
f, err := os.Open(path)
if err != nil {
return false, err
}
idx, err := LoadIndex(f)
if err != nil {
f.Close()
return false, err
}
s.curFile = f
s.curScanner = NewScanner(f, idx, 0, -1)
return true, nil
}
...@@ -129,7 +129,12 @@ func (s *Scanner) Record() []byte { ...@@ -129,7 +129,12 @@ func (s *Scanner) Record() []byte {
return s.chunk.records[ri] return s.chunk.records[ri]
} }
// Error returns the error that stopped Scan. // Err returns the first non-EOF error that was encountered by the
func (s *Scanner) Error() error { // Scanner.
func (s *Scanner) Err() error {
if s.err == io.EOF {
return nil
}
return s.err return s.err
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册