From 0fa409246b98c636d4dd32553782ca962f70a6f7 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 29 Jun 2017 09:43:00 +0000 Subject: [PATCH] fix bugs --- go/master/c/client.go | 18 ++++++++++++++++-- go/master/client.go | 21 +++++++++++++++------ go/master/client_test.go | 18 ++++++++++++++---- python/paddle/v2/reader/creator.py | 6 ++---- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/go/master/c/client.go b/go/master/c/client.go index b88911b858..79e13e4b63 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -13,6 +13,7 @@ typedef int paddle_master_client; import "C" import ( + "io" "sync" "unsafe" @@ -84,14 +85,27 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int return C.PADDLE_MASTER_OK } +// return value: +// 0:ok +// -1:EOF +// -2:error //export paddle_next_record func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { c := get(client) - r := c.NextRecord() - if r == nil { + r, err := c.NextRecord() + if err == io.EOF { // EOF + *record = (*C.uchar)(nullPtr) return -1 } + + if err != nil { + // Error + // TODO: return the type of error? + *record = (*C.uchar)(nullPtr) + return -2 + } + if len(r) == 0 { // Empty record *record = (*C.uchar)(nullPtr) diff --git a/go/master/client.go b/go/master/client.go index fa479338c5..c122d17c8f 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -1,6 +1,7 @@ package master import ( + "io" "os" "time" @@ -17,7 +18,12 @@ type Addresser interface { // Client is the client of the master server. type Client struct { conn *connection.Conn - ch chan []byte + ch chan record +} + +type record struct { + r []byte + err error } // NewClient creates a new Client. @@ -27,7 +33,7 @@ type Client struct { func NewClient(addr Addresser, bufSize int) *Client { c := &Client{} c.conn = connection.New() - c.ch = make(chan []byte, bufSize) + c.ch = make(chan record, bufSize) go c.monitorMaster(addr) go c.getRecords() return c @@ -52,18 +58,20 @@ func (c *Client) getRecords() { s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) for s.Scan() { - c.ch <- s.Record() + c.ch <- record{s.Record(), nil} } if s.Err() != nil { + c.ch <- record{nil, s.Err()} log.Errorln(err, chunk.Path) } err = f.Close() - c.ch <- nil if err != nil { log.Errorln(err) } + + c.ch <- record{nil, io.EOF} } // We treat a task as finished whenever the last data @@ -133,6 +141,7 @@ func (c *Client) taskFinished(taskID int) error { // // NextRecord will block until the next record is available. It is // thread-safe. -func (c *Client) NextRecord() []byte { - return <-c.ch +func (c *Client) NextRecord() ([]byte, error) { + r := <-c.ch + return r.r, r.err } diff --git a/go/master/client_test.go b/go/master/client_test.go index 85a86761c2..05201941e3 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -2,6 +2,7 @@ package master_test import ( "fmt" + "io" "net" "net/http" "net/rpc" @@ -69,13 +70,22 @@ func TestNextRecord(t *testing.T) { for pass := 0; pass < 50; pass++ { received := make(map[byte]bool) - for i := 0; i < total; i++ { - r := c.NextRecord() + for i := 0; i <= total; i++ { + r, err := c.NextRecord() + if err == io.EOF { + break + } + + if err != nil { + t.Fatal(pass, i, "Read error:", err) + } + if len(r) != 1 { - t.Fatal("Length should be 1.", r) + t.Fatal(pass, i, "Length should be 1.", r) } + if received[r[0]] { - t.Fatal("Received duplicate.", received, r) + t.Fatal(pass, i, "Received duplicate.", received, r) } received[r[0]] = true } diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 3376d7accb..b575f57dc6 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -79,7 +79,6 @@ def recordio_local(paths): return reader - def recordio(paths, addr="", buf_size=100): """ Creates a data reader that outputs record one one by one @@ -90,8 +89,8 @@ def recordio(paths, addr="", buf_size=100): import os import paddle.v2.master.client as cloud - if len(os.environ["KUBERNETES_SERVICE_HOST"]) == 0: - return recordio_local(path) + if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): + return recordio_local(paths) def reader(): c = cloud(addr, buf_size) @@ -106,4 +105,3 @@ def recordio(paths, addr="", buf_size=100): c.close() return reader - -- GitLab