From f3959e9ddc4397af6bc73b587e51c99e3808003e Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Thu, 11 Mar 2021 10:03:26 +0800
Subject: [PATCH] [save/load] Fix bug with input_spec=dict[InputSpec] in jit.save (#31517)

* fix bug with jit.save

* refine code
---
 python/paddle/fluid/dygraph/jit.py            |  7 +++-
 .../tests/unittests/test_jit_save_load.py     | 34 +++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py
index 90b0085fe33..4b35d778459 100644
--- a/python/paddle/fluid/dygraph/jit.py
+++ b/python/paddle/fluid/dygraph/jit.py
@@ -25,7 +25,7 @@ import paddle
 from paddle.fluid import core
 from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
 from paddle.fluid.data_feeder import check_type
-from paddle.fluid.layers.utils import flatten
+from paddle.fluid.layers.utils import flatten, pack_sequence_as
 from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
 from paddle.fluid.dygraph.dygraph_to_static import logging_utils
 from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ConversionOptions, CONVERSION_OPTIONS
@@ -681,6 +681,11 @@ def save(layer, path, input_spec=None, **configs):
                 inner_input_spec)
         elif 'forward' == attr_func:
             # transform in jit.save, if input_spec is incomplete, declarative will throw error
+            # inner_input_spec is a list[InputSpec]; it should be packed back into the same
+            # structure as the original input_spec here.
+            if inner_input_spec:
+                inner_input_spec = pack_sequence_as(input_spec,
+                                                    inner_input_spec)
             static_forward = declarative(
                 inner_layer.forward, input_spec=inner_input_spec)
             concrete_program = static_forward.concrete_program
diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
index a43918765d4..bf9912c89cb 100644
--- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -222,6 +222,16 @@ class LinearNetWithDictInput(paddle.nn.Layer):
         return out
 
 
+class LinearNetWithDictInputNoPrune(paddle.nn.Layer):
+    def __init__(self, in_size, out_size):
+        super(LinearNetWithDictInputNoPrune, self).__init__()
+        self._linear = Linear(in_size, out_size)
+
+    def forward(self, img):
+        out = self._linear(img['img'] + img['img2'])
+        return out
+
+
 class EmptyLayer(paddle.nn.Layer):
     def __init__(self):
         super(EmptyLayer, self).__init__()
@@ -443,6 +453,30 @@ class TestSaveLoadWithDictInput(unittest.TestCase):
         self.assertEqual(len(loaded_net._input_spec()), 1)
 
 
+class TestSaveLoadWithDictInputNoPrune(unittest.TestCase):
+    def test_dict_input(self):
+        net = LinearNetWithDictInputNoPrune(8, 8)
+
+        path = "test_jit_save_load_with_dict_input_no_prune/model"
+        # no pruning here: forward uses both 'img' and 'img2', so both specs are kept
+        paddle.jit.save(
+            layer=net,
+            path=path,
+            input_spec=[{
+                'img': InputSpec(
+                    shape=[None, 8], dtype='float32', name='img'),
+                'img2': InputSpec(
+                    shape=[None, 8], dtype='float32', name='img2')
+            }])
+
+        img = paddle.randn(shape=[4, 8], dtype='float32')
+        img2 = paddle.randn(shape=[4, 8], dtype='float32')
+        loaded_net = paddle.jit.load(path)
+        loaded_out = loaded_net(img, img2)
+
+        self.assertEqual(len(loaded_net._input_spec()), 2)
+
+
 class TestSaveLoadWithInputSpec(unittest.TestCase):
     def setUp(self):
         # enable dygraph mode
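
Reviewer note (not part of the patch): the fix leans on flatten() and
pack_sequence_as() from paddle.fluid.layers.utils acting as inverses over
nested structures. The sketch below only illustrates that round trip for the
dict-of-InputSpec case the new test covers; the variable names are made up
here and a Paddle build of roughly this era (2.0.x) is assumed.

    # Illustrative only: shows why re-packing restores the dict structure
    # that jit.save flattened away, so forward(self, img) still sees a dict.
    from paddle.static import InputSpec
    from paddle.fluid.layers.utils import flatten, pack_sequence_as

    # Nested spec as a user would pass it: a list holding one dict of InputSpec.
    input_spec = [{
        'img': InputSpec(shape=[None, 8], dtype='float32', name='img'),
        'img2': InputSpec(shape=[None, 8], dtype='float32', name='img2'),
    }]

    # jit.save works with a flat list of InputSpec internally ...
    flat_spec = flatten(input_spec)

    # ... and the patched branch packs that flat list back into the original
    # nesting before calling declarative().
    repacked = pack_sequence_as(input_spec, flat_spec)
    assert isinstance(repacked[0], dict)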