diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 5d82ca17474dd02b65c09ed3d681beaeefe80eb6..d618874ad9866ef7ae90f249e396cfe9086989c1 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -356,7 +356,9 @@ def _get_input_var_names(inputs, input_spec): "in input_spec is the same as the name of InputSpec in " \ "`to_static` decorated on the Layer.forward method." result_list = [] - input_var_names = [var.name for var in inputs if isinstance(var, Variable)] + input_var_names = [ + var.name for var in flatten(inputs) if isinstance(var, Variable) + ] if input_spec is None: # no prune result_list = input_var_names @@ -606,7 +608,7 @@ def save(layer, path, input_spec=None, **configs): "The input input_spec should be 'list', but received input_spec's type is %s." % type(input_spec)) inner_input_spec = [] - for var in input_spec: + for var in flatten(input_spec): if isinstance(var, paddle.static.InputSpec): inner_input_spec.append(var) elif isinstance(var, (core.VarBase, Variable)): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 5973199125716b23a6982717619e8fec8cc8cf6c..62d1d175d10a0a51e5b6b0b33e1b3401e3402cef 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -169,6 +169,25 @@ class LinearNetWithNestOut(fluid.dygraph.Layer): return y, [(z, loss), out] +class LinearNetWithDictInput(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(LinearNetWithDictInput, self).__init__() + self._linear = Linear(in_size, out_size) + + @paddle.jit.to_static(input_spec=[{ + 'img': InputSpec( + shape=[None, 8], dtype='float32', name='img') + }, { + 'label': InputSpec( + shape=[None, 1], dtype='int64', name='label') + }]) + def forward(self, img, label): + out = self._linear(img['img']) + # not return loss to avoid prune output + loss = paddle.nn.functional.cross_entropy(out, label['label']) + return out + + class EmptyLayer(paddle.nn.Layer): def __init__(self): super(EmptyLayer, self).__init__() @@ -359,6 +378,37 @@ class TestSaveLoadWithNestOut(unittest.TestCase): self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy())) +class TestSaveLoadWithDictInput(unittest.TestCase): + def test_dict_input(self): + # NOTE: This net cannot be executed, it is just + # a special case for exporting models in model validation + # We DO NOT recommend this writing way of Layer + net = LinearNetWithDictInput(8, 8) + # net.forward.concrete_program.inputs: + # (<__main__.LinearNetWithDictInput object at 0x7f2655298a98>, + # {'img': var img : fluid.VarType.LOD_TENSOR.shape(-1, 8).astype(VarType.FP32)}, + # {'label': var label : fluid.VarType.LOD_TENSOR.shape(-1, 1).astype(VarType.INT64)}) + self.assertEqual(len(net.forward.concrete_program.inputs), 3) + + path = "test_jit_save_load_with_dict_input/model" + # prune inputs + paddle.jit.save( + layer=net, + path=path, + input_spec=[{ + 'img': InputSpec( + shape=[None, 8], dtype='float32', name='img') + }]) + + img = paddle.randn(shape=[4, 8], dtype='float32') + loaded_net = paddle.jit.load(path) + loaded_out = loaded_net(img) + + # loaded_net._input_spec(): + # [InputSpec(shape=(-1, 8), dtype=VarType.FP32, name=img)] + self.assertEqual(len(loaded_net._input_spec()), 1) + + class TestSaveLoadWithInputSpec(unittest.TestCase): def setUp(self): # enable dygraph mode