# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import Iterable, namedtuple
9
from functools import partial
10
from typing import Iterable
11 12 13 14 15 16 17 18

import numpy as np
import tabulate

import megengine as mge
import megengine.module as m
import megengine.module.qat as qatm
import megengine.module.quantized as qm
19 20
from megengine import Tensor
from megengine import functional as F
21
from megengine.core.tensor.dtype import get_dtype_bit
22
from megengine.functional.tensor import zeros
23
from megengine.tensor import Tensor
24

25 26
from .module_utils import set_module_mode_safe

# Lift the log formatter's line limit so the (potentially long) stats tables
# are not truncated; fail loudly if this MegEngine build lacks the attribute.
try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
    # Chain the cause so the missing attribute is visible in the traceback.
    raise ValueError("set logger max lines failed") from e

logger = mge.get_logger(__name__)
logger.setLevel("INFO")


# Registries mapping a module class to the function computing one statistic for
# it; populated by the register_flops / register_receptive_field decorators.
_calc_flops_dict = dict()
_calc_receptive_field_dict = dict()


def _receptive_field_fallback(module, inputs, outputs):
    """Propagate receptive field / stride through a module with no registered rule.

    Does nothing (returns ``None``) unless receptive-field tracking is enabled.
    Otherwise stores the result on ``module._rf`` / ``module._stride`` and
    returns ``(rf, stride)``. A module with no inputs gets ``(1, 1)`` for both;
    otherwise the element-wise maxima over the inputs' owners are used.
    """
    if not _receptive_field_enabled:
        return None
    # Each module is expected to be visited only once per run.
    assert not hasattr(module, "_rf")
    assert not hasattr(module, "_stride")
    if len(inputs) > 0:
        rf, stride = preprocess_receptive_field(module, inputs, outputs)
    else:
        # TODO: support other dimension
        rf, stride = (1, 1), (1, 1)
    module._rf, module._stride = rf, stride
    return rf, stride


# Each entry is (key, impl_dict, fallback):
#   key       -- result key, or tuple of keys, written into get_op_stats' dict
#   impl_dict -- registry mapping module classes to implementations
#   fallback  -- called when no registered class matches, or None
_iter_list = [
    ("flops_num", _calc_flops_dict, None),
    (("receptive_field", "stride"), _calc_receptive_field_dict, _receptive_field_fallback),
]

# Module-level switch read by _receptive_field_fallback and print_op_stats.
_receptive_field_enabled = False

68 69

def _register_dict(*modules, dict=None):
    """Build a decorator that records ``impl`` in ``dict`` for every class in ``modules``.

    The decorated function is returned unchanged, so it stays usable directly.
    """
    registry = dict

    def decorator(impl):
        registry.update((module, impl) for module in modules)
        return impl

    return decorator


def register_flops(*modules):
    """Decorator factory: register the decorated function as the flops counter for ``modules``."""
    registry = _calc_flops_dict
    return _register_dict(*modules, dict=registry)


def register_receptive_field(*modules):
    """Decorator factory: register the decorated function as the receptive-field rule for ``modules``."""
    registry = _calc_receptive_field_dict
    return _register_dict(*modules, dict=registry)


def enable_receptive_field():
    """Switch receptive-field tracking on by setting the module-level flag."""
    global _receptive_field_enabled
    _receptive_field_enabled = True


def disable_receptive_field():
    """Switch receptive-field tracking off by clearing the module-level flag."""
    global _receptive_field_enabled
    _receptive_field_enabled = False


# Registered before flops_convNd so transposed convolutions always match this
# rule first, regardless of how the conv classes relate by inheritance.
@register_flops(m.ConvTranspose2d)
def flops_convNdTranspose(module: m.ConvTranspose2d, inputs, outputs):
    """Estimate the multiply-accumulate count of a transposed convolution.

    Each *input* element is scattered through ``out_channels // groups`` kernels
    of ``prod(kernel_size)`` taps. The bias, when present, is added once per
    output element.
    """
    bias = 1 if module.bias is not None else 0
    # int64 accumulation: np.prod's default platform int is 32-bit on some
    # platforms (e.g. Windows with NumPy < 2) and these products overflow it.
    kernel_volume = np.prod(module.kernel_size, dtype=np.int64)
    # N x Cin x H_in x W_in x (Cout / groups x Kh x Kw) + N x Cout x H_out x W_out x bias
    return np.prod(inputs[0].shape, dtype=np.int64) * (
        module.out_channels // module.groups * kernel_volume
    ) + np.prod(outputs[0].shape, dtype=np.int64) * bias


