From f5a5baa5c71a25df07c651b38ee59215de8ada0f Mon Sep 17 00:00:00 2001 From: likejiao Date: Mon, 21 Sep 2020 17:58:30 +0800 Subject: [PATCH] Save the model without specifying a program --- parl/core/fluid/agent.py | 83 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py index c587b97..fc29122 100644 --- a/parl/core/fluid/agent.py +++ b/parl/core/fluid/agent.py @@ -21,6 +21,7 @@ from parl.core.fluid import layers from parl.core.agent_base import AgentBase from parl.core.fluid.algorithm import Algorithm from parl.utils import machine_info +from parl.utils import logger __all__ = ['Agent'] @@ -132,7 +133,85 @@ class Agent(AgentBase): """ raise NotImplementedError - def save(self, save_path, program=None): + def save(self, save_path=None): + """Save parameters for every fluid program. + + Args: + save_path(str): a directory where to save all the parameters. + + Example: + + .. code-block:: python + + agent = AtariAgent() + agent.save() + + """ + if save_path is None: + save_path = './program_model' + if not os.path.exists(save_path): + os.makedirs(save_path) + for keyval in self.__dict__.items(): + filename = keyval[0] + program = keyval[1] + if isinstance(program, fluid.framework.Program) or \ + isinstance(program, fluid.compiler.CompiledProgram): + fluid.io.save_params( + executor=self.fluid_executor, + dirname=save_path, + main_program=program, + filename=filename) + + def restore(self, save_path=None): + """Restore previously saved parameters from save_path. + default save_path is ./program_model + + Args: + save_path(str): path where parameters were previously saved. + + Raises: + ValueError: if save_path does not exist or no file in save_path. + + Example: + + .. code-block:: python + + agent = AtariAgent() + agent.save() + agent.restore() + + """ + if save_path is None: + save_path = './program_model' + if not os.path.exists(save_path): + raise Exception( + 'can not restore from {}, directory does not exists'.format( + save_path)) + if os.path.isfile(save_path): + raise Exception( + 'can not restore from {}, it is a file, not directory'.format( + save_path)) + + for keyval in self.__dict__.items(): + filename = keyval[0] + program = keyval[1] + if isinstance(program, fluid.framework.Program) or \ + isinstance(program, fluid.compiler.CompiledProgram): + if not os.path.isfile('{}/{}'.format(save_path, filename)): + raise Exception('{}/{} does not exits'.format( + save_path, filename)) + logger.info(type(program)) + if type(program) is fluid.compiler.CompiledProgram: + program = program._init_program + logger.info(type(program)) + + fluid.io.load_params( + executor=self.fluid_executor, + dirname=save_path, + main_program=program, + filename=filename) + + def save_program(self, save_path, program=None): """Save parameters. Args: @@ -160,7 +239,7 @@ class Agent(AgentBase): main_program=program, filename=filename) - def restore(self, save_path, program=None): + def restore_program(self, save_path, program=None): """Restore previously saved parameters. This method requires a program that describes the network structure. The save_path argument is typically a value previously passed to ``save_params()``. -- GitLab