提交 0fa40924 编写于 作者: G gongweibao

fix bugs

上级 4874810b
......@@ -13,6 +13,7 @@ typedef int paddle_master_client;
import "C"
import (
"io"
"sync"
"unsafe"
......@@ -84,14 +85,27 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:EOF
// -2:error
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r := c.NextRecord()
if r == nil {
r, err := c.NextRecord()
if err == io.EOF {
// EOF
*record = (*C.uchar)(nullPtr)
return -1
}
if err != nil {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
return -2
}
if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nullPtr)
......
package master
import (
"io"
"os"
"time"
......@@ -17,7 +18,12 @@ type Addresser interface {
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan []byte
ch chan record
}
type record struct {
r []byte
err error
}
// NewClient creates a new Client.
......@@ -27,7 +33,7 @@ type Client struct {
func NewClient(addr Addresser, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
c.ch = make(chan record, bufSize)
go c.monitorMaster(addr)
go c.getRecords()
return c
......@@ -52,18 +58,20 @@ func (c *Client) getRecords() {
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- s.Record()
c.ch <- record{s.Record(), nil}
}
if s.Err() != nil {
c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path)
}
err = f.Close()
c.ch <- nil
if err != nil {
log.Errorln(err)
}
c.ch <- record{nil, io.EOF}
}
// We treat a task as finished whenever the last data
......@@ -133,6 +141,7 @@ func (c *Client) taskFinished(taskID int) error {
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() []byte {
return <-c.ch
func (c *Client) NextRecord() ([]byte, error) {
r := <-c.ch
return r.r, r.err
}
......@@ -2,6 +2,7 @@ package master_test
import (
"fmt"
"io"
"net"
"net/http"
"net/rpc"
......@@ -69,13 +70,22 @@ func TestNextRecord(t *testing.T) {
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r := c.NextRecord()
for i := 0; i <= total; i++ {
r, err := c.NextRecord()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 {
t.Fatal("Length should be 1.", r)
t.Fatal(pass, i, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal("Received duplicate.", received, r)
t.Fatal(pass, i, "Received duplicate.", received, r)
}
received[r[0]] = true
}
......
......@@ -79,7 +79,6 @@ def recordio_local(paths):
return reader
def recordio(paths, addr="", buf_size=100):
"""
Creates a data reader that outputs record one one by one
......@@ -90,8 +89,8 @@ def recordio(paths, addr="", buf_size=100):
import os
import paddle.v2.master.client as cloud
if len(os.environ["KUBERNETES_SERVICE_HOST"]) == 0:
return recordio_local(path)
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths)
def reader():
c = cloud(addr, buf_size)
......@@ -106,4 +105,3 @@ def recordio(paths, addr="", buf_size=100):
c.close()
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册