提交 2fb280c9 编写于 作者: K kexinzhao 提交者: Yiqun Liu

Revise python save load api using new load/save op (#7995)

* initial commit

* add get_parameters method

* add get_parameters method

* small fix

* address comments

* address comments

* address comments

* fix
上级 270ecbe4
......@@ -489,7 +489,8 @@ class Operator(object):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'listen_and_serv', 'parallel_do'
'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
'load_combine'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -46,6 +46,9 @@ def is_parameter(var):
def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False
return var.persistable
......@@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var):
persistable=True)
def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def save_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
save_file_name=None):
"""
Save variables to directory by executor.
......@@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored
:param save_file_name: The name of a single file that all vars are saved to.
If it is None, save variables to separate files.
:return: None
"""
if vars is None:
......@@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
save_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
save_file_name=save_file_name)
else:
save_program = Program()
save_block = save_program.global_block()
save_var_map = {}
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
if save_file_name is None:
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
save_var_map[new_var.name] = new_var
if save_file_name is not None:
save_var_list = []
for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name])
save_block.append_op(
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(dirname, save_file_name)})
executor.run(save_program)
def save_params(executor, dirname, main_program=None):
def save_params(executor, dirname, main_program=None, save_file_name=None):
"""
Save all parameters to directory with executor.
"""
......@@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_parameter)
predicate=is_parameter,
save_file_name=save_file_name)
def save_persistables(executor, dirname, main_program=None):
def save_persistables(executor, dirname, main_program=None,
save_file_name=None):
"""
Save all persistables to directory with executor.
"""
......@@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_persistable)
predicate=is_persistable,
save_file_name=save_file_name)
def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def load_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
load_file_name=None):
"""
Load variables from directory by executor.
:param executor: executor that save variable
:param executor: executor that load variable
:param dirname: directory path
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program().
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored
:param load_file_name: The name of the single file that all vars are loaded from.
If it is None, load variables from separate files.
:return: None
"""
if vars is None:
......@@ -144,23 +184,40 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
load_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
load_file_name=load_file_name)
else:
load_prog = Program()
load_block = load_prog.global_block()
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
if load_file_name is None:
load_block.append_op(
type='load',
inputs={},
outputs={"Out": [new_var]},
outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
load_var_map[new_var.name] = new_var
if load_file_name is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, load_file_name)})
executor.run(load_prog)
def load_params(executor, dirname, main_program=None):
def load_params(executor, dirname, main_program=None, load_file_name=None):
"""
load all parameters from directory by executor.
"""
......@@ -168,10 +225,12 @@ def load_params(executor, dirname, main_program=None):
executor,
dirname=dirname,
main_program=main_program,
predicate=is_parameter)
predicate=is_parameter,
load_file_name=load_file_name)
def load_persistables(executor, dirname, main_program=None):
def load_persistables(executor, dirname, main_program=None,
load_file_name=None):
"""
load all persistables from directory by executor.
"""
......@@ -179,7 +238,8 @@ def load_persistables(executor, dirname, main_program=None):
executor,
dirname=dirname,
main_program=main_program,
predicate=is_persistable)
predicate=is_persistable,
load_file_name=load_file_name)
def get_inference_program(target_vars, main_program=None):
......@@ -238,7 +298,8 @@ def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
main_program=None):
main_program=None,
save_file_name=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
......@@ -249,6 +310,8 @@ def save_inference_model(dirname,
:param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model.
Default default_main_program().
:param save_file_name: The name of a single file that all parameters are saved to.
If it is None, save parameters to separate files.
:return: None
"""
......@@ -283,25 +346,7 @@ def save_inference_model(dirname,
with open(model_file_name, "wb") as f:
f.write(inference_program.desc.serialize_to_string())
save_params(executor, dirname, main_program)
def load_persistables_if_exist(executor, dirname, main_program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)
def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames
load_vars(
executor,
dirname,
main_program=main_program,
vars=None,
predicate=_is_presistable_and_exist_)
save_persistables(executor, dirname, inference_program, save_file_name)
def get_feed_targets_names(program):
......@@ -322,12 +367,14 @@ def get_fetch_targets_names(program):
return fetch_targets_names
def load_inference_model(dirname, executor):
def load_inference_model(dirname, executor, load_file_name=None):
"""
Load inference model from a directory
:param dirname: directory path
:param executor: executor that load inference model
:param load_file_name: The name of the single file that all parameters are loaded from.
If it is None, load parameters from separate files.
:return: [program, feed_target_names, fetch_targets]
program: program especially for inference.
......@@ -342,7 +389,7 @@ def load_inference_model(dirname, executor):
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
load_persistables(executor, dirname, program, load_file_name)
feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program)
......@@ -359,6 +406,7 @@ def get_parameter_value(para, executor):
:param executor: executor for retrieving the value
:param para: the given parameter
:return: the LoDTensor for the parameter
"""
assert is_parameter(para)
......@@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None):
:param name: the name of the parameter
:param program: the program where the variable is found
Default default_main_program().
:return: the LoDTensor for the variable
"""
if program is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册