# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np

from ..core import GraphWrapper, dygraph2program

__all__ = ["flops", "dygraph_flops"]

def flops(model, inputs=None, dtypes=None, only_conv=True, detail=False):
    """
    Compute the FLOPs of a ``paddle.nn.Layer`` or a ``paddle.static.Program``.

    Args:
      model(paddle.nn.Layer|paddle.static.Program): The target model.
      inputs(list): It is only used when model is instance of 'paddle.nn.Layer'. The dummy inputs used for 'model.forward'. It can be:
                      1. list<int>|tuple<int>: means 'model.forward' accepts
                         only one variable as argument and the shape of
                         variable is 'inputs'.
                      2. list<list<list>>: means 'model.forward' accepts multiple
                         variables as arguments and the shapes of variables is 'inputs'.
                      3. others: 'inputs' will be used as argument list by calling
                         'model.forward(*inputs)'.
      dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means
                      data type of each input. None means all the inputs is 'float32'.
                      Default: None.
      only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
                         default: True.
      detail(bool): Whether to return detail of each convolution layer.

    Returns:
        int|tuple: If `detail` is true, a tuple `(FLOPs, details)`;
        otherwise just the total FLOPs.

    Raises:
        TypeError: If `model` is neither a `paddle.static.Program` nor a
            `paddle.nn.Layer`.
    """
    if isinstance(model, paddle.static.Program):
        return _static_flops(model, only_conv=only_conv, detail=detail)
    if isinstance(model, paddle.nn.Layer):
        return dygraph_flops(
            model, inputs, dtypes=dtypes, only_conv=only_conv, detail=detail)
    # Unsupported model types used to fall through and return None silently;
    # fail loudly instead so misuse is caught at the call site.
    raise TypeError(
        "'model' should be a paddle.static.Program or paddle.nn.Layer, but got {}".
        format(type(model)))


def _static_flops(program, only_conv=True, detail=False):
    """Get FLOPs of a static-graph program.

    Args:
        program(Program): The program used to calculate FLOPs.
        only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
                         default: True.
        detail(bool): Whether to return detail of each convolution layer.

    Returns:
        int|tuple: If `detail` is true, then return a tuple in format `(FLOPs, details)`, otherwise it will just return `FLOPs`. The details is a dict whose key is the parameter name of convolution layer and value is the FLOPs of each convolution layer.
    """
    # Wrap the program so the generic graph walker below can traverse it.
    graph = GraphWrapper(program)
    return _graph_flops(graph, only_conv=only_conv, detail=detail)


def _graph_flops(graph, only_conv=True, detail=False):
    """Accumulate FLOPs over every supported op in a wrapped graph.

    Args:
        graph(GraphWrapper): The graph whose ops are counted.
        only_conv(bool): If True, only convolution and mul/matmul ops are
                         counted; pooling and activation/batch-norm ops are
                         skipped. Default: True.
        detail(bool): Whether to also return per-parameter FLOPs.

    Returns:
        int|tuple: `(FLOPs, details)` when `detail` is True, otherwise just
        the total FLOPs. `details` maps each conv/FC parameter name to the
        FLOPs of its layer.
    """
    assert isinstance(graph, GraphWrapper)
    flops = 0
    params2flops = {}
    for op in graph.ops():
        if op.type() in ['conv2d', 'depthwise_conv2d']:
            filter_shape = op.inputs("Filter")[0].shape()
            output_shape = op.outputs("Output")[0].shape()
            c_out, c_in, k_h, k_w = filter_shape
            _, _, h_out, w_out = output_shape
            # c_in is the channel number of the filter. It is
            # (input_channel // groups).
            kernel_ops = k_h * k_w * float(c_in)
            with_bias = 1 if len(op.inputs("Bias")) > 0 else 0
            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
            flops += op_flops
            params2flops[op.inputs("Filter")[0].name()] = op_flops
        elif op.type() == 'pool2d' and not only_conv:
            output_shape = op.outputs("Out")[0].shape()
            _, c_out, h_out, w_out = output_shape
            k_size = op.attr("ksize")
            # NOTE(review): assumes a square pooling window (k_size[0] ** 2);
            # confirm for non-square kernels. Max pooling is counted as free.
            if op.attr('pooling_type') == 'avg':
                flops += (h_out * w_out * c_out * (k_size[0]**2) * 2)
        elif op.type() in ['mul', 'matmul', 'matmul_v2']:
            x_shape = list(op.inputs("X")[0].shape())
            y_shape = op.inputs("Y")[0].shape()
            if x_shape[0] == -1:
                # -1 marks a dynamic batch dimension; count for batch size 1.
                x_shape[0] = 1
            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
            flops += op_flops
            params2flops[op.inputs("Y")[0].name()] = op_flops
        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'
                           ] and not only_conv:
            input_shape = list(op.inputs("X")[0].shape())
            if input_shape[0] == -1:
                input_shape[0] = 1
            # `np.product` was deprecated and removed in NumPy 2.0;
            # `np.prod` is the supported equivalent.
            if op.type() == 'batch_norm':
                op_flops = np.prod(input_shape) * 2
            else:
                op_flops = np.prod(input_shape)
            flops += op_flops

    if detail:
        return flops, params2flops
    else:
        return flops


def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
    """
    Compute the FLOPs of nn.Layer.
    Args:
      model(nn.Layer): The target model.
      inputs(list): The dummy inputs used for 'model.forward'. It can be:
                      1. list<int>|tuple<int>: means 'model.forward' accepts
                         only one variable as argument and the shape of
                         variable is 'inputs'.
                      2. list<list<list>>: means 'model.forward' accepts multiple
                         variables as arguments and the shapes of variables is 'inputs'.
                      3. others: 'inputs' will be used as argument list by calling
                         'model.forward(*inputs)'.
      dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means
                      data type of each input. None means all the inputs is 'float32'.
                      Default: None.
      only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
                         default: True.
      detail(bool): Whether to return detail of each convolution layer.

    Returns:
        int|tuple: `(FLOPs, details)` when `detail` is True, otherwise just
        the total FLOPs.
    """
    # Trace the dygraph model into a static program, then reuse the
    # static-graph counting path.
    program = dygraph2program(model, inputs, dtypes=dtypes)
    graph = GraphWrapper(program)
    return _graph_flops(graph, only_conv=only_conv, detail=detail)