From a7ff580e5441a26e2ce8f407286db8075851380d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Aug 2020 18:07:18 +0800 Subject: [PATCH] feat(mge/utils): add net stats to calculate parameters and flops GitOrigin-RevId: a77f89e24bf10c7d3a0f79659f2a78382b38ce5a --- python_module/megengine/utils/net_stats.py | 279 +++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 python_module/megengine/utils/net_stats.py diff --git a/python_module/megengine/utils/net_stats.py b/python_module/megengine/utils/net_stats.py new file mode 100644 index 000000000..fa35d114b --- /dev/null +++ b/python_module/megengine/utils/net_stats.py @@ -0,0 +1,279 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial

import numpy as np
import tabulate

import megengine as mge
import megengine._internal as mgb
import megengine.module as m
import megengine.module.qat as qatm
import megengine.module.quantized as qm

# Remove the cap on MegEngineLogFormatter.max_lines -- presumably so that the
# multi-line tables logged by this module are not truncated (TODO confirm).
try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
    # Keep the original ValueError contract, but chain the cause so the
    # underlying AttributeError is not lost from the traceback.
    raise ValueError("set logger max lines failed") from e

logger = mge.get_logger(__name__)


# Registry: module class -> function(module, input, output) returning the
# number of operations for one forward call of that module.
CALC_FLOPS = {}


def _register_modules(*modules):
    """Return a decorator that registers a FLOPs counter for ``modules``.

    The decorated function is stored in ``CALC_FLOPS`` under every class in
    ``modules`` and is returned unchanged. A later registration for the same
    class replaces the earlier one.
    """

    def callback(impl):
        for module in modules:
            CALC_FLOPS[module] = impl
        return impl

    return callback


# m.ConvTranspose2d is deliberately not listed here: it is registered for
# count_deconvNd below, which would overwrite an entry made here anyway.
@_register_modules(
    m.Conv2d,
    m.LocalConv2d,
    qm.Conv2d,
    qm.ConvRelu2d,
    qm.ConvBn2d,
    qm.ConvBnRelu2d,
    qatm.Conv2d,
    qatm.ConvRelu2d,
    qatm.ConvBn2d,
    qatm.ConvBnRelu2d,
)
def count_convNd(module, input, output):
    """Count the operations of a (grouped) convolution.

    :param module: convolution module; ``bias``, ``groups`` and
        ``kernel_size`` are read from it.
    :param input: sequence of input tensors; ``input[0].shape[1]`` is taken
        as the number of input channels.
    :param output: sequence of output tensors; ``output[0].shape`` is taken
        as ``(N, Cout, *spatial)``.
    :return: ``N * Cout * prod(spatial) * (Cin / groups * prod(kernel) + bias)``.
    """
    bias = 1 if module.bias is not None else 0
    in_channels_per_group = input[0].shape[1] // module.groups
    out_shape = output[0].shape
    batch = out_shape[0]
    out_channels = out_shape[1]
    spatial = np.prod(out_shape[2:])
    # N x Cout x H x W x (Cin/groups x Kw x Kh + bias)
    # Every output channel is computed, so the full Cout is used here; using
    # Cout // groups would under-count grouped convolutions by ``groups``.
    return (
        batch
        * spatial
        * out_channels
        * (in_channels_per_group * np.prod(module.kernel_size) + bias)
    )


@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
    """Count the operations of a transposed convolution.

    :param module: module whose ``kernel_size`` is read.
    :param input: sequence of input tensors; all elements of ``input[0]``
        are counted.
    :param output: sequence of output tensors; ``output[0].shape[1]`` is
        taken as the number of output channels.
    :return: ``prod(input[0].shape) * Cout * prod(kernel_size)``.
    """
    return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)


