# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Utilities to compute and print per-module FLOPs / parameter statistics."""
from functools import partial

import numpy as np
import tabulate

import megengine as mge
import megengine.module as m
import megengine.module.qat as qatm
import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros

try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
    # Chain the original exception so the real cause shows in the traceback.
    raise ValueError("set logger max lines failed") from e

logger = mge.get_logger(__name__)
logger.setLevel("INFO")

# Registries mapping module classes to their stat-computing implementations.
_calc_flops_dict = {}
_calc_receptive_field_dict = {}


def _receptive_field_fallback(module, inputs, outputs):
    """Default receptive-field computation for modules with no registered impl.

    Caches the result on the module as ``_rf`` / ``_stride`` so that modules
    further down the graph can look it up through their inputs' owners.
    """
    # Each module is expected to be visited at most once per forward pass.
    assert not hasattr(module, "_rf")
    assert not hasattr(module, "_stride")
    if len(inputs) == 0:
        # TODO: support other dimension
        module._rf = (1, 1)
        module._stride = (1, 1)
        return module._rf, module._stride
    rf, stride = preprocess_receptive_field(module, inputs, outputs)
    module._rf = rf
    module._stride = stride
    return rf, stride


# (result key or key tuple, impl registry, fallback impl)
_iter_list = [
    ("flops_num", _calc_flops_dict, None),
    (
        ("receptive_field", "stride"),
        _calc_receptive_field_dict,
        _receptive_field_fallback,
    ),
]


def _register_dict(*modules, target_dict=None):
    """Return a decorator that registers its target for each class in *modules*.

    Note: parameter renamed from ``dict`` to avoid shadowing the builtin.
    """

    def callback(impl):
        for module in modules:
            target_dict[module] = impl
        return impl

    return callback


def register_flops(*modules):
    """Decorator: register a FLOPs implementation for the given module classes."""
    return _register_dict(*modules, target_dict=_calc_flops_dict)


def register_receptive_field(*modules):
    """Decorator: register a receptive-field impl for the given module classes."""
    return _register_dict(*modules, target_dict=_calc_receptive_field_dict)


@register_flops(
    m.Conv1d, m.Conv2d, m.Conv3d,
)
def flops_convNd(module: m.Conv2d, inputs, outputs):
    """FLOPs of an N-d (grouped) convolution."""
    bias = 1 if module.bias is not None else 0
    group = module.groups
    ic = inputs[0].shape[1]
    oc = outputs[0].shape[1]
    goc = oc // group
    gic = ic // group
    N = outputs[0].shape[0]
    HW = np.prod(outputs[0].shape[2:])
    # N x Cout x H x W x (Cin x Kw x Kh + bias)
    return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)


@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
    """FLOPs of a transposed convolution (bias not counted)."""
    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)


@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    """FLOPs of a linear layer: output elements x (in_features + bias).

    Fixed: ``bias`` was previously computed but never used, which was
    inconsistent with :func:`flops_convNd`'s ``(... + bias)`` convention.
    """
    bias = 1 if module.bias is not None else 0
    return np.prod(outputs[0].shape) * (module.in_features + bias)


@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
    """FLOPs of a batched matmul with optional bias."""
    bias = 1 if module.bias is not None else 0
    x = inputs[0]
    w = module.weight
    batch_size = x.shape[0]
    n, p = x.shape[1:]
    # Local renamed from ``m`` to avoid shadowing the ``megengine.module`` alias.
    _, out_dim = w.shape[1:]
    return n * (p + bias) * out_dim * batch_size


# does not need import qat and quantized module since they inherit from float module.
hook_modules = (
    m.conv._ConvNd,
    m.Linear,
    m.BatchMatMulActivation,
)


def dict2table(list_of_dict, header):
    """Convert a list of dicts into tabulate-ready rows ordered by *header*.

    Missing keys render as an empty cell.
    """
    table_data = [header]
    for d in list_of_dict:
        table_data.append([d.get(h, "") for h in header])
    return table_data


def sizeof_fmt(num, suffix="B"):
    """Format *num* with binary (1024-based) unit prefixes, e.g. ``1.500 KiB``."""
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "{:3.3f} {}{}".format(num, unit, suffix)
        num /= 1024.0
    sign_str = "-" if num < 0 else ""
    return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)


def preprocess_receptive_field(module, inputs, outputs):
    """Take the elementwise max receptive field / stride over all input owners.

    Fixed: the second component previously read ``_rf[1]`` / ``_stride[1]``
    without a default and crashed whenever an input's owner had not been
    visited, while the first component silently defaulted to ``(1, 1)``.
    """
    # TODO: support other dimensions
    pre_rf = (
        max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
        max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
    )
    pre_stride = (
        max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
        max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
    )
    return pre_rf, pre_stride


