From 3ea128654423b53bfb87d60e47aa6a9ccd23c916 Mon Sep 17 00:00:00 2001
From: whs
Date: Tue, 24 Dec 2019 10:29:31 +0800
Subject: [PATCH] Fix FLOPs API (#11)

---
 paddleslim/analysis/flops.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index 273d7b8f..fed377db 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -18,14 +18,22 @@ from ..core import GraphWrapper
 __all__ = ["flops"]
 
 
-def flops(program, detail=False):
+def flops(program, only_conv=True, detail=False):
     """
     Get FLOPS of target graph.
     Args:
         program(Program): The program used to calculate FLOPS.
+        only_conv(bool): Just return the number of mul-adds in convolution and FC layers if `only_conv` is True.
+                         Default: True.
+        detail(bool): Whether to return the detail of each convolution layer.
+
+    Return:
+        If `detail` is True, return a tuple in the format `(FLOPs, details)`; otherwise just return `FLOPs`.
+        FLOPs(int): The FLOPs of the target network.
+        details(dict): The key is the parameter name of a convolution layer and the value is the FLOPs of that layer.
     """
     graph = GraphWrapper(program)
-    return _graph_flops(graph, detail=detail)
+    return _graph_flops(graph, only_conv=only_conv, detail=detail)
 
 
 def _graph_flops(graph, only_conv=False, detail=False):
@@ -44,7 +52,7 @@ def _graph_flops(graph, only_conv=False, detail=False):
                 with_bias = 1
             else:
                 with_bias = 0
-            op_flops = 2 * h_out * w_out * c_out * (kernel_ops + with_bias)
+            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
             flops += op_flops
             params2flops[op.inputs("Filter")[0].name()] = op_flops
         elif op.type() == 'pool2d' and not only_conv:
@@ -53,14 +61,17 @@
             k_size = op.attr("ksize")
             flops += h_out * w_out * c_out * (k_size[0]**2)
 
-        elif op.type() == 'mul' and not only_conv:
+        elif op.type() == 'mul':
             x_shape = list(op.inputs("X")[0].shape())
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
 
-            flops += 2 * x_shape[0] * x_shape[1] * y_shape[1]
-        elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv:
+            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
+            flops += op_flops
+            params2flops[op.inputs("Y")[0].name()] = op_flops
+
+        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
             input_shape = list(op.inputs("X")[0].shape())
             if input_shape[0] == -1:
                 input_shape[0] = 1
--
GitLab
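
For context, a short usage sketch of the patched `flops` API follows; it is not part of the patch. It assumes a Paddle 1.x fluid-style program and that `flops` is importable from `paddleslim.analysis`, as the module path above suggests; the small network is hypothetical.

import paddle.fluid as fluid
from paddleslim.analysis import flops

# Build a toy program with one conv layer and one FC layer.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name="image", shape=[None, 3, 32, 32], dtype="float32")
    conv = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
    out = fluid.layers.fc(conv, size=10)

# With the new default only_conv=True, only the mul-adds of conv and FC
# ('mul') ops are counted; pooling and activations are skipped.
total = flops(main_prog)

# detail=True additionally returns a dict mapping each conv/FC parameter
# name (the 'Filter'/'Y' input) to the FLOPs of its layer.
total, per_layer = flops(main_prog, only_conv=True, detail=True)
print(total, per_layer)

Note that after this patch the conv and 'mul' branches drop the factor of 2, so the reported numbers are mul-add counts rather than separate multiply and add operations.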