@register_flops(m.Conv1d, m.Conv2d, m.Conv3d, m.LocalConv2d, m.DeformableConv2d)
def flops_convNd(module: m.Conv2d, inputs, outputs):
    """Estimate the multiply-accumulate count of a (non-transposed) convolution.

    Each output element costs ``in_channels // groups * prod(kernel_size)``
    MACs, plus one add when the module has a bias.
    """
    bias = 1 if module.bias is not None else 0
    kernel_volume = np.prod(module.kernel_size, dtype=np.int64)
    # N x Cout x H x W x  (Cin x Kw x Kh + bias)
    return np.prod(outputs[0].shape, dtype=np.int64) * (
        module.in_channels // module.groups * kernel_volume + bias
    )

106

@register_flops(
    m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
)
def flops_norm(module: m.Module, inputs, outputs):
    """Estimate normalization cost as a fixed 7 operations per input element.

    NOTE(review): the per-element constant 7 is not derived anywhere in this file.
    """
    # int64 accumulation avoids silent overflow when the platform int is 32-bit.
    return np.prod(inputs[0].shape, dtype=np.int64) * 7


@register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs):
    """Estimate pooling cost as one op per kernel tap per output element.

    ``kernel_size`` may be a single int (square window) or a per-axis
    tuple/list. The original ``kernel_size ** 2`` raised TypeError for the
    latter.
    """
    kernel_size = module.kernel_size
    if isinstance(kernel_size, (tuple, list)):
        kernel_area = np.prod(kernel_size, dtype=np.int64)
    else:
        kernel_area = kernel_size ** 2
    return np.prod(outputs[0].shape, dtype=np.int64) * kernel_area


@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
    """Estimate adaptive pooling cost as one op per window tap per output element.

    The equivalent fixed pooling window is derived from the input and *output*
    spatial sizes: ``stride = in // out`` and ``kernel = in - (out - 1) * stride``.
    The previous code used ``in - 1`` in place of the output size. That made
    the window about 1 (a global pool counted 1 op per output) and divided by
    zero when the input spatial size was 1.
    """
    # assumes 4-D inputs/outputs with spatial dims at indices 2 and 3 (NCHW),
    # matching the indexing the original code used
    in_h, in_w = inputs[0].shape[2], inputs[0].shape[3]
    out_h, out_w = outputs[0].shape[2], outputs[0].shape[3]
    stride_h = in_h // out_h
    kernel_h = in_h - (out_h - 1) * stride_h
    stride_w = in_w // out_w
    kernel_w = in_w - (out_w - 1) * stride_w
    return np.prod(outputs[0].shape, dtype=np.int64) * kernel_h * kernel_w


@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    """Estimate a fully connected layer's cost.

    Each output element costs ``in_features`` MACs, plus one add when the module
    has a bias. This mirrors ``flops_convNd``. The previous code added only
    ``out_features`` bias ops in total, undercounting by the batch size.
    """
    bias = 1 if module.bias is not None else 0
    return np.prod(outputs[0].shape, dtype=np.int64) * (module.in_features + bias)

133 134 135 136 137 138 139 140 141 142

@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
    """Estimate a batched matmul's cost from the input and weight shapes.

    Expects a 3-D input and a 3-D weight (both are unpacked positionally).
    NOTE(review): the result multiplies input dims 1 and 2 with weight dim 2;
    confirm this matches BatchMatMulActivation's weight layout.
    """
    has_bias = 0 if module.bias is None else 1
    x = inputs[0]
    batch_size = x.shape[0]
    x_d1, x_d2 = x.shape[1:]
    # local renamed from ``m`` so it no longer shadows the module alias
    _, w_d2 = module.weight.shape[1:]
    return x_d1 * (x_d2 + has_bias) * w_d2 * batch_size


# The qat and quantized variants need not be listed: they inherit from these float modules.
hook_modules = (
    # convolutions
    m.conv._ConvNd,
    # dense / matmul layers
    m.Linear,
    m.BatchMatMulActivation,
    # normalization layers
    m.batchnorm._BatchNorm,
    m.LayerNorm,
    m.GroupNorm,
    m.InstanceNorm,
    # pooling layers
    m.pooling._PoolNd,
    m.adaptive_pooling._AdaptivePoolNd,
)


def _mean(inp):
    """Return the mean of ``inp`` (converted to a MegEngine tensor) as a numpy value."""
    return F.mean(mge.tensor(inp)).numpy()


def _std(inp):
    """Return the standard deviation of ``inp`` (converted to a MegEngine tensor) as a numpy value."""
    return F.std(mge.tensor(inp)).numpy()


