From 96934b74307242b3f1f16ad85207a84802a2e9cd Mon Sep 17 00:00:00 2001 From: yukavio <67678385+yukavio@users.noreply.github.com> Date: Mon, 21 Dec 2020 15:35:13 +0800 Subject: [PATCH] fix flops (#29758) * fix flops * fix flops --- python/paddle/hapi/dynamic_flops.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index 9e2f78b559..8f6697872c 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -221,7 +221,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): if m_type in custom_ops: flops_fn = custom_ops[m_type] if m_type not in types_collection: - print("Customize Function has been appied to {}".format(m_type)) + print("Customize Function has been applied to {}".format( + m_type)) elif m_type in register_hooks: flops_fn = register_hooks[m_type] if m_type not in types_collection: @@ -254,11 +255,9 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): for m in model.sublayers(): if len(list(m.children())) > 0: continue - total_ops += m.total_ops - total_params += m.total_params - if hasattr(m, 'total_ops') and hasattr(m, 'total_params'): - total_ops = int(total_ops) - total_params = int(total_params) + if hasattr(m, 'total_ops') and hasattr(m, 'total_params'): + total_ops += m.total_ops + total_params += m.total_params if training: model.train() @@ -277,7 +276,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): for n, m in model.named_sublayers(): if len(list(m.children())) > 0: continue - if "total_ops" in m._buffers: + if set(['total_ops', 'total_params', 'input_shape', + 'output_shape']).issubset(set(list(m._buffers.keys()))): table.add_row([ m.full_name(), list(m.input_shape.numpy()), list(m.output_shape.numpy()), int(m.total_params), @@ -289,6 +289,6 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): m._buffers.pop('output_shape') if (print_detail): print(table) - print('Total Flops: {} Total Params: {}'.format(total_ops, - total_params)) - return total_ops + print('Total Flops: {} Total Params: {}'.format( + int(total_ops), int(total_params))) + return int(total_ops) -- GitLab