提交 ec9d4d52 编写于 作者: Y Yancey 提交者: GitHub

Add start_record interface (#3128)

* add start_record interface

* call master client in reader

* update

* add demo code in comments

* update comments

* delete unittest for recordio reader
上级 aaf8401f
...@@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing ...@@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master import paddle.v2.master as master
import os import os
import cPickle as pickle import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379" etcd_endpoints = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint print "etcd endpoints: ", etcd_endpoints
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader():
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
def main(): def main():
...@@ -49,7 +36,7 @@ def main(): ...@@ -49,7 +36,7 @@ def main():
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
is_local=False, is_local=False,
pserver_spec=etcd_endpoint, pserver_spec=etcd_endpoints,
use_etcd=True) use_etcd=True)
# event_handler to print training and testing info # event_handler to print training and testing info
...@@ -75,7 +62,11 @@ def main(): ...@@ -75,7 +62,11 @@ def main():
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
cloud_reader, buf_size=500), batch_size=2), cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0, feeding={'x': 0,
'y': 1}, 'y': 1},
event_handler=event_handler, event_handler=event_handler,
......
...@@ -76,3 +76,6 @@ class client(object): ...@@ -76,3 +76,6 @@ class client(object):
# Memory created from C should be freed. # Memory created from C should be freed.
get_c_lib().mem_free(ret.contents) get_c_lib().mem_free(ret.contents)
return record, 0 return record, 0
def paddle_start_get_records(self, pass_id):
get_c_lib().paddle_start_get_records(self.c, pass_id)
...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could ...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could
be used in user program. be used in user program.
""" """
__all__ = ['np_array', 'text_file', "recordio"] __all__ = ['np_array', 'text_file', "cloud_reader"]
def np_array(x): def np_array(x):
...@@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100): ...@@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100):
return dec.buffered(reader, buf_size) return dec.buffered(reader, buf_size)
def recordio(paths, buf_size=100): pass_num = 0
def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
""" """
Creates a data reader that outputs record one one by one Create a data reader that yield a record one by one from
from given local or cloud recordio path. the paths:
:path: path of recordio files. :path: path of recordio files.
:etcd_endpoints: the endpoints for etcd cluster
:returns: data reader of recordio files. :returns: data reader of recordio files.
.. code-block:: python
from paddle.v2.reader.creator import cloud_reader
etcd_endpoints = "http://127.0.0.1:2379"
trainer.train.(
reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints),
)
""" """
import os import os
import paddle.v2.master.client as cloud import cPickle as pickle
import paddle.v2.master as master
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): c = master.client(etcd_endpoints, timeout_sec, buf_size)
return recordio_local(paths) c.set_dataset(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environment variable.')
addr = os.environ(host)
def reader(): def reader():
c = cloud(addr, buf_size) global pass_num
c.set_dataset(paths) c.paddle_start_get_records(pass_num)
pass_num += 1
while True: while True:
r, err = client.next_record() r, e = c.next_record()
if err < 0: if not r:
if e != -2:
print "get record error: ", e
break break
yield r yield pickle.loads(r)
c.release()
return reader return reader
...@@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase): ...@@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def test_recordio(self):
path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册