def dict2table(list_of_dict, header):
    """Turn a list of dicts into table rows; the first row is ``header`` itself.

    Missing keys become empty strings so every row has ``len(header)`` cells.
    """
    rows = [[d[h] if h in d else "" for h in header] for d in list_of_dict]
    return [header, *rows]


def sizeof_fmt(num, suffix="B"):
    """Format ``num`` with a scaled unit prefix, e.g. ``"1.500 KiB"``.

    ``suffix == "B"`` selects binary prefixes (Ki, Mi, ...; base 1024). Any other
    suffix uses decimal prefixes (K, M, ...; base 1000). Values beyond the
    largest prefix are printed with that prefix, unscaled further.
    """
    if suffix == "B":
        base = 1024.0
        prefixes = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
    else:
        base = 1000.0
        prefixes = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
    last_idx = len(prefixes) - 1
    for idx, prefix in enumerate(prefixes):
        if abs(num) < base or idx == last_idx:
            return "{:3.3f} {}{}".format(num, prefix, suffix)
        num /= base


def preprocess_receptive_field(module, inputs, outputs):
    """Combine the receptive field and stride recorded on the inputs' owners.

    For each of the two axes, the maximum over all inputs is taken. An owner
    without ``_rf`` / ``_stride`` counts as ``(1, 1)``.
    """
    # TODO: support other dimensions
    owners = [i.owner for i in inputs]
    rfs = [getattr(o, "_rf", (1, 1)) for o in owners]
    strides = [getattr(o, "_stride", (1, 1)) for o in owners]
    pre_rf = tuple(max(r[axis] for r in rfs) for axis in (0, 1))
    pre_stride = tuple(max(s[axis] for s in strides) for axis in (0, 1))
    return pre_rf, pre_stride


def get_op_stats(module, inputs, outputs):
    """Collect shapes plus every registered statistic for one module call.

    :param module: module whose forward hook fired.
    :param inputs: positional inputs of the forward call.
    :param outputs: forward output(s); a non-sequence value is wrapped in a tuple.
    :return: dict with ``input_shapes``, ``output_shapes`` and the keys of every
        ``_iter_list`` entry whose registry matched ``module``; ``None`` if no
        registry matched.
    """
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],
    }
    matched = False
    for key, registry, fallback in _iter_list:
        impl = next(
            (registry[_type] for _type in registry if isinstance(module, _type)), None
        )
        if impl is None:
            # Unregistered module: the fallback (if any) runs, but its return
            # value is not recorded in rst.
            if fallback is not None:
                fallback(module, inputs, outputs)
            continue
        value = impl(module, inputs, outputs)
        matched = True
        if isinstance(key, tuple):
            assert isinstance(value, tuple)
            rst.update(zip(key, value))
        else:
            rst[key] = value
    return rst if matched else None


def sum_op_stats(flops, bar_length_max=20):
    """Annotate per-op flops records with cumulative/percentage/bar columns.

    Mutates each dict in ``flops``, adding ``flops_cum``, ``ratio``,
    ``percentage``, ``bar`` and ``flops``. Then appends a ``total`` row.

    :param flops: list of dicts, each with ``flops_num`` and ``output_shapes``.
    :param bar_length_max: bar length, in ``#`` characters, of the largest op.
    :return: ``(total_flops_num, flops)``.
    """
    max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:
        total_flops_num += int(d["flops_num"])
        d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

    for d in flops:
        # Guard the all-zero case: 0/0 would give nan/inf (or ZeroDivisionError)
        # and int() of those raises.
        ratio = d["flops_num"] / total_flops_num if total_flops_num else 0.0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        if max_flops_num:
            bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
        else:
            bar_length = 0
        d["bar"] = "#" * bar_length
        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    # NOTE(review): sums dim 1 of every output shape (presumably channels) and
    # reports it under "output_shapes" in the total row; confirm intent.
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    return total_flops_num, flops


def print_op_stats(flops):
    """Log the per-op flops table.

    Receptive-field/stride columns are included only while tracking is enabled.
    """
    header = ["name", "class_name", "input_shapes", "output_shapes"]
    if _receptive_field_enabled:
        header += ["receptive_field", "stride"]
    header += ["flops", "flops_cum", "percentage", "bar"]
    table = dict2table(flops, header=header)
    logger.info("flops stats: \n" + tabulate.tabulate(table))


