From fca9e8847d5017601251ee8813e7af513b2603ed Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 8 Apr 2018 16:10:14 +0800 Subject: [PATCH] Update Readers Python API 1. Combine 'open_files', 'multi_pass_reader' and 'threaded_reader' together to make the new 'open_files' interface. 2. Add some docstring. 3. Simplify interface names of 'create_XXX_reader', e.g, rename 'create_double_buffer_reader' to 'double_buffer'. --- python/paddle/fluid/layers/io.py | 109 ++++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 7413e69234..fc8809ce15 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -22,7 +22,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'open_files', 'read_file', 'create_shuffle_reader', - 'create_double_buffer_reader', 'create_multi_pass_reader' + 'create_double_buffer_reader' ] @@ -283,7 +283,43 @@ def _copy_reader_create_op_(block, op): return new_op -def open_recordio_file(filename, shapes, lod_levels, dtypes): +def open_recordio_file(filename, + shapes, + lod_levels, + dtypes, + pass_num=1, + for_parallel=False): + """ + Open a RecordIO file + + This layer takes a RecordIO file to read from and returns a Reader Variable. + Via the Reader Variable, we can get data from the given RecordIO file. + + Args: + filename(str): The RecordIO file's name. + shapes(list): List of tuples which declaring data shapes. + lod_levels(list): List of ints which declaring data lod_level. + dtypes(list): List of strs which declaring data type. + pass_num(int): Number of passes to run. After completing the + given number of passes, 'has_next()' will return False. + for_parallel(Bool): Set it as True if you are going to run + subsequent operators in parallel. + + Returns: + Variable: A Reader Variable via which we can get RecordIO file data. + + Examples: + .. code-block:: python + + reader = fluid.layers.io.open_recordio_file( + filename='./data.recordio', + shapes=[(3,224,224), (1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + + # Via the reader, we can use 'read_file' layer to get data: + image, label = fluid.layers.read_file(reader) + """ dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] shape_concat = [] ranks = [] @@ -310,6 +346,13 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var.persistable = True main_prog_var = _copy_reader_var_(default_main_program().current_block(), startup_var) + + if pass_num > 1: + main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) + + if for_parallel: + main_prog_var = for_parallel(reader=main_prog_var) + return monkey_patch_reader_methods(main_prog_var) @@ -318,11 +361,15 @@ def open_files(filenames, lod_levels, dtypes, thread_num, - buffer_size=None): + buffer_size=None, + pass_num=1, + for_parallel=False): """ Open files - This layer takes a list of files to read from and returns a Reader Variable. Via the Reader Variable, we can get data from given files. + This layer takes a list of files to read from and returns a Reader Variable. + Via the Reader Variable, we can get data from given files. All files must + have name suffixs to indicate their formats, e.g., '*.recordio'. Args: filenames(list): The list of file names. @@ -331,6 +378,10 @@ def open_files(filenames, dtypes(list): List of strs which declaring data type. thread_num(int): The maximal concurrent prefetch thread number. buffer_size(int): The size of prefetch buffer. + pass_num(int): Number of passes to run. After completing the + given number of passes, 'has_next()' will return False. + for_parallel(Bool): Set it as True if you are going to run + subsequent operators in parallel. Returns: Variable: A Reader Variable via which we can get file data. @@ -338,16 +389,16 @@ def open_files(filenames, Examples: .. code-block:: python - reader = fluid.layers.open_files(filenames=['./data1.recordio', + reader = fluid.layers.io.open_files(filenames=['./data1.recordio', './data2.recordio'], - shapes=[(3,224,224), (1)], - lod_levels=[0, 0], - dtypes=['float32', 'int64'], - thread_num=2, - buffer_size=2) + shapes=[(3,224,224), (1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64'], + thread_num=2, + buffer_size=2) # Via the reader, we can use 'read_file' layer to get data: - image, label = fluid.layers.read_file(reader) + image, label = fluid.layers.io.read_file(reader) """ if buffer_size is None: buffer_size = thread_num @@ -361,13 +412,12 @@ def open_files(filenames, shape_concat.extend(shape) ranks.append(len(shape)) - var_name = unique_name('multiple_reader') - + multi_file_reader_name = unique_name('multi_file_reader') startup_blk = default_startup_program().current_block() - startup_var = startup_blk.create_var(name=var_name) + startup_reader = startup_blk.create_var(name=multi_file_reader_name) startup_blk.append_op( type='open_files', - outputs={'Out': [startup_var]}, + outputs={'Out': [startup_reader]}, attrs={ 'shape_concat': shape_concat, 'lod_levels': lod_levels, @@ -377,14 +427,21 @@ def open_files(filenames, 'buffer_size': buffer_size }) - startup_var.desc.set_dtypes(dtypes) - startup_var.persistable = True - main_prog_var = _copy_reader_var_(default_main_program().current_block(), - startup_var) - return monkey_patch_reader_methods(main_prog_var) + startup_reader.desc.set_dtypes(dtypes) + startup_reader.persistable = True + main_prog_reader = _copy_reader_var_(default_main_program().current_block(), + startup_reader) + if pass_num > 1: + main_prog_reader = multi_pass( + reader=main_prog_reader, pass_num=pass_num) + if for_parallel: + main_prog_reader = for_parallel(reader=main_prog_reader) -def __create_decorated_reader__(op_type, reader, attrs): + return monkey_patch_reader_methods(main_prog_reader) + + +def __create_decorated_reader__(op_type, reader, attrs={}): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() startup_var = startup_blk.create_var(name=var_name) @@ -400,12 +457,12 @@ def __create_decorated_reader__(op_type, reader, attrs): return monkey_patch_reader_methods(main_prog_var) -def create_shuffle_reader(reader, buffer_size): +def shuffle(reader, buffer_size): return __create_decorated_reader__('create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) -def create_double_buffer_reader(reader, place=None): +def double_buffer(reader, place=None): attrs = dict() if place is not None: attrs['place'] = str(place).upper() @@ -413,11 +470,15 @@ def create_double_buffer_reader(reader, place=None): attrs) -def create_multi_pass_reader(reader, pass_num): +def multi_pass(reader, pass_num): return __create_decorated_reader__('create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) +def for_parallel(reader): + return __create_decorated_reader__('create_threaded_reader', reader) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ -- GitLab