“9605fcd124ae6a3cdad171d2d61107e9cabe4c2f”上不存在“paddle/fluid/framework/ir/shuffle_channel_detect_pass.h”
提交 d310c549 编写于 作者: L likejiao

compatible with `program` argument

上级 ba58d597
...@@ -132,11 +132,15 @@ class Agent(AgentBase): ...@@ -132,11 +132,15 @@ class Agent(AgentBase):
""" """
raise NotImplementedError raise NotImplementedError
def save(self, save_path=None): def save(self, save_path=None, program=None):
"""Save parameters for every fluid program. """Save parameters.
Args: Args:
save_path(str): a directory where to save all the parameters. save_path(str): a directory where to save the parameters.
program(fluid.Program): program that describes the neural network structure. If None, will all program.
Raises:
Error: if program does not exist
Example: Example:
...@@ -144,32 +148,55 @@ class Agent(AgentBase): ...@@ -144,32 +148,55 @@ class Agent(AgentBase):
agent = AtariAgent() agent = AtariAgent()
agent.save() agent.save()
agent.save('./program_model')
agent.save('./program_model', program=agent.learn_program)
""" """
if save_path is None: if save_path is None:
save_path = './program_model' save_path = './program_model'
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
for keyval in self.__dict__.items(): all_programs = [
filename = keyval[0] (kv[0], kv[1]) for kv in self.__dict__.items()
program = keyval[1] if (isinstance(kv[1], fluid.framework.Program)
if isinstance(program, fluid.framework.Program) or \ or isinstance(kv[1], fluid.compiler.CompiledProgram))
isinstance(program, fluid.compiler.CompiledProgram): ]
fluid.io.save_params(
executor=self.fluid_executor, if program:
dirname=save_path, filename = None
main_program=program, for keyval in all_programs:
filename=filename) if program == keyval[1]:
filename = keyval[0]
def restore(self, save_path=None): break
if filename is None:
raise Exception('can not find program {}.'.format(program))
fluid.io.save_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
else:
for keyval in all_programs:
filename = keyval[0]
program = keyval[1]
if isinstance(program, fluid.framework.Program) or \
isinstance(program, fluid.compiler.CompiledProgram):
fluid.io.save_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
def restore(self, save_path=None, program=None):
"""Restore previously saved parameters from save_path. """Restore previously saved parameters from save_path.
default save_path is ./program_model default save_path is ./program_model
Args: Args:
save_path(str): path where parameters were previously saved. save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will restore all program.
Raises: Raises:
ValueError: if save_path does not exist or no file in save_path. Error: if save_path does not exist or can not find the specific program file in save_path.
Example: Example:
...@@ -190,12 +217,42 @@ class Agent(AgentBase): ...@@ -190,12 +217,42 @@ class Agent(AgentBase):
raise Exception( raise Exception(
'can not restore from {}, it is a file, not directory'.format( 'can not restore from {}, it is a file, not directory'.format(
save_path)) save_path))
all_programs = [
for keyval in self.__dict__.items(): (kv[0], kv[1]) for kv in self.__dict__.items()
filename = keyval[0] if (isinstance(kv[1], fluid.framework.Program)
program = keyval[1] or isinstance(kv[1], fluid.compiler.CompiledProgram))
if isinstance(program, fluid.framework.Program) or \ ]
isinstance(program, fluid.compiler.CompiledProgram):
if program:
filename = None
for keyval in all_programs:
if program == keyval[1]:
filename = keyval[0]
break
if filename is None:
raise Exception('can not find the program to restore.')
if not os.path.isfile('{}/{}'.format(save_path, filename)):
raise Exception('{}/{} does not exits'.format(
save_path, filename))
if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program
fluid.io.load_params(
executor=self.fluid_executor,
dirname=save_path,
main_program=program,
filename=filename)
else:
programs_list = [kv[0] for kv in all_programs]
exist_files = os.listdir(save_path)
if len(programs_list) != len(exist_files):
raise Exception(
'expected to restore {} model file under directory {}: {}, but {} files are found: {}.'
.format(
len(programs_list), save_path, programs_list,
len(exist_files), exist_files))
for keyval in all_programs:
filename = keyval[0]
program = keyval[1]
if not os.path.isfile('{}/{}'.format(save_path, filename)): if not os.path.isfile('{}/{}'.format(save_path, filename)):
raise Exception('{}/{} does not exits'.format( raise Exception('{}/{} does not exits'.format(
save_path, filename)) save_path, filename))
...@@ -207,65 +264,3 @@ class Agent(AgentBase): ...@@ -207,65 +264,3 @@ class Agent(AgentBase):
dirname=save_path, dirname=save_path,
main_program=program, main_program=program,
filename=filename) filename=filename)
def save_program(self, save_path, program=None):
"""Save parameters.
Args:
save_path(str): where to save the parameters.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
"""
if program is None:
program = self.learn_program
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.save_params(
executor=self.fluid_executor,
dirname=dirname,
main_program=program,
filename=filename)
def restore_program(self, save_path, program=None):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save_params()``.
Args:
save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
agent.restore('./model.ckpt')
"""
if program is None:
program = self.learn_program
if type(program) is fluid.compiler.CompiledProgram:
program = program._init_program
dirname = os.sep.join(save_path.split(os.sep)[:-1])
filename = save_path.split(os.sep)[-1]
fluid.io.load_params(
executor=self.fluid_executor,
dirname=dirname,
main_program=program,
filename=filename)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册