未验证 提交 291cf821 编写于 作者: X xiongkun 提交者: GitHub

[ bugfix ] fix bugs in Indexable and support LayerDict (#49409)

* bugfix: fix bugs in Indexable and support LayerDict

* fix bugs.
上级 5c4adfae
...@@ -441,5 +441,39 @@ class TestErrorInForLoop(TestTransformForLoop): ...@@ -441,5 +441,39 @@ class TestErrorInForLoop(TestTransformForLoop):
self.dyfunc = for_loop_dyfunc_not_support self.dyfunc = for_loop_dyfunc_not_support
class Net(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layer_dict = paddle.nn.LayerDict(
{
"conv1": paddle.nn.Conv2D(3, 3, 1),
"conv2": paddle.nn.Conv2D(3, 3, 1),
"conv3": paddle.nn.Conv2D(3, 3, 1),
}
)
def forward(self, x):
out = 0
for layer_name in self.layer_dict:
out += self.layer_dict[layer_name](x)
return out
class TestForLoopMeetDict(unittest.TestCase):
def test_start(self):
net = Net()
model = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32'
)
],
)
paddle.jit.save(model, "./inference/inference")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -42,10 +42,12 @@ def convert_attr(x, attr): ...@@ -42,10 +42,12 @@ def convert_attr(x, attr):
def indexable(x, code=None): def indexable(x, code=None):
if isinstance(x, Variable): if isinstance(x, Variable):
return x return x
if hasattr(x, '__len__') and hasattr(x, '__getitem__'): elif hasattr(x, '__iter__'):
return x
if hasattr(x, '__iter__'):
return [i for i in x] return [i for i in x]
elif hasattr(x, '__len__') and hasattr(
x, '__getitem__'
): # used for customed type and non-iterable type.
return x
else: else:
raise RuntimeError("X can't be convert into indexable.") raise RuntimeError("X can't be convert into indexable.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册