提交 26e661bc 编写于 作者: G gongweibao

fix by helin's comments

上级 af5ac2c4
......@@ -106,7 +106,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// return value:
// 0:ok
// -2:error
// -1:error
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
......@@ -115,7 +115,7 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
return -2
return -1
}
if len(r) == 0 {
......
......@@ -28,12 +28,12 @@ class client(object):
# return format: (record, errno)
# errno = 0: ok
# < -1: error
# < 0: error
def next_record(self):
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
if size < -1:
if size < 0:
# Error
return None, size
......
......@@ -57,19 +57,20 @@ def text_file(path):
return reader
def recordio_local(paths):
def recordio_local(paths, buf_size=100):
"""
Creates a data reader that outputs record one one by one
from given local recordio fils path.
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
:path: path of recordio files.
:returns: data reader of recordio files.
"""
import recordio as rec
import paddle.v2.reader.decorator as dec
def reader():
for i, path in enumerate(paths):
f = rec.reader(path)
a = ','.join(paths)
f = rec.reader(a)
while True:
r = f.read()
if r is None:
......@@ -77,9 +78,10 @@ def recordio_local(paths):
yield r
f.close()
return reader
return dec.buffered(reader, buf_size)
def recordio(paths, addr="", buf_size=100):
def recordio(paths, buf_size=100):
"""
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
......@@ -92,6 +94,12 @@ def recordio(paths, addr="", buf_size=100):
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environ.')
addr = os.environ(host)
def reader():
c = cloud(addr, buf_size)
c.set_dataset(paths)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册