Commit 6e3e3f13 authored by Chen Weihang

add get inference program api

Parent 827ac36f
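
This commit adds a non-public helper, paddle.jit.get_inference_program, intended for paddle2onnx: it returns the pruned inference Program of a dygraph Layer without writing anything to disk. A minimal usage sketch, assuming an already trained dygraph Layer (LinearNet and train below are placeholders mirroring the new unit test):

import paddle

layer = LinearNet(784, 1)  # any dygraph Layer
train(layer)               # hypothetical training helper
# returns the pruned, inference-only Program for layer.forward
infer_program = paddle.jit.get_inference_program(layer)
print(infer_program.num_blocks)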
......@@ -19,9 +19,11 @@ import pickle
import warnings
import functools
from collections import OrderedDict
import six
import paddle
# deprecated module import
from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
......@@ -644,6 +646,18 @@ class SaveLoadConfig(object):
self._keep_name_table = value
# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
def deprecate_save_load_configs(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if 'configs' in kwargs:
kwargs['config'] = kwargs['configs']
kwargs.pop('configs')
return func(*args, **kwargs)
return wrapper
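# e.g. a call such as paddle.jit.save(layer, model_path, configs=cfg) is
# transparently rewritten to paddle.jit.save(layer, model_path, config=cfg)
# for backward compatibility.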
def _get_input_var_names(inputs, input_spec):
name_none_error = "The %s's name is None. " \
"When using jit.save, please set InputSpec's name in " \
......@@ -696,9 +710,9 @@ def _get_output_vars(outputs, output_spec):
if isinstance(var, Variable):
output_vars_dict[var.name] = var
if output_spec is None:
result_list = output_vars_dict.values()
result_list = list(output_vars_dict.values())
elif output_spec is not None and len(output_spec) == len(output_vars_dict):
result_list = output_vars_dict.values()
result_list = list(output_vars_dict.values())
for var in output_spec:
if var.name not in output_vars_dict:
warnings.warn(name_no_exists_error % var.name)
......@@ -711,16 +725,95 @@ def _get_output_vars(outputs, output_spec):
return result_list
# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
def deprecate_save_load_configs(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if 'configs' in kwargs:
kwargs['config'] = kwargs['configs']
kwargs.pop('configs')
return func(*args, **kwargs)
def _infer_input_check(layer, input_spec):
prog_translator = ProgramTranslator()
if not prog_translator.enable_to_static:
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
raise TypeError(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
% type(layer))
return wrapper
# avoid changing the user-given input_spec
inner_input_spec = None
if input_spec is not None:
if not isinstance(input_spec, list):
raise TypeError(
"The input input_spec should be 'list', but received input_spec's type is %s."
% type(input_spec))
inner_input_spec = []
for var in input_spec:
if isinstance(var, paddle.static.InputSpec):
inner_input_spec.append(var)
elif isinstance(var, (core.VarBase, Variable)):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var))
else:
raise TypeError(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
return inner_input_spec
def _get_concrete_program_from_layer(layer, inner_input_spec):
# TODO(chenweihang): add support for other methods, not only forward
if isinstance(layer.forward, StaticLayer):
concrete_program = layer.forward.concrete_program
else:
# transform inside jit.save; if input_spec is incomplete, declarative will raise an error
static_forward = declarative(layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program
# the input_spec has already been used by declarative, which is equivalent
# to using @declarative with input_spec and calling jit.save without
# input_spec, so reset it here to avoid a needless warning
inner_input_spec = None
return concrete_program
def _build_input_and_output(concrete_program, inner_input_spec, config):
# NOTE(chenweihang): [ Get input variable names ]
# There are two cases, depending on whether the inputs are pruned or not:
# - inputs not pruned (recommended):
#   - len(input_spec) == len(concrete_program.inputs) - 1
#   - concrete_program.inputs can be used directly here
# - inputs pruned:
#   - len(input_spec) < len(concrete_program.inputs) - 1
#   - each name in input_spec should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is similar to [ Get input variable names ]. For output vars,
# we only support VarBase specs, and in fact we only need the output
# var names, so using output_spec is not recommended.
output_vars = _get_output_vars(concrete_program.outputs, config.output_spec)
return input_var_names, output_vars
# NOTE: This function is not exposed to users, only used for paddle2onnx now
@switch_to_static_graph
def get_inference_program(layer, input_spec=None, config=None):
# 1. input check
inner_input_spec = _infer_input_check(layer, input_spec)
if config is None:
config = SaveLoadConfig()
# 2. get program from Layer
concrete_program = _get_concrete_program_from_layer(layer, inner_input_spec)
# 3. build input & output of save_inference_model
input_var_names, output_vars = _build_input_and_output(
concrete_program, inner_input_spec, config)
# 4. only get inference program
inference_program = paddle.fluid.io.get_inference_program(
input_var_names, output_vars, concrete_program.main_program.clone())
return inference_program
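# NOTE: steps 1-3 above mirror the corresponding steps of paddle.jit.save
# below; instead of writing the model to disk via save_inference_model, the
# pruned inference program is returned to the caller (used by paddle2onnx).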
@deprecate_save_load_configs
......@@ -830,72 +923,18 @@ def save(layer, model_path, input_spec=None, config=None):
model_path = "linear.example.model"
paddle.jit.save(layer, model_path)
"""
# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_to_static:
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
raise TypeError(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
% type(layer))
configs = config
if configs is None:
configs = SaveLoadConfig()
inner_input_spec = _infer_input_check(layer, input_spec)
# avoid changing the user-given input_spec
inner_input_spec = None
if input_spec is not None:
if not isinstance(input_spec, list):
raise TypeError(
"The input input_spec should be 'list', but received input_spec's type is %s."
% type(input_spec))
inner_input_spec = []
for var in input_spec:
if isinstance(var, paddle.static.InputSpec):
inner_input_spec.append(var)
elif isinstance(var, (core.VarBase, Variable)):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var))
else:
raise TypeError(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
if config is None:
config = SaveLoadConfig()
# 2. get program from Layer
# TODO(chenweihang): add support for other methods, not only forward
if isinstance(layer.forward, StaticLayer):
concrete_program = layer.forward.concrete_program
else:
# transform inside jit.save; if input_spec is incomplete, declarative will raise an error
static_forward = declarative(layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program
# the input_spec has already been used by declarative, which is equivalent
# to using @declarative with input_spec and calling jit.save without
# input_spec, so reset it here to avoid a needless warning
inner_input_spec = None
concrete_program = _get_concrete_program_from_layer(layer, inner_input_spec)
# 3. build input & output of save_inference_model
# NOTE(chenweihang): [ Get input variable names ]
# There are two cases, depending on whether the inputs are pruned or not:
# - inputs not pruned (recommended):
#   - len(input_spec) == len(concrete_program.inputs) - 1
#   - concrete_program.inputs can be used directly here
# - inputs pruned:
#   - len(input_spec) < len(concrete_program.inputs) - 1
#   - each name in input_spec should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is similar to [ Get input variable names ]. For output vars,
# we only support VarBase specs, and in fact we only need the output
# var names, so using output_spec is not recommended.
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
input_var_names, output_vars = _build_input_and_output(
concrete_program, inner_input_spec, config)
# NOTE(chenweihang): we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
......@@ -927,8 +966,8 @@ def save(layer, model_path, input_spec=None, config=None):
from paddle.fluid.io import save_inference_model
# VARIABLE_FILENAME keeps the naming style consistent with '__model__'
if configs.params_filename is None:
configs.params_filename = VARIABLE_FILENAME
if config.params_filename is None:
config.params_filename = VARIABLE_FILENAME
with scope_guard(scope):
save_inference_model(
......@@ -937,11 +976,11 @@ def save(layer, model_path, input_spec=None, config=None):
target_vars=output_vars,
executor=Executor(_current_expected_place()),
main_program=concrete_program.main_program.clone(),
model_filename=configs.model_filename,
model_filename=config.model_filename,
params_filename=None
if configs.separate_params else configs.params_filename,
export_for_deployment=configs._export_for_deployment,
program_only=configs._program_only)
if config.separate_params else config.params_filename,
export_for_deployment=config._export_for_deployment,
program_only=config._program_only)
# NOTE(chenweihang): [ Save extra variable info ]
# save_inference_model will lose some important variable information, including:
......
......@@ -22,10 +22,11 @@ import logging
import pickle
import contextlib
from functools import reduce
import numpy as np
import paddle
# deprecated module import
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator
......@@ -220,6 +221,113 @@ def _get_valid_program(main_program):
return main_program
def _feed_fetch_check(feeded_var_names, target_vars,
export_for_deployment=True):
if isinstance(feeded_var_names, six.string_types):
feeded_var_names = [feeded_var_names]
elif export_for_deployment:
if len(feeded_var_names) > 0:
# TODO(paddle-dev): polish these code blocks
if not (bool(feeded_var_names) and all(
isinstance(name, six.string_types)
for name in feeded_var_names)):
raise ValueError("'feeded_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
target_vars = [target_vars]
elif export_for_deployment:
if not (bool(target_vars) and
all(isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.")
def _auc_states_check_and_remind(main_program):
all_ops = main_program.global_block().ops
for op in all_ops:
# clear device of Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op._set_attr(device_attr_name, "")
if op.type == 'auc':
warnings.warn(
"please ensure that you have set the auc states to zeros before saving inference model"
)
break
def _update_target_vars(target_vars, main_program):
# fix the bug that an activation op's output used as a target would be pruned,
# which will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with program_guard(main_program):
uniq_target_vars = []
for i, var in enumerate(target_vars):
if isinstance(var, Variable):
var = layers.scale(
var, 1., name="save_infer_model/scale_{}".format(i))
uniq_target_vars.append(var)
target_vars = uniq_target_vars
return target_vars
def _get_train_program(feeded_var_names, target_vars, main_program):
# 1. feed & fetch check
_feed_fetch_check(feeded_var_names, target_vars, False)
# 2. remind user to set auc_states to zeros if the program contains auc op
_auc_states_check_and_remind(main_program)
# 3. update input target_vars to fix bug
target_vars = _update_target_vars(target_vars, main_program)
return main_program
def _serialization(main_program, model_basename):
with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string())
# NOTE: This function is not exposed to users, only used for paddle2onnx now
@dygraph_not_support
def get_inference_program(feeded_var_names, target_vars, main_program):
# 1. feed & fetch check
_feed_fetch_check(feeded_var_names, target_vars)
# 2. remind user to set auc_states to zeros if the program contains auc op
_auc_states_check_and_remind(main_program)
# 3. update input target_vars to fix bug
target_vars = _update_target_vars(target_vars, main_program)
# 4. build inference program
main_program = main_program.clone()
global_block = main_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
main_program.desc.flush()
main_program = main_program._prune_with_input(
feeded_var_names=feeded_var_names, targets=target_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc)
return main_program
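# A minimal usage sketch for this static-graph entry point (variable names
# such as `image`, `prediction` and `program` below are illustrative only):
#
#   import paddle.fluid as fluid
#   # ... build a network under static graph mode ...
#   program = fluid.default_main_program()
#   pruned = fluid.io.get_inference_program(
#       feeded_var_names=["image"],
#       target_vars=[prediction],
#       main_program=program)
#
# The returned Program is pruned to the ops needed to compute the targets
# from the feeds, with feed/fetch ops added and inference optimizations
# applied; nothing is written to disk.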
@dygraph_not_support
def save_vars(executor,
dirname,
......@@ -1257,50 +1365,16 @@ def save_inference_model(dirname,
# "./infer_model".
"""
if isinstance(feeded_var_names, six.string_types):
feeded_var_names = [feeded_var_names]
elif export_for_deployment:
if len(feeded_var_names) > 0:
# TODO(paddle-dev): polish these code blocks
if not (bool(feeded_var_names) and all(
isinstance(name, six.string_types)
for name in feeded_var_names)):
raise ValueError("'feeded_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
target_vars = [target_vars]
elif export_for_deployment:
if not (bool(target_vars) and
all(isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.")
# 1. get main program
main_program = _get_valid_program(main_program)
# remind user to set auc_states to zeros if the program contains auc op
all_ops = main_program.global_block().ops
for op in all_ops:
# clear device of Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op._set_attr(device_attr_name, "")
if op.type == 'auc':
warnings.warn(
"please ensure that you have set the auc states to zeros before saving inference model"
)
break
# fix the bug that an activation op's output used as a target would be pruned,
# which will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with program_guard(main_program):
uniq_target_vars = []
for i, var in enumerate(target_vars):
if isinstance(var, Variable):
var = layers.scale(
var, 1., name="save_infer_model/scale_{}".format(i))
uniq_target_vars.append(var)
target_vars = uniq_target_vars
target_var_name_list = [var.name for var in target_vars]
# When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole
# original program and related meta are saved so that future usage can be
# more flexible.
origin_program = main_program.clone()
# 2. dirname check & create
# when a pserver and a trainer are running on the same machine, mkdir may conflict
save_dirname = dirname
try:
......@@ -1310,57 +1384,34 @@ def save_inference_model(dirname,
if e.errno != errno.EEXIST:
raise
# 3. model_filename check & create
if model_filename is not None:
model_basename = os.path.basename(model_filename)
else:
model_basename = "__model__"
model_basename = os.path.join(save_dirname, model_basename)
# When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole
# original program and related meta are saved so that future usage can be
# more flexible.
origin_program = main_program.clone()
# 4. get & serialize program
if export_for_deployment:
main_program = main_program.clone()
global_block = main_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
main_program.desc.flush()
main_program = main_program._prune_with_input(
feeded_var_names=feeded_var_names, targets=target_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string())
main_program = get_inference_program(feeded_var_names, target_vars,
main_program)
_serialization(main_program, model_basename)
else:
# TODO(panyx0718): Save more information so that it can also be used
# for training and more flexible post-processing.
with open(model_basename + ".main_program", "wb") as f:
f.write(main_program.desc.serialize_to_string())
main_program = _get_train_program(feeded_var_names, target_vars,
main_program)
_serialization(main_program, model_basename + ".main_program")
# 5. get the target var name list & check whether to serialize the program only
target_var_name_list = [var.name for var in target_vars]
if program_only:
warnings.warn(
"save_inference_model specified the param `program_only` to True; it will not save the params of the Program."
)
return target_var_name_list
# 6. save persistables
main_program._copy_dist_param_info_from(origin_program)
if params_filename is not None:
......
......@@ -23,6 +23,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
import warnings
import paddle
import paddle.fluid.executor as executor
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
......@@ -201,4 +202,5 @@ class TestLoadInferenceModelError(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -755,5 +755,34 @@ class TestJitSaveLoadNoParamLayer(unittest.TestCase):
self.assertTrue(np.array_equal(out, load_out))
class TestJitGetInferenceProgram(unittest.TestCase):
def setUp(self):
# enable dygraph mode
paddle.disable_static()
def test_get_inference_program(self):
layer = LinearNet(784, 1)
train(layer)
model_path = "model.jit_get_inference_program"
paddle.jit.save(layer, model_path)
infer_program = paddle.jit.get_inference_program(layer)
# the program of jit.load is different from the original inference program
model_file_path = os.path.join(model_path, "__model__")
load_program_desc = fluid.dygraph.io._load_program_desc(model_file_path)
load_program = fluid.dygraph.io._build_program_by_desc(
load_program_desc)
self.assertEqual(infer_program.num_blocks, load_program.num_blocks)
self.assertEqual(
len(infer_program.global_block().ops),
len(load_program.global_block().ops))
self.assertEqual(
len(infer_program.global_block().vars),
len(load_program.global_block().vars))
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,9 @@ from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS
from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS
# NOTE: This function is not exposed to users, only used for paddle2onnx now
from ..fluid.dygraph.jit import get_inference_program #DEFINE_ALIAS
__all__ = [
'save', 'load', 'TracedLayer', 'to_static', 'ProgramTranslator',
'TranslatedLayer', 'set_code_level', 'set_verbosity'
......
......@@ -43,3 +43,6 @@ from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS
from ..fluid.param_attr import WeightNormParamAttr #DEFINE_ALIAS
from ..tensor.io import save #DEFINE_ALIAS
from ..tensor.io import load #DEFINE_ALIAS
# NOTE: This function is not exposed to users, only used for paddle2onnx now
from ..fluid.io import get_inference_program