def get_flops_stats(module, inputs, outputs):
    """Collect flops / receptive-field stats for *module*.

    :return: dict of stats, or ``None`` when no registered impl matched.
    """
    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],
    }
    valid_flag = False
    for key, _dict, fallback in _iter_list:
        for _type in _dict:
            if isinstance(module, _type):
                value = _dict[_type](module, inputs, outputs)
                valid_flag = True
                break
        else:
            # No registered impl matched. The fallback is run only for its
            # side effect of caching ``_rf``/``_stride`` on the module; its
            # return value is intentionally discarded.
            if fallback is not None:
                fallback(module, inputs, outputs)
            continue

        if isinstance(key, tuple):
            assert isinstance(value, tuple)
            for k, v in zip(key, value):
                rst[k] = v
        else:
            rst[key] = value

    # Fixed: removed unreachable ``return`` that followed this if/else.
    return rst if valid_flag else None


def print_flops_stats(flops, bar_length_max=20):
    """Log a per-module FLOPs table and return the total FLOPs count."""
    max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:
        total_flops_num += int(d["flops_num"])
        d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

    for d in flops:
        # Guard against ZeroDivisionError when no module reported any FLOPs.
        ratio = d["ratio"] = (
            d["flops_num"] / total_flops_num if total_flops_num else 0.0
        )
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        bar_length = (
            int(d["flops_num"] / max_flops_num * bar_length_max)
            if max_flops_num
            else 0
        )
        d["bar"] = "#" * bar_length
        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

    header = [
        "name",
        "class_name",
        "input_shapes",
        "output_shapes",
        "receptive_field",
        "stride",
        "flops",
        "flops_cum",
        "percentage",
        "bar",
    ]

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))

    return total_flops_num


def get_param_stats(param: np.ndarray):
    """Describe one parameter array: dtype, shape, moments, bit width, byte size."""
    nbits = get_dtype_bit(param.dtype.name)
    shape = param.shape
    param_dim = np.prod(param.shape)
    param_size = param_dim * nbits // 8
    return {
        "dtype": param.dtype,
        "shape": shape,
        "mean": "{:.3g}".format(param.mean()),
        "std": "{:.3g}".format(param.std()),
        "param_dim": param_dim,
        "nbits": nbits,
        "size": param_size,
    }


def print_params_stats(params, bar_length_max=20):
    """Log a per-parameter table; return (total param dims, total size in bytes)."""
    max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:
        total_param_dims += int(d["param_dim"])
        total_param_size += int(d["size"])
        d["size_cum"] = sizeof_fmt(total_param_size)

    for d in params:
        # Guard against ZeroDivisionError for empty / all-zero-size params.
        ratio = (d["size"] / total_param_size) if total_param_size else 0.0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        bar_length = int(d["size"] / max_size * bar_length_max) if max_size else 0
        d["size_bar"] = "#" * bar_length
        d["size"] = sizeof_fmt(d["size"])

    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

    header = [
        "name",
        "dtype",
        "shape",
        "mean",
        "std",
        "param_dim",
        # Fixed: was "bits", which never matched the "nbits" stats key,
        # so the column always rendered empty.
        "nbits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]

    logger.info(
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )

    return total_param_dims, total_param_size


def print_summary(**kwargs):
    """Log a two-column item/value summary table of the given keyword args."""
    data = [["item", "value"]]
    data.extend(list(kwargs.items()))
    logger.info("summary\n" + tabulate.tabulate(data))


def module_stats(
    model: m.Module,
    input_size: tuple,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.

    :param model: model that need to get stats info.
    :param input_size: shape of a single input, or a list/tuple of shapes for
        multi-input models (annotation fixed: this is not an ``int``).
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: ``(total_param_size, total_flops)``.
    """

    def module_stats_hook(module, inputs, outputs, name=""):
        # Records per-module flops and parameter stats into the enclosing lists.
        class_name = str(module.__class__).split(".")[-1].split("'")[0]

        flops_stats = get_flops_stats(module, inputs, outputs)
        if flops_stats is not None:
            flops_stats["name"] = name
            flops_stats["class_name"] = class_name
            flops.append(flops_stats)

        if hasattr(module, "weight") and module.weight is not None:
            w = module.weight
            param_stats = get_param_stats(w.numpy())
            param_stats["name"] = name + "-w"
            params.append(param_stats)

        if hasattr(module, "bias") and module.bias is not None:
            b = module.bias
            param_stats = get_param_stats(b.numpy())
            param_stats["name"] = name + "-b"
            params.append(param_stats)

    # multiple inputs to the network
    if not isinstance(input_size[0], tuple):
        input_size = [input_size]

    params = []
    flops = []
    hooks = []

    for (name, module) in model.named_modules():
        if isinstance(module, hook_modules):
            hooks.append(
                module.register_forward_hook(partial(module_stats_hook, name=name))
            )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
    model.eval()
    model(*inputs)
    for h in hooks:
        h.remove()

    extra_info = {
        "#params": len(params),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_flops_stats(flops, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    print_summary(**extra_info)

    return total_param_size, total_flops