From 61b3a5977f6f061d7d1b237e6d4832b7763b8c63 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Sat, 14 Jul 2018 21:31:50 +0800 Subject: [PATCH] Refine Python Reader --- .../operators/reader/create_py_reader_op.cc | 2 ++ python/paddle/fluid/layers/io.py | 26 ++++++++++++------- python/paddle/fluid/tests/demo/pyreader.py | 6 +++-- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 833776f56e..0f31ca1a94 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -33,6 +33,8 @@ class PyReader : public framework::FileReader { if (!success) out->clear(); } + ~PyReader() { queue_->Close(); } + void Shutdown() override { queue_->Close(); } void Start() override { queue_->ReOpen(); } diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 8a82f8e05f..0bf9f46cf7 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -558,6 +558,7 @@ def py_reader(capacity, current_reset_method = reader.reset reader.thread = None reader.tensor_provider = None + reader.exited = False def start_provide_thread(func): def __provider_thread__(): @@ -571,17 +572,20 @@ def py_reader(capacity, array.append(item) + if reader.exited: + break feed_queue.push(array) + if reader.exited: + break feed_queue.close() reader.thread = threading.Thread(target=__provider_thread__) reader.thread.start() def __set_tensor_provider__(func): - reader._tensor_provider = func - start_provide_thread(reader._tensor_provider) + reader.tensor_provider = func - def __set_paddle_reader__(reader): + def __set_paddle_reader__(paddle_reader): with program_guard(Program(), Program()): feed_list = [] counter = 0 @@ -596,25 +600,29 @@ def py_reader(capacity, counter += 1 feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace()) - - reader = feeder.decorate_reader(reader, multi_devices=False) + paddle_reader = feeder.decorate_reader( + paddle_reader, multi_devices=False) def __tensor_provider__(): - for data in reader(): - yield [data[str(idx)] for idx in xrange(counter)] + for slots in paddle_reader(): + yield [slots[str(idx)] for idx in xrange(counter)] __set_tensor_provider__(__tensor_provider__) def __reset__(): current_reset_method() if reader.thread is not None and reader.tensor_provider is not None: + reader.exited = True reader.thread.join() - # restart provider thread. - start_provide_thread(reader.tensor_provider) + reader.exited = False + + def __start__(): + start_provide_thread(reader.tensor_provider) reader.reset = __reset__ reader.decorate_tensor_provider = __set_tensor_provider__ reader.decorate_paddle_reader = __set_paddle_reader__ + reader.start = __start__ return reader diff --git a/python/paddle/fluid/tests/demo/pyreader.py b/python/paddle/fluid/tests/demo/pyreader.py index 3185df07db..3a7dbf8106 100644 --- a/python/paddle/fluid/tests/demo/pyreader.py +++ b/python/paddle/fluid/tests/demo/pyreader.py @@ -67,11 +67,12 @@ def main(): train_reader.decorate_paddle_reader( paddle.v2.reader.shuffle( - paddle.batch(mnist.train(), 256), buf_size=8192)) + paddle.batch(mnist.train(), 512), buf_size=8192)) - test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 256)) + test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512)) for epoch_id in xrange(10): + train_reader.start() try: while True: print 'train_loss', numpy.array( @@ -80,6 +81,7 @@ def main(): print 'End of epoch', epoch_id train_reader.reset() + test_reader.start() try: while True: print 'test loss', numpy.array( -- GitLab