From b28cc734d00ba5db8c59e8bd013751fe117383ef Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Thu, 2 Sep 2021 13:54:19 +0800 Subject: [PATCH] fix static error in summary (#35303) --- python/paddle/hapi/model_summary.py | 2 ++ python/paddle/tests/test_model.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 45db83f9141..8d581f38e9b 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -147,6 +147,8 @@ def summary(net, input_size=None, dtypes=None, input=None): input_size = [] for key in input.keys(): input_size.append(tuple(input[key].shape)) + elif isinstance(input, paddle.fluid.framework.Variable): + input_size = tuple(input.shape) else: raise ValueError( "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size." diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index f90ff0c99af..037601cd083 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -662,6 +662,12 @@ class TestModelFunction(unittest.TestCase): np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) def test_summary_input(self): + paddle.enable_static() + mymodel = MyModel() + input_data = paddle.rand([1, 20]) + paddle.summary(mymodel, input=input_data) + paddle.disable_static() + rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') input_data = paddle.rand([4, 23, 16]) paddle.summary(rnn, input=input_data) -- GitLab