From 291cf8211227cb2837dc929866a8e41904d7c2b9 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 30 Dec 2022 20:15:00 +0800 Subject: [PATCH] [ bugfix ] fix bugs in Indexable and support LayerDict (#49409) * bugfix: fix bugs in Indexable and support LayerDict * fix bugs. --- .../unittests/dygraph_to_static/test_loop.py | 34 +++++++++++++++++++ .../paddle/jit/dy2static/convert_operators.py | 8 +++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index fabfa8edc3..5c84da8e62 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -441,5 +441,39 @@ class TestErrorInForLoop(TestTransformForLoop): self.dyfunc = for_loop_dyfunc_not_support +class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + + self.layer_dict = paddle.nn.LayerDict( + { + "conv1": paddle.nn.Conv2D(3, 3, 1), + "conv2": paddle.nn.Conv2D(3, 3, 1), + "conv3": paddle.nn.Conv2D(3, 3, 1), + } + ) + + def forward(self, x): + out = 0 + for layer_name in self.layer_dict: + out += self.layer_dict[layer_name](x) + return out + + +class TestForLoopMeetDict(unittest.TestCase): + def test_start(self): + + net = Net() + model = paddle.jit.to_static( + net, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32' + ) + ], + ) + paddle.jit.save(model, "./inference/inference") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 3ec3dba88d..328b879c5a 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -42,10 +42,12 @@ def convert_attr(x, attr): def indexable(x, code=None): if isinstance(x, Variable): return x - if hasattr(x, '__len__') and hasattr(x, '__getitem__'): - return x - if hasattr(x, '__iter__'): + elif hasattr(x, '__iter__'): return [i for i in x] + elif hasattr(x, '__len__') and hasattr( + x, '__getitem__' + ): # used for customed type and non-iterable type. + return x else: raise RuntimeError("X can't be convert into indexable.") -- GitLab