未验证 提交 7ca85280 编写于 作者: Y yuyang18

Refine pyreader demo

上级 e8eb81ca
......@@ -447,7 +447,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var)
def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
def py_reader(capacity,
shapes,
dtypes,
lod_levels=None,
name=None,
use_double_buffer=True):
"""
Create a reader and blocking queue for data feeding in Python
......@@ -460,6 +465,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
using `close()` method when unused.
Args:
use_double_buffer(bool): Whether use double buffer or not.
capacity(int): The maximum capacity of the BlockingQueue.
shapes(list|tuple): List of tuples which declaring data shapes.
dtypes(list|tuple): List of strs which declaring data type.
......@@ -509,9 +515,11 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader')
double_buffer_name = unique_name('double_buffer')
else:
queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"])
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
......@@ -534,7 +542,10 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var), feed_queue
reader = monkey_patch_reader_methods(main_prog_var)
if use_double_buffer:
reader = double_buffer(reader, name=double_buffer_name)
return reader, feed_queue
def open_files(filenames,
......
......@@ -26,7 +26,7 @@ def network(is_train):
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader")
img, label = fluid.layers.read_file(fluid.layers.double_buffer(reader))
img, label = fluid.layers.read_file(reader)
hidden = img
......@@ -100,7 +100,7 @@ def main():
trainer.run(fetch_list=[loss.name]))
except fluid.core.EOFException:
print 'End of epoch', epoch_id
train_reader.reset()
# train_reader.reset()
train_data_thread.join()
test_data_thread = pipe_reader_to_queue(
......@@ -111,10 +111,12 @@ def main():
tester.run(fetch_list=[test_loss.name]))
except fluid.core.EOFException:
print 'End of testing'
test_reader.reset()
# test_reader.reset()
test_data_thread.join()
break
del trainer
del tester
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册