未验证 提交 b28cc734 编写于 作者: W wangna11BD 提交者: GitHub

fix static error in summary (#35303)

上级 25871e0e
...@@ -147,6 +147,8 @@ def summary(net, input_size=None, dtypes=None, input=None): ...@@ -147,6 +147,8 @@ def summary(net, input_size=None, dtypes=None, input=None):
input_size = [] input_size = []
for key in input.keys(): for key in input.keys():
input_size.append(tuple(input[key].shape)) input_size.append(tuple(input[key].shape))
elif isinstance(input, paddle.fluid.framework.Variable):
input_size = tuple(input.shape)
else: else:
raise ValueError( raise ValueError(
"Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size." "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size."
......
...@@ -662,6 +662,12 @@ class TestModelFunction(unittest.TestCase): ...@@ -662,6 +662,12 @@ class TestModelFunction(unittest.TestCase):
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
def test_summary_input(self): def test_summary_input(self):
paddle.enable_static()
mymodel = MyModel()
input_data = paddle.rand([1, 20])
paddle.summary(mymodel, input=input_data)
paddle.disable_static()
rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
input_data = paddle.rand([4, 23, 16]) input_data = paddle.rand([4, 23, 16])
paddle.summary(rnn, input=input_data) paddle.summary(rnn, input=input_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册