Commit c3a1ac3d authored by Megvii Engine Team

feat(mge/tools): module_stats-add-functions

BREAKING CHANGE:

GitOrigin-RevId: ced3da3a129713c652d93b73756b93273bf1cc9b
Parent 05e4c826
@@ -9,6 +9,7 @@
import argparse
import logging
import re
from collections import namedtuple
import numpy as np
@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
enable_receptive_field,
get_activation_stats,
get_op_stats,
get_param_stats,
print_activations_stats,
print_op_stats,
print_param_stats,
print_summary,
sizeof_fmt,
sum_activations_stats,
sum_op_stats,
sum_param_stats,
)
from megengine.utils.network import Network
@@ -34,6 +40,7 @@ def visualize(
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
):
r"""
Load a dumped MegEngine model and visualize its graph structure with TensorBoard log files.
@@ -44,6 +51,7 @@
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether to print and record param sizes.
:param log_flops: whether to print and record op flops.
:param log_activations: whether to print and record op activations.
"""
if log_path:
try:
@@ -83,6 +91,10 @@ def visualize(
node_list = []
flops_list = []
params_list = []
activations_list = []
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for node in graph.all_oprs:
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
@@ -124,6 +136,11 @@
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
acts = get_activation_stats(node_oup.numpy())
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy())
# add tensor size attr
@@ -149,20 +166,36 @@
"#params": len(params_list),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
(
total_flops,
total_param_dims,
total_param_size,
total_act_dims,
total_act_size,
) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_params:
total_param_dims, total_param_size = print_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
print_param_stats(params)
total_flops, flops = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_flops:
total_flops = print_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
print_op_stats(flops)
total_act_dims, total_act_size, activations = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if log_activations:
print_activations_stats(activations)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -179,7 +212,12 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
return total_param_size, total_flops
return (
total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size,
),
stats_details(params=params, flops=flops, activations=activations),
)
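With this change, ``visualize`` returns two namedtuples instead of the old ``(total_param_size, total_flops)`` pair. A minimal usage sketch of the new contract, assuming the dumped model path and the TensorBoard log directory are the first two arguments (the paths here are hypothetical):

    total_stats, stats_details = visualize("model.mge", "./logs")
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)
    print(len(stats_details.params), len(stats_details.flops), len(stats_details.activations))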
def main():
......
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import namedtuple
from functools import partial
import numpy as np
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros
from .module_utils import set_module_mode_safe
try:
mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
)
@register_flops(
m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
)
def flops_norm(module: m.batchnorm._BatchNorm, inputs, outputs):
# roughly 7 elementwise ops per input element (mean, variance,
# subtract, divide, scale and shift); a coarse heuristic
return np.prod(inputs[0].shape) * 7
@register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs):
return np.prod(outputs[0].shape) * (module.kernel_size ** 2)
@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
# recover the equivalent kernel/stride of adaptive pooling from the
# input and output spatial sizes: stride = floor(in / out),
# kernel = in - (out - 1) * stride
stride_h = np.floor(inputs[0].shape[2] / outputs[0].shape[2])
kernel_h = inputs[0].shape[2] - (outputs[0].shape[2] - 1) * stride_h
stride_w = np.floor(inputs[0].shape[3] / outputs[0].shape[3])
kernel_w = inputs[0].shape[3] - (outputs[0].shape[3] - 1) * stride_w
return np.prod(outputs[0].shape) * kernel_h * kernel_w
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = module.out_features if module.bias is not None else 0
@@ -120,6 +143,12 @@ hook_modules = (
m.conv._ConvNd,
m.Linear,
m.BatchMatMulActivation,
m.batchnorm._BatchNorm,
m.LayerNorm,
m.GroupNorm,
m.InstanceNorm,
m.pooling._PoolNd,
m.adaptive_pooling._AdaptivePoolNd,
)
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
if suffix == "B":
scale = 1024.0
units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
else:
scale = 1000.0
units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
for unit in units:
if abs(num) < scale or unit == units[-1]:
return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
num /= scale
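The rewritten ``sizeof_fmt`` picks its scale from the suffix: binary (1024-based, Ki/Mi/...) for byte sizes, decimal (1000-based, K/M/...) for everything else. Two quick examples of the expected output (a sketch; the values follow from the code above):

    sizeof_fmt(1536)                   # -> "1.500 KiB"  (suffix "B", 1024-based)
    sizeof_fmt(1500000, suffix="OPs")  # -> "1.500 MOPs" (1000-based)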
def preprocess_receptive_field(module, inputs, outputs):
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
if not isinstance(outputs, tuple) and not isinstance(outputs, list):
outputs = (outputs,)
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
@@ -189,7 +224,7 @@
return
def print_op_stats(flops, bar_length_max=20):
def sum_op_stats(flops, bar_length_max=20):
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
return total_flops_num, flops
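The former ``print_op_stats`` is split into an aggregation step and a printing step, so callers can compute totals without logging. A hedged sketch of the intended call pattern:

    total_flops_num, flops_rows = sum_op_stats(flops_list, bar_length_max=20)
    if log_flops:
        print_op_stats(flops_rows)  # the "total" row is already appended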
def print_op_stats(flops):
header = [
"name",
"class_name",
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
if _receptive_field_enabled:
header.insert(4, "receptive_field")
header.insert(5, "stride")
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
return total_flops_num
def get_param_stats(param: np.ndarray):
nbits = get_dtype_bit(param.dtype.name)
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
}
def print_param_stats(params, bar_length_max=20):
def sum_param_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
return total_param_dims, total_param_size, params
def print_param_stats(params):
header = [
"name",
"dtype",
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
"mean",
"std",
"param_dim",
"bits",
"nbits",
"size",
"size_cum",
"percentage",
"size_bar",
]
logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)
return total_param_dims, total_param_size
def get_activation_stats(output: np.ndarray):
out_shape = output.shape
activations_dtype = output.dtype
nbits = get_dtype_bit(activations_dtype.name)
act_dim = np.prod(out_shape)
act_size = act_dim * nbits // 8
return {
"dtype": activations_dtype,
"shape": out_shape,
"act_dim": act_dim,
"mean": "{:.3g}".format(output.mean()),
"std": "{:.3g}".format(output.std()),
"nbits": nbits,
"size": act_size,
}
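``get_activation_stats`` operates on a plain numpy array; for float32 the size is ``act_dim * 32 // 8`` bytes. A small sanity check (a sketch, not part of the diff):

    stats = get_activation_stats(np.zeros((1, 64, 56, 56), dtype=np.float32))
    # stats["act_dim"] == 200704, stats["nbits"] == 32, stats["size"] == 802816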
def sum_activations_stats(activations, bar_length_max=20):
max_act_size = max([i["size"] for i in activations] + [0])
total_act_dims, total_act_size = 0, 0
for d in activations:
total_act_size += int(d["size"])
total_act_dims += int(d["act_dim"])
d["size_cum"] = sizeof_fmt(total_act_size)
for d in activations:
ratio = d["ratio"] = d["size"] / total_act_size
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["size"] / max_act_size * bar_length_max)
d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])
act_size = sizeof_fmt(total_act_size)
activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,))
return total_act_dims, total_act_size, activations
def print_activations_stats(activations):
header = [
"name",
"class_name",
"dtype",
"shape",
"mean",
"std",
"nbits",
"act_dim",
"size",
"size_cum",
"percentage",
"size_bar",
]
logger.info(
"activations stats: \n"
+ tabulate.tabulate(dict2table(activations, header=header))
)
def print_summary(**kwargs):
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
def module_stats(
model: m.Module,
input_size: int,
input_shapes: list,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
):
r"""
Calculate and print ``model``'s statistics by adding hooks and recording the Module's input and output sizes.
:param model: model that need to get stats info.
:param input_size: size of input for running model and calculating stats.
:param input_shapes: shapes of inputs for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether to print and record param sizes.
:param log_flops: whether to print and record op flops.
:param log_activations: whether to print and record op activations.
"""
disable_receptive_field()
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
@@ -331,38 +428,25 @@ def module_stats(
param_stats["name"] = name + "-b"
params.append(param_stats)
@contextlib.contextmanager
def adjust_stats(module, training=False):
"""Adjust module to training/eval mode temporarily.
Args:
module (M.Module): used module.
training (bool): training mode. True for train mode, False for eval mode.
"""
def recursive_backup_stats(module, mode):
for m in module.modules():
# save prev status to _prev_training
m._prev_training = m.training
m.train(mode, recursive=False)
def recursive_recover_stats(module):
for m in module.modules():
# recover prev status and delete attribute
m.training = m._prev_training
delattr(m, "_prev_training")
recursive_backup_stats(module, mode=training)
yield module
recursive_recover_stats(module)
if not isinstance(outputs, (tuple, list)):
output = outputs.numpy()
else:
output = outputs[0].numpy()
activation_stats = get_activation_stats(output)
activation_stats["name"] = name
activation_stats["class_name"] = class_name
activations.append(activation_stats)
# multiple inputs to the network
if not isinstance(input_size[0], tuple):
input_size = [input_size]
if not isinstance(input_shapes[0], tuple):
input_shapes = [input_shapes]
params = []
flops = []
hooks = []
activations = []
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for (name, module) in model.named_modules():
if isinstance(module, hook_modules):
@@ -370,8 +454,8 @@
module.register_forward_hook(partial(module_stats_hook, name=name))
)
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
with adjust_stats(model, training=False) as model:
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
with set_module_mode_safe(model, training=False) as model:
model(*inputs)
for h in hooks:
@@ -380,19 +464,40 @@
extra_info = {
"#params": len(params),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
(
total_flops,
total_param_dims,
total_param_size,
total_act_dims,
total_act_size,
) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_params:
total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
print_param_stats(params)
total_flops, flops = sum_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_flops:
total_flops = print_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
print_op_stats(flops)
total_act_dims, total_act_size, activations = sum_activations_stats(
activations, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if log_activations:
print_activations_stats(activations)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
print_summary(**extra_info)
return total_param_size, total_flops
return (
total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size,
),
stats_details(params=params, flops=flops, activations=activations),
)
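This is the breaking change the commit message refers to: ``module_stats`` now takes ``input_shapes`` instead of ``input_size`` and returns two namedtuples instead of a ``(total_param_size, total_flops)`` pair. A minimal migration sketch, assuming an existing ``model``:

    # before: total_param_size, total_flops = module_stats(model, (1, 3, 224, 224))
    total_stats, stats_details = module_stats(model, input_shapes=(1, 3, 224, 224))
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)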
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import Iterable
from ..module import Sequential
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
parent[key] = value
_access_structure(obj, key, callback=f)
@contextlib.contextmanager
def set_module_mode_safe(
module: Module, training: bool = False,
):
"""Adjust module to training/eval mode temporarily.
:param module: used module.
:param training: training mode. True for train mode, False for eval mode.
"""
backup_stats = {}
def recursive_backup_stats(module, mode):
for m in module.modules():
backup_stats[m] = m.training
m.train(mode, recursive=False)
def recursive_recover_stats(module):
for m in module.modules():
m.training = backup_stats.pop(m)
recursive_backup_stats(module, mode=training)
yield module
recursive_recover_stats(module)
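A usage sketch of the relocated helper: each submodule's ``training`` flag is saved, the whole module tree is switched to the requested mode, and the original flags are restored on exit.

    with set_module_mode_safe(model, training=False) as m:
        m(inputs)  # runs with every submodule in eval mode
    # original training flags are restored here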
import math
from copy import deepcopy
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats
@pytest.mark.skipif(
use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
net = ResNet(BasicBlock, [2, 2, 2, 2])
input_shape = (1, 3, 224, 224)
total_stats, stats_details = module_stats(net, input_shape)
x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
gt_flops, gt_acts = net.get_stats(x1)
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops,
gt_acts,
)
class BasicBlock(M.Module):
expansion = 1
def __init__(
self,
in_channels,
channels,
stride=1,
groups=1,
base_width=64,
dilation=1,
norm=M.BatchNorm2d,
):
super().__init__()
self.tmp_in_channels = in_channels
self.tmp_channels = channels
self.stride = stride
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = M.Conv2d(
in_channels, channels, 3, stride, padding=dilation, bias=False
)
self.bn1 = norm(channels)
self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
self.bn2 = norm(channels)
self.downsample_id = M.Identity()
self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
self.downsample_norm = norm(channels)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
identity = self.downsample_id(identity)
else:
identity = self.downsample_conv(identity)
identity = self.downsample_norm(identity)
x += identity
x = F.relu(x)
return x
def get_stats(self, x):
activations, flops = 0, 0
identity = x
in_x = deepcopy(x)
x = self.conv1(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn1(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
activations += tmp_acts
flops += tmp_flops
x = F.relu(x)
in_x = deepcopy(x)
x = self.conv2(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn2(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
activations += tmp_acts
flops += tmp_flops
if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
identity = self.downsample_id(identity)
else:
in_x = deepcopy(identity)
identity = self.downsample_conv(identity)
tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(identity)
identity = self.downsample_norm(identity)
tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
activations += tmp_acts
flops += tmp_flops
x += identity
x = F.relu(x)
return x, flops, activations
class ResNet(M.Module):
def __init__(
self,
block,
layers=[2, 2, 2, 2],
num_classes=1000,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
norm=M.BatchNorm2d,
):
super().__init__()
self.in_channels = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = norm(self.in_channels)
self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1_0 = BasicBlock(
self.in_channels,
64,
stride=1,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=M.BatchNorm2d,
)
self.layer1_1 = BasicBlock(
self.in_channels,
64,
stride=1,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=M.BatchNorm2d,
)
self.layer2_0 = BasicBlock(64, 128, stride=2)
self.layer2_1 = BasicBlock(128, 128)
self.layer3_0 = BasicBlock(128, 256, stride=2)
self.layer3_1 = BasicBlock(256, 256)
self.layer4_0 = BasicBlock(256, 512, stride=2)
self.layer4_1 = BasicBlock(512, 512)
self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
self.layer2 = self._make_layer(
block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
)
self.layer3 = self._make_layer(
block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
)
self.layer4 = self._make_layer(
block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
)
self.fc = M.Linear(512, num_classes)
for m in self.modules():
if isinstance(m, M.Conv2d):
M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
M.init.uniform_(m.bias, -bound, bound)
elif isinstance(m, M.BatchNorm2d):
M.init.ones_(m.weight)
M.init.zeros_(m.bias)
elif isinstance(m, M.Linear):
M.init.msra_uniform_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
M.init.uniform_(m.bias, -bound, bound)
if zero_init_residual:
for m in self.modules():
# only BasicBlock has a bn2; guard so other modules don't raise AttributeError
if isinstance(m, BasicBlock):
M.init.zeros_(m.bn2.weight)
def _make_layer(
self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
):
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
layers = []
layers.append(
block(
self.in_channels,
channels,
stride,
groups=self.groups,
base_width=self.base_width,
dilation=previous_dilation,
norm=norm,
)
)
self.in_channels = channels * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.in_channels,
channels,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=norm,
)
)
return M.Sequential(*layers)
def extract_features(self, x):
outputs = {}
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.maxpool(x)
outputs["stem"] = x
x = self.layer1(x)
outputs["res2"] = x
x = self.layer2(x)
outputs["res3"] = x
x = self.layer3(x)
outputs["res4"] = x
x = self.layer4(x)
outputs["res5"] = x
return outputs
def forward(self, x):
x = self.extract_features(x)["res5"]
x = F.avg_pool2d(x, 7)
x = F.flatten(x, 1)
x = self.fc(x)
return x
def get_stats(self, x):
flops, activations = 0, 0
in_x = deepcopy(x)
x = self.conv1(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn1(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
activations += tmp_acts
flops += tmp_flops
x = F.relu(x)
in_x = deepcopy(x)
x = self.maxpool(x)
tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x = F.avg_pool2d(x, 7)
x = F.flatten(x, 1)
in_x = deepcopy(x)
x = self.fc(x)
tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
activations += tmp_acts
flops += tmp_flops
return flops, activations
def cal_conv_stats(module, input, output):
bias = 1 if module.bias is not None else 0
flops = np.prod(output[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
)
acts = np.prod(output[0].shape)
return flops, acts
def cal_norm_stats(module, input, output):
return np.prod(input[0].shape) * 7, np.prod(output[0].shape)
def cal_linear_stats(module, inputs, outputs):
bias = module.out_features if module.bias is not None else 0
return (
np.prod(outputs[0].shape) * module.in_features + bias,
np.prod(outputs[0].shape),
)
def cal_pool_stats(module, inputs, outputs):
return (
np.prod(outputs[0].shape) * (module.kernel_size ** 2),
np.prod(outputs[0].shape),
)
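A quick hand check of ``cal_conv_stats`` on the first ResNet conv (7x7, 3->64 channels, stride 2, 224x224 input, so the output is (1, 64, 112, 112)); the numbers follow from the formula above:

    # flops = prod(output[0].shape) * (in_channels // groups * prod(kernel) + bias)
    #       = (64 * 112 * 112) * (3 * 7 * 7 + 0) = 802816 * 147 = 118013952
    # acts  = prod(output[0].shape) = 802816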