提交 0fa40924 编写于 作者: G gongweibao

fix bugs

上级 4874810b
...@@ -13,6 +13,7 @@ typedef int paddle_master_client; ...@@ -13,6 +13,7 @@ typedef int paddle_master_client;
import "C" import "C"
import ( import (
"io"
"sync" "sync"
"unsafe" "unsafe"
...@@ -84,14 +85,27 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -84,14 +85,27 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value:
// 0:ok
// -1:EOF
// -2:error
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r := c.NextRecord() r, err := c.NextRecord()
if r == nil { if err == io.EOF {
// EOF // EOF
*record = (*C.uchar)(nullPtr)
return -1 return -1
} }
if err != nil {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
return -2
}
if len(r) == 0 { if len(r) == 0 {
// Empty record // Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nullPtr)
......
package master package master
import ( import (
"io"
"os" "os"
"time" "time"
...@@ -17,7 +18,12 @@ type Addresser interface { ...@@ -17,7 +18,12 @@ type Addresser interface {
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan []byte ch chan record
}
type record struct {
r []byte
err error
} }
// NewClient creates a new Client. // NewClient creates a new Client.
...@@ -27,7 +33,7 @@ type Client struct { ...@@ -27,7 +33,7 @@ type Client struct {
func NewClient(addr Addresser, bufSize int) *Client { func NewClient(addr Addresser, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize) c.ch = make(chan record, bufSize)
go c.monitorMaster(addr) go c.monitorMaster(addr)
go c.getRecords() go c.getRecords()
return c return c
...@@ -52,18 +58,20 @@ func (c *Client) getRecords() { ...@@ -52,18 +58,20 @@ func (c *Client) getRecords() {
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() { for s.Scan() {
c.ch <- s.Record() c.ch <- record{s.Record(), nil}
} }
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Errorln(err, chunk.Path)
} }
err = f.Close() err = f.Close()
c.ch <- nil
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
} }
c.ch <- record{nil, io.EOF}
} }
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
...@@ -133,6 +141,7 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -133,6 +141,7 @@ func (c *Client) taskFinished(taskID int) error {
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() []byte { func (c *Client) NextRecord() ([]byte, error) {
return <-c.ch r := <-c.ch
return r.r, r.err
} }
...@@ -2,6 +2,7 @@ package master_test ...@@ -2,6 +2,7 @@ package master_test
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
...@@ -69,13 +70,22 @@ func TestNextRecord(t *testing.T) { ...@@ -69,13 +70,22 @@ func TestNextRecord(t *testing.T) {
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i <= total; i++ {
r := c.NextRecord() r, err := c.NextRecord()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 { if len(r) != 1 {
t.Fatal("Length should be 1.", r) t.Fatal(pass, i, "Length should be 1.", r)
} }
if received[r[0]] { if received[r[0]] {
t.Fatal("Received duplicate.", received, r) t.Fatal(pass, i, "Received duplicate.", received, r)
} }
received[r[0]] = true received[r[0]] = true
} }
......
...@@ -79,7 +79,6 @@ def recordio_local(paths): ...@@ -79,7 +79,6 @@ def recordio_local(paths):
return reader return reader
def recordio(paths, addr="", buf_size=100): def recordio(paths, addr="", buf_size=100):
""" """
Creates a data reader that outputs records one by one Creates a data reader that outputs records one by one
...@@ -90,8 +89,8 @@ def recordio(paths, addr="", buf_size=100): ...@@ -90,8 +89,8 @@ def recordio(paths, addr="", buf_size=100):
import os import os
import paddle.v2.master.client as cloud import paddle.v2.master.client as cloud
if len(os.environ["KUBERNETES_SERVICE_HOST"]) == 0: if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(path) return recordio_local(paths)
def reader(): def reader():
c = cloud(addr, buf_size) c = cloud(addr, buf_size)
...@@ -106,4 +105,3 @@ def recordio(paths, addr="", buf_size=100): ...@@ -106,4 +105,3 @@ def recordio(paths, addr="", buf_size=100):
c.close() c.close()
return reader return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册