From 5067e3a8d2ac7d78c1c8913f432f16dfd2d29992 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Wed, 20 Jan 2021 09:56:38 +0800
Subject: [PATCH] [Dy2Static]Enhance check of TracedLayers out vars (#30576)

---
 python/paddle/fluid/dygraph/jit.py             | 42 +++++++++----------
 .../unittests/test_traced_layer_err_msg.py     | 28 +++++++++++--
 2 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py
index 5bafbe7f41c..90b0085fe33 100644
--- a/python/paddle/fluid/dygraph/jit.py
+++ b/python/paddle/fluid/dygraph/jit.py
@@ -53,21 +53,21 @@ def create_program_from_desc(program_desc):
     return program
 
 
-def _extract_vars(inputs, result_list):
+def _extract_vars(inputs, result_list, err_tag='inputs'):
     if isinstance(inputs, Variable):
         result_list.append(inputs)
     elif isinstance(inputs, (list, tuple)):
         for var in inputs:
-            _extract_vars(var, result_list)
+            _extract_vars(var, result_list, err_tag)
     else:
         raise TypeError(
-            "The type of 'each element of inputs' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
-            format(type(inputs)))
+            "The type of 'each element of {}' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
+            format(err_tag, type(inputs)))
 
 
-def extract_vars(inputs):
+def extract_vars(inputs, err_tag='inputs'):
     result_list = []
-    _extract_vars(inputs, result_list)
+    _extract_vars(inputs, result_list, err_tag)
     return result_list
 
 
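The hunk above threads an err_tag through the recursive helper so the TypeError names the offending argument. For reference, a minimal standalone sketch of the resulting behavior (not part of the patch; it assumes a Paddle 2.0-era build, where dygraph tensors pass the isinstance(x, Variable) check via Paddle's metaclass hook):

    import numpy as np
    import paddle
    from paddle.fluid.dygraph.jit import extract_vars

    x = paddle.to_tensor(np.ones((2, 2), dtype='float32'))
    y = paddle.to_tensor(np.zeros((2, 2), dtype='float32'))

    # Nested lists/tuples are flattened recursively into one flat list.
    assert len(extract_vars([x, (y,)])) == 2

    # Non-Variable elements raise TypeError; err_tag only controls whether
    # the message blames 'inputs' or 'outputs'.
    try:
        extract_vars([x, None], err_tag='outputs')
    except TypeError as e:
        print(e)  # "... 'each element of outputs' ... must be fluid.Variable ..."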
@@ -278,8 +278,8 @@ class _SaveLoadConfig(object):
         # NOTE: Users rarely use following configs, so these configs are not open to users,
         # reducing user learning costs, but we retain the configuration capabilities
 
-        # If True, programs are modified to only support direct inference deployment. 
-        # Otherwise,more information will be stored for flexible optimization and re-training. 
+        # If True, programs are modified to only support direct inference deployment.
+        # Otherwise, more information will be stored for flexible optimization and re-training.
         # Currently, only True is supported
         self._export_for_deployment = True
 
@@ -406,7 +406,7 @@ def _get_input_var_names(inputs, input_spec):
     elif input_spec is not None and len(input_spec) == len(input_var_names):
         # no prune
         result_list = input_var_names
-        # if input spec name not in input_var_names, only raise warning 
+        # if input spec name not in input_var_names, only raise warning
         for spec in input_spec:
             if spec.name is None:
                 warnings.warn(name_none_error % spec)
@@ -624,7 +624,7 @@ def save(layer, path, input_spec=None, **configs):
 
     # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
     # the args and kwargs of forward method will can't be parsed by
-    # function_spec, so here we save DataParallel._layers instead 
+    # function_spec, so here we save DataParallel._layers instead
     # DataParallel it self
     # NOTE(chenweihang): using inner_layer, do not change input layer
     if isinstance(layer, paddle.DataParallel):
@@ -684,7 +684,7 @@ def save(layer, path, input_spec=None, **configs):
                 static_forward = declarative(
                     inner_layer.forward, input_spec=inner_input_spec)
                 concrete_program = static_forward.concrete_program
-                # the input_spec has been used in declarative, which is equal to 
+                # the input_spec has been used in declarative, which is equal to
                 # @declarative with input_spec and jit.save without input_spec,
                 # avoid needless warning
                 inner_input_spec = None
@@ -704,21 +704,21 @@ def save(layer, path, input_spec=None, **configs):
                                                  inner_input_spec)
 
         # NOTE(chenweihang): [ Get output variables ]
-        # the rule is like [ Get input variables name ]. For output var, 
-        # we only support VarBase spec, and actually, we only need the 
+        # the rule is like [ Get input variables name ]. For output var,
+        # we only support VarBase spec, and actually, we only need the
         # var name of output, and we don't recommended to use output_spec
         output_vars = _get_output_vars(concrete_program.outputs,
                                        configs.output_spec)
 
         # NOTE(chenweihang): we maintain the mapping of variable name to
         # structured name, the buffer variable (non-persistable)
-        # saved to inference program may not need by dygraph Layer, 
+        # saved to inference program may not be needed by the dygraph Layer,
         # we only record the state_dict variable's structured name
         state_names_dict = dict()
         for structured_name, var in six.iteritems(inner_layer.state_dict()):
             state_names_dict[var.name] = structured_name
 
-        # 4. share parameters from Layer to scope & record var info 
+        # 4. share parameters from Layer to scope & record var info
         for param_or_buffer in concrete_program.parameters:
             # share to scope
             param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
             )
@@ -742,7 +742,7 @@ def save(layer, path, input_spec=None, **configs):
         # construct new save_inference_model arguments
         model_path = dirname
         # NOTE(chenweihang): because prefix contains model and params filename,
-        # so we don't support set model_filename & params_filename 
+        # so we don't support setting model_filename & params_filename
         if 'forward' == attr_func:
             model_filename = file_prefix + INFER_MODEL_SUFFIX
             params_filename = file_prefix + INFER_PARAMS_SUFFIX
@@ -769,12 +769,12 @@ def save(layer, path, input_spec=None, **configs):
         # - Which persistent variable are parameter and which are not
         # - Parameter.trainable information
         #
-        # The lost information cannot be recovered when it is loaded again, 
-        # so if we want to perform fine-tune after loading, we may need to 
+        # The lost information cannot be recovered when it is loaded again,
+        # so if we want to perform fine-tune after loading, we may need to
         # configure redundant information to proceed.
         #
-        # Due to compatibility issues, we cannot change the original storage structure, 
-        # but we can save these information in `jit.save` without changing the original 
+        # Due to compatibility issues, we cannot change the original storage structure,
+        # but we can save this information in `jit.save` without changing the original
         # storage to improve user experience. So we save extra information into
         # file `***.pdiparams.info`
         with scope_guard(scope):
@@ -1032,7 +1032,7 @@ def _trace(layer,
             outputs = [original_outputs]
         else:
             outputs = original_outputs
-        out_vars = [var for var in outputs]
+        out_vars = extract_vars(outputs, err_tag='outputs')
 
         program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
             var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
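The last hunk above is the functional heart of the patch: _trace used to copy the layer's raw outputs into out_vars with [var for var in outputs], so a None or other non-Variable output only surfaced later, inside tracer.create_program_desc, as an obscure failure. Routing outputs through extract_vars(outputs, err_tag='outputs') flattens nested containers and fails fast with an explicit TypeError. A sketch of the user-visible effect (the layer below is illustrative, not from the patch):

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.nn as nn

    class BadOutputNet(nn.Layer):  # illustrative helper, not in the patch
        def __init__(self):
            super(BadOutputNet, self).__init__()
            self._linear = nn.Linear(4, 2)

        def forward(self, x):
            # The stray None makes the traced outputs invalid.
            return [self._linear(x), None]

    x = paddle.to_tensor(np.random.random((3, 4)).astype('float32'))
    try:
        fluid.dygraph.TracedLayer.trace(BadOutputNet(), [x])
    except TypeError as e:
        # With this patch: "The type of 'each element of outputs' in
        # fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable,
        # but received <class 'NoneType'>."
        print(e)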
diff --git a/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py b/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
index 48251d17d0a..38543fecac8 100644
--- a/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
+++ b/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
@@ -13,16 +13,18 @@
 # limitations under the License.
 
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 import six
 import unittest
+import paddle.nn as nn
 
 
-class SimpleFCLayer(fluid.dygraph.Layer):
+class SimpleFCLayer(nn.Layer):
     def __init__(self, feature_size, batch_size, fc_size):
         super(SimpleFCLayer, self).__init__()
-        self._linear = fluid.dygraph.Linear(feature_size, fc_size)
-        self._offset = fluid.dygraph.to_variable(
+        self._linear = nn.Linear(feature_size, fc_size)
+        self._offset = paddle.to_tensor(
             np.random.random((batch_size, fc_size)).astype('float32'))
 
     def forward(self, x):
@@ -30,6 +32,17 @@ class SimpleFCLayer(fluid.dygraph.Layer):
         return fc + self._offset
 
 
+class LinearNetWithNone(nn.Layer):
+    def __init__(self, feature_size, fc_size):
+        super(LinearNetWithNone, self).__init__()
+        self._linear = nn.Linear(feature_size, fc_size)
+
+    def forward(self, x):
+        fc = self._linear(x)
+
+        return [fc, [None, 2]]
+
+
 class TestTracedLayerErrMsg(unittest.TestCase):
     def setUp(self):
         self.batch_size = 4
@@ -152,5 +165,14 @@ class TestTracedLayerErrMsg(unittest.TestCase):
         return layer
 
 
+class TestOutVarWithNoneErrMsg(unittest.TestCase):
+    def test_linear_net_with_none(self):
+        model = LinearNetWithNone(100, 16)
+        in_x = paddle.to_tensor(np.random.random((4, 100)).astype('float32'))
+        with self.assertRaises(TypeError):
+            dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(model,
+                                                                        [in_x])
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
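Conversely, nested containers of genuine tensors should remain legal trace outputs, since extract_vars recurses through lists and tuples instead of rejecting them outright. A quick sketch under the same assumptions as above (illustrative layer, not from the patch):

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.nn as nn

    class NestedOutputNet(nn.Layer):  # illustrative helper, not in the patch
        def __init__(self):
            super(NestedOutputNet, self).__init__()
            self._linear = nn.Linear(4, 2)

        def forward(self, x):
            fc = self._linear(x)
            return [fc, [fc * 2]]  # nested, but every element is a Variable

    x = paddle.to_tensor(np.random.random((3, 4)).astype('float32'))
    # Expected to trace successfully: the outputs flatten to [fc, fc * 2].
    dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
        NestedOutputNet(), [x])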