diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index ae98e299a4838db8b7089e09683c33c940cfde3d..7f5187d29984482d91b5709bf4514d013028767a 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -489,7 +489,8 @@ class Operator(object): no_kernel_op_set = { 'feed', 'fetch', 'save', 'load', 'recurrent', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', - 'recv', 'listen_and_serv', 'parallel_do' + 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', + 'load_combine' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index d56ec45c538b580f5520bc060b4b339bb1be0539..613dc20b6ea5533d126a73b7ec47796b3f812db5 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -46,6 +46,9 @@ def is_parameter(var): def is_persistable(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST: + return False return var.persistable @@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var): persistable=True) -def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): +def save_vars(executor, + dirname, + main_program=None, + vars=None, + predicate=None, + save_file_name=None): """ Save variables to directory by executor. @@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default default_main_program. :param predicate: The Predicate describes a callable that returns a variable - as a bool. If it returns true, the variables will be saved. - :param vars: variables need to be saved. If specify vars, program & predicate + as a bool. If it returns true, the corresponding input variable will be saved. + :param vars: variables need to be saved. If vars is specified, program & predicate will be ignored + :param save_file_name: The name of a single file that all vars are saved to. + If it is None, save variables to separate files. + :return: None """ if vars is None: @@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): save_vars( executor, dirname=dirname, - vars=filter(predicate, main_program.list_vars())) + vars=filter(predicate, main_program.list_vars()), + save_file_name=save_file_name) else: save_program = Program() save_block = save_program.global_block() + + save_var_map = {} for each_var in vars: new_var = _clone_var_in_block_(save_block, each_var) + if save_file_name is None: + save_block.append_op( + type='save', + inputs={'X': [new_var]}, + outputs={}, + attrs={'file_path': os.path.join(dirname, new_var.name)}) + else: + save_var_map[new_var.name] = new_var + + if save_file_name is not None: + save_var_list = [] + for name in sorted(save_var_map.keys()): + save_var_list.append(save_var_map[name]) + save_block.append_op( - type='save', - inputs={'X': [new_var]}, + type='save_combine', + inputs={'X': save_var_list}, outputs={}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + attrs={'file_path': os.path.join(dirname, save_file_name)}) + executor.run(save_program) -def save_params(executor, dirname, main_program=None): +def save_params(executor, dirname, main_program=None, save_file_name=None): """ Save all parameters to directory with executor. """ @@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None): dirname=dirname, main_program=main_program, vars=None, - predicate=is_parameter) + predicate=is_parameter, + save_file_name=save_file_name) -def save_persistables(executor, dirname, main_program=None): +def save_persistables(executor, dirname, main_program=None, + save_file_name=None): """ Save all persistables to directory with executor. """ @@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None): dirname=dirname, main_program=main_program, vars=None, - predicate=is_persistable) + predicate=is_persistable, + save_file_name=save_file_name) -def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): +def load_vars(executor, + dirname, + main_program=None, + vars=None, + predicate=None, + load_file_name=None): """ Load variables from directory by executor. - :param executor: executor that save variable + :param executor: executor that load variable :param dirname: directory path :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default default_main_program(). :param predicate: The Predicate describes a callable that returns a variable - as a bool. If it returns true, the variables will be loaded. - :param vars: variables need to be loaded. If specify vars, program & + as a bool. If it returns true, the corresponding input variable will be loaded. + :param vars: variables need to be loaded. If vars is specified, program & predicate will be ignored + :param load_file_name: The name of the single file that all vars are loaded from. + If it is None, load variables from separate files. + :return: None """ if vars is None: @@ -144,23 +184,40 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): load_vars( executor, dirname=dirname, - vars=filter(predicate, main_program.list_vars())) + vars=filter(predicate, main_program.list_vars()), + load_file_name=load_file_name) else: load_prog = Program() load_block = load_prog.global_block() + + load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) new_var = _clone_var_in_block_(load_block, each_var) + if load_file_name is None: + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [new_var]}, + attrs={'file_path': os.path.join(dirname, new_var.name)}) + else: + load_var_map[new_var.name] = new_var + + if load_file_name is not None: + load_var_list = [] + for name in sorted(load_var_map.keys()): + load_var_list.append(load_var_map[name]) + load_block.append_op( - type='load', + type='load_combine', inputs={}, - outputs={"Out": [new_var]}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + outputs={"Out": load_var_list}, + attrs={'file_path': os.path.join(dirname, load_file_name)}) executor.run(load_prog) -def load_params(executor, dirname, main_program=None): +def load_params(executor, dirname, main_program=None, load_file_name=None): """ load all parameters from directory by executor. """ @@ -168,10 +225,12 @@ def load_params(executor, dirname, main_program=None): executor, dirname=dirname, main_program=main_program, - predicate=is_parameter) + predicate=is_parameter, + load_file_name=load_file_name) -def load_persistables(executor, dirname, main_program=None): +def load_persistables(executor, dirname, main_program=None, + load_file_name=None): """ load all persistables from directory by executor. """ @@ -179,7 +238,8 @@ def load_persistables(executor, dirname, main_program=None): executor, dirname=dirname, main_program=main_program, - predicate=is_persistable) + predicate=is_persistable, + load_file_name=load_file_name) def get_inference_program(target_vars, main_program=None): @@ -238,7 +298,8 @@ def save_inference_model(dirname, feeded_var_names, target_vars, executor, - main_program=None): + main_program=None, + save_file_name=None): """ Build a model especially for inference, and save it to directory by the executor. @@ -249,6 +310,8 @@ def save_inference_model(dirname, :param executor: executor that save inference model :param main_program: original program, which will be pruned to build the inference model. Default default_main_program(). + :param save_file_name: The name of a single file that all parameters are saved to. + If it is None, save parameters to separate files. :return: None """ @@ -283,25 +346,7 @@ def save_inference_model(dirname, with open(model_file_name, "wb") as f: f.write(inference_program.desc.serialize_to_string()) - save_params(executor, dirname, main_program) - - -def load_persistables_if_exist(executor, dirname, main_program=None): - filenames = next(os.walk(dirname))[2] - filenames = set(filenames) - - def _is_presistable_and_exist_(var): - if not is_persistable(var): - return False - else: - return var.name in filenames - - load_vars( - executor, - dirname, - main_program=main_program, - vars=None, - predicate=_is_presistable_and_exist_) + save_persistables(executor, dirname, inference_program, save_file_name) def get_feed_targets_names(program): @@ -322,13 +367,15 @@ def get_fetch_targets_names(program): return fetch_targets_names -def load_inference_model(dirname, executor): +def load_inference_model(dirname, executor, load_file_name=None): """ Load inference model from a directory :param dirname: directory path :param executor: executor that load inference model - + :param load_file_name: The name of the single file that all parameters are loaded from. + If it is None, load parameters from separate files. + :return: [program, feed_target_names, fetch_targets] program: program especially for inference. feed_target_names: Names of variables that need to feed data @@ -342,7 +389,7 @@ def load_inference_model(dirname, executor): program_desc_str = f.read() program = Program.parse_from_string(program_desc_str) - load_persistables_if_exist(executor, dirname, program) + load_persistables(executor, dirname, program, load_file_name) feed_target_names = get_feed_targets_names(program) fetch_target_names = get_fetch_targets_names(program) @@ -359,6 +406,7 @@ def get_parameter_value(para, executor): :param executor: executor for retrieving the value :param para: the given parameter + :return: the LoDTensor for the parameter """ assert is_parameter(para) @@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None): :param name: the name of the parameter :param program: the program where the variable is found Default default_main_program(). + :return: the LoDTensor for the variable """ if program is None: