diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py index b9f19926a4ce43fcdccf3afd2e00f8d0bbf31d4d..273d7b8ff44f9eafdfe2632a383cc3dff389d8b8 100644 --- a/paddleslim/analysis/flops.py +++ b/paddleslim/analysis/flops.py @@ -35,13 +35,11 @@ def _graph_flops(graph, only_conv=False, detail=False): for op in graph.ops(): if op.type() in ['conv2d', 'depthwise_conv2d']: filter_shape = op.inputs("Filter")[0].shape() - input_shape = op.inputs("Input")[0].shape() output_shape = op.outputs("Output")[0].shape() - _, c_in, _, _ = input_shape - c_out, _, k_h, k_w = filter_shape + c_out, c_in, k_h, k_w = filter_shape _, _, h_out, w_out = output_shape - groups = op.attr("groups") - kernel_ops = k_h * k_w * (float(c_in) / groups) + # c_in is the channel number of filter. It is (input_channel // groups). + kernel_ops = k_h * k_w * float(c_in) if len(op.inputs("Bias")) > 0: with_bias = 1 else: @@ -50,7 +48,6 @@ def _graph_flops(graph, only_conv=False, detail=False): flops += op_flops params2flops[op.inputs("Filter")[0].name()] = op_flops elif op.type() == 'pool2d' and not only_conv: - input_shape = op.inputs("X")[0].shape() output_shape = op.outputs("Out")[0].shape() _, c_out, h_out, w_out = output_shape k_size = op.attr("ksize")