提交 2d74b5f9 编写于 作者: L Liu Yiqun

Refine the Python API load/save_inference_model.

上级 b44917d0
...@@ -101,8 +101,8 @@ void TestInference(const std::string& dirname, ...@@ -101,8 +101,8 @@ void TestInference(const std::string& dirname,
if (IsCombined) { if (IsCombined) {
// All parameters are saved in a single file. // All parameters are saved in a single file.
// Hard-coding the file names of program and parameters in unittest. // Hard-coding the file names of program and parameters in unittest.
// Users are free to specify different filename // The file names should be consistent with that used in Python API
// (provided: the filenames are changed in the python api as well: io.py) // `fluid.io.save_inference_model`.
std::string prog_filename = "__model_combined__"; std::string prog_filename = "__model_combined__";
std::string param_filename = "__params_combined__"; std::string param_filename = "__params_combined__";
inference_program = paddle::inference::Load(executor, inference_program = paddle::inference::Load(executor,
......
...@@ -68,7 +68,7 @@ def save_vars(executor, ...@@ -68,7 +68,7 @@ def save_vars(executor,
main_program=None, main_program=None,
vars=None, vars=None,
predicate=None, predicate=None,
save_file_name=None): filename=None):
""" """
Save variables to directory by executor. Save variables to directory by executor.
...@@ -80,8 +80,8 @@ def save_vars(executor, ...@@ -80,8 +80,8 @@ def save_vars(executor,
as a bool. If it returns true, the corresponding input variable will be saved. as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If vars is specified, program & predicate :param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored will be ignored
:param save_file_name: The name of a single file that all vars are saved to. :param filename: The name of a single file that all vars are saved to.
If it is None, save variables to separate files. If it is None, save variables to separate files.
:return: None :return: None
""" """
...@@ -95,7 +95,7 @@ def save_vars(executor, ...@@ -95,7 +95,7 @@ def save_vars(executor,
executor, executor,
dirname=dirname, dirname=dirname,
vars=filter(predicate, main_program.list_vars()), vars=filter(predicate, main_program.list_vars()),
save_file_name=save_file_name) filename=filename)
else: else:
save_program = Program() save_program = Program()
save_block = save_program.global_block() save_block = save_program.global_block()
...@@ -103,7 +103,7 @@ def save_vars(executor, ...@@ -103,7 +103,7 @@ def save_vars(executor,
save_var_map = {} save_var_map = {}
for each_var in vars: for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var) new_var = _clone_var_in_block_(save_block, each_var)
if save_file_name is None: if filename is None:
save_block.append_op( save_block.append_op(
type='save', type='save',
inputs={'X': [new_var]}, inputs={'X': [new_var]},
...@@ -112,7 +112,7 @@ def save_vars(executor, ...@@ -112,7 +112,7 @@ def save_vars(executor,
else: else:
save_var_map[new_var.name] = new_var save_var_map[new_var.name] = new_var
if save_file_name is not None: if filename is not None:
save_var_list = [] save_var_list = []
for name in sorted(save_var_map.keys()): for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name]) save_var_list.append(save_var_map[name])
...@@ -121,12 +121,12 @@ def save_vars(executor, ...@@ -121,12 +121,12 @@ def save_vars(executor,
type='save_combine', type='save_combine',
inputs={'X': save_var_list}, inputs={'X': save_var_list},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, save_file_name)}) attrs={'file_path': os.path.join(dirname, filename)})
executor.run(save_program) executor.run(save_program)
def save_params(executor, dirname, main_program=None, save_file_name=None): def save_params(executor, dirname, main_program=None, filename=None):
""" """
Save all parameters to directory with executor. Save all parameters to directory with executor.
""" """
...@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None): ...@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None):
main_program=main_program, main_program=main_program,
vars=None, vars=None,
predicate=is_parameter, predicate=is_parameter,
save_file_name=save_file_name) filename=filename)
def save_persistables(executor, dirname, main_program=None, def save_persistables(executor, dirname, main_program=None, filename=None):
save_file_name=None):
""" """
Save all persistables to directory with executor. Save all persistables to directory with executor.
""" """
...@@ -150,7 +149,7 @@ def save_persistables(executor, dirname, main_program=None, ...@@ -150,7 +149,7 @@ def save_persistables(executor, dirname, main_program=None,
main_program=main_program, main_program=main_program,
vars=None, vars=None,
predicate=is_persistable, predicate=is_persistable,
save_file_name=save_file_name) filename=filename)
def load_vars(executor, def load_vars(executor,
...@@ -158,7 +157,7 @@ def load_vars(executor, ...@@ -158,7 +157,7 @@ def load_vars(executor,
main_program=None, main_program=None,
vars=None, vars=None,
predicate=None, predicate=None,
load_file_name=None): filename=None):
""" """
Load variables from directory by executor. Load variables from directory by executor.
...@@ -170,8 +169,8 @@ def load_vars(executor, ...@@ -170,8 +169,8 @@ def load_vars(executor,
as a bool. If it returns true, the corresponding input variable will be loaded. as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If vars is specified, program & :param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored predicate will be ignored
:param load_file_name: The name of the single file that all vars are loaded from. :param filename: The name of the single file that all vars are loaded from.
If it is None, load variables from separate files. If it is None, load variables from separate files.
:return: None :return: None
""" """
...@@ -185,7 +184,7 @@ def load_vars(executor, ...@@ -185,7 +184,7 @@ def load_vars(executor,
executor, executor,
dirname=dirname, dirname=dirname,
vars=filter(predicate, main_program.list_vars()), vars=filter(predicate, main_program.list_vars()),
load_file_name=load_file_name) filename=filename)
else: else:
load_prog = Program() load_prog = Program()
load_block = load_prog.global_block() load_block = load_prog.global_block()
...@@ -194,7 +193,7 @@ def load_vars(executor, ...@@ -194,7 +193,7 @@ def load_vars(executor,
for each_var in vars: for each_var in vars:
assert isinstance(each_var, Variable) assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var) new_var = _clone_var_in_block_(load_block, each_var)
if load_file_name is None: if filename is None:
load_block.append_op( load_block.append_op(
type='load', type='load',
inputs={}, inputs={},
...@@ -203,7 +202,7 @@ def load_vars(executor, ...@@ -203,7 +202,7 @@ def load_vars(executor,
else: else:
load_var_map[new_var.name] = new_var load_var_map[new_var.name] = new_var
if load_file_name is not None: if filename is not None:
load_var_list = [] load_var_list = []
for name in sorted(load_var_map.keys()): for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name]) load_var_list.append(load_var_map[name])
...@@ -212,12 +211,12 @@ def load_vars(executor, ...@@ -212,12 +211,12 @@ def load_vars(executor,
type='load_combine', type='load_combine',
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, load_file_name)}) attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog) executor.run(load_prog)
def load_params(executor, dirname, main_program=None, load_file_name=None): def load_params(executor, dirname, main_program=None, filename=None):
""" """
load all parameters from directory by executor. load all parameters from directory by executor.
""" """
...@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None): ...@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None):
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
predicate=is_parameter, predicate=is_parameter,
load_file_name=load_file_name) filename=filename)
def load_persistables(executor, dirname, main_program=None, def load_persistables(executor, dirname, main_program=None, filename=None):
load_file_name=None):
""" """
load all persistables from directory by executor. load all persistables from directory by executor.
""" """
...@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None, ...@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None,
dirname=dirname, dirname=dirname,
main_program=main_program, main_program=main_program,
predicate=is_persistable, predicate=is_persistable,
load_file_name=load_file_name) filename=filename)
def get_inference_program(target_vars, main_program=None): def get_inference_program(target_vars, main_program=None):
...@@ -299,7 +297,8 @@ def save_inference_model(dirname, ...@@ -299,7 +297,8 @@ def save_inference_model(dirname,
target_vars, target_vars,
executor, executor,
main_program=None, main_program=None,
save_file_name=None): model_filename=None,
params_filename=None):
""" """
Build a model especially for inference, Build a model especially for inference,
and save it to directory by the executor. and save it to directory by the executor.
...@@ -310,8 +309,11 @@ def save_inference_model(dirname, ...@@ -310,8 +309,11 @@ def save_inference_model(dirname,
:param executor: executor that save inference model :param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model. :param main_program: original program, which will be pruned to build the inference model.
Default default_main_program(). Default default_main_program().
:param save_file_name: The name of a single file that all parameters are saved to. :param model_filename: The name of file to save inference program.
If it is None, save parameters to separate files. If not specified, default filename `__model__` will be used.
:param params_filename: The name of file to save parameters.
It is used for the case that all parameters are saved in a single binary file.
If not specified, parameters are considered saved in separate files.
:return: None :return: None
""" """
...@@ -342,15 +344,19 @@ def save_inference_model(dirname, ...@@ -342,15 +344,19 @@ def save_inference_model(dirname,
prepend_feed_ops(inference_program, feeded_var_names) prepend_feed_ops(inference_program, feeded_var_names)
append_fetch_ops(inference_program, fetch_var_names) append_fetch_ops(inference_program, fetch_var_names)
if save_file_name == None: if model_filename is not None:
model_file_name = dirname + "/__model__" model_filename = os.path.basename(model_filename)
else: else:
model_file_name = dirname + "/__model_combined__" model_filename = "__model__"
model_filename = os.path.join(dirname, model_filename)
with open(model_file_name, "wb") as f: if params_filename is not None:
params_filename = os.path.basename(params_filename)
with open(model_filename, "wb") as f:
f.write(inference_program.desc.serialize_to_string()) f.write(inference_program.desc.serialize_to_string())
save_persistables(executor, dirname, inference_program, save_file_name) save_persistables(executor, dirname, inference_program, params_filename)
def get_feed_targets_names(program): def get_feed_targets_names(program):
...@@ -371,15 +377,21 @@ def get_fetch_targets_names(program): ...@@ -371,15 +377,21 @@ def get_fetch_targets_names(program):
return fetch_targets_names return fetch_targets_names
def load_inference_model(dirname, executor, load_file_name=None): def load_inference_model(dirname,
executor,
model_filename=None,
params_filename=None):
""" """
Load inference model from a directory Load inference model from a directory
:param dirname: directory path :param dirname: directory path
:param executor: executor that load inference model :param executor: executor that load inference model
:param load_file_name: The name of the single file that all parameters are loaded from. :param model_filename: The name of file to load inference program.
If it is None, load parameters from separate files. If not specified, default filename `__model__` will be used.
:param params_filename: The name of file to load parameters.
It is used for the case that all parameters are saved in a single binary file.
If not specified, parameters are considered saved in separate files.
:return: [program, feed_target_names, fetch_targets] :return: [program, feed_target_names, fetch_targets]
program: program especially for inference. program: program especially for inference.
feed_target_names: Names of variables that need to feed data feed_target_names: Names of variables that need to feed data
...@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None): ...@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None):
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
if load_file_name == None: if model_filename is not None:
model_file_name = dirname + "/__model__" model_filename = os.path.basename(model_filename)
else: else:
model_file_name = dirname + "/__model_combined__" model_filename = "__model__"
model_filename = os.path.join(dirname, model_filename)
if params_filename is not None:
params_filename = os.path.basename(params_filename)
with open(model_file_name, "rb") as f: with open(model_filename, "rb") as f:
program_desc_str = f.read() program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
load_persistables(executor, dirname, program, load_file_name) load_persistables(executor, dirname, program, params_filename)
feed_target_names = get_feed_targets_names(program) feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program) fetch_target_names = get_fetch_targets_names(program)
......
...@@ -78,7 +78,12 @@ def conv_net(img, label): ...@@ -78,7 +78,12 @@ def conv_net(img, label):
return loss_net(conv_pool_2, label) return loss_net(conv_pool_2, label)
def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): def train(nn_type,
use_cuda,
parallel,
save_dirname=None,
model_filename=None,
params_filename=None):
if use_cuda and not fluid.core.is_compiled_with_cuda(): if use_cuda and not fluid.core.is_compiled_with_cuda():
return return
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
...@@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): ...@@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
fluid.io.save_inference_model( fluid.io.save_inference_model(
save_dirname, ["img"], [prediction], save_dirname, ["img"], [prediction],
exe, exe,
save_file_name=save_param_filename) model_filename=model_filename,
params_filename=params_filename)
return return
else: else:
print( print(
...@@ -158,7 +164,10 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename): ...@@ -158,7 +164,10 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
raise AssertionError("Loss of recognize digits is too large") raise AssertionError("Loss of recognize digits is too large")
def infer(use_cuda, save_dirname=None, param_filename=None): def infer(use_cuda,
save_dirname=None,
model_filename=None,
params_filename=None):
if save_dirname is None: if save_dirname is None:
return return
...@@ -171,8 +180,9 @@ def infer(use_cuda, save_dirname=None, param_filename=None): ...@@ -171,8 +180,9 @@ def infer(use_cuda, save_dirname=None, param_filename=None):
# the feed_target_names (the names of variables that will be feeded # the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that # data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators). # we want to obtain data from using fetch operators).
[inference_program, feed_target_names, fetch_targets [inference_program, feed_target_names,
] = fluid.io.load_inference_model(save_dirname, exe, param_filename) fetch_targets] = fluid.io.load_inference_model(
save_dirname, exe, model_filename, params_filename)
# The input's dimension of conv should be 4-D or 5-D. # The input's dimension of conv should be 4-D or 5-D.
# Use normilized image pixels as input data, which should be in the range [-1.0, 1.0]. # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0].
...@@ -189,25 +199,27 @@ def infer(use_cuda, save_dirname=None, param_filename=None): ...@@ -189,25 +199,27 @@ def infer(use_cuda, save_dirname=None, param_filename=None):
def main(use_cuda, parallel, nn_type, combine): def main(use_cuda, parallel, nn_type, combine):
save_dirname = None
model_filename = None
params_filename = None
if not use_cuda and not parallel: if not use_cuda and not parallel:
save_dirname = "recognize_digits_" + nn_type + ".inference.model" save_dirname = "recognize_digits_" + nn_type + ".inference.model"
save_filename = None
if combine == True: if combine == True:
save_filename = "__params_combined__" model_filename = "__model_combined__"
else: params_filename = "__params_combined__"
save_dirname = None
save_filename = None
train( train(
nn_type=nn_type, nn_type=nn_type,
use_cuda=use_cuda, use_cuda=use_cuda,
parallel=parallel, parallel=parallel,
save_dirname=save_dirname, save_dirname=save_dirname,
save_param_filename=save_filename) model_filename=model_filename,
params_filename=params_filename)
infer( infer(
use_cuda=use_cuda, use_cuda=use_cuda,
save_dirname=save_dirname, save_dirname=save_dirname,
param_filename=save_filename) model_filename=model_filename,
params_filename=params_filename)
class TestRecognizeDigits(unittest.TestCase): class TestRecognizeDigits(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册