提交 633171c2 编写于 作者: H Helin Wang

fix according to comments

上级 0e80dadf
#include "libclient.h"
#include <stdio.h>
//#include "gtest/gtest.h"
#include "libclient.h"
void panic() {
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
*(void*)0;
printf("test failed.\n");
exit(-1);
}
int main() {
......@@ -35,7 +36,7 @@ retry:
goto retry;
}
} else {
panic();
fail();
}
char content[] = {0x00, 0x11, 0x22};
......@@ -44,25 +45,25 @@ retry:
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) {
panic();
fail();
}
paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
panic();
fail();
}
// get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) {
panic();
fail();
}
paddle_release_param(params[0]);
paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) {
panic();
fail();
}
return 0;
......
......@@ -31,11 +31,9 @@ function(ADD_GO_LIBRARY NAME BUILD_TYPE)
# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
if(NOT EXISTS ${PADDLE_IN_GOPATH})
add_custom_target(copyPaddle
COMMAND ln -s ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
add_dependencies(goGet copyPaddle)
endif()
add_custom_target(copyPaddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
add_dependencies(goGet copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
......
......@@ -9,10 +9,8 @@ typedef int writer;
import "C"
import (
"io"
"log"
"os"
"path/filepath"
"strings"
"unsafe"
......@@ -27,84 +25,24 @@ type writer struct {
}
type reader struct {
buffer chan []byte
cancel chan struct{}
scanner *recordio.MultiScanner
}
func read(paths []string, buffer chan<- []byte, cancel chan struct{}) {
var curFile *os.File
var curScanner *recordio.Scanner
var pathIdx int
var nextFile func() bool
nextFile = func() bool {
if pathIdx >= len(paths) {
return false
}
path := paths[pathIdx]
pathIdx++
f, err := os.Open(path)
if err != nil {
return nextFile()
}
idx, err := recordio.LoadIndex(f)
if err != nil {
log.Println(err)
err = f.Close()
if err != nil {
log.Println(err)
}
return nextFile()
}
curFile = f
curScanner = recordio.NewScanner(f, idx, 0, -1)
return true
}
more := nextFile()
if !more {
close(buffer)
return
}
closeFile := func() {
err := curFile.Close()
if err != nil {
log.Println(err)
}
curFile = nil
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
for {
for curScanner.Scan() {
select {
case buffer <- curScanner.Record():
case <-cancel:
close(buffer)
closeFile()
return
}
}
if err := curScanner.Error(); err != nil && err != io.EOF {
log.Println(err)
}
closeFile()
more := nextFile()
if !more {
close(buffer)
return
}
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
//export paddle_new_writer
func paddle_new_writer(path *C.char) C.writer {
//export create_recordio_writer
func create_recordio_writer(path *C.char) C.writer {
p := C.GoString(path)
f, err := os.Create(p)
if err != nil {
......@@ -117,21 +55,8 @@ func paddle_new_writer(path *C.char) C.writer {
return addWriter(writer)
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
//export paddle_writer_write
func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int {
//export write_recordio
func write_recordio(writer C.writer, buf *C.uchar, size C.int) int {
w := getWriter(writer)
b := cArrayToSlice(unsafe.Pointer(buf), int(size))
_, err := w.w.Write(b)
......@@ -143,66 +68,50 @@ func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int {
return 0
}
//export paddle_writer_release
func paddle_writer_release(writer C.writer) {
//export release_recordio
func release_recordio(writer C.writer) {
w := removeWriter(writer)
w.w.Close()
w.f.Close()
}
//export paddle_new_reader
func paddle_new_reader(path *C.char, bufferSize C.int) C.reader {
//export create_recordio_reader
func create_recordio_reader(path *C.char) C.reader {
p := C.GoString(path)
ss := strings.Split(p, ",")
var paths []string
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil {
log.Printf("error applying glob to %s: %v\n", s, err)
return -1
}
paths = append(paths, match...)
}
if len(paths) == 0 {
log.Println("no valid path provided.", p)
s, err := recordio.NewMultiScanner(strings.Split(p, ","))
if err != nil {
log.Println(err)
return -1
}
buffer := make(chan []byte, int(bufferSize))
cancel := make(chan struct{})
r := &reader{buffer: buffer, cancel: cancel}
go read(paths, buffer, cancel)
r := &reader{scanner: s}
return addReader(r)
}
//export paddle_reader_next_item
func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar {
//export read_next_item
func read_next_item(reader C.reader, size *C.int) *C.uchar {
r := getReader(reader)
buf, ok := <-r.buffer
if !ok {
// channel closed and empty, reached EOF.
*size = -1
return (*C.uchar)(nullPtr)
}
if r.scanner.Scan() {
buf := r.scanner.Record()
*size = C.int(len(buf))
if len(buf) == 0 {
return (*C.uchar)(nullPtr)
}
if len(buf) == 0 {
// empty item
*size = 0
return (*C.uchar)(nullPtr)
ptr := C.malloc(C.size_t(len(buf)))
C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
return (*C.uchar)(ptr)
}
ptr := C.malloc(C.size_t(len(buf)))
C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
*size = C.int(len(buf))
return (*C.uchar)(ptr)
*size = -1
return (*C.uchar)(nullPtr)
}
//export paddle_reader_release
func paddle_reader_release(reader C.reader) {
//export release_recordio_reader
func release_recordio_reader(reader C.reader) {
r := removeReader(reader)
close(r.cancel)
r.scanner.Close()
}
func main() {} // Required but ignored
......@@ -3,30 +3,55 @@
#include "librecordio.h"
void panic() {
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
*(void*)0;
printf("test failed.\n");
exit(-1);
}
int main() {
writer w = paddle_new_writer("/tmp/test");
paddle_writer_write(w, "hello", 6);
paddle_writer_write(w, "hi", 3);
paddle_writer_release(w);
writer w = create_recordio_writer("/tmp/test_recordio_0");
write_recordio(w, "hello", 6);
write_recordio(w, "hi", 3);
release_recordio(w);
reader r = paddle_new_reader("/tmp/test", 10);
w = create_recordio_writer("/tmp/test_recordio_1");
write_recordio(w, "dog", 4);
write_recordio(w, "cat", 4);
release_recordio(w);
reader r = create_recordio_reader("/tmp/test_recordio_*");
int size;
unsigned char* item = paddle_reader_next_item(r, &size);
if (!strcmp(item, "hello") || size != 6) {
panic();
unsigned char* item = read_next_item(r, &size);
if (strcmp(item, "hello") || size != 6) {
fail();
}
free(item);
item = read_next_item(r, &size);
if (strcmp(item, "hi") || size != 3) {
fail();
}
free(item);
item = paddle_reader_next_item(r, &size);
if (!strcmp(item, "hi") || size != 2) {
panic();
item = read_next_item(r, &size);
if (strcmp(item, "dog") || size != 4) {
fail();
}
free(item);
paddle_reader_release(r);
item = read_next_item(r, &size);
if (strcmp(item, "cat") || size != 4) {
fail();
}
free(item);
item = read_next_item(r, &size);
if (item != NULL || size != -1) {
fail();
}
release_recordio_reader(r);
}
......@@ -32,7 +32,7 @@ f.Close()
for s.Scan() {
fmt.Println(string(s.Record()))
}
if s.Error() != nil && s.Error() != io.EOF {
if s.Err() != nil {
log.Fatalf("Something wrong with scanning: %v", e)
}
f.Close()
......
package recordio
import (
"fmt"
"os"
"path/filepath"
)
// MultiScanner is a scanner for multiple recordio files.
type MultiScanner struct {
paths []string
curFile *os.File
curScanner *Scanner
pathIdx int
end bool
err error
}
// NewMultiScanner creates a new MultiScanner.
func NewMultiScanner(paths []string) (*MultiScanner, error) {
var ps []string
for _, s := range paths {
match, err := filepath.Glob(s)
if err != nil {
return nil, err
}
ps = append(ps, match...)
}
if len(ps) == 0 {
return nil, fmt.Errorf("no valid path provided: %v", paths)
}
return &MultiScanner{paths: ps}, nil
}
// Scan moves the cursor forward for one record and loads the chunk
// containing the record if not yet.
func (s *MultiScanner) Scan() bool {
if s.err != nil {
return false
}
if s.end {
return false
}
if s.curScanner == nil {
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
}
curMore := s.curScanner.Scan()
s.err = s.curScanner.Err()
if s.err != nil {
return curMore
}
if !curMore {
err := s.curFile.Close()
if err != nil {
s.err = err
return false
}
s.curFile = nil
more, err := s.nextFile()
if err != nil {
s.err = err
return false
}
if !more {
s.end = true
return false
}
return s.Scan()
}
return true
}
// Err returns the first non-EOF error that was encountered by the
// Scanner.
func (s *MultiScanner) Err() error {
return s.err
}
// Record returns the record under the current cursor.
func (s *MultiScanner) Record() []byte {
if s.curScanner == nil {
return nil
}
return s.curScanner.Record()
}
// Close release the resources.
func (s *MultiScanner) Close() error {
s.curScanner = nil
if s.curFile != nil {
err := s.curFile.Close()
s.curFile = nil
return err
}
return nil
}
func (s *MultiScanner) nextFile() (bool, error) {
if s.pathIdx >= len(s.paths) {
return false, nil
}
path := s.paths[s.pathIdx]
s.pathIdx++
f, err := os.Open(path)
if err != nil {
return false, err
}
idx, err := LoadIndex(f)
if err != nil {
f.Close()
return false, err
}
s.curFile = f
s.curScanner = NewScanner(f, idx, 0, -1)
return true, nil
}
......@@ -129,7 +129,12 @@ func (s *Scanner) Record() []byte {
return s.chunk.records[ri]
}
// Error returns the error that stopped Scan.
func (s *Scanner) Error() error {
// Err returns the first non-EOF error that was encountered by the
// Scanner.
func (s *Scanner) Err() error {
if s.err == io.EOF {
return nil
}
return s.err
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册