提交 f5a5baa5 编写于 作者: L likejiao

Save the model without specifying a program

上级 b966fa78
......@@ -21,6 +21,7 @@ from parl.core.fluid import layers
from parl.core.agent_base import AgentBase
from parl.core.fluid.algorithm import Algorithm
from parl.utils import machine_info
from parl.utils import logger
__all__ = ['Agent']
......@@ -132,7 +133,85 @@ class Agent(AgentBase):
"""
raise NotImplementedError
def save(self, save_path, program=None):
def save(self, save_path=None):
"""Save parameters for every fluid program.
Args:
save_path(str): a directory where to save all the parameters.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save()
"""
if save_path is None:
save_path = './program_model'
if not os.path.exists(save_path):
os.makedirs(save_path)
for keyval in self.__dict__.items():
filename = keyval[0]
program = keyval[1]
if isinstance(program, fluid.framework.Program) or \
isinstance(program, fluid.compiler.CompiledProgram):
fluid.io.save_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
def restore(self, save_path=None):
"""Restore previously saved parameters from save_path.
default save_path is ./program_model
Args:
save_path(str): path where parameters were previously saved.
Raises:
ValueError: if save_path does not exist or no file in save_path.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save()
agent.restore()
"""
if save_path is None:
save_path = './program_model'
if not os.path.exists(save_path):
raise Exception(
'can not restore from {}, directory does not exists'.format(
save_path))
if os.path.isfile(save_path):
raise Exception(
'can not restore from {}, it is a file, not directory'.format(
save_path))
for keyval in self.__dict__.items():
filename = keyval[0]
program = keyval[1]
if isinstance(program, fluid.framework.Program) or \
isinstance(program, fluid.compiler.CompiledProgram):
if not os.path.isfile('{}/{}'.format(save_path, filename)):
raise Exception('{}/{} does not exits'.format(
save_path, filename))
logger.info(type(program))
if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program
logger.info(type(program))
fluid.io.load_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
def save_program(self, save_path, program=None):
"""Save parameters.
Args:
......@@ -160,7 +239,7 @@ class Agent(AgentBase):
main_program=program,
filename=filename)
def restore(self, save_path, program=None):
def restore_program(self, save_path, program=None):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save_params()``.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册