提交 fca9e884 编写于 作者: F fengjiayi

Update Readers Python API

1. Combine 'open_files', 'multi_pass_reader' and 'threaded_reader'
together to make the new 'open_files' interface.
2. Add some docstrings.
3. Simplify the interface names of 'create_XXX_reader', e.g., rename
'create_double_buffer_reader' to 'double_buffer'.
上级 6be51f10
......@@ -22,7 +22,7 @@ from ..executor import global_scope
# Names exported by this module.
# NOTE(review): the list still names 'create_shuffle_reader' and
# 'create_double_buffer_reader', but this change renames those functions to
# 'shuffle' and 'double_buffer' -- confirm the old names are still defined
# somewhere in the module, otherwise update this list to the new names.
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
'create_double_buffer_reader'
]
......@@ -283,7 +283,43 @@ def _copy_reader_create_op_(block, op):
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes):
def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=False):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples declaring the data shapes.
lod_levels(list): List of ints declaring the data lod_levels.
dtypes(list): List of strs declaring the data types.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1,)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -310,6 +346,13 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = for_parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
......@@ -318,11 +361,15 @@ def open_files(filenames,
lod_levels,
dtypes,
thread_num,
buffer_size=None):
buffer_size=None,
pass_num=1,
for_parallel=False):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable. Via the Reader Variable, we can get data from given files.
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given files. All files must
have name suffixes that indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
......@@ -331,6 +378,10 @@ def open_files(filenames,
dtypes(list): List of strs declaring the data types.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
......@@ -338,16 +389,16 @@ def open_files(filenames,
Examples:
.. code-block:: python
reader = fluid.layers.open_files(filenames=['./data1.recordio',
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1,)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
shapes=[(3,224,224), (1,)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
......@@ -361,13 +412,12 @@ def open_files(filenames,
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
......@@ -377,14 +427,21 @@ def open_files(filenames,
'buffer_size': buffer_size
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var)
startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True
main_prog_reader = _copy_reader_var_(default_main_program().current_block(),
startup_reader)
if pass_num > 1:
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = for_parallel(reader=main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs={}):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -400,12 +457,12 @@ def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size):
def shuffle(reader, buffer_size):
    """Decorate ``reader`` with a 'create_shuffle_reader' op.

    Args:
        reader: The reader to decorate; passed through unchanged to
            ``__create_decorated_reader__``.
        buffer_size: Value for the op's 'buffer_size' attribute. It is
            coerced with ``int()`` before being handed to the op.

    Returns:
        Whatever ``__create_decorated_reader__`` returns for the new op.
    """
    shuffle_attrs = {'buffer_size': int(buffer_size)}
    return __create_decorated_reader__('create_shuffle_reader', reader,
                                       shuffle_attrs)
def create_double_buffer_reader(reader, place=None):
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
......@@ -413,11 +470,15 @@ def create_double_buffer_reader(reader, place=None):
attrs)
def create_multi_pass_reader(reader, pass_num):
def multi_pass(reader, pass_num):
    """Decorate ``reader`` with a 'create_multi_pass_reader' op.

    Args:
        reader: The reader to decorate; passed through unchanged to
            ``__create_decorated_reader__``.
        pass_num: Value for the op's 'pass_num' attribute (the number of
            passes to run). It is coerced with ``int()`` first.

    Returns:
        Whatever ``__create_decorated_reader__`` returns for the new op.
    """
    multi_pass_attrs = {'pass_num': int(pass_num)}
    return __create_decorated_reader__('create_multi_pass_reader', reader,
                                       multi_pass_attrs)
def for_parallel(reader):
    """Decorate ``reader`` with a 'create_threaded_reader' op.

    No attributes are passed, so ``__create_decorated_reader__`` falls back
    to its default ``attrs``.

    Args:
        reader: The reader to decorate.

    Returns:
        Whatever ``__create_decorated_reader__`` returns for the new op.
    """
    # NOTE(review): open_recordio_file and open_files each take a boolean
    # parameter that is also named `for_parallel` and then call
    # `for_parallel(reader=...)`. Inside those functions the name resolves to
    # the parameter, not to this function, so that call would raise a
    # TypeError when the flag is True. Consider renaming one of the two.
    threaded_reader = __create_decorated_reader__('create_threaded_reader',
                                                  reader)
    return threaded_reader
def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册