# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from collections import OrderedDict
from paddle.static import Program, Variable

__all__ = []


class VarWrapper:
    """Thin read-only view over a ``Variable`` that remembers its owning graph."""

    def __init__(self, var, graph):
        # Both checks run before any state is stored, variable first.
        assert isinstance(var, Variable)
        assert isinstance(graph, GraphWrapper)
        self._graph = graph
        self._var = var

    def name(self):
        """Return the ``name`` attribute of the wrapped variable."""
        wrapped = self._var
        return wrapped.name

    def shape(self):
        """Return the ``shape`` attribute of the wrapped variable."""
        wrapped = self._var
        return wrapped.shape


class OpWrapper:
    """Thin view over an operator that resolves its variables through a graph."""

    def __init__(self, op, graph):
        assert isinstance(graph, GraphWrapper)
        self._graph = graph
        self._op = op

    def type(self):
        """Return the ``type`` attribute of the wrapped operator."""
        wrapped = self._op
        return wrapped.type

    def inputs(self, name):
        """
        Look up the variables bound to the input slot ``name``.

        Returns an empty list when the operator has no input slot with that
        name; otherwise each variable name in the slot is resolved through
        the owning graph's ``var`` lookup.
        """
        if name not in self._op.input_names:
            return []
        resolved = []
        for var_name in self._op.input(name):
            resolved.append(self._graph.var(var_name))
        return resolved

    def outputs(self, name):
        """
        Look up the variables bound to the output slot ``name``.

        Unlike ``inputs``, the slot name is passed straight to the wrapped
        operator without first checking that it exists.
        """
        resolved = []
        for var_name in self._op.output(name):
            resolved.append(self._graph.var(var_name))
        return resolved


class GraphWrapper(object):
    """
    A lightweight wrapper around a ``Program`` that exposes its operators and
    variables as ``OpWrapper`` / ``VarWrapper`` objects.

    Args:
        program(Program): The program to wrap. A new empty ``Program`` is
                        created when this is None.
        in_nodes(dict): A dict to indicate the input nodes of the graph.
                        The key is user-defined and human-readable name.
                        The value is the name of Variable.
                        None is treated as empty.
        out_nodes(dict): A dict to indicate the output nodes of the graph.
                        The key is user-defined and human-readable name.
                        The value is the name of Variable.
                        None is treated as empty.
    """

    def __init__(self, program=None, in_nodes=None, out_nodes=None):
        """Wrap ``program`` and index its persistable variables by name."""
        super().__init__()
        self.program = Program() if program is None else program
        # Name -> variable, for every variable flagged as persistable.
        self.persistables = {
            var.name: var
            for var in self.program.list_vars()
            if var.persistable
        }
        self.teacher_persistables = {}
        self.compiled_graph = None
        # Defaults are None rather than shared mutable lists; both are
        # normalised to an empty mapping here.
        in_nodes = [] if in_nodes is None else in_nodes
        out_nodes = [] if out_nodes is None else out_nodes
        self.in_nodes = OrderedDict(in_nodes)
        self.out_nodes = OrderedDict(out_nodes)
        self._attrs = OrderedDict()

    def ops(self):
        """
        Return all operators of the program as a list of ``OpWrapper``,
        visiting the blocks in order and the operators of each block in order.
        """
        return [
            OpWrapper(op, self)
            for block in self.program.blocks
            for op in block.ops
        ]

    def var(self, name):
        """
        Get the variable by variable name.

        The blocks are searched in order and the first block that has the
        variable wins. Returns None when no block has a variable called
        ``name``.
        """
        for block in self.program.blocks:
            if block.has_var(name):
                return VarWrapper(block.var(name), self)
        return None


def count_convNd(op):
    """
    Estimate the operation count of a convolution operator.

    The count is the number of output elements (all dimensions of the first
    "Output" variable except the leading one) multiplied by the work per
    output element: the product of the "Filter" shape without its leading
    dimension, plus one when the operator has a "Bias" input.

    Args:
        op: An operator wrapper exposing ``inputs(name)`` and
            ``outputs(name)``, whose items expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    filter_shape = op.inputs("Filter")[0].shape()
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    filter_ops = np.prod(filter_shape[1:])
    bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0
    output_numel = np.prod(op.outputs("Output")[0].shape()[1:])
    total_ops = output_numel * (filter_ops + bias_ops)
    # A negative shape entry (presumably -1 for an unknown dimension) can
    # flip the sign of the product, so only the magnitude is reported.
    return abs(total_ops)


def count_leaky_relu(op):
    """
    Estimate the operation count of a leaky-ReLU operator as one operation
    per output element, skipping the leading dimension of the first
    "Output" variable.

    Args:
        op: An operator wrapper exposing ``outputs(name)``, whose items
            expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    # NOTE(review): this reads the "Output" slot; confirm that is the slot
    # name the leaky_relu operator actually uses.
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    total_ops = np.prod(op.outputs("Output")[0].shape()[1:])
    # Same sign guard as the other counters in this module: a negative
    # shape entry must not produce a negative count.
    return abs(total_ops)


def count_bn(op):
    """
    Estimate the operation count of a batch-norm operator as two operations
    per output element, skipping the leading dimension of the first "Y"
    output variable.

    Args:
        op: An operator wrapper exposing ``outputs(name)``, whose items
            expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    output_numel = np.prod(op.outputs("Y")[0].shape()[1:])
    total_ops = 2 * output_numel
    # Report the magnitude only: a negative shape entry would otherwise
    # yield a negative count.
    return abs(total_ops)


def count_linear(op):
    """
    Estimate the operation count of a matrix-multiply style operator.

    The count is the first dimension of the "Y" input multiplied by the
    number of output elements (all dimensions of the first "Out" variable
    except the leading one).

    Args:
        op: An operator wrapper exposing ``inputs(name)`` and
            ``outputs(name)``, whose items expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    total_mul = op.inputs("Y")[0].shape()[0]
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    numel = np.prod(op.outputs("Out")[0].shape()[1:])
    total_ops = total_mul * numel
    # Report the magnitude only: a negative shape entry would otherwise
    # yield a negative count.
    return abs(total_ops)


def count_pool2d(op):
    """
    Estimate the operation count of a 2-D pooling operator.

    The pooling window is inferred as the element-wise integer ratio of the
    input shape to the output shape, both taken from index 2 onwards. Each
    output element is charged one operation per window element plus one
    division, and that is multiplied by the number of output elements
    (all dimensions of the "Out" variable except the leading one).

    Args:
        op: An operator wrapper exposing ``inputs(name)`` and
            ``outputs(name)``, whose items expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    input_shape = op.inputs("X")[0].shape()
    output_shape = op.outputs('Out')[0].shape()
    kernel = np.array(input_shape[2:]) // np.array(output_shape[2:])
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    total_add = np.prod(kernel)
    total_div = 1
    kernel_ops = total_add + total_div
    num_elements = np.prod(output_shape[1:])
    total_ops = kernel_ops * num_elements
    # Report the magnitude only: a negative shape entry would otherwise
    # yield a negative count.
    return abs(total_ops)


def count_element_op(op):
    """
    Estimate the operation count of an element-wise operator as one
    operation per element of the first "X" input, skipping its leading
    dimension.

    Args:
        op: An operator wrapper exposing ``inputs(name)``, whose items
            expose ``shape()``.

    Returns:
        The non-negative operation count.
    """
    input_shape = op.inputs("X")[0].shape()
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    total_ops = np.prod(input_shape[1:])
    # Report the magnitude only: a negative shape entry would otherwise
    # yield a negative count.
    return abs(total_ops)


def _graph_flops(graph, detail=False):
    """
    Sum the estimated operation counts of every supported operator in a graph.

    Supported operator types are ``conv2d`` / ``depthwise_conv2d``,
    ``pool2d``, ``mul`` / ``matmul``, ``batch_norm`` and any type whose name
    starts with ``element``. All other operators contribute nothing.

    Args:
        graph(GraphWrapper): The graph whose operators are counted.
        detail(bool): When True, print a per-operator table of the non-zero
            counts after the traversal.

    Returns:
        The total operation count over all supported operators.
    """
    assert isinstance(graph, GraphWrapper)
    flops = 0
    table = Table(["OP Type", 'Param name', "Flops"])
    for op in graph.ops():
        op_type = op.type()
        # Reset per operator so unsupported types add zero and get no row.
        op_flops = 0
        param_name = ''
        if op_type in ['conv2d', 'depthwise_conv2d']:
            op_flops = count_convNd(op)
            param_name = op.inputs("Filter")[0].name()
        elif op_type == 'pool2d':
            op_flops = count_pool2d(op)
        elif op_type in ['mul', 'matmul']:
            op_flops = count_linear(op)
            param_name = op.inputs("Y")[0].name()
        elif op_type == 'batch_norm':
            op_flops = count_bn(op)
        elif op_type.startswith('element'):
            op_flops = count_element_op(op)
        flops += op_flops
        # Operators with a zero count are left out of the detail table.
        if op_flops != 0:
            table.add_row([op_type, param_name, op_flops])
    if detail:
        table.print_table()
    return flops


def static_flops(program, print_detail=False):
    """
    Count the estimated operations of a static program.

    The program is wrapped in a ``GraphWrapper`` and handed to
    ``_graph_flops``; ``print_detail`` is forwarded as its ``detail`` flag.
    """
    wrapped = GraphWrapper(program)
    total = _graph_flops(wrapped, detail=print_detail)
    return total


class Table(object):
    """
    A minimal text table that prints rows in centered, ``|``-separated
    columns framed by ``+---+`` rules.

    Each column is as wide as its longest cell (header included) plus one
    space of padding on either side.
    """

    def __init__(self, table_heads):
        """
        Args:
            table_heads(list): Column header strings; their count fixes the
                number of columns.
        """
        self.table_heads = table_heads
        # Current content width of every column, seeded from the headers.
        self.table_len = [len(head) for head in table_heads]
        self.data = []
        self.col_num = len(table_heads)

    def add_row(self, row_str):
        """
        Append a row and widen any column whose new cell is longer.

        A row that is not a list, or whose length differs from the number of
        columns, only triggers a printed warning; the row is still processed.
        A row shorter than the number of columns therefore raises IndexError.
        """
        if not isinstance(row_str, list):
            print('The row_str should be a list')
        if len(row_str) != self.col_num:
            print(
                'The length of row data should be equal the length of table heads, but the data: {} is not equal table heads {}'.format(
                    len(row_str), self.col_num
                )
            )
        for i in range(self.col_num):
            self.table_len[i] = max(self.table_len[i], len(str(row_str[i])))
        self.data.append(row_str)

    def print_row(self, row):
        """Print one row, each of the first ``col_num`` cells centered in its column."""
        cells = (
            '|' + str(row[i]).center(self.table_len[i] + 2)
            for i in range(self.col_num)
        )
        print(''.join(cells) + '|')

    def print_shelf(self):
        """Print a horizontal rule matching the current column widths."""
        segments = ('+' + '-' * (length + 2) for length in self.table_len)
        print(''.join(segments) + '+')

    def print_table(self):
        """Print the header and every stored row, framed by horizontal rules."""
        self.print_shelf()
        self.print_row(self.table_heads)
        self.print_shelf()
        for data in self.data:
            self.print_row(data)
        self.print_shelf()