未验证 提交 8c5f1581 编写于 作者: Y yukavio 提交者: GitHub

remove PrettyTable dependence from paddle.flops (#30675)

上级 fb7fbc7a
......@@ -16,7 +16,7 @@ import paddle
import warnings
import paddle.nn as nn
import numpy as np
from .static_flops import static_flops
from .static_flops import static_flops, Table
__all__ = ['flops']
......@@ -265,13 +265,7 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
for handler in handler_collection:
handler.remove()
try:
from prettytable import PrettyTable
except ImportError:
raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
table = PrettyTable(
table = Table(
["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"])
for n, m in model.named_sublayers():
......@@ -288,8 +282,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
m._buffers.pop("total_params")
m._buffers.pop('input_shape')
m._buffers.pop('output_shape')
if (print_detail):
print(table)
if print_detail:
table.print_table()
print('Total Flops: {} Total Params: {}'.format(
int(total_ops), int(total_params)))
return int(total_ops)
......@@ -169,13 +169,7 @@ def count_element_op(op):
def _graph_flops(graph, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
try:
from prettytable import PrettyTable
except ImportError:
raise ImportError(
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
)
table = PrettyTable(["OP Type", 'Param name', "Flops"])
table = Table(["OP Type", 'Param name', "Flops"])
for op in graph.ops():
param_name = ''
if op.type() in ['conv2d', 'depthwise_conv2d']:
......@@ -200,10 +194,55 @@ def _graph_flops(graph, detail=False):
table.add_row([op.type(), param_name, op_flops])
op_flops = 0
if detail:
print(table)
table.print_table()
return flops
def static_flops(program, print_detail=False):
graph = GraphWrapper(program)
return _graph_flops(graph, detail=print_detail)
class Table(object):
def __init__(self, table_heads):
self.table_heads = table_heads
self.table_len = []
self.data = []
self.col_num = len(table_heads)
for head in table_heads:
self.table_len.append(len(head))
def add_row(self, row_str):
if not isinstance(row_str, list):
print('The row_str should be a list')
if len(row_str) != self.col_num:
print(
'The length of row data should be equal the length of table heads, but the data: {} is not equal table heads {}'.
format(len(row_str), self.col_num))
for i in range(self.col_num):
if len(str(row_str[i])) > self.table_len[i]:
self.table_len[i] = len(str(row_str[i]))
self.data.append(row_str)
def print_row(self, row):
string = ''
for i in range(self.col_num):
string += '|' + str(row[i]).center(self.table_len[i] + 2)
string += '|'
print(string)
def print_shelf(self):
string = ''
for length in self.table_len:
string += '+'
string += '-' * (length + 2)
string += '+'
print(string)
def print_table(self):
self.print_shelf()
self.print_row(self.table_heads)
self.print_shelf()
for data in self.data:
self.print_row(data)
self.print_shelf()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部