提交 fca9e884 编写于 作者: F fengjiayi

Update Readers Python API

1. Combine 'open_files', 'multi_pass_reader' and 'threaded_reader'
together to make the new 'open_files' interface.
2. Add some docstring.
3. Simplify interface names of 'create_XXX_reader', e.g, rename
'create_double_buffer_reader' to 'double_buffer'.
上级 6be51f10
...@@ -22,7 +22,7 @@ from ..executor import global_scope ...@@ -22,7 +22,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader', 'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader' 'create_double_buffer_reader'
] ]
...@@ -283,7 +283,43 @@ def _copy_reader_create_op_(block, op): ...@@ -283,7 +283,43 @@ def _copy_reader_create_op_(block, op):
return new_op return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes): def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=False):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = [] shape_concat = []
ranks = [] ranks = []
...@@ -310,6 +346,13 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -310,6 +346,13 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.persistable = True startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = for_parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
...@@ -318,11 +361,15 @@ def open_files(filenames, ...@@ -318,11 +361,15 @@ def open_files(filenames,
lod_levels, lod_levels,
dtypes, dtypes,
thread_num, thread_num,
buffer_size=None): buffer_size=None,
pass_num=1,
for_parallel=False):
""" """
Open files Open files
This layer takes a list of files to read from and returns a Reader Variable. Via the Reader Variable, we can get data from given files. This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixs to indicate their formats, e.g., '*.recordio'.
Args: Args:
filenames(list): The list of file names. filenames(list): The list of file names.
...@@ -331,6 +378,10 @@ def open_files(filenames, ...@@ -331,6 +378,10 @@ def open_files(filenames,
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number. thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer. buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns: Returns:
Variable: A Reader Variable via which we can get file data. Variable: A Reader Variable via which we can get file data.
...@@ -338,16 +389,16 @@ def open_files(filenames, ...@@ -338,16 +389,16 @@ def open_files(filenames,
Examples: Examples:
.. code-block:: python .. code-block:: python
reader = fluid.layers.open_files(filenames=['./data1.recordio', reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'], './data2.recordio'],
shapes=[(3,224,224), (1)], shapes=[(3,224,224), (1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'],
thread_num=2, thread_num=2,
buffer_size=2) buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data: # Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader) image, label = fluid.layers.io.read_file(reader)
""" """
if buffer_size is None: if buffer_size is None:
buffer_size = thread_num buffer_size = thread_num
...@@ -361,13 +412,12 @@ def open_files(filenames, ...@@ -361,13 +412,12 @@ def open_files(filenames,
shape_concat.extend(shape) shape_concat.extend(shape)
ranks.append(len(shape)) ranks.append(len(shape))
var_name = unique_name('multiple_reader') multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op( startup_blk.append_op(
type='open_files', type='open_files',
outputs={'Out': [startup_var]}, outputs={'Out': [startup_reader]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
...@@ -377,14 +427,21 @@ def open_files(filenames, ...@@ -377,14 +427,21 @@ def open_files(filenames,
'buffer_size': buffer_size 'buffer_size': buffer_size
}) })
startup_var.desc.set_dtypes(dtypes) startup_reader.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_reader.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_reader = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_reader)
return monkey_patch_reader_methods(main_prog_var) if pass_num > 1:
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = for_parallel(reader=main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs): return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs={}):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
...@@ -400,12 +457,12 @@ def __create_decorated_reader__(op_type, reader, attrs): ...@@ -400,12 +457,12 @@ def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size): def shuffle(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader, return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)}) {'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None): def double_buffer(reader, place=None):
attrs = dict() attrs = dict()
if place is not None: if place is not None:
attrs['place'] = str(place).upper() attrs['place'] = str(place).upper()
...@@ -413,11 +470,15 @@ def create_double_buffer_reader(reader, place=None): ...@@ -413,11 +470,15 @@ def create_double_buffer_reader(reader, place=None):
attrs) attrs)
def create_multi_pass_reader(reader, pass_num): def multi_pass(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader, return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)}) {'pass_num': int(pass_num)})
def for_parallel(reader):
return __create_decorated_reader__('create_threaded_reader', reader)
def read_file(file_obj): def read_file(file_obj):
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册