未验证 提交 f3959e9d 编写于 作者: A Aurelius84 提交者: GitHub

[save/load] Fix bug with input_spec=dict[InputSpec] in jit.save (#31517)

* fix bug with jit.save

* refine code
上级 83a2fb1f
...@@ -25,7 +25,7 @@ import paddle ...@@ -25,7 +25,7 @@ import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten, pack_sequence_as
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ConversionOptions, CONVERSION_OPTIONS from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ConversionOptions, CONVERSION_OPTIONS
...@@ -681,6 +681,11 @@ def save(layer, path, input_spec=None, **configs): ...@@ -681,6 +681,11 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec) inner_input_spec)
elif 'forward' == attr_func: elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error # transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same sturcture
# as original input_spec here.
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
inner_input_spec)
static_forward = declarative( static_forward = declarative(
inner_layer.forward, input_spec=inner_input_spec) inner_layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program concrete_program = static_forward.concrete_program
......
...@@ -222,6 +222,16 @@ class LinearNetWithDictInput(paddle.nn.Layer): ...@@ -222,6 +222,16 @@ class LinearNetWithDictInput(paddle.nn.Layer):
return out return out
class LinearNetWithDictInputNoPrune(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(LinearNetWithDictInputNoPrune, self).__init__()
self._linear = Linear(in_size, out_size)
def forward(self, img):
out = self._linear(img['img'] + img['img2'])
return out
class EmptyLayer(paddle.nn.Layer): class EmptyLayer(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(EmptyLayer, self).__init__() super(EmptyLayer, self).__init__()
...@@ -443,6 +453,30 @@ class TestSaveLoadWithDictInput(unittest.TestCase): ...@@ -443,6 +453,30 @@ class TestSaveLoadWithDictInput(unittest.TestCase):
self.assertEqual(len(loaded_net._input_spec()), 1) self.assertEqual(len(loaded_net._input_spec()), 1)
class TestSaveLoadWithDictInputNoPrune(unittest.TestCase):
def test_dict_input(self):
net = LinearNetWithDictInputNoPrune(8, 8)
path = "test_jit_save_load_with_dict_input_no_prune/model"
# prune inputs
paddle.jit.save(
layer=net,
path=path,
input_spec=[{
'img': InputSpec(
shape=[None, 8], dtype='float32', name='img'),
'img2': InputSpec(
shape=[None, 8], dtype='float32', name='img2')
}])
img = paddle.randn(shape=[4, 8], dtype='float32')
img2 = paddle.randn(shape=[4, 8], dtype='float32')
loaded_net = paddle.jit.load(path)
loaded_out = loaded_net(img, img2)
self.assertEqual(len(loaded_net._input_spec()), 2)
class TestSaveLoadWithInputSpec(unittest.TestCase): class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self): def setUp(self):
# enable dygraph mode # enable dygraph mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册