Unverified commit 5067e3a8, authored by Aurelius84, committed by GitHub

[Dy2Static]Enhance check of TracedLayers out vars (#30576)

Parent commit: d1b25ed9
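This change makes TracedLayer.trace validate the variables returned by the traced layer's forward: the outputs are now flattened through extract_vars(outputs, err_tag='outputs'), so any element that is not a Variable raises a TypeError naming 'outputs', instead of being passed on to program-desc creation. Below is a minimal sketch of the enforced behavior, mirroring the unit test added in this commit (the layer and sizes are illustrative only):

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn as nn


class LinearNetWithNone(nn.Layer):
    def __init__(self, feature_size, fc_size):
        super(LinearNetWithNone, self).__init__()
        self._linear = nn.Linear(feature_size, fc_size)

    def forward(self, x):
        # The second element is not a Variable, which the enhanced check rejects.
        return [self._linear(x), [None, 2]]


model = LinearNetWithNone(100, 16)
in_x = paddle.to_tensor(np.random.random((4, 100)).astype('float32'))
try:
    fluid.dygraph.TracedLayer.trace(model, [in_x])
except TypeError as e:
    print(e)  # the message now points at 'outputs' rather than failing deeper in tracing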
@@ -53,21 +53,21 @@ def create_program_from_desc(program_desc):
     return program
 
 
-def _extract_vars(inputs, result_list):
+def _extract_vars(inputs, result_list, err_tag='inputs'):
     if isinstance(inputs, Variable):
         result_list.append(inputs)
     elif isinstance(inputs, (list, tuple)):
         for var in inputs:
-            _extract_vars(var, result_list)
+            _extract_vars(var, result_list, err_tag)
     else:
         raise TypeError(
-            "The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
-            format(type(inputs)))
+            "The type of 'each element of {}' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
+            format(err_tag, type(inputs)))
 
 
-def extract_vars(inputs):
+def extract_vars(inputs, err_tag='inputs'):
     result_list = []
-    _extract_vars(inputs, result_list)
+    _extract_vars(inputs, result_list, err_tag)
     return result_list
@@ -278,8 +278,8 @@ class _SaveLoadConfig(object):
         # NOTE: Users rarely use following configs, so these configs are not open to users,
         # reducing user learning costs, but we retain the configuration capabilities
         # If True, programs are modified to only support direct inference deployment.
         # Otherwise,more information will be stored for flexible optimization and re-training.
         # Currently, only True is supported
         self._export_for_deployment = True
@@ -406,7 +406,7 @@ def _get_input_var_names(inputs, input_spec):
     elif input_spec is not None and len(input_spec) == len(input_var_names):
         # no prune
         result_list = input_var_names
         # if input spec name not in input_var_names, only raise warning
         for spec in input_spec:
             if spec.name is None:
                 warnings.warn(name_none_error % spec)
@@ -624,7 +624,7 @@ def save(layer, path, input_spec=None, **configs):
     # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
     # the args and kwargs of forward method will can't be parsed by
     # function_spec, so here we save DataParallel._layers instead
     # DataParallel it self
     # NOTE(chenweihang): using inner_layer, do not change input layer
     if isinstance(layer, paddle.DataParallel):
@@ -684,7 +684,7 @@ def save(layer, path, input_spec=None, **configs):
             static_forward = declarative(
                 inner_layer.forward, input_spec=inner_input_spec)
             concrete_program = static_forward.concrete_program
             # the input_spec has been used in declarative, which is equal to
             # @declarative with input_spec and jit.save without input_spec,
             # avoid needless warning
             inner_input_spec = None
@@ -704,21 +704,21 @@ def save(layer, path, input_spec=None, **configs):
                                                          inner_input_spec)
         # NOTE(chenweihang): [ Get output variables ]
         # the rule is like [ Get input variables name ]. For output var,
         # we only support VarBase spec, and actually, we only need the
         # var name of output, and we don't recommended to use output_spec
         output_vars = _get_output_vars(concrete_program.outputs,
                                        configs.output_spec)
 
         # NOTE(chenweihang): we maintain the mapping of variable name to
         # structured name, the buffer variable (non-persistable)
         # saved to inference program may not need by dygraph Layer,
         # we only record the state_dict variable's structured name
         state_names_dict = dict()
         for structured_name, var in six.iteritems(inner_layer.state_dict()):
             state_names_dict[var.name] = structured_name
 
         # 4. share parameters from Layer to scope & record var info
         for param_or_buffer in concrete_program.parameters:
             # share to scope
             param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
@@ -742,7 +742,7 @@ def save(layer, path, input_spec=None, **configs):
         # construct new save_inference_model arguments
         model_path = dirname
         # NOTE(chenweihang): because prefix contains model and params filename,
         # so we don't support set model_filename & params_filename
         if 'forward' == attr_func:
             model_filename = file_prefix + INFER_MODEL_SUFFIX
             params_filename = file_prefix + INFER_PARAMS_SUFFIX
@@ -769,12 +769,12 @@ def save(layer, path, input_spec=None, **configs):
     # - Which persistent variable are parameter and which are not
     # - Parameter.trainable information
     #
     # The lost information cannot be recovered when it is loaded again,
     # so if we want to perform fine-tune after loading, we may need to
     # configure redundant information to proceed.
     #
     # Due to compatibility issues, we cannot change the original storage structure,
     # but we can save these information in `jit.save` without changing the original
     # storage to improve user experience. So we save extra information into
     # file `***.pdiparams.info`
     with scope_guard(scope):
@@ -1032,7 +1032,7 @@ def _trace(layer,
             outputs = [original_outputs]
         else:
             outputs = original_outputs
-        out_vars = [var for var in outputs]
+        out_vars = extract_vars(outputs, err_tag='outputs')
 
         program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
             var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
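The second file below updates the TracedLayer error-message unit tests accordingly. For reference, here is a simplified, Paddle-free mimic of the flattening helper changed above (not the Paddle implementation itself; Var merely stands in for fluid.Variable), showing how err_tag threads through the recursion and into the error message:

class Var:
    # Stand-in for fluid.Variable in this sketch.
    pass


def extract_vars(nested, err_tag='inputs'):
    # Flatten arbitrarily nested lists/tuples of Var; anything else raises a
    # TypeError whose message names the argument being checked (err_tag).
    flat = []

    def _walk(item):
        if isinstance(item, Var):
            flat.append(item)
        elif isinstance(item, (list, tuple)):
            for sub in item:
                _walk(sub)
        else:
            raise TypeError(
                "each element of {} must be a Variable, but received {}".format(
                    err_tag, type(item)))

    _walk(nested)
    return flat


print(len(extract_vars([Var(), (Var(), [Var()])])))  # 3
try:
    extract_vars([Var(), [None, 2]], err_tag='outputs')
except TypeError as e:
    print(e)  # mentions 'outputs', matching the new TracedLayer.trace behavior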
@@ -13,16 +13,18 @@
 # limitations under the License.
 
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 import six
 import unittest
+import paddle.nn as nn
 
 
-class SimpleFCLayer(fluid.dygraph.Layer):
+class SimpleFCLayer(nn.Layer):
     def __init__(self, feature_size, batch_size, fc_size):
         super(SimpleFCLayer, self).__init__()
-        self._linear = fluid.dygraph.Linear(feature_size, fc_size)
-        self._offset = fluid.dygraph.to_variable(
+        self._linear = nn.Linear(feature_size, fc_size)
+        self._offset = paddle.to_tensor(
             np.random.random((batch_size, fc_size)).astype('float32'))
 
     def forward(self, x):
@@ -30,6 +32,17 @@ class SimpleFCLayer(fluid.dygraph.Layer):
         return fc + self._offset
 
 
+class LinearNetWithNone(nn.Layer):
+    def __init__(self, feature_size, fc_size):
+        super(LinearNetWithNone, self).__init__()
+        self._linear = nn.Linear(feature_size, fc_size)
+
+    def forward(self, x):
+        fc = self._linear(x)
+        return [fc, [None, 2]]
+
+
 class TestTracedLayerErrMsg(unittest.TestCase):
     def setUp(self):
         self.batch_size = 4
@@ -152,5 +165,14 @@ class TestTracedLayerErrMsg(unittest.TestCase):
         return layer
 
 
+class TestOutVarWithNoneErrMsg(unittest.TestCase):
+    def test_linear_net_with_none(self):
+        model = LinearNetWithNone(100, 16)
+        in_x = paddle.to_tensor(np.random.random((4, 100)).astype('float32'))
+        with self.assertRaises(TypeError):
+            dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(model,
+                                                                        [in_x])
+
+
 if __name__ == '__main__':
     unittest.main()