未验证 提交 ae1e71b3 编写于 作者: Y yingyibiao 提交者: GitHub

Fix paddle.flops AttributeError (#38850)

* Fix AttributeError when output y is a tuple which has no attribute 'shape'

* Add unit test for dynamic_flops with multiple outputs

* Add unit test for dynamic_flops with multiple outputs
上级 01222f52
......@@ -181,6 +181,9 @@ def count_parameters(m, x, y):
def count_io_info(m, x, y):
m.register_buffer('input_shape', paddle.to_tensor(x[0].shape))
if isinstance(y, (list, tuple)):
m.register_buffer('output_shape', paddle.to_tensor(y[0].shape))
else:
m.register_buffer('output_shape', paddle.to_tensor(y.shape))
......@@ -258,8 +261,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
for m in model.sublayers():
if len(list(m.children())) > 0:
continue
if set(['total_ops', 'total_params', 'input_shape',
'output_shape']).issubset(set(list(m._buffers.keys()))):
if {'total_ops', 'total_params', 'input_shape',
'output_shape'}.issubset(set(list(m._buffers.keys()))):
total_ops += m.total_ops
total_params += m.total_params
......@@ -274,8 +277,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
for n, m in model.named_sublayers():
if len(list(m.children())) > 0:
continue
if set(['total_ops', 'total_params', 'input_shape',
'output_shape']).issubset(set(list(m._buffers.keys()))):
if {'total_ops', 'total_params', 'input_shape',
'output_shape'}.issubset(set(list(m._buffers.keys()))):
table.add_row([
m.full_name(), list(m.input_shape.numpy()),
list(m.output_shape.numpy()), int(m.total_params),
......
......@@ -732,6 +732,18 @@ class TestModelFunction(unittest.TestCase):
custom_ops={paddle.nn.Dropout: customize_dropout},
print_detail=True)
def test_dynamic_flops_with_multiple_outputs(self):
net = paddle.nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, return_mask=True)
def customize_dropout(m, x, y):
m.total_ops += 0
paddle.flops(
net, [1, 2, 32, 32],
custom_ops={paddle.nn.Dropout: customize_dropout},
print_detail=True)
def test_export_deploy_model(self):
self.set_seed()
np.random.seed(201)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册