From fc3d03142582dcd673cc97fb3b0239bac59815f4 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 29 Jun 2017 09:38:25 +0800 Subject: [PATCH] first add --- go/master/c/client.go | 5 ++ go/master/client.go | 3 +- python/paddle/v2/master/client.py | 3 ++ python/paddle/v2/reader/creator.py | 49 ++++++++++++++----- python/paddle/v2/reader/tests/creator_test.py | 2 +- 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/go/master/c/client.go b/go/master/c/client.go index b186474dc3..b88911b858 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -88,7 +88,12 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { c := get(client) r := c.NextRecord() + if r == nil { + // EOF + return -1 + } if len(r) == 0 { + // Empty record *record = (*C.uchar)(nullPtr) return 0 } diff --git a/go/master/client.go b/go/master/client.go index 8451820c19..4f8df5ba66 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -60,6 +60,7 @@ func (c *Client) getRecords() { } err = f.Close() + c.ch <- nil if err != nil { log.Errorln(err) } @@ -112,7 +113,7 @@ func (c *Client) monitorMaster(addr Addresser) { // // SetDataset can be call multiple times from different nodes. But // only the first call will be honored. -func (c *Client) SetDataset(globPaths []string) error { +func (c *Client) SetDataset(globPaths ...string) error { return c.conn.Call("Service.SetDataset", globPaths, nil) } diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index de8e9bb88e..9fd3ef0860 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -30,6 +30,9 @@ class client(object): p = ctypes.c_char_p() ret = ctypes.pointer(p) size = lib.paddle_next_record(self.c, ret) + if size < 0: + # EOF + return None if size == 0: # Empty record return "" diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 9f888b16d6..669867fd10 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -57,22 +57,49 @@ def text_file(path): return reader -def recordio(path): +def recordio_local(paths): """ - Creates a data reader that outputs record one one by one from given recordio file - :path: path of recordio file - :returns: data reader of recordio file + Creates a data reader that outputs record one one by one + from given local recordio fils path. + :path: path of recordio files. + :returns: data reader of recordio files. """ import recordio as rec def reader(): - f = rec.reader(path) - while True: - r = f.read() - if r is None: - break - yield r - f.close() + for i, path in enumerate(paths): + f = rec.reader(path) + while True: + r = f.read() + if r is None: + break + yield r + f.close() return reader + + +def recordio(paths, addr="", buf_size=100): + """ + Creates a data reader that outputs record one one by one + from given local or cloud recordio path. + :path: path of recordio files. + :returns: data reader of recordio files. + """ + import os + import paddle.v2.master.client as cloud + + if len(os.environ["KUBERNETES_SERVICE_HOST"]) == 0: + return recordio_local(path) + + c = cloud(addr, buf_size) + c.set_dataset(paths) + + while True: + r = client.next_record() + if r is None: + break + yield r + + c.close() diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index ba4f558874..b42d273ecf 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase): def test_recordio(self): path = os.path.join( os.path.dirname(__file__), "test_recordio_creator.dat") - reader = paddle.v2.reader.creator.recordio(path) + reader = paddle.v2.reader.creator.recordio([path]) for idx, r in enumerate(reader()): self.assertSequenceEqual(r, str(idx)) -- GitLab