Commit 6e3e3f13 authored by Chen Weihang

add get inference program api

Parent 827ac36f
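In short, this change adds an internal helper that returns the pruned inference Program for a dygraph Layer without writing any files (intended for paddle2onnx). A minimal usage sketch, mirroring the unit test added below; `LinearNet` and `train()` are the test's own helpers and are shown here only for illustration:

    import paddle

    paddle.disable_static()                  # dygraph mode, as in the test
    layer = LinearNet(784, 1)                # any trained paddle.nn.Layer subclass
    train(layer)                             # hypothetical training loop from the test
    # returns the pruned, deploy-ready Program; nothing is saved to disk
    infer_program = paddle.jit.get_inference_program(layer)
    print(infer_program.num_blocks)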
...@@ -19,9 +19,11 @@ import pickle
import warnings
import functools
from collections import OrderedDict
import six
import paddle
# deprecated module import
from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
...@@ -644,6 +646,18 @@ class SaveLoadConfig(object):
        self._keep_name_table = value


# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
def deprecate_save_load_configs(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if 'configs' in kwargs:
            kwargs['config'] = kwargs['configs']
            kwargs.pop('configs')
        return func(*args, **kwargs)

    return wrapper
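For illustration, a minimal sketch of the call pattern this shim keeps working: the old `configs` keyword is forwarded to the new `config` argument of `jit.save`/`jit.load` (`layer` stands in for any trained Layer; SaveLoadConfig is the class defined earlier in this file):

    cfg = SaveLoadConfig()
    cfg.model_filename = "my_net"
    # old keyword spelling still accepted; rewritten to `config` by the wrapper
    paddle.jit.save(layer, "example.model", configs=cfg)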
def _get_input_var_names(inputs, input_spec):
    name_none_error = "The %s's name is None. " \
        "When using jit.save, please set InputSepc's name in " \
...@@ -696,9 +710,9 @@ def _get_output_vars(outputs, output_spec):
        if isinstance(var, Variable):
            output_vars_dict[var.name] = var
    if output_spec is None:
-        result_list = output_vars_dict.values()
+        result_list = list(output_vars_dict.values())
    elif output_spec is not None and len(output_spec) == len(output_vars_dict):
-        result_list = output_vars_dict.values()
+        result_list = list(output_vars_dict.values())
        for var in output_spec:
            if var.name not in output_vars_dict:
                warnings.warn(name_no_exists_error % var.name)
...@@ -711,16 +725,95 @@ def _get_output_vars(outputs, output_spec):
    return result_list
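The `list(...)` wrapping above is most likely a Python 3 fix: there `dict.values()` returns a view object rather than a list, so the result could not be used where a real list of Variables is expected further down in `jit.save`.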
-# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
-def deprecate_save_load_configs(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        if 'configs' in kwargs:
-            kwargs['config'] = kwargs['configs']
-            kwargs.pop('configs')
-        return func(*args, **kwargs)
-    return wrapper


def _infer_input_check(layer, input_spec):
    prog_translator = ProgramTranslator()
    if not prog_translator.enable_to_static:
        raise RuntimeError(
            "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
        )
    if not isinstance(layer, Layer):
        raise TypeError(
            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
            % type(layer))

    # avoid change user given input_spec
    inner_input_spec = None
    if input_spec is not None:
        if not isinstance(input_spec, list):
            raise TypeError(
                "The input input_spec should be 'list', but received input_spec's type is %s."
                % type(input_spec))
        inner_input_spec = []
        for var in input_spec:
            if isinstance(var, paddle.static.InputSpec):
                inner_input_spec.append(var)
            elif isinstance(var, (core.VarBase, Variable)):
                inner_input_spec.append(
                    paddle.static.InputSpec.from_tensor(var))
            else:
                raise TypeError(
                    "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
                    % type(var))
    return inner_input_spec


def _get_concrete_program_from_layer(layer, inner_input_spec):
    # TODO(chenweihang): add support for other method, not only forward
    if isinstance(layer.forward, StaticLayer):
        concrete_program = layer.forward.concrete_program
    else:
        # transform in jit.save, if input_spec is incomplete, declarative will throw error
        static_forward = declarative(layer.forward, input_spec=inner_input_spec)
        concrete_program = static_forward.concrete_program
        # the input_spec has been used in declarative, which is equal to
        # @declarative with input_spec and jit.save without input_spec,
        # avoid needless warning
        inner_input_spec = None
    return concrete_program


def _build_input_and_output(concrete_program, inner_input_spec, config):
    # NOTE(chenweihang): [ Get input variables name ]
    # There are two cases, whether to prune the inputs or not
    # - not prune inputs (recommend):
    #   - len(input_spec) == len(concrete_program.inputs) - 1
    #   - here we can use concrete_program.inputs directly
    # - prune inputs:
    #   - len(input_spec) < len(concrete_program.inputs) - 1
    #   - the input_spec's name should be in concrete_program.inputs
    input_var_names = _get_input_var_names(concrete_program.inputs,
                                           inner_input_spec)

    # NOTE(chenweihang): [ Get output variables ]
    # the rule is like [ Get input variables name ]. For output var,
    # we only support VarBase spec, and actually, we only need the
    # var name of output, and we don't recommend using output_spec
    output_vars = _get_output_vars(concrete_program.outputs, config.output_spec)

    return input_var_names, output_vars


# NOTE: This function is not exposed to users, only used for paddle2onnx now
@switch_to_static_graph
def get_inference_program(layer, input_spec=None, config=None):
    # 1. input check
    inner_input_spec = _infer_input_check(layer, input_spec)

    if config is None:
        config = SaveLoadConfig()

    # 2. get program from Layer
    concrete_program = _get_concrete_program_from_layer(layer, inner_input_spec)

    # 3. build input & output of save_inference_model
    input_var_names, output_vars = _build_input_and_output(
        concrete_program, inner_input_spec, config)

    # 4. only get inference program
    inference_program = paddle.fluid.io.get_inference_program(
        input_var_names, output_vars, concrete_program.main_program.clone())

    return inference_program
@deprecate_save_load_configs
...@@ -830,72 +923,18 @@ def save(layer, model_path, input_spec=None, config=None):
            model_path = "linear.example.model"
            paddle.jit.save(layer, model_path)
    """

    # 1. input check
-    prog_translator = ProgramTranslator()
-    if not prog_translator.enable_to_static:
-        raise RuntimeError(
-            "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
-        )
-    if not isinstance(layer, Layer):
-        raise TypeError(
-            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
-            % type(layer))
-    configs = config
-    if configs is None:
-        configs = SaveLoadConfig()
-    # avoid change user given input_spec
-    inner_input_spec = None
-    if input_spec is not None:
-        if not isinstance(input_spec, list):
-            raise TypeError(
-                "The input input_spec should be 'list', but received input_spec's type is %s."
-                % type(input_spec))
-        inner_input_spec = []
-        for var in input_spec:
-            if isinstance(var, paddle.static.InputSpec):
-                inner_input_spec.append(var)
-            elif isinstance(var, (core.VarBase, Variable)):
-                inner_input_spec.append(
-                    paddle.static.InputSpec.from_tensor(var))
-            else:
-                raise TypeError(
-                    "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
-                    % type(var))
+    inner_input_spec = _infer_input_check(layer, input_spec)
+
+    if config is None:
+        config = SaveLoadConfig()

    # 2. get program from Layer
-    # TODO(chenweihang): add support for other method, not only forward
-    if isinstance(layer.forward, StaticLayer):
-        concrete_program = layer.forward.concrete_program
-    else:
-        # transform in jit.save, if input_spec is incomplete, declarative will throw error
-        static_forward = declarative(layer.forward, input_spec=inner_input_spec)
-        concrete_program = static_forward.concrete_program
-        # the input_spec has been used in declarative, which is equal to
-        # @declarative with input_spec and jit.save without input_spec,
-        # avoid needless warning
-        inner_input_spec = None
+    concrete_program = _get_concrete_program_from_layer(layer, inner_input_spec)

    # 3. build input & output of save_inference_model
-    # NOTE(chenweihang): [ Get input variables name ]
-    # There are two cases, whether to prune the inputs or not
-    # - not prune inputs (recommend):
-    #   - the len(input_spec) == len((concrete_program.inputs) - 1
-    #   - here can use concrete_program.inputs directly
-    # - prune inputs:
-    #   - the input_spec length < len((concrete_program.inputs) - 1
-    #   - the input_spec's name should be in concrete_program.inputs
-    input_var_names = _get_input_var_names(concrete_program.inputs,
-                                           inner_input_spec)
-    # NOTE(chenweihang): [ Get output variables ]
-    # the rule is like [ Get input variables name ]. For output var,
-    # we only support VarBase spec, and actually, we only need the
-    # var name of output, and we don't recommended to use output_spec
-    output_vars = _get_output_vars(concrete_program.outputs,
-                                   configs.output_spec)
+    input_var_names, output_vars = _build_input_and_output(
+        concrete_program, inner_input_spec, config)

    # NOTE(chenweihang): we maintain the mapping of variable name to
    # structured name, the buffer variable (non-persistable)
...@@ -927,8 +966,8 @@ def save(layer, model_path, input_spec=None, config=None):
    from paddle.fluid.io import save_inference_model

    # VARIABLE_FILENAME keep naming style consistent with '__model__'
-    if configs.params_filename is None:
-        configs.params_filename = VARIABLE_FILENAME
+    if config.params_filename is None:
+        config.params_filename = VARIABLE_FILENAME

    with scope_guard(scope):
        save_inference_model(
...@@ -937,11 +976,11 @@ def save(layer, model_path, input_spec=None, config=None):
            target_vars=output_vars,
            executor=Executor(_current_expected_place()),
            main_program=concrete_program.main_program.clone(),
-            model_filename=configs.model_filename,
+            model_filename=config.model_filename,
            params_filename=None
-            if configs.separate_params else configs.params_filename,
-            export_for_deployment=configs._export_for_deployment,
-            program_only=configs._program_only)
+            if config.separate_params else config.params_filename,
+            export_for_deployment=config._export_for_deployment,
+            program_only=config._program_only)

    # NOTE(chenweihang): [ Save extra variable info ]
    # save_inference_model will lose some important variable information, including:
...
...@@ -22,10 +22,11 @@ import logging
import pickle
import contextlib
from functools import reduce
import numpy as np
import paddle
# deprecated module import
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator
...@@ -220,6 +221,113 @@ def _get_valid_program(main_program):
    return main_program
def _feed_fetch_check(feeded_var_names, target_vars,
                      export_for_deployment=True):
    if isinstance(feeded_var_names, six.string_types):
        feeded_var_names = [feeded_var_names]
    elif export_for_deployment:
        if len(feeded_var_names) > 0:
            # TODO(paddle-dev): polish these code blocks
            if not (bool(feeded_var_names) and all(
                    isinstance(name, six.string_types)
                    for name in feeded_var_names)):
                raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif export_for_deployment:
        if not (bool(target_vars) and
                all(isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")


def _auc_states_check_and_remind(main_program):
    all_ops = main_program.global_block().ops
    for op in all_ops:
        # clear device of Op
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        op._set_attr(device_attr_name, "")
        if op.type == 'auc':
            warnings.warn(
                "please ensure that you have set the auc states to zeros before saving inference model"
            )
            break


def _update_target_vars(target_vars, main_program):
    # fix the bug that the activation op's output as target will be pruned.
    # will affect the inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with program_guard(main_program):
        uniq_target_vars = []
        for i, var in enumerate(target_vars):
            if isinstance(var, Variable):
                var = layers.scale(
                    var, 1., name="save_infer_model/scale_{}".format(i))
            uniq_target_vars.append(var)
        target_vars = uniq_target_vars
    return target_vars


def _get_train_program(feeded_var_names, target_vars, main_program):
    # 1. feed & fetch check
    _feed_fetch_check(feeded_var_names, target_vars, False)

    # 2. remind user to set auc_states to zeros if the program contains auc op
    _auc_states_check_and_remind(main_program)

    # 3. update input target_vars to fix bug
    target_vars = _update_target_vars(target_vars, main_program)

    return main_program


def _serialization(main_program, model_basename):
    with open(model_basename, "wb") as f:
        f.write(main_program.desc.serialize_to_string())
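For reference, a program serialized by this helper can be parsed back later; a minimal sketch, assuming the standard `Program.parse_from_string` API that `fluid.io.load_inference_model` itself uses (the file path is illustrative):

    from paddle.fluid.framework import Program

    # read the serialized ProgramDesc bytes back into a Program object
    with open("infer_model/__model__", "rb") as f:
        restored = Program.parse_from_string(f.read())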
# NOTE: This function is not exposed to users, only used for paddle2onnx now
@dygraph_not_support
def get_inference_program(feeded_var_names, target_vars, main_program):
    # 1. feed & fetch check
    _feed_fetch_check(feeded_var_names, target_vars)

    # 2. remind user to set auc_states to zeros if the program contains auc op
    _auc_states_check_and_remind(main_program)

    # 3. update input target_vars to fix bug
    target_vars = _update_target_vars(target_vars, main_program)

    # 4. build inference program
    main_program = main_program.clone()
    global_block = main_program.global_block()
    need_to_remove_op_index = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            need_to_remove_op_index.append(i)

    for index in need_to_remove_op_index[::-1]:
        global_block._remove_op(index)

    main_program.desc.flush()

    main_program = main_program._prune_with_input(
        feeded_var_names=feeded_var_names, targets=target_vars)
    main_program = main_program._inference_optimize(prune_read_op=True)
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(main_program, feeded_var_names)
    append_fetch_ops(main_program, fetch_var_names)

    main_program.desc._set_version()
    paddle.fluid.core.save_op_compatible_info(main_program.desc)

    return main_program
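In static-graph mode this helper takes the same feed/fetch arguments as `save_inference_model`; a hedged usage sketch (the network and variable names are illustrative, not part of this change):

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    img = fluid.data(name="img", shape=[None, 784], dtype="float32")
    predict = fluid.layers.fc(input=img, size=10, act="softmax")
    # prunes feed/fetch ops and returns a deploy-ready clone of the Program
    inference_program = fluid.io.get_inference_program(
        ["img"], [predict], fluid.default_main_program())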
@dygraph_not_support
def save_vars(executor,
              dirname,
...@@ -1257,50 +1365,16 @@ def save_inference_model(dirname,
            # "./infer_model".
    """

-    if isinstance(feeded_var_names, six.string_types):
-        feeded_var_names = [feeded_var_names]
-    elif export_for_deployment:
-        if len(feeded_var_names) > 0:
-            # TODO(paddle-dev): polish these code blocks
-            if not (bool(feeded_var_names) and all(
-                    isinstance(name, six.string_types)
-                    for name in feeded_var_names)):
-                raise ValueError("'feed_var_names' should be a list of str.")
-    if isinstance(target_vars, Variable):
-        target_vars = [target_vars]
-    elif export_for_deployment:
-        if not (bool(target_vars) and
-                all(isinstance(var, Variable) for var in target_vars)):
-            raise ValueError("'target_vars' should be a list of Variable.")
+    # 1. get main program
    main_program = _get_valid_program(main_program)

-    # remind user to set auc_states to zeros if the program contains auc op
-    all_ops = main_program.global_block().ops
-    for op in all_ops:
-        # clear device of Op
-        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
-        op._set_attr(device_attr_name, "")
-        if op.type == 'auc':
-            warnings.warn(
-                "please ensure that you have set the auc states to zeros before saving inference model"
-            )
-            break
-    # fix the bug that the activation op's output as target will be pruned.
-    # will affect the inference performance.
-    # TODO(Superjomn) add an IR pass to remove 1-scale op.
-    with program_guard(main_program):
-        uniq_target_vars = []
-        for i, var in enumerate(target_vars):
-            if isinstance(var, Variable):
-                var = layers.scale(
-                    var, 1., name="save_infer_model/scale_{}".format(i))
-            uniq_target_vars.append(var)
-        target_vars = uniq_target_vars
-    target_var_name_list = [var.name for var in target_vars]
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    origin_program = main_program.clone()

+    # 2. dirname check & create
    # when a pserver and a trainer running on the same machine, mkdir may conflict
    save_dirname = dirname
    try:
...@@ -1310,57 +1384,34 @@ def save_inference_model(dirname,
        if e.errno != errno.EEXIST:
            raise

+    # 3. model_filename check & create
    if model_filename is not None:
        model_basename = os.path.basename(model_filename)
    else:
        model_basename = "__model__"
    model_basename = os.path.join(save_dirname, model_basename)

-    # When export_for_deployment is true, we modify the program online so that
-    # it can only be loaded for inference directly. If it's false, the whole
-    # original program and related meta are saved so that future usage can be
-    # more flexible.
-    origin_program = main_program.clone()
+    # 4. get & serialize program
    if export_for_deployment:
-        main_program = main_program.clone()
-        global_block = main_program.global_block()
-        need_to_remove_op_index = []
-        for i, op in enumerate(global_block.ops):
-            op.desc.set_is_target(False)
-            if op.type == "feed" or op.type == "fetch":
-                need_to_remove_op_index.append(i)
-        for index in need_to_remove_op_index[::-1]:
-            global_block._remove_op(index)
-        main_program.desc.flush()
-        main_program = main_program._prune_with_input(
-            feeded_var_names=feeded_var_names, targets=target_vars)
-        main_program = main_program._inference_optimize(prune_read_op=True)
-        fetch_var_names = [v.name for v in target_vars]
-        prepend_feed_ops(main_program, feeded_var_names)
-        append_fetch_ops(main_program, fetch_var_names)
-        main_program.desc._set_version()
-        paddle.fluid.core.save_op_compatible_info(main_program.desc)
-        with open(model_basename, "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
+        main_program = get_inference_program(feeded_var_names, target_vars,
+                                             main_program)
+        _serialization(main_program, model_basename)
    else:
        # TODO(panyx0718): Save more information so that it can also be used
        # for training and more flexible post-processing.
-        with open(model_basename + ".main_program", "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
+        main_program = _get_train_program(feeded_var_names, target_vars,
+                                          main_program)
+        _serialization(main_program, model_basename + ".main_program")

+    # 5. get target var_name list & judge whether serialize program only
+    target_var_name_list = [var.name for var in target_vars]
    if program_only:
        warnings.warn(
            "save_inference_model specified the param `program_only` to True, It will not save params of Program."
        )
        return target_var_name_list

+    # 6. save persistables
    main_program._copy_dist_param_info_from(origin_program)
    if params_filename is not None:
...
...@@ -23,6 +23,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
import warnings
import paddle
import paddle.fluid.executor as executor
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
...@@ -201,4 +202,5 @@ class TestLoadInferenceModelError(unittest.TestCase):

if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
...@@ -755,5 +755,34 @@ class TestJitSaveLoadNoParamLayer(unittest.TestCase):
        self.assertTrue(np.array_equal(out, load_out))


class TestJitGetInferenceProgram(unittest.TestCase):
    def setUp(self):
        # enable dygraph mode
        paddle.disable_static()

    def test_get_inference_program(self):
        layer = LinearNet(784, 1)
        train(layer)

        model_path = "model.jit_get_inference_program"
        paddle.jit.save(layer, model_path)

        infer_program = paddle.jit.get_inference_program(layer)

        # the program loaded by jit.load differs from the original inference program
        model_file_path = os.path.join(model_path, "__model__")
        load_program_desc = fluid.dygraph.io._load_program_desc(model_file_path)
        load_program = fluid.dygraph.io._build_program_by_desc(
            load_program_desc)

        self.assertEqual(infer_program.num_blocks, load_program.num_blocks)
        self.assertEqual(
            len(infer_program.global_block().ops),
            len(load_program.global_block().ops))
        self.assertEqual(
            len(infer_program.global_block().vars),
            len(load_program.global_block().vars))


if __name__ == '__main__':
    unittest.main()
...@@ -21,6 +21,9 @@ from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS
from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS

# NOTE: This function is not exposed to users, only used for paddle2onnx now
from ..fluid.dygraph.jit import get_inference_program #DEFINE_ALIAS

__all__ = [
    'save', 'load', 'TracedLayer', 'to_static', 'ProgramTranslator',
    'TranslatedLayer', 'set_code_level', 'set_verbosity'
...
...@@ -43,3 +43,6 @@ from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS
from ..fluid.param_attr import WeightNormParamAttr #DEFINE_ALIAS
from ..tensor.io import save #DEFINE_ALIAS
from ..tensor.io import load #DEFINE_ALIAS

# NOTE: This function is not exposed to users, only used for paddle2onnx now
from ..fluid.io import get_inference_program