提交 8e8f3601 编写于 作者: G gongweibao 提交者: GitHub

Merge pull request #2665 from gongweibao/cloudandlocal

Recordio cloud and local interface
...@@ -104,11 +104,22 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -104,11 +104,22 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value:
// 0:ok
// -1:error
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r := c.NextRecord() r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
return -1
}
if len(r) == 0 { if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nullPtr)
return 0 return 0
} }
......
...@@ -11,7 +11,12 @@ import ( ...@@ -11,7 +11,12 @@ import (
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan []byte ch chan record
}
type record struct {
r []byte
err error
} }
// NewClient creates a new Client. // NewClient creates a new Client.
...@@ -21,7 +26,7 @@ type Client struct { ...@@ -21,7 +26,7 @@ type Client struct {
func NewClient(addrCh <-chan string, bufSize int) *Client { func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize) c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh) go c.monitorMaster(addrCh)
go c.getRecords() go c.getRecords()
return c return c
...@@ -46,10 +51,11 @@ func (c *Client) getRecords() { ...@@ -46,10 +51,11 @@ func (c *Client) getRecords() {
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() { for s.Scan() {
c.ch <- s.Record() c.ch <- record{s.Record(), nil}
} }
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Errorln(err, chunk.Path)
} }
...@@ -116,6 +122,7 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -116,6 +122,7 @@ func (c *Client) taskFinished(taskID int) error {
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() []byte { func (c *Client) NextRecord() ([]byte, error) {
return <-c.ch r := <-c.ch
return r.r, r.err
} }
...@@ -68,12 +68,17 @@ func TestNextRecord(t *testing.T) { ...@@ -68,12 +68,17 @@ func TestNextRecord(t *testing.T) {
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
r := c.NextRecord() r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 { if len(r) != 1 {
t.Fatal("Length should be 1.", r) t.Fatal(pass, i, "Length should be 1.", r)
} }
if received[r[0]] { if received[r[0]] {
t.Fatal("Received duplicate.", received, r) t.Fatal(pass, i, "Received duplicate.", received, r)
} }
received[r[0]] = true received[r[0]] = true
} }
......
...@@ -26,14 +26,22 @@ class client(object): ...@@ -26,14 +26,22 @@ class client(object):
holder[idx] = c_ptr holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths)) lib.paddle_set_dataset(self.c, holder, len(paths))
# return format: (record, errno)
# errno = 0: ok
# < 0: error
def next_record(self): def next_record(self):
p = ctypes.c_char_p() p = ctypes.c_char_p()
ret = ctypes.pointer(p) ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret) size = lib.paddle_next_record(self.c, ret)
if size < 0:
# Error
return None, size
if size == 0: if size == 0:
# Empty record # Empty record
return "" return "", 0
record = ret.contents.value[:size] record = ret.contents.value[:size]
# Memory created from C should be freed. # Memory created from C should be freed.
lib.mem_free(ret.contents) lib.mem_free(ret.contents)
return record return record, 0
...@@ -57,17 +57,20 @@ def text_file(path): ...@@ -57,17 +57,20 @@ def text_file(path):
return reader return reader
def recordio(path): def recordio_local(paths, buf_size=100):
""" """
Creates a data reader that outputs record one one by one from given recordio file Creates a data reader from given RecordIO file paths separated by ",",
:path: path of recordio file glob pattern is supported.
:returns: data reader of recordio file :path: path of recordio files.
:returns: data reader of recordio files.
""" """
import recordio as rec import recordio as rec
import paddle.v2.reader.decorator as dec
def reader(): def reader():
f = rec.reader(path) a = ','.join(paths)
f = rec.reader(a)
while True: while True:
r = f.read() r = f.read()
if r is None: if r is None:
...@@ -75,4 +78,38 @@ def recordio(path): ...@@ -75,4 +78,38 @@ def recordio(path):
yield r yield r
f.close() f.close()
return dec.buffered(reader, buf_size)
def recordio(paths, buf_size=100):
"""
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
:path: path of recordio files.
:returns: data reader of recordio files.
"""
import os
import paddle.v2.master.client as cloud
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environ.')
addr = os.environ(host)
def reader():
c = cloud(addr, buf_size)
c.set_dataset(paths)
while True:
r, err = client.next_record()
if err < 0:
break
yield r
c.close()
return reader return reader
...@@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase): ...@@ -38,7 +38,7 @@ class TestRecordIO(unittest.TestCase):
def test_recordio(self): def test_recordio(self):
path = os.path.join( path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat") os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio(path) reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()): for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx)) self.assertSequenceEqual(r, str(idx))
......
...@@ -29,7 +29,7 @@ setup(name='paddle', ...@@ -29,7 +29,7 @@ setup(name='paddle',
description='Parallel Distributed Deep Learning', description='Parallel Distributed Deep Learning',
install_requires=setup_requires, install_requires=setup_requires,
packages=packages, packages=packages,
package_data={'paddle.v2.master': ['${paddle_master_LIB_NAME}'], }, package_data={'paddle.v2.master': ['libpaddle_master.so'], },
package_dir={ package_dir={
'': '${CMAKE_CURRENT_SOURCE_DIR}', '': '${CMAKE_CURRENT_SOURCE_DIR}',
# The paddle.v2.framework.proto will be generated while compiling. # The paddle.v2.framework.proto will be generated while compiling.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册