未验证 提交 3c5f2cac 编写于 作者: C Chen Weihang 提交者: GitHub

fix save parse error for dict input (#28712)

上级 9ab335bb
......@@ -356,7 +356,9 @@ def _get_input_var_names(inputs, input_spec):
"in input_spec is the same as the name of InputSpec in " \
"`to_static` decorated on the Layer.forward method."
result_list = []
input_var_names = [var.name for var in inputs if isinstance(var, Variable)]
input_var_names = [
var.name for var in flatten(inputs) if isinstance(var, Variable)
]
if input_spec is None:
# no prune
result_list = input_var_names
......@@ -606,7 +608,7 @@ def save(layer, path, input_spec=None, **configs):
"The input input_spec should be 'list', but received input_spec's type is %s."
% type(input_spec))
inner_input_spec = []
for var in input_spec:
for var in flatten(input_spec):
if isinstance(var, paddle.static.InputSpec):
inner_input_spec.append(var)
elif isinstance(var, (core.VarBase, Variable)):
......
......@@ -169,6 +169,25 @@ class LinearNetWithNestOut(fluid.dygraph.Layer):
return y, [(z, loss), out]
class LinearNetWithDictInput(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(LinearNetWithDictInput, self).__init__()
self._linear = Linear(in_size, out_size)
@paddle.jit.to_static(input_spec=[{
'img': InputSpec(
shape=[None, 8], dtype='float32', name='img')
}, {
'label': InputSpec(
shape=[None, 1], dtype='int64', name='label')
}])
def forward(self, img, label):
out = self._linear(img['img'])
# not return loss to avoid prune output
loss = paddle.nn.functional.cross_entropy(out, label['label'])
return out
class EmptyLayer(paddle.nn.Layer):
def __init__(self):
super(EmptyLayer, self).__init__()
......@@ -359,6 +378,37 @@ class TestSaveLoadWithNestOut(unittest.TestCase):
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
class TestSaveLoadWithDictInput(unittest.TestCase):
def test_dict_input(self):
# NOTE: This net cannot be executed, it is just
# a special case for exporting models in model validation
# We DO NOT recommend this writing way of Layer
net = LinearNetWithDictInput(8, 8)
# net.forward.concrete_program.inputs:
# (<__main__.LinearNetWithDictInput object at 0x7f2655298a98>,
# {'img': var img : fluid.VarType.LOD_TENSOR.shape(-1, 8).astype(VarType.FP32)},
# {'label': var label : fluid.VarType.LOD_TENSOR.shape(-1, 1).astype(VarType.INT64)})
self.assertEqual(len(net.forward.concrete_program.inputs), 3)
path = "test_jit_save_load_with_dict_input/model"
# prune inputs
paddle.jit.save(
layer=net,
path=path,
input_spec=[{
'img': InputSpec(
shape=[None, 8], dtype='float32', name='img')
}])
img = paddle.randn(shape=[4, 8], dtype='float32')
loaded_net = paddle.jit.load(path)
loaded_out = loaded_net(img)
# loaded_net._input_spec():
# [InputSpec(shape=(-1, 8), dtype=VarType.FP32, name=img)]
self.assertEqual(len(loaded_net._input_spec()), 1)
class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册