diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 45db83f9141a72a9f44457bc8f19e105e1d8ce73..8d581f38e9b01f552eec85417566dc13c3fb8072 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -147,6 +147,8 @@ def summary(net, input_size=None, dtypes=None, input=None): input_size = [] for key in input.keys(): input_size.append(tuple(input[key].shape)) + elif isinstance(input, paddle.fluid.framework.Variable): + input_size = tuple(input.shape) else: raise ValueError( "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size." diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index f90ff0c99af959ab7cf4c15f5856dde0b67503dd..037601cd083c2e2c5ed50a82285a670504efb322 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -662,6 +662,12 @@ class TestModelFunction(unittest.TestCase): np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) def test_summary_input(self): + paddle.enable_static() + mymodel = MyModel() + input_data = paddle.rand([1, 20]) + paddle.summary(mymodel, input=input_data) + paddle.disable_static() + rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') input_data = paddle.rand([4, 23, 16]) paddle.summary(rnn, input=input_data)