# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from ..core import GraphWrapper

__all__ = ["flops"]


def flops(program, only_conv=True, detail=False):
    """
    Get FLOPs of target graph.

    Args:
        program(Program): The program used to calculate FLOPs.
        only_conv(bool): If true, only count the multiply-adds of convolution
                         (`conv2d`, `depthwise_conv2d`) and fully-connected
                         (`mul`) ops. If false, pooling (`pool2d`),
                         activation (`relu`, `relu6`, `sigmoid`) and
                         `batch_norm` ops are counted as well.
                         default: True.
        detail(bool): Whether to also return the FLOPs of each convolution
                      and fully-connected layer. default: False.

    Returns:
        If `detail` is true, return a tuple `(FLOPs, details)`; otherwise
        return `FLOPs` only.
        FLOPs(float|int): The FLOPs of target network. Convolution FLOPs are
                          computed in floating point, so the total is a float
                          whenever the graph contains a convolution.
        details(dict): Maps the parameter name of each convolution filter and
                       fully-connected weight to the FLOPs of that layer.
    """
    graph = GraphWrapper(program)
    return _graph_flops(graph, only_conv=only_conv, detail=detail)


def _graph_flops(graph, only_conv=False, detail=False):
    """
    Count the FLOPs of the supported ops in `graph`.

    Args:
        graph(GraphWrapper): The graph to analyze.
        only_conv(bool): If true, only `conv2d`, `depthwise_conv2d` and `mul`
                         ops are counted; `pool2d`, `relu`, `relu6`,
                         `sigmoid` and `batch_norm` ops are skipped.
        detail(bool): Whether to also return the per-layer FLOPs.

    Returns:
        The total FLOPs if `detail` is false, otherwise a tuple
        `(total, params2flops)` where `params2flops` maps the filter/weight
        parameter name of every convolution and `mul` op to that op's FLOPs.
    """
    assert isinstance(graph, GraphWrapper)
    total = 0
    params2flops = {}
    for op in graph.ops():
        op_type = op.type()
        if op_type in ['conv2d', 'depthwise_conv2d']:
            filter_var = op.inputs("Filter")[0]
            c_out, c_in, k_h, k_w = filter_var.shape()
            _, _, h_out, w_out = op.outputs("Output")[0].shape()
            # c_in is the channel number of filter. It is (input_channel // groups).
            kernel_ops = k_h * k_w * float(c_in)
            with_bias = 1 if len(op.inputs("Bias")) > 0 else 0
            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
            total += op_flops
            params2flops[filter_var.name()] = op_flops

        elif op_type == 'pool2d' and not only_conv:
            _, c_out, h_out, w_out = op.outputs("Out")[0].shape()
            k_size = op.attr("ksize")
            # `ksize` holds the pooling window as [k_h, k_w]; do not assume it
            # is square. NOTE(review): global pooling ignores `ksize` and is
            # not special-cased here.
            k_h = k_size[0]
            k_w = k_size[1] if len(k_size) > 1 else k_size[0]
            total += h_out * w_out * c_out * k_h * k_w

        elif op_type == 'mul':
            x_shape = list(op.inputs("X")[0].shape())
            weight = op.inputs("Y")[0]
            y_shape = weight.shape()
            if x_shape[0] == -1:
                # Unknown batch size: count FLOPs for a single sample.
                x_shape[0] = 1
            # mul flattens X to an (M, K) matrix and Y to (K, N), so
            # M * K == numel(X) and FLOPs == numel(X) * N. This stays correct
            # when X has more than two dims (e.g. an FC layer on a 4-D feature
            # map), unlike x_shape[0] * x_shape[1] * N.
            # Assumes y_num_col_dims == 1 (the default): N = prod(y_shape[1:]).
            op_flops = int(np.prod(x_shape)) * int(np.prod(y_shape[1:]))
            total += op_flops
            params2flops[weight.name()] = op_flops

        elif op_type in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
            input_shape = list(op.inputs("X")[0].shape())
            if input_shape[0] == -1:
                # Unknown batch size: count FLOPs for a single sample.
                input_shape[0] = 1
            # np.product was removed in NumPy 2.0; np.prod is the supported name.
            total += int(np.prod(input_shape))

    if detail:
        return total, params2flops
    return total