提交 2fb280c9 编写于 作者: K kexinzhao 提交者: Yiqun Liu

Revise python save load api using new load/save op (#7995)

* initial commit

* add get_parameters method

* add get_parameters method

* small fix

* address comments

* address comments

* address comments

* fix
上级 270ecbe4
...@@ -489,7 +489,8 @@ class Operator(object): ...@@ -489,7 +489,8 @@ class Operator(object):
no_kernel_op_set = { no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent', 'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'listen_and_serv', 'parallel_do' 'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
'load_combine'
} }
if type not in no_kernel_op_set: if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
......
...@@ -46,6 +46,9 @@ def is_parameter(var): ...@@ -46,6 +46,9 @@ def is_parameter(var):
def is_persistable(var): def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False
return var.persistable return var.persistable
...@@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var): ...@@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var):
persistable=True) persistable=True)
def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): def save_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
save_file_name=None):
""" """
Save variables to directory by executor. Save variables to directory by executor.
...@@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): ...@@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
:param main_program: program. If vars is None, then filter all variables in this :param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program. program which fit `predicate`. Default default_main_program.
:param predicate: The Predicate describes a callable that returns a variable :param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved. as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate :param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored will be ignored
:param save_file_name: The name of a single file that all vars are saved to.
If it is None, save variables to separate files.
:return: None :return: None
""" """
if vars is None: if vars is None:
...@@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): ...@@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
save_vars( save_vars(
executor, executor,
dirname=dirname, dirname=dirname,
vars=filter(predicate, main_program.list_vars())) vars=filter(predicate, main_program.list_vars()),
save_file_name=save_file_name)
else: else:
save_program = Program() save_program = Program()
save_block = save_program.global_block() save_block = save_program.global_block()
save_var_map = {}
for each_var in vars: for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var) new_var = _clone_var_in_block_(save_block, each_var)
if save_file_name is None:
save_block.append_op( save_block.append_op(
type='save', type='save',
inputs={'X': [new_var]}, inputs={'X': [new_var]},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
save_var_map[new_var.name] = new_var
if save_file_name is not None:
save_var_list = []
for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name])
save_block.append_op(
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(dirname, save_file_name)})
executor.run(save_program) executor.run(save_program)
def save_params(executor, dirname, main_program=None): def save_params(executor, dirname, main_program=None, save_file_name=None):
""" """
Save all parameters to directory with executor. Save all parameters to directory with executor.
""" """
...@@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None): ...@@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None):
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
vars=None, vars=None,
predicate=is_parameter) predicate=is_parameter,
save_file_name=save_file_name)
def save_persistables(executor, dirname, main_program=None): def save_persistables(executor, dirname, main_program=None,
save_file_name=None):
""" """
Save all persistables to directory with executor. Save all persistables to directory with executor.
""" """
...@@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None): ...@@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None):
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
vars=None, vars=None,
predicate=is_persistable) predicate=is_persistable,
save_file_name=save_file_name)
def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): def load_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
load_file_name=None):
""" """
Load variables from directory by executor. Load variables from directory by executor.
:param executor: executor that save variable :param executor: executor that load variable
:param dirname: directory path :param dirname: directory path
:param main_program: program. If vars is None, then filter all variables in this :param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program(). program which fit `predicate`. Default default_main_program().
:param predicate: The Predicate describes a callable that returns a variable :param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded. as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If specify vars, program & :param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored predicate will be ignored
:param load_file_name: The name of the single file that all vars are loaded from.
If it is None, load variables from separate files.
:return: None :return: None
""" """
if vars is None: if vars is None:
...@@ -144,23 +184,40 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): ...@@ -144,23 +184,40 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
load_vars( load_vars(
executor, executor,
dirname=dirname, dirname=dirname,
vars=filter(predicate, main_program.list_vars())) vars=filter(predicate, main_program.list_vars()),
load_file_name=load_file_name)
else: else:
load_prog = Program() load_prog = Program()
load_block = load_prog.global_block() load_block = load_prog.global_block()
load_var_map = {}
for each_var in vars: for each_var in vars:
assert isinstance(each_var, Variable) assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var) new_var = _clone_var_in_block_(load_block, each_var)
if load_file_name is None:
load_block.append_op( load_block.append_op(
type='load', type='load',
inputs={}, inputs={},
outputs={"Out": [new_var]}, outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
load_var_map[new_var.name] = new_var
if load_file_name is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, load_file_name)})
executor.run(load_prog) executor.run(load_prog)
def load_params(executor, dirname, main_program=None): def load_params(executor, dirname, main_program=None, load_file_name=None):
""" """
load all parameters from directory by executor. load all parameters from directory by executor.
""" """
...@@ -168,10 +225,12 @@ def load_params(executor, dirname, main_program=None): ...@@ -168,10 +225,12 @@ def load_params(executor, dirname, main_program=None):
executor, executor,
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
predicate=is_parameter) predicate=is_parameter,
load_file_name=load_file_name)
def load_persistables(executor, dirname, main_program=None): def load_persistables(executor, dirname, main_program=None,
load_file_name=None):
""" """
load all persistables from directory by executor. load all persistables from directory by executor.
""" """
...@@ -179,7 +238,8 @@ def load_persistables(executor, dirname, main_program=None): ...@@ -179,7 +238,8 @@ def load_persistables(executor, dirname, main_program=None):
executor, executor,
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
predicate=is_persistable) predicate=is_persistable,
load_file_name=load_file_name)
def get_inference_program(target_vars, main_program=None): def get_inference_program(target_vars, main_program=None):
...@@ -238,7 +298,8 @@ def save_inference_model(dirname, ...@@ -238,7 +298,8 @@ def save_inference_model(dirname,
feeded_var_names, feeded_var_names,
target_vars, target_vars,
executor, executor,
main_program=None): main_program=None,
save_file_name=None):
""" """
Build a model especially for inference, Build a model especially for inference,
and save it to directory by the executor. and save it to directory by the executor.
...@@ -249,6 +310,8 @@ def save_inference_model(dirname, ...@@ -249,6 +310,8 @@ def save_inference_model(dirname,
:param executor: executor that save inference model :param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model. :param main_program: original program, which will be pruned to build the inference model.
Default default_main_program(). Default default_main_program().
:param save_file_name: The name of a single file that all parameters are saved to.
If it is None, save parameters to separate files.
:return: None :return: None
""" """
...@@ -283,25 +346,7 @@ def save_inference_model(dirname, ...@@ -283,25 +346,7 @@ def save_inference_model(dirname,
with open(model_file_name, "wb") as f: with open(model_file_name, "wb") as f:
f.write(inference_program.desc.serialize_to_string()) f.write(inference_program.desc.serialize_to_string())
save_params(executor, dirname, main_program) save_persistables(executor, dirname, inference_program, save_file_name)
def load_persistables_if_exist(executor, dirname, main_program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)
def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames
load_vars(
executor,
dirname,
main_program=main_program,
vars=None,
predicate=_is_presistable_and_exist_)
def get_feed_targets_names(program): def get_feed_targets_names(program):
...@@ -322,12 +367,14 @@ def get_fetch_targets_names(program): ...@@ -322,12 +367,14 @@ def get_fetch_targets_names(program):
return fetch_targets_names return fetch_targets_names
def load_inference_model(dirname, executor): def load_inference_model(dirname, executor, load_file_name=None):
""" """
Load inference model from a directory Load inference model from a directory
:param dirname: directory path :param dirname: directory path
:param executor: executor that load inference model :param executor: executor that load inference model
:param load_file_name: The name of the single file that all parameters are loaded from.
If it is None, load parameters from separate files.
:return: [program, feed_target_names, fetch_targets] :return: [program, feed_target_names, fetch_targets]
program: program especially for inference. program: program especially for inference.
...@@ -342,7 +389,7 @@ def load_inference_model(dirname, executor): ...@@ -342,7 +389,7 @@ def load_inference_model(dirname, executor):
program_desc_str = f.read() program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program) load_persistables(executor, dirname, program, load_file_name)
feed_target_names = get_feed_targets_names(program) feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program) fetch_target_names = get_fetch_targets_names(program)
...@@ -359,6 +406,7 @@ def get_parameter_value(para, executor): ...@@ -359,6 +406,7 @@ def get_parameter_value(para, executor):
:param executor: executor for retrieving the value :param executor: executor for retrieving the value
:param para: the given parameter :param para: the given parameter
:return: the LoDTensor for the parameter :return: the LoDTensor for the parameter
""" """
assert is_parameter(para) assert is_parameter(para)
...@@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None):
:param name: the name of the parameter :param name: the name of the parameter
:param program: the program where the variable is found :param program: the program where the variable is found
Default default_main_program(). Default default_main_program().
:return: the LoDTensor for the variable :return: the LoDTensor for the variable
""" """
if program is None: if program is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册