未验证 提交 61b3a597 编写于 作者: Y yuyang18

Refine Python Reader

上级 c204f0cc
......@@ -33,6 +33,8 @@ class PyReader : public framework::FileReader {
if (!success) out->clear();
}
~PyReader() { queue_->Close(); }
void Shutdown() override { queue_->Close(); }
void Start() override { queue_->ReOpen(); }
......
......@@ -558,6 +558,7 @@ def py_reader(capacity,
current_reset_method = reader.reset
reader.thread = None
reader.tensor_provider = None
reader.exited = False
def start_provide_thread(func):
def __provider_thread__():
......@@ -571,17 +572,20 @@ def py_reader(capacity,
array.append(item)
if reader.exited:
break
feed_queue.push(array)
if reader.exited:
break
feed_queue.close()
reader.thread = threading.Thread(target=__provider_thread__)
reader.thread.start()
def __set_tensor_provider__(func):
reader._tensor_provider = func
start_provide_thread(reader._tensor_provider)
reader.tensor_provider = func
def __set_paddle_reader__(reader):
def __set_paddle_reader__(paddle_reader):
with program_guard(Program(), Program()):
feed_list = []
counter = 0
......@@ -596,25 +600,29 @@ def py_reader(capacity,
counter += 1
feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace())
reader = feeder.decorate_reader(reader, multi_devices=False)
paddle_reader = feeder.decorate_reader(
paddle_reader, multi_devices=False)
def __tensor_provider__():
for data in reader():
yield [data[str(idx)] for idx in xrange(counter)]
for slots in paddle_reader():
yield [slots[str(idx)] for idx in xrange(counter)]
__set_tensor_provider__(__tensor_provider__)
def __reset__():
current_reset_method()
if reader.thread is not None and reader.tensor_provider is not None:
reader.exited = True
reader.thread.join()
# restart provider thread.
start_provide_thread(reader.tensor_provider)
reader.exited = False
def __start__():
start_provide_thread(reader.tensor_provider)
reader.reset = __reset__
reader.decorate_tensor_provider = __set_tensor_provider__
reader.decorate_paddle_reader = __set_paddle_reader__
reader.start = __start__
return reader
......
......@@ -67,11 +67,12 @@ def main():
train_reader.decorate_paddle_reader(
paddle.v2.reader.shuffle(
paddle.batch(mnist.train(), 256), buf_size=8192))
paddle.batch(mnist.train(), 512), buf_size=8192))
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 256))
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
for epoch_id in xrange(10):
train_reader.start()
try:
while True:
print 'train_loss', numpy.array(
......@@ -80,6 +81,7 @@ def main():
print 'End of epoch', epoch_id
train_reader.reset()
test_reader.start()
try:
while True:
print 'test loss', numpy.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册