提交 d310c549 编写于 作者: L likejiao

compatible with `program` argument

上级 ba58d597
...@@ -132,11 +132,15 @@ class Agent(AgentBase): ...@@ -132,11 +132,15 @@ class Agent(AgentBase):
""" """
raise NotImplementedError raise NotImplementedError
def save(self, save_path=None): def save(self, save_path=None, program=None):
"""Save parameters for every fluid program. """Save parameters.
Args: Args:
save_path(str): a directory where to save all the parameters. save_path(str): a directory where to save the parameters.
program(fluid.Program): program that describes the neural network structure. If None, will all program.
Raises:
Error: if program does not exist
Example: Example:
...@@ -144,13 +148,35 @@ class Agent(AgentBase): ...@@ -144,13 +148,35 @@ class Agent(AgentBase):
agent = AtariAgent() agent = AtariAgent()
agent.save() agent.save()
agent.save('./program_model')
agent.save('./program_model', program=agent.learn_program)
""" """
if save_path is None: if save_path is None:
save_path = './program_model' save_path = './program_model'
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
for keyval in self.__dict__.items(): all_programs = [
(kv[0], kv[1]) for kv in self.__dict__.items()
if (isinstance(kv[1], fluid.framework.Program)
or isinstance(kv[1], fluid.compiler.CompiledProgram))
]
if program:
filename = None
for keyval in all_programs:
if program == keyval[1]:
filename = keyval[0]
break
if filename is None:
raise Exception('can not find program {}.'.format(program))
fluid.io.save_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
else:
for keyval in all_programs:
filename = keyval[0] filename = keyval[0]
program = keyval[1] program = keyval[1]
if isinstance(program, fluid.framework.Program) or \ if isinstance(program, fluid.framework.Program) or \
...@@ -161,15 +187,16 @@ class Agent(AgentBase): ...@@ -161,15 +187,16 @@ class Agent(AgentBase):
main_program=program, main_program=program,
filename=filename) filename=filename)
def restore(self, save_path=None): def restore(self, save_path=None, program=None):
"""Restore previously saved parameters from save_path. """Restore previously saved parameters from save_path.
default save_path is ./program_model default save_path is ./program_model
Args: Args:
save_path(str): path where parameters were previously saved. save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will restore all program.
Raises: Raises:
ValueError: if save_path does not exist or no file in save_path. Error: if save_path does not exist or can not find the specific program file in save_path.
Example: Example:
...@@ -190,82 +217,50 @@ class Agent(AgentBase): ...@@ -190,82 +217,50 @@ class Agent(AgentBase):
raise Exception( raise Exception(
'can not restore from {}, it is a file, not directory'.format( 'can not restore from {}, it is a file, not directory'.format(
save_path)) save_path))
all_programs = [
for keyval in self.__dict__.items(): (kv[0], kv[1]) for kv in self.__dict__.items()
if (isinstance(kv[1], fluid.framework.Program)
or isinstance(kv[1], fluid.compiler.CompiledProgram))
]
if program:
filename = None
for keyval in all_programs:
if program == keyval[1]:
filename = keyval[0] filename = keyval[0]
program = keyval[1] break
if isinstance(program, fluid.framework.Program) or \ if filename is None:
isinstance(program, fluid.compiler.CompiledProgram): raise Exception('can not find the program to restore.')
if not os.path.isfile('{}/{}'.format(save_path, filename)): if not os.path.isfile('{}/{}'.format(save_path, filename)):
raise Exception('{}/{} does not exits'.format( raise Exception('{}/{} does not exits'.format(
save_path, filename)) save_path, filename))
if type(program) is fluid.compiler.CompiledProgram: if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program program = program._init_program
fluid.io.load_params( fluid.io.load_params(
executor=self.fluid_executor, executor=self.fluid_executor,
dirname=save_path, dirname=save_path,
main_program=program, main_program=program,
filename=filename) filename=filename)
else:
def save_program(self, save_path, program=None): programs_list = [kv[0] for kv in all_programs]
"""Save parameters. exist_files = os.listdir(save_path)
if len(programs_list) != len(exist_files):
Args: raise Exception(
save_path(str): where to save the parameters. 'expected to restore {} model file under directory {}: {}, but {} files are found: {}.'
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program. .format(
len(programs_list), save_path, programs_list,
Raises: len(exist_files), exist_files))
ValueError: if program is None and self.learn_program does not exist. for keyval in all_programs:
filename = keyval[0]
Example: program = keyval[1]
if not os.path.isfile('{}/{}'.format(save_path, filename)):
.. code-block:: python raise Exception('{}/{} does not exits'.format(
save_path, filename))
agent = AtariAgent()
agent.save('./model.ckpt')
"""
if program is None:
program = self.learn_program
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.save_params(
executor=self.fluid_executor,
dirname=dirname,
main_program=program,
filename=filename)
def restore_program(self, save_path, program=None):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save_params()``.
Args:
save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
agent.restore('./model.ckpt')
"""
if program is None:
program = self.learn_program
if type(program) is fluid.compiler.CompiledProgram: if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program program = program._init_program
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.load_params( fluid.io.load_params(
executor=self.fluid_executor, executor=self.fluid_executor,
dirname=dirname, dirname=save_path,
main_program=program, main_program=program,
filename=filename) filename=filename)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册