# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import paddle.jit as jit
from ..core import GraphWrapper, dygraph2program

__all__ = ["flops", "dygraph_flops"]


def flops(program, only_conv=True, detail=False):
    """Get FLOPs of a static ``Program``.

    Args:
        program(Program): The program used to calculate FLOPs.
        only_conv(bool): Just return number of mul-adds in convolution and FC
                         layer if `only_conv` is true. default: True.
        detail(bool): Whether to return detail of each convolution layer.

    Returns:
        int|tuple: If `detail` is true, return a tuple `(FLOPs, details)`,
            otherwise just return `FLOPs`. The details is a dict whose key is
            the parameter name of a convolution layer and whose value is the
            FLOPs of that layer.
    """
    # Wrap the program once and delegate the actual counting to the
    # graph-level helper.
    wrapped = GraphWrapper(program)
    return _graph_flops(wrapped, only_conv=only_conv, detail=detail)


def _graph_flops(graph, only_conv=True, detail=False):
    """Count FLOPs of the ops in `graph`.

    Args:
        graph(GraphWrapper): The graph whose ops are traversed.
        only_conv(bool): If True, only count mul-adds in convolution
            (`conv2d`/`depthwise_conv2d`) and FC (`mul`) ops; otherwise also
            count `pool2d` and elementwise activation/normalization ops.
            Default: True.
        detail(bool): Whether to also return per-parameter FLOPs.

    Returns:
        int|tuple: Total FLOPs, or `(FLOPs, details)` when `detail` is True.
            `details` maps each conv/FC parameter name to its FLOPs.
    """
    assert isinstance(graph, GraphWrapper)
    # NOTE: named `total_flops` to avoid shadowing the module-level `flops()`.
    total_flops = 0
    params2flops = {}
    for op in graph.ops():
        if op.type() in ['conv2d', 'depthwise_conv2d']:
            filter_shape = op.inputs("Filter")[0].shape()
            output_shape = op.outputs("Output")[0].shape()
            c_out, c_in, k_h, k_w = filter_shape
            _, _, h_out, w_out = output_shape
            # c_in is the channel number of filter. It is (input_channel // groups).
            kernel_ops = k_h * k_w * float(c_in)
            with_bias = 1 if len(op.inputs("Bias")) > 0 else 0
            op_flops = h_out * w_out * c_out * (kernel_ops + with_bias)
            total_flops += op_flops
            params2flops[op.inputs("Filter")[0].name()] = op_flops
        elif op.type() == 'pool2d' and not only_conv:
            output_shape = op.outputs("Out")[0].shape()
            _, c_out, h_out, w_out = output_shape
            k_size = op.attr("ksize")
            total_flops += h_out * w_out * c_out * (k_size[0]**2)
        elif op.type() == 'mul':
            x_shape = list(op.inputs("X")[0].shape())
            y_shape = op.inputs("Y")[0].shape()
            # A leading -1 is a dynamic batch dimension; count one sample.
            if x_shape[0] == -1:
                x_shape[0] = 1
            op_flops = x_shape[0] * x_shape[1] * y_shape[1]
            total_flops += op_flops
            params2flops[op.inputs("Y")[0].name()] = op_flops
        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'
                           ] and not only_conv:
            input_shape = list(op.inputs("X")[0].shape())
            if input_shape[0] == -1:
                input_shape[0] = 1
            # Bug fix: np.product is deprecated (removed in NumPy 2.0);
            # np.prod is the supported equivalent.
            total_flops += np.prod(input_shape)

    if detail:
        return total_flops, params2flops
    return total_flops


def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
    """
    Compute the FLOPs of nn.Layer.
    Args:
      model(nn.Layer): The target model.
      inputs(list): The dummy inputs used for 'model.forward'. It can be:
                      1. list<int>|tuple<int>: means 'model.forward' accepts
                         only one variable as argument and the shape of
                         variable is 'inputs'.
                      2. list<list<list>>: means 'model.forward' accepts multiple
                         variables as arguments and the shapes of variables is 'inputs'.
                      3. others: 'inputs' will be used as argument list by calling
                         'model.forward(*inputs)'.
      dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means
                      data type of each input. None means all the inputs is 'float32'.
                      Default: None.
      only_conv(bool): Just return number of mul-adds in convolution and FC layer
                         if `only_conv` is true. Default: False.
      detail(bool): Whether to return detail of each convolution layer.

    Returns:
      int|tuple: Total FLOPs, or `(FLOPs, details)` when `detail` is True.
    """
    # Bug fix: 'dtypes' was accepted and documented but never forwarded, so
    # it was silently ignored. Pass it through so non-float32 dummy inputs
    # are traced with the requested dtypes.
    program = dygraph2program(model, inputs, dtypes=dtypes)
    graph = GraphWrapper(program)
    return _graph_flops(graph, only_conv=only_conv, detail=detail)