diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py index 4e4f6b755847492af5501ab54ed2b863dd80b8e2..aef49b95ea8481dc41bf2a604d6ba04e6794d58e 100644 --- a/parl/core/fluid/agent.py +++ b/parl/core/fluid/agent.py @@ -132,11 +132,15 @@ class Agent(AgentBase): """ raise NotImplementedError - def save(self, save_path=None): - """Save parameters for every fluid program. + def save(self, save_path=None, program=None): + """Save parameters. Args: - save_path(str): a directory where to save all the parameters. + save_path(str): a directory where to save the parameters. + program(fluid.Program): program that describes the neural network structure. If None, will all program. + + Raises: + Error: if program does not exist Example: @@ -144,32 +148,55 @@ class Agent(AgentBase): agent = AtariAgent() agent.save() + agent.save('./program_model') + agent.save('./program_model', program=agent.learn_program) """ if save_path is None: save_path = './program_model' if not os.path.exists(save_path): os.makedirs(save_path) - for keyval in self.__dict__.items(): - filename = keyval[0] - program = keyval[1] - if isinstance(program, fluid.framework.Program) or \ - isinstance(program, fluid.compiler.CompiledProgram): - fluid.io.save_params( - executor=self.fluid_executor, - dirname=save_path, - main_program=program, - filename=filename) - - def restore(self, save_path=None): + all_programs = [ + (kv[0], kv[1]) for kv in self.__dict__.items() + if (isinstance(kv[1], fluid.framework.Program) + or isinstance(kv[1], fluid.compiler.CompiledProgram)) + ] + + if program: + filename = None + for keyval in all_programs: + if program == keyval[1]: + filename = keyval[0] + break + if filename is None: + raise Exception('can not find program {}.'.format(program)) + fluid.io.save_params( + executor=self.fluid_executor, + dirname=save_path, + main_program=program, + filename=filename) + else: + for keyval in all_programs: + filename = keyval[0] + program = keyval[1] + if isinstance(program, fluid.framework.Program) or \ + isinstance(program, fluid.compiler.CompiledProgram): + fluid.io.save_params( + executor=self.fluid_executor, + dirname=save_path, + main_program=program, + filename=filename) + + def restore(self, save_path=None, program=None): """Restore previously saved parameters from save_path. default save_path is ./program_model Args: save_path(str): path where parameters were previously saved. + program(fluid.Program): program that describes the neural network structure. If None, will restore all program. Raises: - ValueError: if save_path does not exist or no file in save_path. + Error: if save_path does not exist or can not find the specific program file in save_path. Example: @@ -190,12 +217,42 @@ class Agent(AgentBase): raise Exception( 'can not restore from {}, it is a file, not directory'.format( save_path)) - - for keyval in self.__dict__.items(): - filename = keyval[0] - program = keyval[1] - if isinstance(program, fluid.framework.Program) or \ - isinstance(program, fluid.compiler.CompiledProgram): + all_programs = [ + (kv[0], kv[1]) for kv in self.__dict__.items() + if (isinstance(kv[1], fluid.framework.Program) + or isinstance(kv[1], fluid.compiler.CompiledProgram)) + ] + + if program: + filename = None + for keyval in all_programs: + if program == keyval[1]: + filename = keyval[0] + break + if filename is None: + raise Exception('can not find the program to restore.') + if not os.path.isfile('{}/{}'.format(save_path, filename)): + raise Exception('{}/{} does not exits'.format( + save_path, filename)) + if type(program) is fluid.compiler.CompiledProgram: + program = program._init_program + fluid.io.load_params( + executor=self.fluid_executor, + dirname=save_path, + main_program=program, + filename=filename) + else: + programs_list = [kv[0] for kv in all_programs] + exist_files = os.listdir(save_path) + if len(programs_list) != len(exist_files): + raise Exception( + 'expected to restore {} model file under directory {}: {}, but {} files are found: {}.' + .format( + len(programs_list), save_path, programs_list, + len(exist_files), exist_files)) + for keyval in all_programs: + filename = keyval[0] + program = keyval[1] if not os.path.isfile('{}/{}'.format(save_path, filename)): raise Exception('{}/{} does not exits'.format( save_path, filename)) @@ -207,65 +264,3 @@ class Agent(AgentBase): dirname=save_path, main_program=program, filename=filename) - - def save_program(self, save_path, program=None): - """Save parameters. - - Args: - save_path(str): where to save the parameters. - program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program. - - Raises: - ValueError: if program is None and self.learn_program does not exist. - - Example: - - .. code-block:: python - - agent = AtariAgent() - agent.save('./model.ckpt') - - """ - if program is None: - program = self.learn_program - dirname = os.sep.join(save_path.split(os.sep)[:-1]) - filename = save_path.split(os.sep)[-1] - fluid.io.save_params( - executor=self.fluid_executor, - dirname=dirname, - main_program=program, - filename=filename) - - def restore_program(self, save_path, program=None): - """Restore previously saved parameters. - This method requires a program that describes the network structure. - The save_path argument is typically a value previously passed to ``save_params()``. - - Args: - save_path(str): path where parameters were previously saved. - program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program. - - Raises: - ValueError: if program is None and self.learn_program does not exist. - - Example: - - .. code-block:: python - - agent = AtariAgent() - agent.save('./model.ckpt') - agent.restore('./model.ckpt') - - """ - - if program is None: - program = self.learn_program - if type(program) is fluid.compiler.CompiledProgram: - program = program._init_program - dirname = os.sep.join(save_path.split(os.sep)[:-1]) - filename = save_path.split(os.sep)[-1] - fluid.io.load_params( - executor=self.fluid_executor, - dirname=dirname, - main_program=program, - filename=filename)