def get_param_stats(param: Tensor):
    """Describe one parameter: dtype, shape, mean/std, element count and byte size.

    ``param`` only needs ``dtype`` and ``shape`` attributes plus whatever
    ``_mean``/``_std`` accept; ``module_stats`` passes a numpy array.
    """
    dtype = np.dtype(param.dtype)
    nbits = get_dtype_bit(dtype.name)
    shape = param.shape
    # int64 accumulation: with a 32-bit platform int (e.g. Windows, NumPy < 2),
    # param_dim * nbits overflows for large tensors.
    param_dim = np.prod(shape, dtype=np.int64)
    param_size = param_dim * nbits // 8
    return {
        "dtype": dtype,
        "shape": shape,
        "mean": "{:.3g}".format(_mean(param)),
        "std": "{:.3g}".format(_std(param)),
        "param_dim": param_dim,
        "nbits": nbits,
        "size": param_size,
    }


def sum_param_stats(params, bar_length_max=20):
    """Annotate per-parameter records with cumulative/percentage/bar columns.

    Mutates each dict in ``params``: adds ``size_cum``, ``ratio``,
    ``percentage`` and ``size_bar``, and replaces ``size`` by its formatted
    string. Then appends a ``total`` row.

    :param params: list of dicts, each with numeric ``param_dim`` and ``size``.
    :param bar_length_max: bar length, in ``#`` characters, of the largest parameter.
    :return: ``(total_param_dims, total_param_size, params)``.
    """
    max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:
        total_param_dims += int(d["param_dim"])
        total_param_size += int(d["size"])
        d["size_cum"] = sizeof_fmt(total_param_size)

    for d in params:
        # Guard zero totals (e.g. only zero-size parameters) against division by zero.
        ratio = d["size"] / total_param_size if total_param_size else 0.0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        bar_length = int(d["size"] / max_size * bar_length_max) if max_size else 0
        d["size_bar"] = "#" * bar_length
        d["size"] = sizeof_fmt(d["size"])

    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size))

    return total_param_dims, total_param_size, params


