diff --git a/python/paddle/fluid/contrib/model_stat.py b/python/paddle/fluid/contrib/model_stat.py new file mode 100644 index 0000000000000000000000000000000000000000..0d974c8d9685840c79de17f297fcba00b01a6c35 --- /dev/null +++ b/python/paddle/fluid/contrib/model_stat.py @@ -0,0 +1,194 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Example: + >>from paddle.fluid.contrib.model_stat import summary + >>main_program = ... + >>summary(main_program) + +-----+------------+----------------+----------------+---------+------------+ + | No. | TYPE | INPUT | OUTPUT | PARAMs | FLOPs | + +-----+------------+----------------+----------------+---------+------------+ + | 0 | conv2d | (3, 200, 200) | (64, 100, 100) | 9408 | 188160000 | + | 1 | batch_norm | (64, 100, 100) | (64, 100, 100) | 256 | 640000 | + | 2 | relu | (64, 100, 100) | (64, 100, 100) | 0 | 640000 | + | 3 | pool2d | (64, 100, 100) | (64, 50, 50) | 0 | 1440000 | + ... + | 176 | conv2d | (512, 7, 7) | (512, 7, 7) | 2359296 | 231211008 | + | 177 | relu | (512, 7, 7) | (512, 7, 7) | 0 | 25088 | + | 178 | conv2d | (512, 7, 7) | (2048, 7, 7) | 1048576 | 102760448 | + | 179 | relu | (2048, 7, 7) | (2048, 7, 7) | 0 | 100352 | + | 180 | pool2d | (2048, 7, 7) | (2048, 1, 1) | 0 | 100352 | + +-----+------------+----------------+----------------+---------+------------+ + Total PARAMs: 48017344(0.0480G) + Total FLOPs: 11692747751(11.69G) +''' +from collections import OrderedDict +from prettytable import PrettyTable + + +def summary(main_prog): + ''' + It can summary model's PARAMS, FLOPs until now. + It support common operator like conv, fc, pool, relu, sigmoid, bn etc. + Args: + main_prog: main program + Returns: + print summary on terminal + ''' + collected_ops_list = [] + for one_b in main_prog.blocks: + block_vars = one_b.vars + for one_op in one_b.ops: + op_info = OrderedDict() + spf_res = _summary_model(block_vars, one_op) + if spf_res is None: + continue + # TODO: get the operator name + op_info['type'] = one_op.type + op_info['input_shape'] = spf_res[0][1:] + op_info['out_shape'] = spf_res[1][1:] + op_info['PARAMs'] = spf_res[2] + op_info['FLOPs'] = spf_res[3] + collected_ops_list.append(op_info) + + summary_table, total = _format_summary(collected_ops_list) + _print_summary(summary_table, total) + + +def _summary_model(block_vars, one_op): + ''' + Compute operator's params and flops. + Args: + block_vars: all vars of one block + one_op: one operator to count + Returns: + in_data_shape: one operator's input data shape + out_data_shape: one operator's output data shape + params: one operator's PARAMs + flops: : one operator's FLOPs + ''' + if one_op.type in ['conv2d', 'depthwise_conv2d']: + k_arg_shape = block_vars[one_op.input("Filter")[0]].shape + in_data_shape = block_vars[one_op.input("Input")[0]].shape + out_data_shape = block_vars[one_op.output("Output")[0]].shape + c_out, c_in, k_h, k_w = k_arg_shape + _, c_out_, h_out, w_out = out_data_shape + assert c_out == c_out_, 'shape error!' + k_groups = one_op.attr("groups") + kernel_ops = k_h * k_w * (c_in / k_groups) + bias_ops = 0 if one_op.input("Bias") == [] else 1 + params = c_out * (kernel_ops + bias_ops) + flops = h_out * w_out * c_out * (kernel_ops + bias_ops) + # base nvidia paper, include mul and add + flops = 2 * flops + + elif one_op.type == 'pool2d': + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + _, c_out, h_out, w_out = out_data_shape + k_size = one_op.attr("ksize") + params = 0 + flops = h_out * w_out * c_out * (k_size[0] * k_size[1]) + + elif one_op.type == 'mul': + k_arg_shape = block_vars[one_op.input("Y")[0]].shape + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + # TODO: fc has mul ops + # add attr to mul op, tell us whether it belongs to 'fc' + # this's not the best way + if 'fc' not in one_op.output("Out")[0]: + return None + k_in, k_out = k_arg_shape + # bias in sum op + params = k_in * k_out + 1 + flops = k_in * k_out + + elif one_op.type in ['sigmoid', 'tanh', 'relu', 'leaky_relu', 'prelu']: + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + params = 0 + if one_op.type == 'prelu': + params = 1 + flops = 1 + for one_dim in in_data_shape: + flops *= one_dim + + elif one_op.type == 'batch_norm': + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Y")[0]].shape + _, c_in, h_out, w_out = in_data_shape + # gamma, beta + params = c_in * 2 + # compute mean and std + flops = h_out * w_out * c_in * 2 + + else: + return None + + return in_data_shape, out_data_shape, params, flops + + +def _format_summary(collected_ops_list): + ''' + Format summary report. + Args: + collected_ops_list: the collected operator with summary + Returns: + summary_table: summary report format + total: sum param and flops + ''' + summary_table = PrettyTable( + ["No.", "TYPE", "INPUT", "OUTPUT", "PARAMs", "FLOPs"]) + summary_table.align = 'r' + + total = {} + total_params = [] + total_flops = [] + for i, one_op in enumerate(collected_ops_list): + # notice the order + table_row = [ + i, + one_op['type'], + one_op['input_shape'], + one_op['out_shape'], + int(one_op['PARAMs']), + int(one_op['FLOPs']), + ] + summary_table.add_row(table_row) + total_params.append(int(one_op['PARAMs'])) + total_flops.append(int(one_op['FLOPs'])) + + total['params'] = total_params + total['flops'] = total_flops + + return summary_table, total + + +def _print_summary(summary_table, total): + ''' + Print all the summary on terminal. + Args: + summary_table: summary report format + total: sum param and flops + ''' + parmas = total['params'] + flops = total['flops'] + print(summary_table) + print('Total PARAMs: {}({:.4f}M)'.format( + sum(parmas), sum(parmas) / (10**6))) + print('Total FLOPs: {}({:.2f}G)'.format(sum(flops), sum(flops) / 10**9)) + print( + "Notice: \n now supported ops include [Conv, DepthwiseConv, FC(mul), BatchNorm, Pool, Activation(sigmoid, tanh, relu, leaky_relu, prelu)]" + )