# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Per-operator PARAMs / FLOPs summary of a program.

Example:

    from paddle.fluid.contrib.model_stat import summary
    main_program = ...
    summary(main_program)

prints a table like the following (abbreviated; the numbers are
illustrative only):

+-----+------------+----------------+----------------+---------+------------+
| No. |    TYPE    |     INPUT      |     OUTPUT     |  PARAMs |   FLOPs    |
+-----+------------+----------------+----------------+---------+------------+
|   0 |     conv2d |  (3, 200, 200) | (64, 100, 100) |    9408 |  188160000 |
|   1 | batch_norm | (64, 100, 100) | (64, 100, 100) |     128 |    1280000 |
|   2 |       relu | (64, 100, 100) | (64, 100, 100) |       0 |     640000 |
|   3 |     pool2d | (64, 100, 100) |   (64, 50, 50) |       0 |    1440000 |
...
| 176 |     conv2d |    (512, 7, 7) |    (512, 7, 7) | 2359296 |  231211008 |
| 177 |       relu |    (512, 7, 7) |    (512, 7, 7) |       0 |      25088 |
| 178 |     conv2d |    (512, 7, 7) |   (2048, 7, 7) | 1048576 |  102760448 |
| 179 |       relu |   (2048, 7, 7) |   (2048, 7, 7) |       0 |     100352 |
| 180 |     pool2d |   (2048, 7, 7) |   (2048, 1, 1) |       0 |     100352 |
+-----+------------+----------------+----------------+---------+------------+
Total PARAMs: 48017344(48.0173M)
Total FLOPs: 11692747751(11.69G)
'''

from collections import OrderedDict


def summary(main_prog):
    '''
    Print a per-operator summary of PARAMs and FLOPs for a program.

    Only operator types handled by `_summary_model` are counted: conv2d,
    depthwise_conv2d, pool2d, mul (only when its output name contains 'fc'),
    sigmoid, tanh, relu, leaky_relu, prelu and batch_norm. Every other
    operator is skipped.

    Args:
        main_prog: the program to analyse; every block in `main_prog.blocks`
            is scanned.

    Returns:
        None. The table and the totals are printed to stdout.

    Raises:
        ImportError: if the `prettytable` package is not installed.
    '''
    collected_ops_list = []
    for one_b in main_prog.blocks:
        block_vars = one_b.vars
        for one_op in one_b.ops:
            spf_res = _summary_model(block_vars, one_op)
            if spf_res is None:
                continue
            in_shape, out_shape, params, flops = spf_res
            # TODO: get the operator name
            op_info = OrderedDict()
            op_info['type'] = one_op.type
            # Drop the first (presumably batch) dimension for display.
            op_info['input_shape'] = in_shape[1:]
            op_info['out_shape'] = out_shape[1:]
            op_info['PARAMs'] = params
            op_info['FLOPs'] = flops
            collected_ops_list.append(op_info)

    summary_table, total = _format_summary(collected_ops_list)
    _print_summary(summary_table, total)


def _numel(dims):
    '''Return the product of the integers in `dims` (1 when empty).'''
    total = 1
    for dim in dims:
        total *= dim
    return total


def _summary_model(block_vars, one_op):
    '''
    Compute one operator's input/output shapes, PARAMs and FLOPs.

    Counts are per sample: the first (presumably batch) dimension of a shape
    never contributes, so a batch size of -1 does not affect the result.

    Args:
        block_vars: mapping from variable name to a variable exposing
            `.shape`, for the block that contains `one_op`.
        one_op: the operator to count.

    Returns:
        A tuple `(in_data_shape, out_data_shape, params, flops)`, or None if
        the operator type is not supported (or is a `mul` that does not look
        like part of an fc layer).
    '''
    op_type = one_op.type

    if op_type in ['conv2d', 'depthwise_conv2d']:
        k_arg_shape = block_vars[one_op.input("Filter")[0]].shape
        in_data_shape = block_vars[one_op.input("Input")[0]].shape
        out_data_shape = block_vars[one_op.output("Output")[0]].shape
        # Paddle creates conv filters with shape
        # (c_out, c_in // groups, k_h, k_w): the second dim is already the
        # per-group input channel count, so it must not be divided by
        # `groups` again.
        c_out, c_in_per_group, k_h, k_w = k_arg_shape
        _, c_out_, h_out, w_out = out_data_shape
        assert c_out == c_out_, 'shape error!'
        kernel_ops = k_h * k_w * c_in_per_group
        bias_ops = 0 if one_op.input("Bias") == [] else 1
        params = c_out * (kernel_ops + bias_ops)
        # base nvidia paper: count both the multiply and the add
        flops = 2 * h_out * w_out * c_out * (kernel_ops + bias_ops)

    elif op_type == 'pool2d':
        in_data_shape = block_vars[one_op.input("X")[0]].shape
        out_data_shape = block_vars[one_op.output("Out")[0]].shape
        _, c_out, h_out, w_out = out_data_shape
        k_size = one_op.attr("ksize")
        params = 0
        # NOTE(review): with global pooling the effective window is the
        # whole input plane and `ksize` may not reflect it -- confirm.
        flops = h_out * w_out * c_out * (k_size[0] * k_size[1])

    elif op_type == 'mul':
        # TODO: fc is built from mul ops. Until mul carries an attribute that
        # says it belongs to an 'fc', rely on the output variable name. This
        # check runs before any variable lookup so that unrelated mul ops are
        # skipped instead of failing on a missing variable.
        if 'fc' not in one_op.output("Out")[0]:
            return None
        k_arg_shape = block_vars[one_op.input("Y")[0]].shape
        in_data_shape = block_vars[one_op.input("X")[0]].shape
        out_data_shape = block_vars[one_op.output("Out")[0]].shape
        k_in, k_out = k_arg_shape
        # bias in sum op
        # NOTE(review): a bias vector would add k_out params, not 1 -- confirm.
        params = k_in * k_out + 1
        flops = k_in * k_out

    elif op_type in ['sigmoid', 'tanh', 'relu', 'leaky_relu', 'prelu']:
        in_data_shape = block_vars[one_op.input("X")[0]].shape
        out_data_shape = block_vars[one_op.output("Out")[0]].shape
        # NOTE(review): counts a single prelu alpha; channel-wise prelu would
        # have more parameters -- confirm.
        params = 1 if op_type == 'prelu' else 0
        # One op per element of a single sample. The first dim is skipped
        # (it may be -1) to match the per-sample counting of the other ops.
        flops = _numel(in_data_shape[1:])

    elif op_type == 'batch_norm':
        in_data_shape = block_vars[one_op.input("X")[0]].shape
        out_data_shape = block_vars[one_op.output("Y")[0]].shape
        # Channels are dim 1 and any remaining dims are spatial, so both
        # (N, C) and (N, C, H, W) inputs are handled.
        # NOTE(review): assumes channel-first layout -- confirm for NHWC.
        c_in = in_data_shape[1]
        spatial = _numel(in_data_shape[2:])
        # gamma, beta
        params = c_in * 2
        # compute mean and std
        flops = spatial * c_in * 2

    else:
        return None

    return in_data_shape, out_data_shape, params, flops


def _format_summary(collected_ops_list):
    '''
    Build the summary table and the per-operator PARAMs / FLOPs lists.

    Args:
        collected_ops_list: list of per-operator dicts with the keys 'type',
            'input_shape', 'out_shape', 'PARAMs' and 'FLOPs'.

    Returns:
        summary_table: a `prettytable.PrettyTable` with one row per operator.
        total: dict with keys 'params' and 'flops', each a list of ints in
            the same order as the table rows.

    Raises:
        ImportError: if `prettytable` is not installed.
    '''
    _verify_dependent_package()
    from prettytable import PrettyTable

    summary_table = PrettyTable(
        ["No.", "TYPE", "INPUT", "OUTPUT", "PARAMs", "FLOPs"])
    summary_table.align = 'r'

    total_params = []
    total_flops = []
    for i, one_op in enumerate(collected_ops_list):
        params = int(one_op['PARAMs'])
        flops = int(one_op['FLOPs'])
        # column order must match the header above
        summary_table.add_row([
            i, one_op['type'], one_op['input_shape'], one_op['out_shape'],
            params, flops
        ])
        total_params.append(params)
        total_flops.append(flops)

    total = {'params': total_params, 'flops': total_flops}
    return summary_table, total


def _verify_dependent_package():
    """
    Verify that `prettytable` is installed.

    Raises:
        ImportError: if `prettytable` cannot be imported.
    """
    try:
        from prettytable import PrettyTable  # noqa: F401
    except ImportError:
        raise ImportError(
            "paddle.summary() requires package `prettytable`, please install it first using `pip install prettytable`."
        )


def _print_summary(summary_table, total):
    '''
    Print the summary table followed by the PARAMs / FLOPs totals.

    Args:
        summary_table: the table built by `_format_summary`.
        total: dict with 'params' and 'flops' lists of per-operator counts.
    '''
    total_params = sum(total['params'])
    total_flops = sum(total['flops'])
    print(summary_table)
    # float divisors keep the scaled values fractional on Python 2 as well
    print('Total PARAMs: {}({:.4f}M)'.format(total_params,
                                              total_params / 1e6))
    print('Total FLOPs: {}({:.2f}G)'.format(total_flops, total_flops / 1e9))
    print(
        "Notice: \n now supported ops include [Conv, DepthwiseConv, FC(mul), BatchNorm, Pool, Activation(sigmoid, tanh, relu, leaky_relu, prelu)]"
    )