From 26e661bc51e2fac36c3692d748b7db8a950cb370 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 3 Jul 2017 03:05:36 +0000 Subject: [PATCH] fix by helin's comments --- go/master/c/client.go | 4 ++-- python/paddle/v2/master/client.py | 4 ++-- python/paddle/v2/reader/creator.py | 34 ++++++++++++++++++------------ 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/go/master/c/client.go b/go/master/c/client.go index 635688f196b..31f43119745 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -106,7 +106,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int // return value: // 0:ok -// -2:error +// -1:error //export paddle_next_record func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { c := get(client) @@ -115,7 +115,7 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { // Error // TODO: return the type of error? *record = (*C.uchar)(nullPtr) - return -2 + return -1 } if len(r) == 0 { diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index 6ddb09e4e89..70f9e43c968 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -28,12 +28,12 @@ class client(object): # return format: (record, errno) # errno = 0: ok - # < -1: error + # < 0: error def next_record(self): p = ctypes.c_char_p() ret = ctypes.pointer(p) size = lib.paddle_next_record(self.c, ret) - if size < -1: + if size < 0: # Error return None, size diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 2e8626e565c..20624d5286b 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -57,29 +57,31 @@ def text_file(path): return reader -def recordio_local(paths): +def recordio_local(paths, buf_size=100): """ - Creates a data reader that outputs record one one by one - from given local recordio fils path. + Creates a data reader from given RecordIO file paths separated by ",", + glob pattern is supported. :path: path of recordio files. :returns: data reader of recordio files. """ import recordio as rec + import paddle.v2.reader.decorator as dec def reader(): - for i, path in enumerate(paths): - f = rec.reader(path) - while True: - r = f.read() - if r is None: - break - yield r - f.close() + a = ','.join(paths) + f = rec.reader(a) + while True: + r = f.read() + if r is None: + break + yield r + f.close() + + return dec.buffered(reader, buf_size) - return reader -def recordio(paths, addr="", buf_size=100): +def recordio(paths, buf_size=100): """ Creates a data reader that outputs record one one by one from given local or cloud recordio path. @@ -92,6 +94,12 @@ def recordio(paths, addr="", buf_size=100): if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): return recordio_local(paths) + host_name = "MASTER_SERVICE_HOST" + if host_name not in os.environ.keys(): + raise Exception('not find ' + host_name + ' in environ.') + + addr = os.environ(host) + def reader(): c = cloud(addr, buf_size) c.set_dataset(paths) -- GitLab