flops.py 3.2 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from ..core import GraphWrapper

__all__ = ["flops"]


W
whs 已提交
21
def flops(program, only_conv=True, detail=False):
    """Compute the FLOPs of the graph built from ``program``.

    The program is wrapped in a ``GraphWrapper`` and the counting itself is
    delegated to ``_graph_flops``.

    Args:
        program(Program): The program whose FLOPs are calculated.
        only_conv(bool): If true, only the mul-adds of convolution and FC
                         (``mul``) layers are counted. Default: True.
        detail(bool): Whether to also return the per-layer FLOPs.
                      Default: False.

    Returns:
        number|tuple: The total FLOPs when ``detail`` is false; otherwise a
        tuple ``(FLOPs, details)`` where ``details`` is a dict mapping the
        parameter name of each convolution / FC layer to the FLOPs of that
        layer.
    """
    return _graph_flops(
        GraphWrapper(program), only_conv=only_conv, detail=detail)
W
wanghaoshuang 已提交
35 36


W
whs 已提交
37
def _graph_flops(graph, only_conv=True, detail=False):
    """Accumulate the FLOPs of every supported operator in ``graph``.

    Operators that contribute to the total:

    - ``conv2d`` / ``depthwise_conv2d``:
      ``h_out * w_out * c_out * (k_h * k_w * c_in + with_bias)``, where
      ``with_bias`` is 1 if the op has a ``Bias`` input and 0 otherwise.
    - ``mul``: ``x_shape[0] * x_shape[1] * y_shape[1]``.
    - ``pool2d`` (only when ``only_conv`` is false):
      ``h_out * w_out * c_out * k_h * k_w``.
    - ``relu``, ``sigmoid``, ``batch_norm``, ``relu6`` (only when
      ``only_conv`` is false): the number of elements of input ``X``.

    For ``mul`` and the element-wise operators, a leading dimension of -1
    (presumably an unknown batch size) is counted as 1. Any other operator
    type is ignored.

    Args:
        graph(GraphWrapper): The graph whose operators are traversed.
        only_conv(bool): If true, only convolution and ``mul`` operators are
                         counted. Default: True.
        detail(bool): Whether to also return the per-parameter FLOPs.
                      Default: False.

    Returns:
        number|tuple: The total FLOPs when ``detail`` is false; otherwise a
        tuple ``(FLOPs, params2flops)`` where ``params2flops`` maps the name
        of each convolution filter / ``mul`` ``Y`` input to the FLOPs of the
        operator consuming it.
    """
    assert isinstance(graph, GraphWrapper)
    total_flops = 0
    params2flops = {}
    for op in graph.ops():
        op_type = op.type()
        if op_type in ('conv2d', 'depthwise_conv2d'):
            filter_var = op.inputs("Filter")[0]
            c_out, c_in, k_h, k_w = filter_var.shape()
            _, _, h_out, w_out = op.outputs("Output")[0].shape()
            # c_in is the channel number of filter. It is
            # (input_channel // groups).
            kernel_ops = k_h * k_w * float(c_in)
            with_bias = 1 if len(op.inputs("Bias")) > 0 else 0
            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
            total_flops += op_flops
            params2flops[filter_var.name()] = op_flops

        elif op_type == 'pool2d' and not only_conv:
            _, c_out, h_out, w_out = op.outputs("Out")[0].shape()
            k_size = op.attr("ksize")
            # Use both window dimensions so that non-square pooling windows
            # are counted correctly; fall back to a square window when only
            # one size is given.
            k_h = k_size[0]
            k_w = k_size[1] if len(k_size) > 1 else k_h
            total_flops += h_out * w_out * c_out * (k_h * k_w)

        elif op_type == 'mul':
            x_shape = list(op.inputs("X")[0].shape())
            y_var = op.inputs("Y")[0]
            y_shape = y_var.shape()
            if x_shape[0] == -1:
                x_shape[0] = 1
            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
            total_flops += op_flops
            params2flops[y_var.name()] = op_flops

        elif op_type in ('relu', 'sigmoid', 'batch_norm',
                         'relu6') and not only_conv:
            input_shape = list(op.inputs("X")[0].shape())
            if input_shape[0] == -1:
                input_shape[0] = 1
            # np.product was removed in NumPy 2.0; np.prod is the supported
            # equivalent.
            total_flops += np.prod(input_shape)

    if detail:
        return total_flops, params2flops
    return total_flops