提交 88f6eef4 编写于 作者: S sneaxiy

Merge branch 'complete_py_reader_python' of https://github.com/sneaxiy/Paddle...

Merge branch 'complete_py_reader_python' of https://github.com/sneaxiy/Paddle into complete_py_reader_python
......@@ -446,7 +446,7 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var)
def py_reader(capacity, shapes, lod_levels, dtypes):
def py_reader(capacity, shapes, dtypes, lod_levels=None):
"""
Create a reader and blocking queue for data feeding in Python
......@@ -461,8 +461,8 @@ def py_reader(capacity, shapes, lod_levels, dtypes):
Args:
capacity(int): The maximum capacity of the BlockingQueue.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
lod_levels(list): List of ints which declaring data lod_level.
Returns:
tuple(Variable, BlockingQueue):
......@@ -477,7 +477,6 @@ def py_reader(capacity, shapes, lod_levels, dtypes):
reader, queue = fluid.layers.py_reader(
capacity=10,
shapes=[[-1,3,224,224], [-1,1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
......@@ -501,6 +500,9 @@ def py_reader(capacity, shapes, lod_levels, dtypes):
shape_concat.extend(shape)
ranks.append(len(shape))
if lod_levels is None:
lod_levels = [0] * len(shapes)
queue_name = unique_name('lod_tensor_blocking_queue')
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册