Commit c3a1ac3d authored by Megvii Engine Team

feat(mge/tools): module_status-add-functions

BREAKING CHANGE:

GitOrigin-RevId: ced3da3a129713c652d93b73756b93273bf1cc9b
Parent 05e4c826
@@ -9,6 +9,7 @@
import argparse
import logging
import re
from collections import namedtuple

import numpy as np
@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
    enable_receptive_field,
    get_activation_stats,
    get_op_stats,
    get_param_stats,
    print_activations_stats,
    print_op_stats,
    print_param_stats,
    print_summary,
    sizeof_fmt,
    sum_activations_stats,
    sum_op_stats,
    sum_param_stats,
)
from megengine.utils.network import Network
@@ -34,6 +40,7 @@ def visualize(
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
r""" r"""
Load megengine dumped model and visualize graph structure with tensorboard log files. Load megengine dumped model and visualize graph structure with tensorboard log files.
@@ -44,6 +51,7 @@ def visualize(
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    :param log_activations: whether to print and record op activations.
    """
    if log_path:
        try:
@@ -83,6 +91,10 @@ def visualize(
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
@@ -124,6 +136,11 @@ def visualize(
flops_stats["class_name"] = node.type flops_stats["class_name"] = node.type
flops_list.append(flops_stats) flops_list.append(flops_stats)
acts = get_activation_stats(node_oup.numpy())
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if node.type == "ImmutableTensor": if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy()) param_stats = get_param_stats(node.numpy())
# add tensor size attr # add tensor size attr
@@ -149,20 +166,36 @@ def visualize(
"#params": len(params_list), "#params": len(params_list),
} }
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    total_param_dims, total_param_size, params = sum_param_stats(
        params_list, bar_length_max
    )
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        print_param_stats(params)

    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        print_op_stats(flops)

    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations_list, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)

    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -179,7 +212,12 @@ def visualize(
    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )
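With this change, visualize returns two namedtuples instead of a bare (param_size, flops) pair. A minimal usage sketch; the leading arguments (a dumped-model path and a tensorboard log directory) are assumptions not shown in this hunk:

total_stats, stats_details = visualize("model.mge", "./tb_log")  # hypothetical paths
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(stats_details.activations[-1])  # the appended "total" row from sum_activations_stats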
def main():
...
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import namedtuple
from functools import partial

import numpy as np
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros

from .module_utils import set_module_mode_safe

try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
    )
@register_flops(
m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
)
def flops_norm(module: m.Linear, inputs, outputs):
return np.prod(inputs[0].shape) * 7
@register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs):
return np.prod(outputs[0].shape) * (module.kernel_size ** 2)
@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1))
kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h
stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1))
kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w
return np.prod(outputs[0].shape) * kernel_h * kernel_w
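For a sense of scale, a small sketch of what the new estimators above report; the shapes are illustrative only:

import numpy as np

# norm layers: 7 ops per input element, e.g. a BatchNorm over a (1, 64, 56, 56) tensor
norm_flops = np.prod((1, 64, 56, 56)) * 7         # 1404928
# fixed-window pooling: kernel_size ** 2 ops per output element, e.g. a 2x2 MaxPool2d
pool_flops = np.prod((1, 64, 28, 28)) * (2 ** 2)  # 200704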
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
@@ -120,6 +143,12 @@ hook_modules = (
    m.conv._ConvNd,
    m.Linear,
    m.BatchMatMulActivation,
    m.batchnorm._BatchNorm,
    m.LayerNorm,
    m.GroupNorm,
    m.InstanceNorm,
    m.pooling._PoolNd,
    m.adaptive_pooling._AdaptivePoolNd,
)
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
def sizeof_fmt(num, suffix="B"): def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if suffix == "B":
if abs(num) < 1024.0: scale = 1024.0
units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
else:
scale = 1000.0
units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
for unit in units:
if abs(num) < scale or unit == units[-1]:
return "{:3.3f} {}{}".format(num, unit, suffix) return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0 num /= scale
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def preprocess_receptive_field(module, inputs, outputs):
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
        outputs = (outputs,)
    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],
@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
    return
def sum_op_stats(flops, bar_length_max=20):
    max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
d["bar"] = "#" * bar_length d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
return total_flops_num, flops
def print_op_stats(flops):
    header = [
        "name",
        "class_name",
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
    if _receptive_field_enabled:
        header.insert(4, "receptive_field")
        header.insert(5, "stride")

    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
def get_param_stats(param: np.ndarray):
    nbits = get_dtype_bit(param.dtype.name)
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
    }
def sum_param_stats(params, bar_length_max=20):
    max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
return total_param_dims, total_param_size, params
def print_param_stats(params):
    header = [
        "name",
        "dtype",
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
"mean", "mean",
"std", "std",
"param_dim", "param_dim",
"bits", "nbits",
"size", "size",
"size_cum", "size_cum",
"percentage", "percentage",
"size_bar", "size_bar",
] ]
logger.info( logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
) )
return total_param_dims, total_param_size
def get_activation_stats(output: np.ndarray):
out_shape = output.shape
activations_dtype = output.dtype
nbits = get_dtype_bit(activations_dtype.name)
act_dim = np.prod(out_shape)
act_size = act_dim * nbits // 8
return {
"dtype": activations_dtype,
"shape": out_shape,
"act_dim": act_dim,
"mean": "{:.3g}".format(output.mean()),
"std": "{:.3g}".format(output.std()),
"nbits": nbits,
"size": act_size,
}
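For example, a minimal sketch of the dict this returns for a float32 feature map (assuming get_dtype_bit("float32") is 32):

import numpy as np

stats = get_activation_stats(np.zeros((1, 3, 224, 224), dtype=np.float32))
# stats["act_dim"] == 150528, stats["nbits"] == 32, stats["size"] == 602112 (bytes)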
def sum_activations_stats(activations, bar_length_max=20):
max_act_size = max([i["size"] for i in activations] + [0])
total_act_dims, total_act_size = 0, 0
for d in activations:
total_act_size += int(d["size"])
total_act_dims += int(d["act_dim"])
d["size_cum"] = sizeof_fmt(total_act_size)
for d in activations:
ratio = d["ratio"] = d["size"] / total_act_size
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["size"] / max_act_size * bar_length_max)
d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])
act_size = sizeof_fmt(total_act_size)
activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,))
return total_act_dims, total_act_size, activations
def print_activations_stats(activations):
header = [
"name",
"class_name",
"dtype",
"shape",
"mean",
"std",
"nbits",
"act_dim",
"size",
"size_cum",
"percentage",
"size_bar",
]
logger.info(
"activations stats: \n"
+ tabulate.tabulate(dict2table(activations, header=header))
)
def print_summary(**kwargs):
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
def module_stats(
    model: m.Module,
    input_shapes: list,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
r""" r"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info. :param model: model that need to get stats info.
:param input_size: size of input for running model and calculating stats. :param input_shapes: shapes of inputs for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size. :param log_params: whether print and record params size.
:param log_flops: whether print and record op flops. :param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
""" """
    disable_receptive_field()

    def module_stats_hook(module, inputs, outputs, name=""):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]
        flops_stats = get_op_stats(module, inputs, outputs)
        if flops_stats is not None:
            flops_stats["name"] = name
@@ -331,38 +428,25 @@ def module_stats(
param_stats["name"] = name + "-b" param_stats["name"] = name + "-b"
params.append(param_stats) params.append(param_stats)
        if not isinstance(outputs, (tuple, list)):
            output = outputs.numpy()
        else:
            output = outputs[0].numpy()
        activation_stats = get_activation_stats(output)
        activation_stats["name"] = name
        activation_stats["class_name"] = class_name
        activations.append(activation_stats)
    # multiple inputs to the network
    if not isinstance(input_shapes[0], tuple):
        input_shapes = [input_shapes]

    params = []
    flops = []
    hooks = []
    activations = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
    for (name, module) in model.named_modules():
        if isinstance(module, hook_modules):
@@ -370,8 +454,8 @@ def module_stats(
                module.register_forward_hook(partial(module_stats_hook, name=name))
            )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
    with set_module_mode_safe(model, training=False) as model:
        model(*inputs)

    for h in hooks:
@@ -380,19 +464,40 @@ def module_stats(
    extra_info = {
        "#params": len(params),
    }
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        print_param_stats(params)

    total_flops, flops = sum_op_stats(flops, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        print_op_stats(flops)

    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)

    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

    print_summary(**extra_info)
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )
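Callers now unpack two namedtuples instead of a (param_size, flops) pair. A minimal sketch; SomeNet is a placeholder for any module built from the hooked layer types:

net = SomeNet()
total_stats, stats_details = module_stats(net, input_shapes=(1, 3, 224, 224))
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(len(stats_details.params), len(stats_details.flops), len(stats_details.activations))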
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import Iterable

from ..module import Sequential
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
        parent[key] = value

    _access_structure(obj, key, callback=f)
@contextlib.contextmanager
def set_module_mode_safe(
module: Module, training: bool = False,
):
"""Adjust module to training/eval mode temporarily.
:param module: used module.
    :param training: training mode. True for train mode, False for eval mode.
"""
backup_stats = {}
def recursive_backup_stats(module, mode):
for m in module.modules():
backup_stats[m] = m.training
m.train(mode, recursive=False)
def recursive_recover_stats(module):
for m in module.modules():
m.training = backup_stats.pop(m)
recursive_backup_stats(module, mode=training)
yield module
recursive_recover_stats(module)
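Outside of module_stats, the helper can also be used on its own to run a forward pass in eval mode and restore every submodule's training flag afterwards; a minimal sketch:

import numpy as np
import megengine.module as M
from megengine.functional.tensor import zeros

net = M.Sequential(M.Conv2d(3, 8, 3), M.BatchNorm2d(8))
with set_module_mode_safe(net, training=False) as net:
    out = net(zeros((1, 3, 32, 32), dtype=np.float32))
# training flags are back to their previous values here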
import math
from copy import deepcopy
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats
@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
net = ResNet(BasicBlock, [2, 2, 2, 2])
input_shape = (1, 3, 224, 224)
total_stats, stats_details = module_stats(net, input_shape)
x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
gt_flops, gt_acts = net.get_stats(x1)
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops,
gt_acts,
)
class BasicBlock(M.Module):
expansion = 1
def __init__(
self,
in_channels,
channels,
stride=1,
groups=1,
base_width=64,
dilation=1,
norm=M.BatchNorm2d,
):
super().__init__()
self.tmp_in_channels = in_channels
self.tmp_channels = channels
self.stride = stride
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = M.Conv2d(
in_channels, channels, 3, stride, padding=dilation, bias=False
)
self.bn1 = norm(channels)
self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
self.bn2 = norm(channels)
self.downsample_id = M.Identity()
self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
self.downsample_norm = norm(channels)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
identity = self.downsample_id(identity)
else:
identity = self.downsample_conv(identity)
identity = self.downsample_norm(identity)
x += identity
x = F.relu(x)
return x
def get_stats(self, x):
activations, flops = 0, 0
identity = x
in_x = deepcopy(x)
x = self.conv1(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn1(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
activations += tmp_acts
flops += tmp_flops
x = F.relu(x)
in_x = deepcopy(x)
x = self.conv2(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn2(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
activations += tmp_acts
flops += tmp_flops
if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
identity = self.downsample_id(identity)
else:
in_x = deepcopy(identity)
identity = self.downsample_conv(identity)
tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(identity)
identity = self.downsample_norm(identity)
tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
activations += tmp_acts
flops += tmp_flops
x += identity
x = F.relu(x)
return x, flops, activations
class ResNet(M.Module):
def __init__(
self,
block,
layers=[2, 2, 2, 2],
num_classes=1000,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
norm=M.BatchNorm2d,
):
super().__init__()
self.in_channels = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = norm(self.in_channels)
self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1_0 = BasicBlock(
self.in_channels,
64,
stride=1,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=M.BatchNorm2d,
)
self.layer1_1 = BasicBlock(
self.in_channels,
64,
stride=1,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=M.BatchNorm2d,
)
self.layer2_0 = BasicBlock(64, 128, stride=2)
self.layer2_1 = BasicBlock(128, 128)
self.layer3_0 = BasicBlock(128, 256, stride=2)
self.layer3_1 = BasicBlock(256, 256)
self.layer4_0 = BasicBlock(256, 512, stride=2)
self.layer4_1 = BasicBlock(512, 512)
self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
self.layer2 = self._make_layer(
block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
)
self.layer3 = self._make_layer(
block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
)
self.layer4 = self._make_layer(
block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
)
self.fc = M.Linear(512, num_classes)
for m in self.modules():
if isinstance(m, M.Conv2d):
M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
M.init.uniform_(m.bias, -bound, bound)
elif isinstance(m, M.BatchNorm2d):
M.init.ones_(m.weight)
M.init.zeros_(m.bias)
elif isinstance(m, M.Linear):
M.init.msra_uniform_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
M.init.uniform_(m.bias, -bound, bound)
if zero_init_residual:
for m in self.modules():
M.init.zeros_(m.bn2.weight)
def _make_layer(
self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
):
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
layers = []
layers.append(
block(
self.in_channels,
channels,
stride,
groups=self.groups,
base_width=self.base_width,
dilation=previous_dilation,
norm=norm,
)
)
self.in_channels = channels * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.in_channels,
channels,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm=norm,
)
)
return M.Sequential(*layers)
def extract_features(self, x):
outputs = {}
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.maxpool(x)
outputs["stem"] = x
x = self.layer1(x)
outputs["res2"] = x
x = self.layer2(x)
outputs["res3"] = x
x = self.layer3(x)
outputs["res4"] = x
x = self.layer4(x)
outputs["res5"] = x
return outputs
def forward(self, x):
x = self.extract_features(x)["res5"]
x = F.avg_pool2d(x, 7)
x = F.flatten(x, 1)
x = self.fc(x)
return x
def get_stats(self, x):
flops, activations = 0, 0
in_x = deepcopy(x)
x = self.conv1(x)
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
activations += tmp_acts
flops += tmp_flops
in_x = deepcopy(x)
x = self.bn1(x)
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
activations += tmp_acts
flops += tmp_flops
x = F.relu(x)
in_x = deepcopy(x)
x = self.maxpool(x)
tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
activations += tmp_acts
flops += tmp_flops
x = F.avg_pool2d(x, 7)
x = F.flatten(x, 1)
in_x = deepcopy(x)
x = self.fc(x)
tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
activations += tmp_acts
flops += tmp_flops
return flops, activations
def cal_conv_stats(module, input, output):
bias = 1 if module.bias is not None else 0
flops = np.prod(output[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
)
acts = np.prod(output[0].shape)
return flops, acts
def cal_norm_stats(module, input, output):
return np.prod(input[0].shape) * 7, np.prod(output[0].shape)
def cal_linear_stats(module, inputs, outputs):
bias = module.out_features if module.bias is not None else 0
return (
np.prod(outputs[0].shape) * module.in_features + bias,
np.prod(outputs[0].shape),
)
def cal_pool_stats(module, inputs, outputs):
return (
np.prod(outputs[0].shape) * (module.kernel_size ** 2),
np.prod(outputs[0].shape),
)