diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index bfbb483ac31ea0537f7f5d2cccdbdd384b3ae444..63de7f971afe8a4c30922e1a84bfe4726858f175 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -16,7 +16,7 @@ import paddle import warnings import paddle.nn as nn import numpy as np -from .static_flops import static_flops +from .static_flops import static_flops, Table __all__ = ['flops'] @@ -265,13 +265,7 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): for handler in handler_collection: handler.remove() - try: - from prettytable import PrettyTable - except ImportError: - raise ImportError( - "paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. " - ) - table = PrettyTable( + table = Table( ["Layer Name", "Input Shape", "Output Shape", "Params", "Flops"]) for n, m in model.named_sublayers(): @@ -288,8 +282,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False): m._buffers.pop("total_params") m._buffers.pop('input_shape') m._buffers.pop('output_shape') - if (print_detail): - print(table) + if print_detail: + table.print_table() print('Total Flops: {} Total Params: {}'.format( int(total_ops), int(total_params))) return int(total_ops) diff --git a/python/paddle/hapi/static_flops.py b/python/paddle/hapi/static_flops.py index 9815d4cfff54bab7672484a8d4501e8c59827f3c..4314633603130e44f9ad2a34181639e28448488f 100644 --- a/python/paddle/hapi/static_flops.py +++ b/python/paddle/hapi/static_flops.py @@ -169,13 +169,7 @@ def count_element_op(op): def _graph_flops(graph, detail=False): assert isinstance(graph, GraphWrapper) flops = 0 - try: - from prettytable import PrettyTable - except ImportError: - raise ImportError( - "paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. " - ) - table = PrettyTable(["OP Type", 'Param name', "Flops"]) + table = Table(["OP Type", 'Param name', "Flops"]) for op in graph.ops(): param_name = '' if op.type() in ['conv2d', 'depthwise_conv2d']: @@ -200,10 +194,55 @@ def _graph_flops(graph, detail=False): table.add_row([op.type(), param_name, op_flops]) op_flops = 0 if detail: - print(table) + table.print_table() return flops def static_flops(program, print_detail=False): graph = GraphWrapper(program) return _graph_flops(graph, detail=print_detail) + + +class Table(object): + def __init__(self, table_heads): + self.table_heads = table_heads + self.table_len = [] + self.data = [] + self.col_num = len(table_heads) + for head in table_heads: + self.table_len.append(len(head)) + + def add_row(self, row_str): + if not isinstance(row_str, list): + print('The row_str should be a list') + if len(row_str) != self.col_num: + print( + 'The length of row data should be equal the length of table heads, but the data: {} is not equal table heads {}'. + format(len(row_str), self.col_num)) + for i in range(self.col_num): + if len(str(row_str[i])) > self.table_len[i]: + self.table_len[i] = len(str(row_str[i])) + self.data.append(row_str) + + def print_row(self, row): + string = '' + for i in range(self.col_num): + string += '|' + str(row[i]).center(self.table_len[i] + 2) + string += '|' + print(string) + + def print_shelf(self): + string = '' + for length in self.table_len: + string += '+' + string += '-' * (length + 2) + string += '+' + print(string) + + def print_table(self): + self.print_shelf() + self.print_row(self.table_heads) + self.print_shelf() + for data in self.data: + self.print_row(data) + self.print_shelf()