@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
    """Count the operations of a fully connected layer.

    :param module: module whose ``in_features`` is read.
    :param input: unused; kept for the common counter signature.
    :param output: sequence of output tensors; all elements of ``output[0]``
        are counted.
    :return: ``prod(output[0].shape) * in_features``.
    """
    return np.prod(output[0].shape) * module.in_features


# No need to list the qat and quantized modules here since they inherit from
# the float modules.
# Module classes that net_stats() attaches a forward hook to. Matching is done
# with isinstance(), so subclasses of these classes are hooked as well.
hook_modules = (
    m.Conv2d,
    m.ConvTranspose2d,
    m.LocalConv2d,
    m.BatchNorm2d,
    m.Linear,
)


def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True):
    """Log per-layer parameter and FLOPs tables for one forward pass of ``model``.

    Forward hooks are attached to every sub-module that is an instance of one
    of ``hook_modules``, the model is run once on all-zero float32 inputs, and
    the collected statistics are logged as tables through the module logger.

    Side effects: ``model.eval()`` is called and the previous mode is not
    restored. The hooks are removed again even if the forward pass raises.

    :param model: module providing ``named_modules()``, ``eval()`` and
        ``__call__``.
    :param input_size: shape of the single network input, or a sequence of
        shapes (tuples or lists) when the network takes several inputs.
    :param bar_length_max: length, in ``#`` characters, of the longest bar in
        the tables.
    :param log_params: log the parameter table and compute its total.
    :param log_flops: log the FLOPs table and compute its total.
    :return: ``(total_params, total_flops)`` -- the total parameter size in
        bytes and the total operation count; each is 0 when the matching
        ``log_*`` flag is False.
    """

    def dict2table(list_of_dict, header):
        """Turn dicts into rows ordered by ``header``; missing keys become ""."""
        table_data = [header]
        for d in list_of_dict:
            table_data.append([d.get(h, "") for h in header])
        return table_data

    def sizeof_fmt(num, suffix="B"):
        """Format ``num`` with binary (1024-based) unit prefixes."""
        for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
            if abs(num) < 1024.0:
                return "{:3.3f} {}{}".format(num, unit, suffix)
            num /= 1024.0
        # ``num`` still carries its own sign, so nothing is prepended here;
        # prepending "-" as well would print a doubled minus sign.
        return "{:.1f} {}{}".format(num, "Yi", suffix)

    def get_byteswidth(tensor):
        """Return the assumed bytes per element of ``tensor``.

        Quantized dtypes count as 1 byte and bfloat16 as 2; every other dtype
        is assumed to be 4 bytes wide.
        """
        dtype = tensor.dtype
        if mgb.dtype.is_quantize(dtype):
            return 1
        elif mgb.dtype.is_bfloat16(dtype):
            return 2
        else:
            return 4

    def print_flops_stats(flops):
        """Fill in the derived FLOPs columns, log the table, return the total."""
        max_flops_num = max((i["flops_num"] for i in flops), default=0)
        # calc total flops and set flops_cum
        total_flops_num = 0
        for d in flops:
            total_flops_num += int(d["flops_num"])
            d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

        for i in flops:
            f = i["flops_num"]
            i["flops"] = sizeof_fmt(f, suffix="OPs")
            # Guard both divisions: every recorded layer may report 0 ops.
            r = i["ratio"] = f / total_flops_num if total_flops_num else 0.0
            i["percentage"] = "{:.2f}%".format(r * 100)
            if max_flops_num:
                bar_length = int(f / max_flops_num * bar_length_max)
            else:
                bar_length = 0
            i["bar"] = "#" * bar_length

        header = [
            "name",
            "class_name",
            "input_shapes",
            "output_shapes",
            "flops",
            "flops_cum",
            "percentage",
            "bar",
        ]

        total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
        # Sum of dimension 1 over all recorded output shapes; shown in the
        # "output_shapes" column of the "total" row.
        total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
        flops.append(
            dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
        )

        logger.info(
            "flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))
        )

        return total_flops_num

    def print_params_stats(params):
        """Fill in the derived parameter columns, log the table, return total bytes."""
        total_param_dims, total_param_size = 0, 0
        for d in params:
            total_param_dims += int(d["param_dim"])
            total_param_size += int(d["size"])
            # From here on "size" holds the formatted string, not the byte count.
            d["size"] = sizeof_fmt(d["size"])
            d["size_cum"] = sizeof_fmt(total_param_size)

        for d in params:
            ratio = d["param_dim"] / total_param_dims if total_param_dims else 0.0
            d["ratio"] = ratio
            d["percentage"] = "{:.2f}%".format(ratio * 100)

        # construct bar; ``default`` covers a model without recorded parameters
        max_ratio = max((d["ratio"] for d in params), default=0)
        for d in params:
            if max_ratio:
                bar_length = int(d["ratio"] / max_ratio * bar_length_max)
            else:
                bar_length = 0
            d["size_bar"] = "#" * bar_length

        param_size = sizeof_fmt(total_param_size)
        params.append(dict(name="total", param_dim=total_param_dims, size=param_size))

        header = [
            "name",
            "shape",
            "mean",
            "std",
            "param_dim",
            "bits",
            "size",
            "size_cum",
            "percentage",
            "size_bar",
        ]

        logger.info(
            "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
        )

        return total_param_size

    def record_param(entry_name, tensor):
        """Append one row describing ``tensor`` to ``params``."""
        value = tensor.numpy()
        param_dim = np.prod(tensor.shape)
        param_bytes = get_byteswidth(tensor)
        params.append(
            dict(
                name=entry_name,
                shape=tensor.shape,
                param_dim=param_dim,
                bits=param_bytes * 8,
                size=param_dim * param_bytes,
                size_cum=0,
                mean="{:.2g}".format(value.mean()),
                std="{:.2g}".format(value.std()),
            )
        )

    def net_stats_hook(module, input, output, name=""):
        """Forward hook: record FLOPs and weight/bias statistics of ``module``.

        The FLOPs counter is looked up by the module's exact type, so a hooked
        module whose class has no entry in ``CALC_FLOPS`` only contributes
        parameter rows.
        """
        class_name = str(module.__class__).split(".")[-1].split("'")[0]

        flops_fun = CALC_FLOPS.get(type(module))
        if callable(flops_fun):
            # Normalise a single output to a list *before* counting: the
            # counters read ``output[0]`` as the first output tensor, and
            # indexing a bare tensor would select along its first axis instead.
            if not isinstance(output, (list, tuple)):
                output = [output]

            flops_num = flops_fun(module, input, output)

            flops.append(
                dict(
                    name=name,
                    class_name=class_name,
                    input_shapes=[i.shape for i in input],
                    output_shapes=[o.shape for o in output],
                    flops_num=flops_num,
                    flops_cum=0,
                )
            )

        if hasattr(module, "weight") and module.weight is not None:
            record_param(name + "-w", module.weight)

        if hasattr(module, "bias") and module.bias is not None:
            record_param(name + "-b", module.bias)

    # multiple inputs to the network: a bare shape becomes a one-element list
    if not isinstance(input_size[0], (tuple, list)):
        input_size = [input_size]

    params = []
    flops = []
    hooks = []

    try:
        for (name, module) in model.named_modules():
            if isinstance(module, hook_modules):
                hooks.append(
                    module.register_forward_hook(partial(net_stats_hook, name=name))
                )

        inputs = [mge.zeros(in_size, dtype=np.float32) for in_size in input_size]
        model.eval()
        model(*inputs)
    finally:
        # Always detach the hooks so a failing forward pass does not leave
        # the model instrumented.
        for h in hooks:
            h.remove()

    total_flops, total_params = 0, 0
    if log_params:
        total_params = print_params_stats(params)
    if log_flops:
        total_flops = print_flops_stats(flops)

    return total_params, total_flops