def print_param_stats(params):
    """Log the per-parameter statistics table."""
    columns = (
        "name",
        "dtype",
        "shape",
        "mean",
        "std",
        "param_dim",
        "nbits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    )
    table = dict2table(params, header=list(columns))
    logger.info("param stats: \n" + tabulate.tabulate(table))

339

def get_activation_stats(output: np.ndarray, has_input=False):
    """Describe one activation array: dtype, shape, element count and byte size.

    Mean/std are added only when ``has_input`` is true. ``module_stats`` sets
    it when the user supplied real inputs rather than zero-filled ones.
    """
    out_shape = output.shape
    activations_dtype = np.dtype(output.dtype)
    nbits = get_dtype_bit(activations_dtype.name)
    # int64 accumulation: with a 32-bit platform int (e.g. Windows, NumPy < 2),
    # act_dim * nbits overflows for large activations.
    act_dim = np.prod(out_shape, dtype=np.int64)
    act_size = act_dim * nbits // 8
    activation_stats = {
        "dtype": activations_dtype,
        "shape": out_shape,
        "act_dim": act_dim,
        "nbits": nbits,
        "size": act_size,
    }
    if has_input:
        activation_stats["mean"] = "{:.3g}".format(output.mean())
        activation_stats["std"] = "{:.3g}".format(output.std())
    return activation_stats


def sum_activations_stats(activations, bar_length_max=20):
    """Annotate per-activation records with cumulative/percentage/bar columns.

    Mutates each dict in ``activations``: adds ``size_cum``, ``ratio``,
    ``percentage`` and ``size_bar``, and replaces ``size`` by its formatted
    string. Then appends a ``total`` row.

    :param activations: list of dicts, each with numeric ``act_dim`` and ``size``.
    :param bar_length_max: bar length, in ``#`` characters, of the largest activation.
    :return: ``(total_act_dims, total_act_size, activations)``.
    """
    max_act_size = max([i["size"] for i in activations] + [0])
    total_act_dims, total_act_size = 0, 0
    for d in activations:
        total_act_size += int(d["size"])
        total_act_dims += int(d["act_dim"])
        d["size_cum"] = sizeof_fmt(total_act_size)

    for d in activations:
        # Guard zero totals (e.g. only empty outputs) against division by zero.
        ratio = d["size"] / total_act_size if total_act_size else 0.0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        if max_act_size:
            bar_length = int(d["size"] / max_act_size * bar_length_max)
        else:
            bar_length = 0
        d["size_bar"] = "#" * bar_length
        d["size"] = sizeof_fmt(d["size"])

    act_size = sizeof_fmt(total_act_size)
    activations.append(dict(name="total", act_dim=total_act_dims, size=act_size))

    return total_act_dims, total_act_size, activations


def print_activations_stats(activations, has_input=False):
    """Log the per-module activation table; mean/std columns only when ``has_input``."""
    header = ["name", "class_name", "dtype", "shape"]
    if has_input:
        header += ["mean", "std"]
    header += ["nbits", "act_dim", "size", "size_cum", "percentage", "size_bar"]
    table = dict2table(activations, header=header)
    logger.info("activations stats: \n" + tabulate.tabulate(table))


def print_summary(**kwargs):
    """Log ``kwargs`` as a two-column ``item`` / ``value`` summary table."""
    rows = [["item", "value"], *kwargs.items()]
    logger.info("summary\n" + tabulate.tabulate(rows))


def module_stats(
    model: m.Module,
    inputs: Iterable[np.ndarray] = None,
    input_shapes: list = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.

    :param model: model that need to get stats info.
    :param inputs: user defined input data for running model and calculating stats, alternative with input_shapes.
        Each item is converted to a float32 ``Tensor``; a single non-sequence value is treated as one input.
    :param input_shapes: shapes used to generate all-zero float32 inputs for running model and calculating
        stats, alternative with inputs. Either one shape (e.g. ``(1, 3, 224, 224)``) or a sequence of shapes.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :return: ``(total_stats, stats_details)``, two namedtuples.
        ``total_stats`` has fields ``param_size``, ``flops`` and ``act_size``.
        ``stats_details`` has fields ``params``, ``flops`` and ``activations``
        with the per-module records. Returns ``None`` (after logging an error)
        when neither ``inputs`` nor ``input_shapes`` is given.
    """
    has_inputs = inputs is not None
    if has_inputs:
        if not isinstance(inputs, (tuple, list)):
            inputs = [inputs]
        inputs = [Tensor(inp, dtype=np.float32) for inp in inputs]
    elif input_shapes:
        # A single shape is a sequence of ints; wrap it so we always iterate shapes.
        if not isinstance(input_shapes[0], (tuple, list)):
            input_shapes = [input_shapes]
        inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
    else:
        # No exception is being handled here, so exc_info would only log "NoneType: None".
        logger.error(
            "Inputs or input_shapes is required for running model and calculating stats."
        )
        return

    disable_receptive_field()

    params = []
    flops = []
    activations = []
    hooks = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    def module_stats_hook(module, inputs, outputs, name=""):
        # Forward hook: append this module's stats to the enclosing lists.
        class_name = module.__class__.__name__
        if cal_flops:
            flops_stats = get_op_stats(module, inputs, outputs)
            if flops_stats is not None:
                flops_stats["name"] = name
                flops_stats["class_name"] = class_name
                flops.append(flops_stats)

        if cal_params:
            if hasattr(module, "weight") and module.weight is not None:
                param_stats = get_param_stats(module.weight.numpy())
                param_stats["name"] = name + "-w"
                params.append(param_stats)

            if hasattr(module, "bias") and module.bias is not None:
                param_stats = get_param_stats(module.bias.numpy())
                param_stats["name"] = name + "-b"
                params.append(param_stats)

        if cal_activations:
            # Only the first output is profiled when the module returns several.
            first = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            activation_stats = get_activation_stats(first.numpy(), has_inputs)
            activation_stats["name"] = name
            activation_stats["class_name"] = class_name
            activations.append(activation_stats)

    try:
        for name, module in model.named_modules():
            if isinstance(module, hook_modules):
                hooks.append(
                    module.register_forward_hook(partial(module_stats_hook, name=name))
                )
        with set_module_mode_safe(model, training=False) as model:
            model(*inputs)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # caller's model is not left instrumented.
        for h in hooks:
            h.remove()

    extra_info = {
        "#params": len(params),
    }
    total_flops = total_param_dims = total_param_size = 0
    total_act_dims = total_act_size = 0

    if cal_params:
        total_param_dims, total_param_size, params = sum_param_stats(
            params, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params)

    if cal_flops:
        total_flops, flops = sum_op_stats(flops, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops)

    if cal_activations:
        total_act_dims, total_act_size, activations = sum_activations_stats(
            activations, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations, has_inputs)

    if cal_flops and cal_params:
        # A parameter-free model (e.g. only pooling) would otherwise raise
        # ZeroDivisionError here.
        if total_param_size:
            extra_info["flops/param_size"] = "{:3.3f}".format(
                total_flops / total_param_size
            )
        else:
            extra_info["flops/param_size"] = "N/A"

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )