From 2cdc36f4be222f5f2212e66aa11acc3645d967e4 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 14 Jan 2021 00:56:31 -0600 Subject: [PATCH] [Cherry-pick] Fix prune input bug of jit.save #30425 [Cherry-pick] Fix prune input bug of jit.save cheryy-pick of #30384 --- .../dygraph_to_static/program_translator.py | 15 ++-- .../fluid/dygraph/dygraph_to_static/utils.py | 48 ++++++------ .../tests/unittests/test_jit_save_load.py | 78 ++++++++++++++++++- 3 files changed, 109 insertions(+), 32 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 7c039efeb1d..770a72fbaf0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -470,19 +470,22 @@ class StaticFunction(object): cached_program_len = len(self._program_cache) # If specific `input_spec`, apply convertion from dygraph layers into static Program. if cached_program_len == 0: - if input_spec is None: - input_spec = self._function_spec.input_spec - elif self._function_spec.input_spec is not None: - if not input_specs_compatible( + desired_input_spec = input_spec + if self._function_spec.input_spec is not None: + if input_spec is not None and not input_specs_compatible( flatten(input_spec), flatten(self._function_spec.input_spec)): raise ValueError( "The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`". format(input_spec, self._function_spec.input_spec)) + # NOTE(chenweihang): we should always translated program based on the `input_spec` + # decorated on forward if it is valid + desired_input_spec = self._function_spec.input_spec - has_input_spec = (input_spec is not None) + has_input_spec = (desired_input_spec is not None) if has_input_spec: - concrete_program, _ = self.get_concrete_program(*input_spec) + concrete_program, _ = self.get_concrete_program( + *desired_input_spec) return concrete_program else: raise ValueError( diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 2fac616673d..3676958f15d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -1222,37 +1222,41 @@ def unwrap(func): return unwrapped_f -def input_specs_compatible(src_input_specs, other_input_specs): +def input_specs_compatible(src_input_specs, desired_input_specs): """ Returns True if the two input specs are compatible, otherwise False. args: src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of paddle.static.InputSpec - other_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of + desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of paddle.static.InputSpec """ len_specs = len(src_input_specs) - if len_specs != len(other_input_specs): - return False - - for i in range(len_specs): - src_shape = src_input_specs[i].shape - other_shape = other_input_specs[i].shape - len_shape = len(src_shape) - if len_shape != len(other_shape): - return False - for j in range(len_shape): - if src_shape[j] is None or src_shape[j] < 0: - continue - if other_shape[j] is None or other_shape[j] < 0: - continue - if src_shape[j] != other_shape[j]: + if len_specs != len(desired_input_specs): + # NOTE(chenweihang): if the input_spec of jit.save is a subset of + # input_spec of to_static, also compatible + for spec in src_input_specs: + if spec not in desired_input_specs: + return False + else: + for i in range(len_specs): + src_shape = src_input_specs[i].shape + other_shape = desired_input_specs[i].shape + len_shape = len(src_shape) + if len_shape != len(other_shape): + return False + for j in range(len_shape): + if src_shape[j] is None or src_shape[j] < 0: + continue + if other_shape[j] is None or other_shape[j] < 0: + continue + if src_shape[j] != other_shape[j]: + return False + + src_dtype = convert_dtype(src_input_specs[i].dtype) + other_dtype = convert_dtype(desired_input_specs[i].dtype) + if src_dtype != other_dtype: return False - - src_dtype = convert_dtype(src_input_specs[i].dtype) - other_dtype = convert_dtype(other_input_specs[i].dtype) - if src_dtype != other_dtype: - return False return True diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index b2704085fd4..a43918765d4 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer): return out, avg_loss +class LinerNetWithPruneInput(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(LinerNetWithPruneInput, self).__init__() + self._linear = Linear(in_size, out_size) + + @declarative(input_spec=[ + InputSpec( + shape=[None, 784], dtype='float32', name="image"), InputSpec( + shape=[None, 1], dtype='int64', name="label") + ]) + def forward(self, x, label): + out = self._linear(x) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + return out + + +class LinerNetWithUselessInput(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(LinerNetWithUselessInput, self).__init__() + self._linear = Linear(in_size, out_size) + + @declarative(input_spec=[ + InputSpec( + shape=[None, 784], dtype='float32', name="image"), InputSpec( + shape=[None, 1], dtype='int64', name="label") + ]) + def forward(self, x, label): + out = self._linear(x) + return out + + class LinearNetReturnLoss(fluid.dygraph.Layer): def __init__(self, in_size, out_size): super(LinearNetReturnLoss, self).__init__() @@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase): paddle.seed(SEED) paddle.framework.random._manual_program_seed(SEED) - def verify_inference_correctness(self, layer, model_path, with_label=False): + def verify_inference_correctness(self, + layer, + model_path, + with_label_and_loss=False, + with_label=False): layer.eval() loaded_layer = paddle.jit.load(model_path) loaded_layer.eval() # inference & compare x = paddle.to_tensor(np.random.random((1, 784)).astype('float32')) - if with_label: + if with_label_and_loss: y = paddle.to_tensor(np.random.random((1, 1)).astype('int64')) pred, _ = layer(x, y) pred = pred.numpy() + elif with_label: + y = paddle.to_tensor(np.random.random((1, 1)).astype('int64')) + pred = layer(x, y) + pred = pred.numpy() else: pred = layer(x).numpy() loaded_pred = loaded_layer(x).numpy() @@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase): ], output_spec=[out]) - self.verify_inference_correctness(layer, model_path, True) + self.verify_inference_correctness( + layer, model_path, with_label_and_loss=True) def test_prune_to_static_no_train(self): layer = LinerNetWithLabel(784, 1) @@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase): ], output_spec=output_spec) - self.verify_inference_correctness(layer, model_path, True) + self.verify_inference_correctness( + layer, model_path, with_label_and_loss=True) + + def test_prune_input_to_static_no_train(self): + layer = LinerNetWithPruneInput(784, 1) + + model_path = "test_prune_input_to_static_no_train/model" + paddle.jit.save( + layer, + model_path, + input_spec=[ + InputSpec( + shape=[None, 784], dtype='float32', name="image") + ]) + + self.verify_inference_correctness(layer, model_path, with_label=True) + + def test_prune_useless_input_to_static_no_train(self): + layer = LinerNetWithUselessInput(784, 1) + + model_path = "test_prune_useless_input_to_static_no_train/model" + paddle.jit.save( + layer, + model_path, + input_spec=[ + InputSpec( + shape=[None, 784], dtype='float32', name="image") + ]) + + self.verify_inference_correctness(layer, model_path, with_label=True) def test_no_prune_input_spec_name_warning(self): layer = LinearNetWithInputSpec(784, 1) -- GitLab