未验证 提交 7ca85280 编写于 作者: Y yuyang18

Refine pyreader demo

上级 e8eb81ca
...@@ -447,7 +447,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -447,7 +447,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): def py_reader(capacity,
shapes,
dtypes,
lod_levels=None,
name=None,
use_double_buffer=True):
""" """
Create a reader and blocking queue for data feeding in Python Create a reader and blocking queue for data feeding in Python
...@@ -460,6 +465,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): ...@@ -460,6 +465,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
using `close()` method when unused. using `close()` method when unused.
Args: Args:
use_double_buffer(bool): Whether use double buffer or not.
capacity(int): The maximum capacity of the BlockingQueue. capacity(int): The maximum capacity of the BlockingQueue.
shapes(list|tuple): List of tuples which declaring data shapes. shapes(list|tuple): List of tuples which declaring data shapes.
dtypes(list|tuple): List of strs which declaring data type. dtypes(list|tuple): List of strs which declaring data type.
...@@ -509,9 +515,11 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): ...@@ -509,9 +515,11 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
if name is None: if name is None:
queue_name = unique_name('lod_tensor_blocking_queue') queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader') reader_name = unique_name('create_py_reader')
double_buffer_name = unique_name('double_buffer')
else: else:
queue_name = "_".join([name, "queue"]) queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"]) reader_name = "_".join([name, "reader"])
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
...@@ -534,7 +542,10 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): ...@@ -534,7 +542,10 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var), feed_queue reader = monkey_patch_reader_methods(main_prog_var)
if use_double_buffer:
reader = double_buffer(reader, name=double_buffer_name)
return reader, feed_queue
def open_files(filenames, def open_files(filenames,
......
...@@ -26,7 +26,7 @@ def network(is_train): ...@@ -26,7 +26,7 @@ def network(is_train):
shapes=((-1, 784), (-1, 1)), shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'), dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader") name="train_reader" if is_train else "test_reader")
img, label = fluid.layers.read_file(fluid.layers.double_buffer(reader)) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
...@@ -100,7 +100,7 @@ def main(): ...@@ -100,7 +100,7 @@ def main():
trainer.run(fetch_list=[loss.name])) trainer.run(fetch_list=[loss.name]))
except fluid.core.EOFException: except fluid.core.EOFException:
print 'End of epoch', epoch_id print 'End of epoch', epoch_id
train_reader.reset() # train_reader.reset()
train_data_thread.join() train_data_thread.join()
test_data_thread = pipe_reader_to_queue( test_data_thread = pipe_reader_to_queue(
...@@ -111,10 +111,12 @@ def main(): ...@@ -111,10 +111,12 @@ def main():
tester.run(fetch_list=[test_loss.name])) tester.run(fetch_list=[test_loss.name]))
except fluid.core.EOFException: except fluid.core.EOFException:
print 'End of testing' print 'End of testing'
test_reader.reset() # test_reader.reset()
test_data_thread.join() test_data_thread.join()
break break
del trainer
del tester
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册