未验证 提交 96934b74 编写于 作者: Y yukavio 提交者: GitHub

fix flops (#29758)

* fix flops

* fix flops
上级 41a7b071
...@@ -221,7 +221,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): ...@@ -221,7 +221,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
if m_type in custom_ops: if m_type in custom_ops:
flops_fn = custom_ops[m_type] flops_fn = custom_ops[m_type]
if m_type not in types_collection: if m_type not in types_collection:
print("Customize Function has been appied to {}".format(m_type)) print("Customize Function has been applied to {}".format(
m_type))
elif m_type in register_hooks: elif m_type in register_hooks:
flops_fn = register_hooks[m_type] flops_fn = register_hooks[m_type]
if m_type not in types_collection: if m_type not in types_collection:
...@@ -254,11 +255,9 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): ...@@ -254,11 +255,9 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
for m in model.sublayers(): for m in model.sublayers():
if len(list(m.children())) > 0: if len(list(m.children())) > 0:
continue continue
total_ops += m.total_ops if hasattr(m, 'total_ops') and hasattr(m, 'total_params'):
total_params += m.total_params total_ops += m.total_ops
if hasattr(m, 'total_ops') and hasattr(m, 'total_params'): total_params += m.total_params
total_ops = int(total_ops)
total_params = int(total_params)
if training: if training:
model.train() model.train()
...@@ -277,7 +276,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): ...@@ -277,7 +276,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
for n, m in model.named_sublayers(): for n, m in model.named_sublayers():
if len(list(m.children())) > 0: if len(list(m.children())) > 0:
continue continue
if "total_ops" in m._buffers: if set(['total_ops', 'total_params', 'input_shape',
'output_shape']).issubset(set(list(m._buffers.keys()))):
table.add_row([ table.add_row([
m.full_name(), list(m.input_shape.numpy()), m.full_name(), list(m.input_shape.numpy()),
list(m.output_shape.numpy()), int(m.total_params), list(m.output_shape.numpy()), int(m.total_params),
...@@ -289,6 +289,6 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): ...@@ -289,6 +289,6 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
m._buffers.pop('output_shape') m._buffers.pop('output_shape')
if (print_detail): if (print_detail):
print(table) print(table)
print('Total Flops: {} Total Params: {}'.format(total_ops, print('Total Flops: {} Total Params: {}'.format(
total_params)) int(total_ops), int(total_params)))
return total_ops return int(total_ops)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册