提交 ec9d4d52 编写于 作者: Y Yancey 提交者: GitHub

Add start_record interface (#3128)

* add start_record interface

* call master client in reader

* update

* add demo code in comments

* update comments

* delete unittest for recordio reader
上级 aaf8401f
......@@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader():
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints
def main():
......@@ -49,7 +36,7 @@ def main():
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec=etcd_endpoint,
pserver_spec=etcd_endpoints,
use_etcd=True)
# event_handler to print training and testing info
......@@ -75,7 +62,11 @@ def main():
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
cloud_reader, buf_size=500), batch_size=2),
cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0,
'y': 1},
event_handler=event_handler,
......
......@@ -76,3 +76,6 @@ class client(object):
# Memory created from C should be freed.
get_c_lib().mem_free(ret.contents)
return record, 0
def paddle_start_get_records(self, pass_id):
get_c_lib().paddle_start_get_records(self.c, pass_id)
......@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could
be used in user program.
"""
__all__ = ['np_array', 'text_file', "recordio"]
__all__ = ['np_array', 'text_file', "cloud_reader"]
def np_array(x):
......@@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100):
return dec.buffered(reader, buf_size)
def recordio(paths, buf_size=100):
pass_num = 0
def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
"""
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
Create a data reader that yield a record one by one from
the paths:
:path: path of recordio files.
:etcd_endpoints: the endpoints for etcd cluster
:returns: data reader of recordio files.
.. code-block:: python
from paddle.v2.reader.creator import cloud_reader
etcd_endpoints = "http://127.0.0.1:2379"
trainer.train.(
reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints),
)
"""
import os
import paddle.v2.master.client as cloud
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environment variable.')
addr = os.environ(host)
import cPickle as pickle
import paddle.v2.master as master
c = master.client(etcd_endpoints, timeout_sec, buf_size)
c.set_dataset(paths)
def reader():
c = cloud(addr, buf_size)
c.set_dataset(paths)
global pass_num
c.paddle_start_get_records(pass_num)
pass_num += 1
while True:
r, err = client.next_record()
if err < 0:
r, e = c.next_record()
if not r:
if e != -2:
print "get record error: ", e
break
yield r
c.release()
yield pickle.loads(r)
return reader
......@@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def test_recordio(self):
path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册