提交 fc3d0314 编写于 作者: G gongweibao

first add

上级 1a0fdb9e
...@@ -88,7 +88,12 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -88,7 +88,12 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r := c.NextRecord() r := c.NextRecord()
if r == nil {
// EOF
return -1
}
if len(r) == 0 { if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nullPtr)
return 0 return 0
} }
......
...@@ -60,6 +60,7 @@ func (c *Client) getRecords() { ...@@ -60,6 +60,7 @@ func (c *Client) getRecords() {
} }
err = f.Close() err = f.Close()
c.ch <- nil
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
} }
...@@ -112,7 +113,7 @@ func (c *Client) monitorMaster(addr Addresser) { ...@@ -112,7 +113,7 @@ func (c *Client) monitorMaster(addr Addresser) {
// //
// SetDataset can be called multiple times from different nodes. But // SetDataset can be called multiple times from different nodes. But
// only the first call will be honored. // only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error { func (c *Client) SetDataset(globPaths ...string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) return c.conn.Call("Service.SetDataset", globPaths, nil)
} }
......
...@@ -30,6 +30,9 @@ class client(object): ...@@ -30,6 +30,9 @@ class client(object):
p = ctypes.c_char_p() p = ctypes.c_char_p()
ret = ctypes.pointer(p) ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret) size = lib.paddle_next_record(self.c, ret)
if size < 0:
# EOF
return None
if size == 0: if size == 0:
# Empty record # Empty record
return "" return ""
......
...@@ -57,22 +57,49 @@ def text_file(path): ...@@ -57,22 +57,49 @@ def text_file(path):
return reader return reader
def recordio(path): def recordio_local(paths):
""" """
Creates a data reader that outputs records one by one from given recordio file Creates a data reader that outputs records one by one
:path: path of recordio file from given local recordio files path.
:returns: data reader of recordio file :path: path of recordio files.
:returns: data reader of recordio files.
""" """
import recordio as rec import recordio as rec
def reader(): def reader():
f = rec.reader(path) for i, path in enumerate(paths):
while True: f = rec.reader(path)
r = f.read() while True:
if r is None: r = f.read()
break if r is None:
yield r break
f.close() yield r
f.close()
return reader return reader
def recordio(paths, addr="", buf_size=100):
    """
    Creates a data reader that outputs records one by one from the given
    local or cloud recordio paths.

    When the KUBERNETES_SERVICE_HOST environment variable is unset or empty,
    the files are read locally through recordio_local. Otherwise records are
    fetched through the master client.

    :paths: paths of recordio files.
    :addr: passed unchanged as the first argument of the master client
        constructor (presumably the master address -- confirm against the
        client implementation).
    :buf_size: passed unchanged as the second argument of the master client
        constructor.
    :returns: data reader of recordio files.
    """
    import os

    # An unset variable must select the local path instead of raising
    # KeyError, so use .get() rather than indexing os.environ.
    if not os.environ.get("KUBERNETES_SERVICE_HOST"):
        return recordio_local(paths)

    def reader():
        # Imported lazily so that building the reader (and the local code
        # path above) does not require the master client library.
        # NOTE(review): the imported name is called directly, so it must
        # resolve to the client class rather than a module -- confirm.
        import paddle.v2.master.client as cloud

        c = cloud(addr, buf_size)
        c.set_dataset(paths)
        try:
            while True:
                r = c.next_record()
                # None marks the end of the data; an empty record is valid.
                if r is None:
                    break
                yield r
        finally:
            # Release the client even if the consumer stops iterating early.
            c.close()

    return reader
...@@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase): ...@@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase):
def test_recordio(self): def test_recordio(self):
path = os.path.join( path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat") os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio(path) reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()): for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx)) self.assertSequenceEqual(r, str(idx))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册