提交 26e661bc 编写于 作者: G gongweibao

fix by helin's comments

上级 af5ac2c4
...@@ -106,7 +106,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -106,7 +106,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// return value: // return value:
// 0:ok // 0:ok
// -2:error // -1:error
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
...@@ -115,7 +115,7 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { ...@@ -115,7 +115,7 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
// Error // Error
// TODO: return the type of error? // TODO: return the type of error?
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nullPtr)
return -2 return -1
} }
if len(r) == 0 { if len(r) == 0 {
......
...@@ -28,12 +28,12 @@ class client(object): ...@@ -28,12 +28,12 @@ class client(object):
# return format: (record, errno) # return format: (record, errno)
# errno = 0: ok # errno = 0: ok
# < -1: error # < 0: error
def next_record(self): def next_record(self):
p = ctypes.c_char_p() p = ctypes.c_char_p()
ret = ctypes.pointer(p) ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret) size = lib.paddle_next_record(self.c, ret)
if size < -1: if size < 0:
# Error # Error
return None, size return None, size
......
...@@ -57,19 +57,20 @@ def text_file(path): ...@@ -57,19 +57,20 @@ def text_file(path):
return reader return reader
def recordio_local(paths): def recordio_local(paths, buf_size=100):
""" """
Creates a data reader that outputs record one one by one Creates a data reader from given RecordIO file paths separated by ",",
from given local recordio fils path. glob pattern is supported.
:path: path of recordio files. :path: path of recordio files.
:returns: data reader of recordio files. :returns: data reader of recordio files.
""" """
import recordio as rec import recordio as rec
import paddle.v2.reader.decorator as dec
def reader(): def reader():
for i, path in enumerate(paths): a = ','.join(paths)
f = rec.reader(path) f = rec.reader(a)
while True: while True:
r = f.read() r = f.read()
if r is None: if r is None:
...@@ -77,9 +78,10 @@ def recordio_local(paths): ...@@ -77,9 +78,10 @@ def recordio_local(paths):
yield r yield r
f.close() f.close()
return reader return dec.buffered(reader, buf_size)
def recordio(paths, addr="", buf_size=100): def recordio(paths, buf_size=100):
""" """
Creates a data reader that outputs record one one by one Creates a data reader that outputs record one one by one
from given local or cloud recordio path. from given local or cloud recordio path.
...@@ -92,6 +94,12 @@ def recordio(paths, addr="", buf_size=100): ...@@ -92,6 +94,12 @@ def recordio(paths, addr="", buf_size=100):
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths) return recordio_local(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environ.')
addr = os.environ(host)
def reader(): def reader():
c = cloud(addr, buf_size) c = cloud(addr, buf_size)
c.set_dataset(paths) c.set_dataset(paths)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册