diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py
index fc67d7530f2321f8e33d081c23a492ef6c0f98fa..3b24584d1ce3dd19ae27b85a02cbf83f6a9eba18 100755
--- a/imperative/python/megengine/tools/network_visualize.py
+++ b/imperative/python/megengine/tools/network_visualize.py
@@ -9,6 +9,7 @@
 import argparse
 import logging
 import re
+from collections import namedtuple
 
 import numpy as np
 
@@ -16,12 +17,17 @@
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
     enable_receptive_field,
+    get_activation_stats,
     get_op_stats,
     get_param_stats,
+    print_activations_stats,
     print_op_stats,
     print_param_stats,
     print_summary,
     sizeof_fmt,
+    sum_activations_stats,
+    sum_op_stats,
+    sum_param_stats,
 )
 from megengine.utils.network import Network
 
@@ -34,6 +40,7 @@ def visualize(
     bar_length_max: int = 20,
     log_params: bool = True,
     log_flops: bool = True,
+    log_activations: bool = True,
 ):
     r"""
     Load megengine dumped model and visualize graph structure with tensorboard log files.
@@ -44,6 +51,7 @@ def visualize(
     :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
     :param log_params: whether print and record params size.
     :param log_flops: whether print and record op flops.
+    :param log_activations: whether print and record op activations.
     """
     if log_path:
         try:
@@ -83,6 +91,10 @@ def visualize(
     node_list = []
     flops_list = []
     params_list = []
+    activations_list = []
+    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
+
     for node in graph.all_oprs:
         if hasattr(node, "output_idx"):
             node_oup = node.outputs[node.output_idx]
@@ -124,6 +136,11 @@ def visualize(
             flops_stats["class_name"] = node.type
             flops_list.append(flops_stats)
 
+            acts = get_activation_stats(node_oup.numpy())
+            acts["name"] = node.name
+            acts["class_name"] = node.type
+            activations_list.append(acts)
+
         if node.type == "ImmutableTensor":
             param_stats = get_param_stats(node.numpy())
             # add tensor size attr
@@ -149,20 +166,36 @@ def visualize(
         "#params": len(params_list),
     }
 
-    total_flops, total_param_dims, total_param_size = 0, 0, 0
+    (
+        total_flops,
+        total_param_dims,
+        total_param_size,
+        total_act_dims,
+        total_act_size,
+    ) = (0, 0, 0, 0, 0)
+
+    total_param_dims, total_param_size, params = sum_param_stats(
+        params_list, bar_length_max
+    )
+    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
+    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_params:
-        total_param_dims, total_param_size = print_param_stats(
-            params_list, bar_length_max
-        )
-        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
-        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
+        print_param_stats(params)
+
+    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
+    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
     if log_flops:
-        total_flops = print_op_stats(flops_list, bar_length_max)
-        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
-    if log_params and log_flops:
-        extra_info["flops/param_size"] = "{:3.3f}".format(
-            total_flops / total_param_size
-        )
+        print_op_stats(flops)
+
+    total_act_dims, total_act_size, activations = sum_activations_stats(
+        activations_list, bar_length_max
+    )
+    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size) + if log_activations: + print_activations_stats(activations) + + extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) if log_path: graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) @@ -179,7 +212,12 @@ def visualize( # FIXME: remove this after resolving "span dist too large" warning _imperative_rt_logger.set_log_level(old_level) - return total_param_size, total_flops + return ( + total_stats( + param_size=total_param_size, flops=total_flops, act_size=total_act_size, + ), + stats_details(params=params, flops=flops, activations=activations), + ) def main(): diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 690cda07ef9854bdb9594342b0ba834b9c55016a..dc3a03ea26143bbe25f461c2e65ed8597e1060d8 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -5,7 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import contextlib +from collections import namedtuple from functools import partial import numpy as np @@ -18,6 +18,8 @@ import megengine.module.quantized as qm from megengine.core.tensor.dtype import get_dtype_bit from megengine.functional.tensor import zeros +from .module_utils import set_module_mode_safe + try: mge.logger.MegEngineLogFormatter.max_lines = float("inf") except AttributeError as e: @@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs): ) +@register_flops( + m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm, +) +def flops_norm(module: m.Linear, inputs, outputs): + return np.prod(inputs[0].shape) * 7 + + +@register_flops(m.AvgPool2d, m.MaxPool2d) +def flops_pool(module: m.AvgPool2d, inputs, outputs): + return np.prod(outputs[0].shape) * (module.kernel_size ** 2) + + +@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d) +def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs): + stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1)) + kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h + stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1)) + kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w + return np.prod(outputs[0].shape) * kernel_h * kernel_w + + @register_flops(m.Linear) def flops_linear(module: m.Linear, inputs, outputs): bias = module.out_features if module.bias is not None else 0 @@ -120,6 +143,12 @@ hook_modules = ( m.conv._ConvNd, m.Linear, m.BatchMatMulActivation, + m.batchnorm._BatchNorm, + m.LayerNorm, + m.GroupNorm, + m.InstanceNorm, + m.pooling._PoolNd, + m.adaptive_pooling._AdaptivePoolNd, ) @@ -137,12 +166,16 @@ def dict2table(list_of_dict, header): def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: - if abs(num) < 1024.0: + if suffix == "B": + scale = 1024.0 + units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"] + else: + scale = 1000.0 + units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"] + for unit in units: + if abs(num) < scale or unit == units[-1]: return "{:3.3f} {}{}".format(num, unit, suffix) - num /= 1024.0 - sign_str = "-" if num < 0 else "" - return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) + num /= scale def preprocess_receptive_field(module, inputs, outputs): 
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
 def get_op_stats(module, inputs, outputs):
+    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
+        outputs = (outputs,)
     rst = {
         "input_shapes": [i.shape for i in inputs],
         "output_shapes": [o.shape for o in outputs],
@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
     return
 
 
-def print_op_stats(flops, bar_length_max=20):
+def sum_op_stats(flops, bar_length_max=20):
     max_flops_num = max([i["flops_num"] for i in flops] + [0])
     total_flops_num = 0
     for d in flops:
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
         d["bar"] = "#" * bar_length
         d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")
 
+    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
+    total_var_size = sum(
+        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
+    )
+    flops.append(
+        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
+    )
+
+    return total_flops_num, flops
+
+
+def print_op_stats(flops):
     header = [
         "name",
         "class_name",
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
     if _receptive_field_enabled:
         header.insert(4, "receptive_field")
         header.insert(5, "stride")
-
-    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
-    total_var_size = sum(
-        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
-    )
-    flops.append(
-        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
-    )
-
     logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
 
-    return total_flops_num
-
 
 def get_param_stats(param: np.ndarray):
     nbits = get_dtype_bit(param.dtype.name)
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
     }
 
 
-def print_param_stats(params, bar_length_max=20):
+def sum_param_stats(params, bar_length_max=20):
     max_size = max([d["size"] for d in params] + [0])
     total_param_dims, total_param_size = 0, 0
     for d in params:
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
     param_size = sizeof_fmt(total_param_size)
     params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
 
+    return total_param_dims, total_param_size, params
+
+
+def print_param_stats(params):
     header = [
         "name",
         "dtype",
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
         "mean",
         "std",
         "param_dim",
-        "bits",
+        "nbits",
         "size",
         "size_cum",
         "percentage",
         "size_bar",
     ]
-
     logger.info(
         "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
     )
 
-    return total_param_dims, total_param_size
+
+def get_activation_stats(output: np.ndarray):
+    out_shape = output.shape
+    activations_dtype = output.dtype
+    nbits = get_dtype_bit(activations_dtype.name)
+    act_dim = np.prod(out_shape)
+    act_size = act_dim * nbits // 8
+    return {
+        "dtype": activations_dtype,
+        "shape": out_shape,
+        "act_dim": act_dim,
+        "mean": "{:.3g}".format(output.mean()),
+        "std": "{:.3g}".format(output.std()),
+        "nbits": nbits,
+        "size": act_size,
+    }
+
+
+def sum_activations_stats(activations, bar_length_max=20):
+    max_act_size = max([i["size"] for i in activations] + [0])
+    total_act_dims, total_act_size = 0, 0
+    for d in activations:
+        total_act_size += int(d["size"])
+        total_act_dims += int(d["act_dim"])
+        d["size_cum"] = sizeof_fmt(total_act_size)
+
+    for d in activations:
+        ratio = d["ratio"] = d["size"] / total_act_size
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["size"] / max_act_size * bar_length_max)
+        d["size_bar"] = "#" * bar_length
+        d["size"] = sizeof_fmt(d["size"])
sizeof_fmt(d["size"]) + + act_size = sizeof_fmt(total_act_size) + activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,)) + + return total_act_dims, total_act_size, activations + + +def print_activations_stats(activations): + header = [ + "name", + "class_name", + "dtype", + "shape", + "mean", + "std", + "nbits", + "act_dim", + "size", + "size_cum", + "percentage", + "size_bar", + ] + logger.info( + "activations stats: \n" + + tabulate.tabulate(dict2table(activations, header=header)) + ) def print_summary(**kwargs): @@ -294,25 +390,26 @@ def print_summary(**kwargs): def module_stats( model: m.Module, - input_size: int, + input_shapes: list, bar_length_max: int = 20, log_params: bool = True, log_flops: bool = True, + log_activations: bool = True, ): r""" Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. :param model: model that need to get stats info. - :param input_size: size of input for running model and calculating stats. + :param input_shapes: shapes of inputs for running model and calculating stats. :param bar_length_max: size of bar indicating max flops or parameter size in net stats. :param log_params: whether print and record params size. :param log_flops: whether print and record op flops. + :param log_activations: whether print and record op activations. """ disable_receptive_field() def module_stats_hook(module, inputs, outputs, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] - flops_stats = get_op_stats(module, inputs, outputs) if flops_stats is not None: flops_stats["name"] = name @@ -331,38 +428,25 @@ def module_stats( param_stats["name"] = name + "-b" params.append(param_stats) - @contextlib.contextmanager - def adjust_stats(module, training=False): - """Adjust module to training/eval mode temporarily. - - Args: - module (M.Module): used module. - training (bool): training mode. True for train mode, False fro eval mode. 
- """ - - def recursive_backup_stats(module, mode): - for m in module.modules(): - # save prev status to _prev_training - m._prev_training = m.training - m.train(mode, recursive=False) - - def recursive_recover_stats(module): - for m in module.modules(): - # recover prev status and delete attribute - m.training = m._prev_training - delattr(m, "_prev_training") - - recursive_backup_stats(module, mode=training) - yield module - recursive_recover_stats(module) + if not isinstance(outputs, tuple) or not isinstance(outputs, list): + output = outputs.numpy() + else: + output = outputs[0].numpy() + activation_stats = get_activation_stats(output) + activation_stats["name"] = name + activation_stats["class_name"] = class_name + activations.append(activation_stats) # multiple inputs to the network - if not isinstance(input_size[0], tuple): - input_size = [input_size] + if not isinstance(input_shapes[0], tuple): + input_shapes = [input_shapes] params = [] flops = [] hooks = [] + activations = [] + total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) + stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) for (name, module) in model.named_modules(): if isinstance(module, hook_modules): @@ -370,8 +454,8 @@ def module_stats( module.register_forward_hook(partial(module_stats_hook, name=name)) ) - inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] - with adjust_stats(model, training=False) as model: + inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] + with set_module_mode_safe(model, training=False) as model: model(*inputs) for h in hooks: @@ -380,19 +464,40 @@ def module_stats( extra_info = { "#params": len(params), } - total_flops, total_param_dims, total_param_size = 0, 0, 0 + ( + total_flops, + total_param_dims, + total_param_size, + total_act_dims, + total_param_size, + ) = (0, 0, 0, 0, 0) + + total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) + extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") + extra_info["total_param_size"] = sizeof_fmt(total_param_size) if log_params: - total_param_dims, total_param_size = print_param_stats(params, bar_length_max) - extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) - extra_info["total_param_size"] = sizeof_fmt(total_param_size) + print_param_stats(params) + + total_flops, flops = sum_op_stats(flops, bar_length_max) + extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") if log_flops: - total_flops = print_op_stats(flops, bar_length_max) - extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") - if log_params and log_flops: - extra_info["flops/param_size"] = "{:3.3f}".format( - total_flops / total_param_size - ) + print_op_stats(flops) + + total_act_dims, total_act_size, activations = sum_activations_stats( + activations, bar_length_max + ) + extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") + extra_info["total_act_size"] = sizeof_fmt(total_act_size) + if log_activations: + print_activations_stats(activations) + + extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) print_summary(**extra_info) - return total_param_size, total_flops + return ( + total_stats( + param_size=total_param_size, flops=total_flops, act_size=total_act_size, + ), + stats_details(params=params, flops=flops, activations=activations), + ) diff --git a/imperative/python/megengine/utils/module_utils.py b/imperative/python/megengine/utils/module_utils.py index 
index c66eb6060c61edfc59ac0cddd60676c757168ee7..2ee8e79b780f98bc5e2dd750b325c93d8a6d80e9 100644
--- a/imperative/python/megengine/utils/module_utils.py
+++ b/imperative/python/megengine/utils/module_utils.py
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
 from collections import Iterable
 
 from ..module import Sequential
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
             parent[key] = value
 
     _access_structure(obj, key, callback=f)
+
+
+@contextlib.contextmanager
+def set_module_mode_safe(
+    module: Module, training: bool = False,
+):
+    """Adjust module to training/eval mode temporarily.
+
+    :param module: used module.
+    :param training: training mode. True for train mode, False for eval mode.
+    """
+    backup_stats = {}
+
+    def recursive_backup_stats(module, mode):
+        for m in module.modules():
+            backup_stats[m] = m.training
+            m.train(mode, recursive=False)
+
+    def recursive_recover_stats(module):
+        for m in module.modules():
+            m.training = backup_stats.pop(m)
+
+    recursive_backup_stats(module, mode=training)
+    yield module
+    recursive_recover_stats(module)
diff --git a/imperative/python/test/unit/utils/test_module_stats.py b/imperative/python/test/unit/utils/test_module_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a570689017fc5f96bbf5988a1c7c5ab507b873c
--- /dev/null
+++ b/imperative/python/test/unit/utils/test_module_stats.py
@@ -0,0 +1,377 @@
+import math
+from copy import deepcopy
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+import megengine.hub as hub
+import megengine.module as M
+from megengine.core._trace_option import use_symbolic_shape
+from megengine.utils.module_stats import module_stats
+
+
+@pytest.mark.skipif(
+    use_symbolic_shape(), reason="This test do not support symbolic shape.",
+)
+def test_module_stats():
+    net = ResNet(BasicBlock, [2, 2, 2, 2])
+    input_shape = (1, 3, 224, 224)
+    total_stats, stats_details = module_stats(net, input_shape)
+
+    x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
+    gt_flops, gt_acts = net.get_stats(x1)
+
+    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
+        gt_flops,
+        gt_acts,
+    )
+
+
+class BasicBlock(M.Module):
+    expansion = 1
+
+    def __init__(
+        self,
+        in_channels,
+        channels,
+        stride=1,
+        groups=1,
+        base_width=64,
+        dilation=1,
+        norm=M.BatchNorm2d,
+    ):
+        super().__init__()
+
+        self.tmp_in_channels = in_channels
+        self.tmp_channels = channels
+        self.stride = stride
+
+        if groups != 1 or base_width != 64:
+            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.conv1 = M.Conv2d(
+            in_channels, channels, 3, stride, padding=dilation, bias=False
+        )
+        self.bn1 = norm(channels)
+        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
+        self.bn2 = norm(channels)
+
+        self.downsample_id = M.Identity()
+        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
+        self.downsample_norm = norm(channels)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
+            identity = self.downsample_id(identity)
+        else:
+            identity = self.downsample_conv(identity)
+            identity = self.downsample_norm(identity)
+        x += identity
+        x = F.relu(x)
+        return x
+
+    def get_stats(self, x):
+        activations, flops = 0, 0
+
+        identity = x
+
+        in_x = deepcopy(x)
+        x = self.conv1(x)
+        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        in_x = deepcopy(x)
+        x = self.bn1(x)
+        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x = F.relu(x)
+
+        in_x = deepcopy(x)
+        x = self.conv2(x)
+        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        in_x = deepcopy(x)
+        x = self.bn2(x)
+        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
+            identity = self.downsample_id(identity)
+        else:
+            in_x = deepcopy(identity)
+            identity = self.downsample_conv(identity)
+            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
+            activations += tmp_acts
+            flops += tmp_flops
+
+            in_x = deepcopy(identity)
+            identity = self.downsample_norm(identity)
+            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
+            activations += tmp_acts
+            flops += tmp_flops
+
+        x += identity
+        x = F.relu(x)
+
+        return x, flops, activations
+
+
+class ResNet(M.Module):
+    def __init__(
+        self,
+        block,
+        layers=[2, 2, 2, 2],
+        num_classes=1000,
+        zero_init_residual=False,
+        groups=1,
+        width_per_group=64,
+        replace_stride_with_dilation=None,
+        norm=M.BatchNorm2d,
+    ):
+        super().__init__()
+        self.in_channels = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError(
+                "replace_stride_with_dilation should be None "
+                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+            )
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = M.Conv2d(
+            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
+        )
+        self.bn1 = norm(self.in_channels)
+        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.layer1_0 = BasicBlock(
+            self.in_channels,
+            64,
+            stride=1,
+            groups=self.groups,
+            base_width=self.base_width,
+            dilation=self.dilation,
+            norm=M.BatchNorm2d,
+        )
+        self.layer1_1 = BasicBlock(
+            self.in_channels,
+            64,
+            stride=1,
+            groups=self.groups,
+            base_width=self.base_width,
+            dilation=self.dilation,
+            norm=M.BatchNorm2d,
+        )
+        self.layer2_0 = BasicBlock(64, 128, stride=2)
+        self.layer2_1 = BasicBlock(128, 128)
+        self.layer3_0 = BasicBlock(128, 256, stride=2)
+        self.layer3_1 = BasicBlock(256, 256)
+        self.layer4_0 = BasicBlock(256, 512, stride=2)
+        self.layer4_1 = BasicBlock(512, 512)
+
+        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
+        self.layer2 = self._make_layer(
+            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
+        )
+        self.layer3 = self._make_layer(
+            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
+        )
+        self.layer4 = self._make_layer(
+            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
+        )
+        self.fc = M.Linear(512, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, M.Conv2d):
+                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
+                    bound = 1 / math.sqrt(fan_in)
+                    M.init.uniform_(m.bias, -bound, bound)
+            elif isinstance(m, M.BatchNorm2d):
+                M.init.ones_(m.weight)
+                M.init.zeros_(m.bias)
+            elif isinstance(m, M.Linear):
+                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
+                if m.bias is not None:
+                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
+                    bound = 1 / math.sqrt(fan_in)
+                    M.init.uniform_(m.bias, -bound, bound)
+        if zero_init_residual:
+            for m in self.modules():
+                M.init.zeros_(m.bn2.weight)
+
+    def _make_layer(
+        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
+    ):
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+
+        layers = []
+        layers.append(
+            block(
+                self.in_channels,
+                channels,
+                stride,
+                groups=self.groups,
+                base_width=self.base_width,
+                dilation=previous_dilation,
+                norm=norm,
+            )
+        )
+        self.in_channels = channels * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.in_channels,
+                    channels,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm=norm,
+                )
+            )
+
+        return M.Sequential(*layers)
+
+    def extract_features(self, x):
+        outputs = {}
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.maxpool(x)
+        outputs["stem"] = x
+
+        x = self.layer1(x)
+        outputs["res2"] = x
+        x = self.layer2(x)
+        outputs["res3"] = x
+        x = self.layer3(x)
+        outputs["res4"] = x
+        x = self.layer4(x)
+        outputs["res5"] = x
+        return outputs
+
+    def forward(self, x):
+        x = self.extract_features(x)["res5"]
+
+        x = F.avg_pool2d(x, 7)
+        x = F.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def get_stats(self, x):
+        flops, activations = 0, 0
+        in_x = deepcopy(x)
+        x = self.conv1(x)
+        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        in_x = deepcopy(x)
+        x = self.bn1(x)
+        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x = F.relu(x)
+
+        in_x = deepcopy(x)
+        x = self.maxpool(x)
+        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        x = F.avg_pool2d(x, 7)
+
+        x = F.flatten(x, 1)
+
+        in_x = deepcopy(x)
+        x = self.fc(x)
+        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
+        activations += tmp_acts
+        flops += tmp_flops
+
+        return flops, activations
+
+
+def cal_conv_stats(module, input, output):
+    bias = 1 if module.bias is not None else 0
+    flops = np.prod(output[0].shape) * (
+        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
+    )
+    acts = np.prod(output[0].shape)
+    return flops, acts
+
+
+def cal_norm_stats(module, input, output):
+    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)
+
+
+def cal_linear_stats(module, inputs, outputs):
+    bias = module.out_features if module.bias is not None else 0
+    return (
+        np.prod(outputs[0].shape) * module.in_features + bias,
+        np.prod(outputs[0].shape),
+    )
+
+
+def cal_pool_stats(module, inputs, outputs):
+    return (
+        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
+        np.prod(outputs[0].shape),
+    )
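
Usage sketch (not part of the patch): the change above alters the public return value of module_stats (and visualize) from the old (total_param_size, total_flops) pair to two namedtuples, total_stats(param_size, flops, act_size) and module_stats(params, flops, activations). The snippet below shows how a caller would read the new fields; TinyNet is a hypothetical toy module used purely for illustration, not code from the repository.

import megengine.module as M
from megengine.utils.module_stats import module_stats


class TinyNet(M.Module):  # hypothetical model, only to exercise the hooks
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3, padding=1)   # matched by m.conv._ConvNd
        self.bn = M.BatchNorm2d(8)                 # matched by m.batchnorm._BatchNorm

    def forward(self, x):
        return self.bn(self.conv(x))


# input_shapes may be a single shape tuple or a list of them.
total, details = module_stats(TinyNet(), (1, 3, 32, 32), log_activations=True)

print(total.param_size, total.flops, total.act_size)   # scalar totals
print(details.activations[-1]["act_dim"])              # "total" row appended by sum_activations_stats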