未验证 提交 b7fac0f9 编写于 作者: H HydrogenSulfate 提交者: GitHub

fix paddle.summary's bug when outputs contains non-tensor (#34160)

* fix paddle.summary's bug when output contains non-tensor
上级 02cc3c5e
...@@ -262,8 +262,10 @@ def summary_string(model, input_size=None, dtypes=None, input=None): ...@@ -262,8 +262,10 @@ def summary_string(model, input_size=None, dtypes=None, input=None):
def _get_output_shape(output): def _get_output_shape(output):
if isinstance(output, (list, tuple)): if isinstance(output, (list, tuple)):
output_shape = [_get_output_shape(o) for o in output] output_shape = [_get_output_shape(o) for o in output]
else: elif hasattr(output, 'shape'):
output_shape = list(output.shape) output_shape = list(output.shape)
else:
output_shape = []
return output_shape return output_shape
def register_hook(layer): def register_hook(layer):
......
...@@ -68,6 +68,28 @@ class LeNetDygraph(paddle.nn.Layer): ...@@ -68,6 +68,28 @@ class LeNetDygraph(paddle.nn.Layer):
return x return x
class ModelInner(paddle.nn.Layer):
def __init__(self):
super(ModelInner, self).__init__()
self.fc = paddle.nn.Linear(3, 4)
def forward(self, x):
y = self.fc(x)
return y, 0
class ModelOutter(paddle.nn.Layer):
def __init__(self):
super(ModelOutter, self).__init__()
self.module1 = ModelInner()
self.module2 = paddle.nn.Linear(4, 5)
def forward(self, x):
y, dummpy = self.module1(x)
y = self.module2(y)
return y, 3
class LeNetListInput(LeNetDygraph): class LeNetListInput(LeNetDygraph):
def forward(self, inputs): def forward(self, inputs):
x = inputs[0] x = inputs[0]
...@@ -607,6 +629,9 @@ class TestModelFunction(unittest.TestCase): ...@@ -607,6 +629,9 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=[(20)]) model.summary(input_size=[(20)])
model.summary(input_size=(20), dtype='float32') model.summary(input_size=(20), dtype='float32')
def test_summary_non_tensor(self):
paddle.summary(ModelOutter(), input_size=(-1, 3))
def test_summary_nlp(self): def test_summary_nlp(self):
def _get_param_from_state_dict(state_dict): def _get_param_from_state_dict(state_dict):
params = 0 params = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册