From 7ca852803142a679a3d231dee0401dace8c7f63b Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Fri, 13 Jul 2018 16:36:13 +0800 Subject: [PATCH] Refine pyreader demo --- python/paddle/fluid/layers/io.py | 15 +++++++++++++-- python/paddle/fluid/tests/demo/pyreader.py | 8 +++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 149b33334a8..a3c287e2839 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -447,7 +447,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): return monkey_patch_reader_methods(main_prog_var) -def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): +def py_reader(capacity, + shapes, + dtypes, + lod_levels=None, + name=None, + use_double_buffer=True): """ Create a reader and blocking queue for data feeding in Python @@ -460,6 +465,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): using `close()` method when unused. Args: + use_double_buffer(bool): Whether use double buffer or not. capacity(int): The maximum capacity of the BlockingQueue. shapes(list|tuple): List of tuples which declaring data shapes. dtypes(list|tuple): List of strs which declaring data type. @@ -509,9 +515,11 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): if name is None: queue_name = unique_name('lod_tensor_blocking_queue') reader_name = unique_name('create_py_reader') + double_buffer_name = unique_name('double_buffer') else: queue_name = "_".join([name, "queue"]) reader_name = "_".join([name, "reader"]) + double_buffer_name = "_".join([name, "double_buffer"]) var = global_scope().var(queue_name) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) @@ -534,7 +542,10 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): main_prog_var = _copy_reader_var_(default_main_program().current_block(), startup_var) - return monkey_patch_reader_methods(main_prog_var), feed_queue + reader = monkey_patch_reader_methods(main_prog_var) + if use_double_buffer: + reader = double_buffer(reader, name=double_buffer_name) + return reader, feed_queue def open_files(filenames, diff --git a/python/paddle/fluid/tests/demo/pyreader.py b/python/paddle/fluid/tests/demo/pyreader.py index cc459a8f761..e4df9f749c6 100644 --- a/python/paddle/fluid/tests/demo/pyreader.py +++ b/python/paddle/fluid/tests/demo/pyreader.py @@ -26,7 +26,7 @@ def network(is_train): shapes=((-1, 784), (-1, 1)), dtypes=('float32', 'int64'), name="train_reader" if is_train else "test_reader") - img, label = fluid.layers.read_file(fluid.layers.double_buffer(reader)) + img, label = fluid.layers.read_file(reader) hidden = img @@ -100,7 +100,7 @@ def main(): trainer.run(fetch_list=[loss.name])) except fluid.core.EOFException: print 'End of epoch', epoch_id - train_reader.reset() + # train_reader.reset() train_data_thread.join() test_data_thread = pipe_reader_to_queue( @@ -111,10 +111,12 @@ def main(): tester.run(fetch_list=[test_loss.name])) except fluid.core.EOFException: print 'End of testing' - test_reader.reset() + # test_reader.reset() test_data_thread.join() break + del trainer + del tester if __name__ == '__main__': -- GitLab