From b7fac0f980e2b05f511640850e832b64c2d1a032 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 29 Jul 2021 19:36:10 +0800 Subject: [PATCH] fix paddle.summary's bug when outputs contains non-tensor (#34160) * fix paddle.summary's bug when output contains non-tensor --- python/paddle/hapi/model_summary.py | 4 +++- python/paddle/tests/test_model.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 7e435fdc27..45db83f914 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -262,8 +262,10 @@ def summary_string(model, input_size=None, dtypes=None, input=None): def _get_output_shape(output): if isinstance(output, (list, tuple)): output_shape = [_get_output_shape(o) for o in output] - else: + elif hasattr(output, 'shape'): output_shape = list(output.shape) + else: + output_shape = [] return output_shape def register_hook(layer): diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 36478289cc..abeb833917 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -68,6 +68,28 @@ class LeNetDygraph(paddle.nn.Layer): return x +class ModelInner(paddle.nn.Layer): + def __init__(self): + super(ModelInner, self).__init__() + self.fc = paddle.nn.Linear(3, 4) + + def forward(self, x): + y = self.fc(x) + return y, 0 + + +class ModelOutter(paddle.nn.Layer): + def __init__(self): + super(ModelOutter, self).__init__() + self.module1 = ModelInner() + self.module2 = paddle.nn.Linear(4, 5) + + def forward(self, x): + y, dummpy = self.module1(x) + y = self.module2(y) + return y, 3 + + class LeNetListInput(LeNetDygraph): def forward(self, inputs): x = inputs[0] @@ -607,6 +629,9 @@ class TestModelFunction(unittest.TestCase): model.summary(input_size=[(20)]) model.summary(input_size=(20), dtype='float32') + def test_summary_non_tensor(self): + paddle.summary(ModelOutter(), input_size=(-1, 3)) + def test_summary_nlp(self): def _get_param_from_state_dict(state_dict): params = 0 -- GitLab