提交 a7ff580e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/utils): add net stats to calculate parameters and flops

GitOrigin-RevId: a77f89e24bf10c7d3a0f79659f2a78382b38ce5a
上级 96ec586d
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
import numpy as np
import tabulate
import megengine as mge
import megengine._internal as mgb
import megengine.module as m
import megengine.module.qat as qatm
import megengine.module.quantized as qm
# Remove the per-record line cap of the MegEngine log formatter (presumably so
# the multi-line statistics tables logged below are not cut off -- confirm
# against MegEngineLogFormatter).
try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
    # Chain the original error so the missing attribute stays visible in the
    # traceback instead of being replaced by the generic message.
    raise ValueError("set logger max lines failed") from e
# Module-level logger used to emit the parameter / flops tables.
logger = mge.get_logger(__name__)
# Registry mapping a module class to the function that counts its flops.
CALC_FLOPS = {}


def _register_modules(*modules):
    """Build a decorator that registers a flops counter for ``modules``.

    The decorated function is stored in ``CALC_FLOPS`` under every class in
    ``modules`` (replacing any earlier entry for the same class) and is
    returned unchanged, so it remains usable under its own name.
    """

    def register(impl):
        # Every listed class maps to the very same counter function.
        CALC_FLOPS.update(dict.fromkeys(modules, impl))
        return impl

    return register
@_register_modules(
    m.Conv2d,
    m.ConvTranspose2d,  # NOTE: replaced by the count_deconvNd registration that follows
    m.LocalConv2d,
    qm.Conv2d,
    qm.ConvRelu2d,
    qm.ConvBn2d,
    qm.ConvBnRelu2d,
    qatm.Conv2d,
    qatm.ConvRelu2d,
    qatm.ConvBn2d,
    qatm.ConvBnRelu2d,
)
def count_convNd(module, input, output):
    """Count the multiply-accumulate operations of a (grouped) convolution.

    Every output element is produced from ``Cin / groups`` input channels times
    the kernel area, plus one extra operation when the module has a bias.

    :param module: convolution module; ``bias``, ``groups`` and ``kernel_size``
        are read from it.
    :param input: sequence whose first item is the input tensor; its
        ``shape[1]`` is taken as the input channel count.
    :param output: sequence whose first item is the output tensor; its shape is
        read as ``(N, Cout, *spatial)``.
    :return: number of operations (a numpy integer, because of ``np.prod``).
    """
    bias = 1 if module.bias is not None else 0
    group = module.groups
    ic = input[0].shape[1]
    oc = output[0].shape[1]
    # Each output channel only sees the input channels of its own group.
    gic = ic // group
    N = output[0].shape[0]
    HW = np.prod(output[0].shape[2:])
    # N x Cout x H x W x (Cin / groups x Kw x Kh + bias)
    # All Cout output channels are counted; multiplying by Cout / groups
    # instead would under-count grouped convolutions by a factor of `groups`.
    return N * HW * oc * (gic * np.prod(module.kernel_size) + bias)
@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
    """Count the operations of a transposed convolution.

    The result is the number of elements of ``input[0]`` multiplied by
    ``output[0].shape[1]`` and by the number of kernel elements
    (``module.kernel_size``).
    """
    input_elems = np.prod(input[0].shape)
    out_dim = output[0].shape[1]
    kernel_elems = np.prod(module.kernel_size)
    return input_elems * out_dim * kernel_elems
@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
    """Count the operations of a linear layer.

    Returns the number of elements of ``output[0]`` multiplied by
    ``module.in_features``; ``input`` is accepted for signature parity with
    the other counters but is not used.
    """
    out_elems = np.prod(output[0].shape)
    return out_elems * module.in_features
# Module classes that get a forward hook in net_stats (matched with isinstance).
# The qat and quantized variants do not need to be listed here since they
# inherit from these float modules.
hook_modules = (
    m.Conv2d,
    m.ConvTranspose2d,
    m.LocalConv2d,
    m.BatchNorm2d,  # no flops counter is registered for it; only its parameters are reported
    m.Linear,
)
def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True):
    """Log per-layer parameter and flops tables for ``model`` and return the totals.

    A forward hook is attached to every sub-module that is an instance of one
    of ``hook_modules``; the model is put in eval mode and run once on all-zero
    float32 inputs; the statistics collected by the hooks are then rendered
    with ``tabulate`` and written to the module logger.

    Side effects: ``model.eval()`` is called and the previous mode is not
    restored. The hooks are always removed, even when the forward pass raises.

    :param model: module to analyse; must offer ``named_modules()`` and
        ``eval()`` and be callable on the generated inputs.
    :param input_size: shape of the single network input, or a sequence of
        shapes (tuples or lists) for a network with several inputs.
    :param bar_length_max: length in characters of the longest bar in the tables.
    :param log_params: whether to build and log the parameter table.
    :param log_flops: whether to build and log the flops table.
    :return: ``(total_params, total_flops)`` -- the total parameter size in
        bytes and the total flops count; each is 0 when its table is disabled.
    """
    params = []
    flops = []
    hooks = []

    def dict2table(list_of_dict, header):
        # One header row, then one row per dict; keys absent from a dict render as "".
        table_data = [header]
        for d in list_of_dict:
            table_data.append([d.get(h, "") for h in header])
        return table_data

    def sizeof_fmt(num, suffix="B"):
        # Scale by powers of 1024 until the magnitude fits the unit.
        for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
            if abs(num) < 1024.0:
                return "{:3.3f} {}{}".format(num, unit, suffix)
            num /= 1024.0
        # The float format already prints the sign of a negative number, so no
        # explicit "-" prefix is added (that would yield "--").
        return "{:.1f} {}{}".format(num, "Yi", suffix)

    def get_byteswidth(tensor):
        # Bytes per element: 1 for quantized dtypes, 2 for bfloat16, else 4.
        dtype = tensor.dtype
        if mgb.dtype.is_quantize(dtype):
            return 1
        elif mgb.dtype.is_bfloat16(dtype):
            return 2
        else:
            return 4

    def make_bar(value, max_value):
        # Bar proportional to value / max_value; empty when there is no scale.
        if not max_value:
            return ""
        return "#" * int(value / max_value * bar_length_max)

    def print_flops_stats(flops):
        flops_list = [i["flops_num"] for i in flops]
        max_flops_num = max(flops_list + [0])
        # calc total flops and set flops_cum
        total_flops_num = 0
        for d in flops:
            total_flops_num += int(d["flops_num"])
            d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

        for i in flops:
            f = i["flops_num"]
            i["flops"] = sizeof_fmt(f, suffix="OPs")
            # Guard against a zero total so an all-zero table cannot divide by zero.
            r = i["ratio"] = f / total_flops_num if total_flops_num else 0.0
            i["percentage"] = "{:.2f}%".format(r * 100)
            i["bar"] = make_bar(f, max_flops_num)

        header = [
            "name",
            "class_name",
            "input_shapes",
            "output_shapes",
            "flops",
            "flops_cum",
            "percentage",
            "bar",
        ]

        total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
        # Sum of dimension 1 of every recorded output shape.
        total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
        flops.append(
            dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
        )

        logger.info(
            "flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))
        )

        return total_flops_num

    def print_params_stats(params):
        total_param_dims, total_param_size = 0, 0
        for d in params:
            total_param_dims += int(d["param_dim"])
            total_param_size += int(d["size"])
            d["size"] = sizeof_fmt(d["size"])
            d["size_cum"] = sizeof_fmt(total_param_size)

        for d in params:
            # Guard against a zero total (only empty tensors were recorded).
            ratio = d["param_dim"] / total_param_dims if total_param_dims else 0.0
            d["ratio"] = ratio
            d["percentage"] = "{:.2f}%".format(ratio * 100)

        # construct bar; default=0 keeps an empty parameter list from raising
        max_ratio = max((d["ratio"] for d in params), default=0)
        for d in params:
            d["size_bar"] = make_bar(d["ratio"], max_ratio)

        param_size = sizeof_fmt(total_param_size)
        params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

        header = [
            "name",
            "shape",
            "mean",
            "std",
            "param_dim",
            "bits",
            "size",
            "size_cum",
            "percentage",
            "size_bar",
        ]

        logger.info(
            "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
        )

        return total_param_size

    def record_param(entry_name, tensor):
        # Append one row to the parameter table; missing parameters are skipped.
        if tensor is None:
            return
        value = tensor.numpy()
        param_dim = np.prod(tensor.shape)
        param_bytes = get_byteswidth(tensor)
        params.append(
            dict(
                name=entry_name,
                shape=tensor.shape,
                param_dim=param_dim,
                bits=param_bytes * 8,
                size=param_dim * param_bytes,
                size_cum=0,
                mean="{:.2g}".format(value.mean()),
                std="{:.2g}".format(value.std()),
            )
        )

    def net_stats_hook(module, inputs, outputs, name=""):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]

        # Wrap a bare output BEFORE calling the counter: the counters read
        # output[0] as the first output tensor, so handing them an unwrapped
        # tensor would make them index into the tensor itself.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        flops_fun = CALC_FLOPS.get(type(module))
        if callable(flops_fun):
            flops_num = flops_fun(module, inputs, outputs)
            flops.append(
                dict(
                    name=name,
                    class_name=class_name,
                    input_shapes=[i.shape for i in inputs],
                    output_shapes=[o.shape for o in outputs],
                    flops_num=flops_num,
                    flops_cum=0,
                )
            )

        record_param(name + "-w", getattr(module, "weight", None))
        record_param(name + "-b", getattr(module, "bias", None))

    # multiple inputs to the network: a single shape is wrapped into a list,
    # a sequence of shapes (tuples or lists) is used as given
    if not isinstance(input_size[0], (tuple, list)):
        input_size = [input_size]

    try:
        for (name, module) in model.named_modules():
            if isinstance(module, hook_modules):
                hooks.append(
                    module.register_forward_hook(partial(net_stats_hook, name=name))
                )

        inputs = [mge.zeros(in_size, dtype=np.float32) for in_size in input_size]
        model.eval()
        model(*inputs)
    finally:
        # Never leave the statistics hooks attached, even if the forward pass fails.
        for h in hooks:
            h.remove()

    total_flops, total_params = 0, 0
    if log_params:
        total_params = print_params_stats(params)
    if log_flops:
        total_flops = print_flops_stats(flops)

    return total_params, total_flops
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册