diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py
index c0868ed39ee2ab02fde481c720f449055b9430fc..e3ef5e3ed2b60f1f742a929f0a95d69519b77e89 100755
--- a/imperative/python/megengine/tools/network_visualize.py
+++ b/imperative/python/megengine/tools/network_visualize.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
+import json
 import logging

 import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
+    get_flops_stats,
     get_param_stats,
     print_flops_stats,
     print_params_stats,
@@ -89,6 +91,7 @@ def visualize(
         inp_list = [process_name(var.owner.name) for var in node.inputs]

         if log_path:
+            # detail format see tensorboard/compat/proto/attr_value.proto
             attr = {
                 "_output_shapes": AttrValue(
                     list=AttrValue.ListValue(
@@ -101,24 +104,20 @@
                         ]
                     )
                 ),
+                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
+                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
             }
-        if hasattr(node, "calc_flops"):
-            flops_num = node.calc_flops()
+        flops_stats = get_flops_stats(node, node.inputs, node.outputs)
+        if flops_stats is not None:
             # add op flops attr
-            if log_path:
+            if log_path and hasattr(flops_stats, "flops_num"):
                 attr["flops"] = AttrValue(
-                    s=sizeof_fmt(flops_num).encode(encoding="utf-8")
-                )
-            flops_list.append(
-                dict(
-                    name=node.name,
-                    class_name=node.type,
-                    input_shapes=[i.shape for i in node.inputs],
-                    output_shapes=[o.shape for o in node.outputs],
-                    flops_num=flops_num,
-                    flops_cum=0,
+                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                 )
-            )
+            flops_stats["name"] = node.name
+            flops_stats["class_name"] = node.type
+            flops_list.append(flops_stats)
+
         if node.type == "ImmutableTensor":
             param_stats = get_param_stats(node.numpy())
             # add tensor size attr
@@ -132,6 +131,7 @@
         # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
         if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
             continue
+
         if log_path:
             node_list.append(
                 NodeDef(
@@ -141,14 +141,26 @@
                     attr=attr,
                 )
             )
+    # summary
+    extra_info = {
+        "#ops": len(graph.all_oprs),
+        "#params": len(params_list),
+    }

-    total_flops, total_params = None, None
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(
             params_list, bar_length_max
         )
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops_list, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )

     if log_path:
         graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@
         writer = SummaryWriter(log_path)
         writer._get_file_writer().add_graph((graph_def, stepstats))

-    # summary
-    extra_info = {
-        "#ops": len(graph.all_oprs),
-        "#params": len(params_list),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }
     print_summary(**extra_info)

     # FIXME: remove this after resolving "span dist too large" warning
     _imperative_rt_logger.set_log_level(old_level)

-    return total_params, total_flops
+    return total_param_size, total_flops


 def main():
diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py
index d063b983fd14c317564bfd7deb17b025b1a822ae..5e94308a0b26e899c1db4252d3b7d7fac2ba78e9 100644
--- a/imperative/python/megengine/utils/module_stats.py
+++ b/imperative/python/megengine/utils/module_stats.py
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
 logger.setLevel("INFO")


-CALC_FLOPS = {}
-
-
-def _register_modules(*modules):
+_calc_flops_dict = {}
+_calc_receptive_field_dict = {}
+
+
+def _receptive_field_fallback(module, inputs, outputs):
+    assert not hasattr(module, "_rf")
+    assert not hasattr(module, "_stride")
+    if len(inputs) == 0:
+        # TODO: support other dimension
+        module._rf = (1, 1)
+        module._stride = (1, 1)
+        return module._rf, module._stride
+    rf, stride = preprocess_receptive_field(module, inputs, outputs)
+    module._rf = rf
+    module._stride = stride
+    return rf, stride
+
+
+# key tuple, impl_dict, fallback
+_iter_list = [
+    ("flops_num", _calc_flops_dict, None),
+    (
+        ("receptive_field", "stride"),
+        _calc_receptive_field_dict,
+        _receptive_field_fallback,
+    ),
+]
+
+
+def _register_dict(*modules, dict=None):
     def callback(impl):
         for module in modules:
-            CALC_FLOPS[module] = impl
+            dict[module] = impl
         return impl

     return callback


-@_register_modules(
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    qm.Conv2d,
-    qm.ConvRelu2d,
-    qm.ConvBn2d,
-    qm.ConvBnRelu2d,
-    qatm.Conv2d,
-    qatm.ConvRelu2d,
-    qatm.ConvBn2d,
-    qatm.ConvBnRelu2d,
+def register_flops(*modules):
+    return _register_dict(*modules, dict=_calc_flops_dict)
+
+
+def register_receptive_field(*modules):
+    return _register_dict(*modules, dict=_calc_receptive_field_dict)
+
+
+@register_flops(
+    m.Conv1d, m.Conv2d, m.Conv3d,
 )
-def count_convNd(module, input, output):
+def flops_convNd(module: m.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
     group = module.groups
-    ic = input[0].shape[1]
-    oc = output[0].shape[1]
+    ic = inputs[0].shape[1]
+    oc = outputs[0].shape[1]
     goc = oc // group
     gic = ic // group
-    N = output[0].shape[0]
-    HW = np.prod(output[0].shape[2:])
+    N = outputs[0].shape[0]
+    HW = np.prod(outputs[0].shape[2:])
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
     return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)


-@_register_modules(m.ConvTranspose2d)
-def count_deconvNd(module, input, output):
-    return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
+@register_flops(m.ConvTranspose2d)
+def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
+    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
+
+@register_flops(m.Linear)
+def flops_linear(module: m.Linear, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    return np.prod(outputs[0].shape) * module.in_features

-@_register_modules(m.Linear, qatm.Linear, qm.Linear)
-def count_linear(module, input, output):
-    return np.prod(output[0].shape) * module.in_features
+
+@register_flops(m.BatchMatMulActivation)
+def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    x = inputs[0]
+    w = module.weight
+    batch_size = x.shape[0]
+    n, p = x.shape[1:]
+    _, m = w.shape[1:]
+    return n * (p + bias) * m * batch_size


 # does not need import qat and quantized module since they inherit from float module.
 hook_modules = (
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    m.BatchNorm2d,
+    m.conv._ConvNd,
     m.Linear,
+    m.BatchMatMulActivation,
 )
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
     return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)


+def preprocess_receptive_field(module, inputs, outputs):
+    # TODO: support other dimensions
+    pre_rf = (
+        max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
+        max(i.owner._rf[1] for i in inputs),
+    )
+    pre_stride = (
+        max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
+        max(i.owner._stride[1] for i in inputs),
+    )
+    return pre_rf, pre_stride
+
+
+def get_flops_stats(module, inputs, outputs):
+    rst = {
+        "input_shapes": [i.shape for i in inputs],
+        "output_shapes": [o.shape for o in outputs],
+    }
+    valid_flag = False
+    for key, _dict, fallback in _iter_list:
+        for _type in _dict:
+            if isinstance(module, _type):
+                value = _dict[_type](module, inputs, outputs)
+                valid_flag = True
+                break
+        else:
+            if fallback is not None:
+                value = fallback(module, inputs, outputs)
+            continue
+
+        if isinstance(key, tuple):
+            assert isinstance(value, tuple)
+            for k, v in zip(key, value):
+                rst[k] = v
+        else:
+            rst[key] = value
+
+    if valid_flag:
+        return rst
+    else:
+        return None
+    return
+
+
 def print_flops_stats(flops, bar_length_max=20):
-    flops_list = [i["flops_num"] for i in flops]
-    max_flops_num = max(flops_list + [0])
-    # calc total flops and set flops_cum
+    max_flops_num = max([i["flops_num"] for i in flops] + [0])
     total_flops_num = 0
     for d in flops:
         total_flops_num += int(d["flops_num"])
         d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

     for d in flops:
-        f = d["flops_num"]
-        d["flops"] = sizeof_fmt(f, suffix="OPs")
-        r = d["ratio"] = f / total_flops_num
-        d["percentage"] = "{:.2f}%".format(r * 100)
-        bar_length = int(f / max_flops_num * bar_length_max)
+        ratio = d["ratio"] = d["flops_num"] / total_flops_num
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
         d["bar"] = "#" * bar_length
+        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

     header = [
         "name",
         "class_name",
         "input_shapes",
         "output_shapes",
+        "receptive_field",
+        "stride",
         "flops",
         "flops_cum",
         "percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
     param_size = param_dim * nbits // 8
     return {
         "shape": shape,
-        "mean": param.mean(),
-        "std": param.std(),
+        "mean": "{:.3g}".format(param.mean()),
+        "std": "{:.3g}".format(param.std()),
         "param_dim": param_dim,
         "nbits": nbits,
         "size": param_size,
@@ -163,21 +240,20 @@


 def print_params_stats(params, bar_length_max=20):
+    max_size = max([d["size"] for d in params] + [0])
     total_param_dims, total_param_size = 0, 0
     for d in params:
         total_param_dims += int(d["param_dim"])
         total_param_size += int(d["size"])
-        ratio = d["size"] / total_param_size
-        d["size"] = sizeof_fmt(d["size"])
         d["size_cum"] = sizeof_fmt(total_param_size)
-        d["ratio"] = ratio
-        d["percentage"] = "{:.2f}%".format(ratio * 100)

-    # construct bar
-    max_ratio = max([d["ratio"] for d in params])
     for d in params:
-        bar_length = int(d["ratio"] / max_ratio * bar_length_max)
+        ratio = d["size"] / total_param_size
+        d["ratio"] = ratio
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["size"] / max_size * bar_length_max)
         d["size_bar"] = "#" * bar_length
+        d["size"] = sizeof_fmt(d["size"])

     param_size = sizeof_fmt(total_param_size)
     params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@ def module_stats(
     :param log_flops: whether print and record op flops.
     """

-    def module_stats_hook(module, input, output, name=""):
+    def module_stats_hook(module, inputs, outputs, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]

-        flops_fun = CALC_FLOPS.get(type(module))
-        if callable(flops_fun):
-            flops_num = flops_fun(module, input, output)
-
-            if not isinstance(output, (list, tuple)):
-                output = [output]
-
-            flops.append(
-                dict(
-                    name=name,
-                    class_name=class_name,
-                    input_shapes=[i.shape for i in input],
-                    output_shapes=[o.shape for o in output],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+        flops_stats = get_flops_stats(module, inputs, outputs)
+        if flops_stats is not None:
+            flops_stats["name"] = name
+            flops_stats["class_name"] = class_name
+            flops.append(flops_stats)

         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
     for h in hooks:
         h.remove()

-    total_flops, total_params = 0, 0
+    extra_info = {
+        "#params": len(params),
+    }
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )

-    extra_info = {
-        "#params": len(params),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }
     print_summary(**extra_info)

-    return total_params, total_flops
+    return total_param_size, total_flops
diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py
index 58a5982eaff895663bec85a56080b7d010e5af7e..5e1f1af339729d377eb139e029bf94c612b6a3be 100644
--- a/imperative/python/megengine/utils/network_node.py
+++ b/imperative/python/megengine/utils/network_node.py
@@ -18,6 +18,11 @@ from ..core.ops import builtin
 from ..core.tensor.megbrain_graph import InputNode
 from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
+from .module_stats import (
+    preprocess_receptive_field,
+    register_flops,
+    register_receptive_field,
+)


 class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
     type = "Elemwise"
     opdef = builtin.Elemwise

-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)
+
+class ElemwiseMultiType(OpNode):
+    type = "ElemwiseMultiType"
+    opdef = builtin.ElemwiseMultiType
+
+    @classmethod
+    def load(cls, opr):
+        obj = super(ElemwiseMultiType, cls).load(opr)
+        obj.params["dtype"] = opr.outputs[0].dtype
+        return obj
+
+
+@register_flops(Elemwise, ElemwiseMultiType)
+def flops_elemwise(opnode: Elemwise, inputs, outputs):
+    return np.prod(outputs[0].shape)


 class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
     type = "MatrixMul"
     opdef = builtin.MatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
-        mid_shape = self.inputs[0].shape[1]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(MatrixMul)
+def flops_matmul(opnode: MatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
+    mid_shape = inputs[0].shape[1]
+    return np.prod(outputs[0].shape) * mid_shape


 class BatchedMatrixMul(OpNode):
     type = "BatchedMatmul"
     opdef = builtin.BatchedMatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
-        mid_shape = self.inputs[0].shape[2]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(BatchedMatrixMul)
+def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
+    mid_shape = inputs[0].shape[2]
+    return np.prod(outputs[0].shape) * mid_shape


 class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
     type = "Convolution"
     opdef = builtin.Convolution

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh)
-        return NCHW * (num_input * kw * kh)
-

 class ConvolutionBackwardData(OpNode):
     type = "ConvTranspose"
     opdef = builtin.ConvolutionBackwardData
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
         obj.params["dtype"] = opr.outputs[0].dtype
         return obj

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh + bias)
-        return NCHW * (num_input * kw * kh + 1)
+
+@register_flops(
+    ConvolutionForward, ConvBiasForward,
+)
+def flops_conv(opnode: ConvolutionForward, inputs, outputs):
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    if len(param_W_shape) == 5:
+        num_input = param_W_shape[2]
+    else:
+        num_input = param_W_shape[1]
+    NCHW = np.prod(outputs[0].shape)
+    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
+    # N x Cout x H x W x (Cin x Kw x Kh)
+    return NCHW * (num_input * kw * kh + bias)
+
+
+@register_receptive_field(ConvolutionForward, ConvBiasForward)
+def receptive_field(opnode: ConvolutionForward, inputs, outputs):
+    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    rf = (
+        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
+        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
+    )
+    stride = (
+        opnode.params["stride_h"] * pre_stride[0],
+        opnode.params["stride_w"] * pre_stride[1],
+    )
+    opnode._rf = rf
+    opnode._stride = stride
+    return rf, stride


 class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
     opdef = builtin.AssertEqual


-class ElemwiseMultiType(OpNode):
-    type = "ElemwiseMultiType"
-    opdef = builtin.ElemwiseMultiType
-
-    @classmethod
-    def load(cls, opr):
-        obj = super(ElemwiseMultiType, cls).load(opr)
-        obj.params["dtype"] = opr.outputs[0].dtype
-        return obj
-
-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)
-
-
 class CvtColorForward(OpNode):
     type = "CvtColor"
     opdef = builtin.CvtColor
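
Usage sketch for the registration API this patch introduces in megengine/utils/module_stats.py. This is a hedged illustration, not part of the diff: the GlobalAvgPool2d class and the shapes below are hypothetical, while register_flops, the (module, inputs, outputs) signature, and the flops-table lookup come from the patch above.

import numpy as np

from megengine.utils.module_stats import register_flops


class GlobalAvgPool2d:  # hypothetical stand-in for a user-defined layer
    pass


@register_flops(GlobalAvgPool2d)
def flops_global_avg_pool(module, inputs, outputs):
    # one accumulation per input element, following the convention of flops_convNd
    return np.prod(inputs[0].shape)


# register_flops fills the module -> flops-function table that get_flops_stats()
# walks with isinstance checks, so stats for GlobalAvgPool2d would now carry a
# "flops_num" entry. Calling the registered function directly on fake arrays
# (only .shape is used here):
x = np.zeros((1, 16, 7, 7))
y = np.zeros((1, 16, 1, 1))
print(flops_global_avg_pool(GlobalAvgPool2d(), [x], [y]))  # 1 * 16 * 7 * 7 = 784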