diff --git a/go/master/c/client.go b/go/master/c/client.go index 9e35e986002c0ae3b7593150ece96dba29a1521b..31f431197454c2ec6a25624d37b60876d99f0087 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -104,11 +104,22 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int return C.PADDLE_MASTER_OK } +// return value: +// 0:ok +// -1:error //export paddle_next_record func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { c := get(client) - r := c.NextRecord() + r, err := c.NextRecord() + if err != nil { + // Error + // TODO: return the type of error? + *record = (*C.uchar)(nullPtr) + return -1 + } + if len(r) == 0 { + // Empty record *record = (*C.uchar)(nullPtr) return 0 } diff --git a/go/master/client.go b/go/master/client.go index d3bea49d0a8166420e83478076cc7bc81e48598d..05383f1bf40c0e2b1f974e802ee8fd6aac109b00 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -11,7 +11,12 @@ import ( // Client is the client of the master server. type Client struct { conn *connection.Conn - ch chan []byte + ch chan record +} + +type record struct { + r []byte + err error } // NewClient creates a new Client. @@ -21,7 +26,7 @@ type Client struct { func NewClient(addrCh <-chan string, bufSize int) *Client { c := &Client{} c.conn = connection.New() - c.ch = make(chan []byte, bufSize) + c.ch = make(chan record, bufSize) go c.monitorMaster(addrCh) go c.getRecords() return c @@ -46,10 +51,11 @@ func (c *Client) getRecords() { s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) for s.Scan() { - c.ch <- s.Record() + c.ch <- record{s.Record(), nil} } if s.Err() != nil { + c.ch <- record{nil, s.Err()} log.Errorln(err, chunk.Path) } @@ -116,6 +122,7 @@ func (c *Client) taskFinished(taskID int) error { // // NextRecord will block until the next record is available. It is // thread-safe. -func (c *Client) NextRecord() []byte { - return <-c.ch +func (c *Client) NextRecord() ([]byte, error) { + r := <-c.ch + return r.r, r.err } diff --git a/go/master/client_test.go b/go/master/client_test.go index c00aeebfd5d1fef6de4a8c67bf7f998a42ee863b..6666d3860c412daa8711fbfa2d729a261b3fd887 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -68,12 +68,17 @@ func TestNextRecord(t *testing.T) { for pass := 0; pass < 50; pass++ { received := make(map[byte]bool) for i := 0; i < total; i++ { - r := c.NextRecord() + r, err := c.NextRecord() + if err != nil { + t.Fatal(pass, i, "Read error:", err) + } + if len(r) != 1 { - t.Fatal("Length should be 1.", r) + t.Fatal(pass, i, "Length should be 1.", r) } + if received[r[0]] { - t.Fatal("Received duplicate.", received, r) + t.Fatal(pass, i, "Received duplicate.", received, r) } received[r[0]] = true } diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index de8e9bb88e1064e41a80e0ef7838e307089a1331..70f9e43c9683033233d48a750668771a4c7ba045 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -26,14 +26,22 @@ class client(object): holder[idx] = c_ptr lib.paddle_set_dataset(self.c, holder, len(paths)) + # return format: (record, errno) + # errno = 0: ok + # < 0: error def next_record(self): p = ctypes.c_char_p() ret = ctypes.pointer(p) size = lib.paddle_next_record(self.c, ret) + if size < 0: + # Error + return None, size + if size == 0: # Empty record - return "" + return "", 0 + record = ret.contents.value[:size] # Memory created from C should be freed. lib.mem_free(ret.contents) - return record + return record, 0 diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 9f888b16d6b2fbf457ee4f4fe94fcb51b6f37fc9..61b5cc134fba875955bdbfddc2bb1e083241940d 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -57,17 +57,20 @@ def text_file(path): return reader -def recordio(path): +def recordio_local(paths, buf_size=100): """ - Creates a data reader that outputs record one one by one from given recordio file - :path: path of recordio file - :returns: data reader of recordio file + Creates a data reader from given RecordIO file paths separated by ",", + glob pattern is supported. + :path: path of recordio files. + :returns: data reader of recordio files. """ import recordio as rec + import paddle.v2.reader.decorator as dec def reader(): - f = rec.reader(path) + a = ','.join(paths) + f = rec.reader(a) while True: r = f.read() if r is None: @@ -75,4 +78,38 @@ def recordio(path): yield r f.close() + return dec.buffered(reader, buf_size) + + +def recordio(paths, buf_size=100): + """ + Creates a data reader that outputs record one one by one + from given local or cloud recordio path. + :path: path of recordio files. + :returns: data reader of recordio files. + """ + import os + import paddle.v2.master.client as cloud + + if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): + return recordio_local(paths) + + host_name = "MASTER_SERVICE_HOST" + if host_name not in os.environ.keys(): + raise Exception('not find ' + host_name + ' in environ.') + + addr = os.environ(host) + + def reader(): + c = cloud(addr, buf_size) + c.set_dataset(paths) + + while True: + r, err = client.next_record() + if err < 0: + break + yield r + + c.close() + return reader diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index ba4f558874a0155d276fcb0e0d2d9258f0903f0e..b42d273ecfe6c4bc5706ec52617960b83496d70d 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase): def test_recordio(self): path = os.path.join( os.path.dirname(__file__), "test_recordio_creator.dat") - reader = paddle.v2.reader.creator.recordio(path) + reader = paddle.v2.reader.creator.recordio([path]) for idx, r in enumerate(reader()): self.assertSequenceEqual(r, str(idx)) diff --git a/python/setup.py.in b/python/setup.py.in index 78423614a6b4777daa7a1b5ef11f1df985065600..eeffbfe80e3b4b5473b06f7372addd9870034e77 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -29,7 +29,7 @@ setup(name='paddle', description='Parallel Distributed Deep Learning', install_requires=setup_requires, packages=packages, - package_data={'paddle.v2.master': ['${paddle_master_LIB_NAME}'], }, + package_data={'paddle.v2.master': ['libpaddle_master.so'], }, package_dir={ '': '${CMAKE_CURRENT_SOURCE_DIR}', # The paddle.v2.framework.proto will be generated while compiling.