From 46d96478a16410a06be3383463cd417eb8390a48 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 19 May 2021 19:59:50 +0800
Subject: [PATCH] feat(mge/tools): optimize statistical tools

BREAKING CHANGE:

GitOrigin-RevId: cd2a1acd1128e482f2cb0963e4e7e1fe0cc47b14
---
 .../megengine/tools/network_visualize.py   | 165 +++++++++++------
 .../python/megengine/utils/module_stats.py | 172 +++++++++++-------
 .../test/unit/utils/test_module_stats.py   |  12 +-
 3 files changed, 222 insertions(+), 127 deletions(-)

diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py
index 86c471413..f752925a8 100755
--- a/imperative/python/megengine/tools/network_visualize.py
+++ b/imperative/python/megengine/tools/network_visualize.py
@@ -12,6 +12,7 @@ import re
 from collections import namedtuple
 
 import numpy as np
+from tqdm import tqdm
 
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
@@ -37,10 +38,13 @@ logger = get_logger(__name__)
 def visualize(
     model_path: str,
     log_path: str,
+    input: np.ndarray = None,
+    inp_dict: dict = None,
+    cal_params: bool = True,
+    cal_flops: bool = True,
+    cal_activations: bool = True,
+    logging_to_stdout: bool = True,
     bar_length_max: int = 20,
-    log_params: bool = True,
-    log_flops: bool = True,
-    log_activations: bool = True,
 ):
     r"""
     Load megengine dumped model and visualize graph structure with tensorboard log files.
@@ -48,10 +52,14 @@ def visualize(
 
     :param model_path: dir path for megengine dumped model.
     :param log_path: dir path for tensorboard graph log.
+    :param input: user defined input data used to run the model and calculate stats; an alternative to inp_dict, for models with a single input.
+    :param inp_dict: dict of input data used to run the model and calculate stats; an alternative to input, for models with more than one input. When both input and inp_dict are None, random inputs are used.
+    :param cal_params: whether to calculate and record params size.
+    :param cal_flops: whether to calculate and record op flops.
+    :param cal_activations: whether to calculate and record op activations.
+    :param logging_to_stdout: whether to print all calculated statistic details.
     :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
-    :param log_params: whether print and record params size.
-    :param log_flops: whether print and record op flops.
-    :param log_activations: whether print and record op activations.
+
     """
     if log_path:
         try:
@@ -78,6 +86,27 @@ def visualize(
         enable_receptive_field()
 
     graph = Network.load(model_path)
+    graph.reset_batch_size(1)
+
+    has_input = False
+    if input is not None or inp_dict is not None:
+        has_input = True
+        repl_dict = {}
+        inp_vars = graph.input_vars
+        if inp_dict is not None:
+            assert len(inp_dict) == len(
+                inp_vars
+            ), "Inputs are not sufficient for calculation."
+            for v in inp_vars:
+                new_input = graph.make_const(inp_dict[v.name], name=v.name)
+                repl_dict[v] = new_input
+        else:
+            assert len(inp_vars) == 1, "The graph has more than one input, please use inp_dict instead."
+            inp_var = inp_vars[0]
+            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
+        graph.replace_vars(repl_dict=repl_dict)
+
+    graph._compile()
 
     def process_name(name):
         # nodes that start with point or contain float const will lead to display bug
@@ -93,7 +122,7 @@ def visualize(
     total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
     stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
 
-    for node in graph.all_oprs:
+    for node in tqdm(graph.all_oprs):
         if hasattr(node, "output_idx"):
             node_oup = node.outputs[node.output_idx]
         else:
@@ -123,31 +152,35 @@ def visualize(
             "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
             "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
         }
-        flops_stats = get_op_stats(node, node.inputs, node.outputs)
-        if flops_stats is not None:
-            # add op flops attr
-            if log_path and hasattr(flops_stats, "flops_num"):
-                attr["flops"] = AttrValue(
-                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
-                )
-            flops_stats["name"] = node.name
-            flops_stats["class_name"] = node.type
-            flops_list.append(flops_stats)
-        acts = get_activation_stats(node_oup)
-        acts["name"] = node.name
-        acts["class_name"] = node.type
-        activations_list.append(acts)
+
+        if cal_flops:
+            flops_stats = get_op_stats(node, node.inputs, node.outputs)
+            if flops_stats is not None:
+                # add op flops attr
+                if log_path and "flops" in flops_stats:
+                    attr["flops"] = AttrValue(
+                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
+                    )
+                flops_stats["name"] = node.name
+                flops_stats["class_name"] = node.type
+                flops_list.append(flops_stats)
+
+        if cal_activations:
+            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
+            acts["name"] = node.name
+            acts["class_name"] = node.type
+            activations_list.append(acts)
 
-        if node.type == "ImmutableTensor":
-            param_stats = get_param_stats(node_oup)
-            # add tensor size attr
-            if log_path:
-                attr["size"] = AttrValue(
-                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
-                )
-            param_stats["name"] = node.name
-            params_list.append(param_stats)
+        if cal_params:
+            if node.type == "ImmutableTensor":
+                param_stats = get_param_stats(node.numpy())
+                # add tensor size attr
+                if log_path:
+                    attr["size"] = AttrValue(
+                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
+                    )
+                param_stats["name"] = node.name
+                params_list.append(param_stats)
 
         if log_path:
             node_list.append(
@@ -169,31 +202,37 @@ def visualize(
         total_param_dims,
         total_param_size,
         total_act_dims,
-        total_param_size,
+        total_act_size,
     ) = (0, 0, 0, 0, 0)
 
-    total_param_dims, total_param_size, params = sum_param_stats(
-        params_list, bar_length_max
-    )
-    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
-    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
-    if log_params:
-        print_param_stats(params)
-
-    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
-    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
-    if log_flops:
-        print_op_stats(flops)
-
-    total_act_dims, total_act_size, activations = sum_activations_stats(
-        activations_list, bar_length_max
-    )
-    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
-    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
-    if log_activations:
-        print_activations_stats(activations)
+    if cal_params:
+        total_param_dims, total_param_size, params_list = sum_param_stats(
+            params_list, bar_length_max
+        )
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
+        if logging_to_stdout:
+            print_param_stats(params_list)
 
-    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
+    if cal_flops:
+        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+        if logging_to_stdout:
+            print_op_stats(flops_list)
+
+    if cal_activations:
+        total_act_dims, total_act_size, activations_list = sum_activations_stats(
+            activations_list, bar_length_max
+        )
+        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
+        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
+        if logging_to_stdout:
+            print_activations_stats(activations_list, has_input=has_input)
+
+    if cal_flops and cal_params:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )
 
     if log_path:
         graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -211,7 +250,9 @@ def visualize(
         total_stats(
             param_size=total_param_size, flops=total_flops, act_size=total_act_size,
         ),
-        stats_details(params=params, flops=flops, activations=activations),
+        stats_details(
+            params=params_list, flops=flops_list, activations=activations_list
+        ),
     )
 
 
@@ -229,12 +270,24 @@ def main():
         help="size of bar indicating max flops or parameter size in net stats.",
     )
     parser.add_argument(
-        "--log_params",
+        "--cal_params",
+        action="store_true",
+        help="whether to calculate and record params size.",
+    )
+    parser.add_argument(
+        "--cal_flops",
+        action="store_true",
+        help="whether to calculate and record op flops.",
+    )
+    parser.add_argument(
+        "--cal_activations",
         action="store_true",
-        help="whether print and record params size.",
+        help="whether to calculate and record op activations.",
     )
     parser.add_argument(
-        "--log_flops", action="store_true", help="whether print and record op flops.",
+        "--logging_to_stdout",
+        action="store_true",
+        help="whether to print all calculated statistic details.",
     )
     parser.add_argument(
         "--all",
@@ -243,8 +296,10 @@ def main():
     )
     args = parser.parse_args()
     if args.all:
-        args.log_params = True
-        args.log_flops = True
+        args.cal_params = True
+        args.cal_flops = True
+        args.cal_activations = True
+        args.logging_to_stdout = True
     if not args.log_path:
         args.log_path = "./log"
     kwargs = vars(args)
diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py
index d91f2cbf4..3e5cce499 100644
--- a/imperative/python/megengine/utils/module_stats.py
+++ b/imperative/python/megengine/utils/module_stats.py
@@ -5,8 +5,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from collections import namedtuple
 from functools import partial
+from typing import Iterable
 
 import numpy as np
 import tabulate
@@ -19,6 +20,7 @@ from megengine import Tensor
 from megengine import functional as F
 from megengine.core.tensor.dtype import get_dtype_bit
 from megengine.functional.tensor import zeros
+from megengine.tensor import Tensor
 
 from .module_utils import set_module_mode_safe
@@ -335,21 +337,23 @@ def print_param_stats(params):
     )
 
 
-def get_activation_stats(output: Tensor):
+def get_activation_stats(output: np.ndarray, has_input=False):
     out_shape = output.shape
     activations_dtype = np.dtype(output.dtype)
     nbits = get_dtype_bit(activations_dtype.name)
     act_dim = np.prod(out_shape)
     act_size = act_dim * nbits // 8
-    return {
+    activation_stats = {
         "dtype": activations_dtype,
         "shape": out_shape,
         "act_dim": act_dim,
-        "mean": "{:.3g}".format(_mean(output)),
-        "std": "{:.3g}".format(_std(output)),
         "nbits": nbits,
         "size": act_size,
     }
+    if has_input:
+        activation_stats["mean"] = "{:.3g}".format(output.mean())
+        activation_stats["std"] = "{:.3g}".format(output.std())
+    return activation_stats
 
 
 def sum_activations_stats(activations, bar_length_max=20):
@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20):
     return total_act_dims, total_act_size, activations
 
 
-def print_activations_stats(activations):
+def print_activations_stats(activations, has_input=False):
     header = [
         "name",
         "class_name",
         "dtype",
         "shape",
-        "mean",
-        "std",
         "nbits",
         "act_dim",
         "size",
@@ -388,6 +390,9 @@ def print_activations_stats(activations):
         "percentage",
         "size_bar",
     ]
+    if has_input:
+        header.insert(4, "mean")
+        header.insert(5, "std")
     logger.info(
         "activations stats: \n"
         + tabulate.tabulate(dict2table(activations, header=header))
@@ -402,56 +407,78 @@ def print_summary(**kwargs):
 
 def module_stats(
     model: m.Module,
-    input_shapes: list,
+    inputs: Iterable[np.ndarray] = None,
+    input_shapes: list = None,
+    cal_params: bool = True,
+    cal_flops: bool = True,
+    cal_activations: bool = True,
+    logging_to_stdout: bool = True,
     bar_length_max: int = 20,
-    log_params: bool = True,
-    log_flops: bool = True,
-    log_activations: bool = True,
 ):
     r"""
     Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
 
     :param model: model that need to get stats info.
-    :param input_shapes: shapes of inputs for running model and calculating stats.
+    :param inputs: user defined input data used to run the model and calculate stats; an alternative to input_shapes.
+    :param input_shapes: shapes used to generate random inputs to run the model and calculate stats; an alternative to inputs.
+    :param cal_params: whether to calculate and record params size.
+    :param cal_flops: whether to calculate and record op flops.
+    :param cal_activations: whether to calculate and record op activations.
+    :param logging_to_stdout: whether to print all calculated statistic details.
     :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
-    :param log_params: whether print and record params size.
-    :param log_flops: whether print and record op flops.
-    :param log_activations: whether print and record op activations.
+ """ + has_inputs = False + if inputs is not None: + has_inputs = True + if not isinstance(inputs, (tuple, list)): + inputs = [inputs] + inputs = [Tensor(input, dtype=np.float32) for input in inputs] + else: + if input_shapes: + if not isinstance(input_shapes[0], tuple): + input_shapes = [input_shapes] + inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] + else: + logger.error( + "Inputs or input_shapes is required for running model and calculating stats.", + exc_info=True, + ) + return + if not cal_activations: + log_activations = False + disable_receptive_field() def module_stats_hook(module, inputs, outputs, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] - flops_stats = get_op_stats(module, inputs, outputs) - if flops_stats is not None: - flops_stats["name"] = name - flops_stats["class_name"] = class_name - flops.append(flops_stats) - - if hasattr(module, "weight") and module.weight is not None: - w = module.weight - param_stats = get_param_stats(w) - param_stats["name"] = name + "-w" - params.append(param_stats) - - if hasattr(module, "bias") and module.bias is not None: - b = module.bias - param_stats = get_param_stats(b) - param_stats["name"] = name + "-b" - params.append(param_stats) - - if not isinstance(outputs, tuple) or not isinstance(outputs, list): - output = outputs - else: - output = outputs[0] - activation_stats = get_activation_stats(output) - activation_stats["name"] = name - activation_stats["class_name"] = class_name - activations.append(activation_stats) - - # multiple inputs to the network - if not isinstance(input_shapes[0], tuple): - input_shapes = [input_shapes] + if cal_flops: + flops_stats = get_op_stats(module, inputs, outputs) + if flops_stats is not None: + flops_stats["name"] = name + flops_stats["class_name"] = class_name + flops.append(flops_stats) + + if cal_params: + if hasattr(module, "weight") and module.weight is not None: + w = module.weight + param_stats = get_param_stats(w.numpy()) + param_stats["name"] = name + "-w" + params.append(param_stats) + + if hasattr(module, "bias") and module.bias is not None: + b = module.bias + param_stats = get_param_stats(b.numpy()) + param_stats["name"] = name + "-b" + params.append(param_stats) + + if cal_activations: + if not isinstance(outputs, (tuple, list)): + output = outputs.numpy() + else: + output = outputs[0].numpy() + activation_stats = get_activation_stats(output, has_inputs) + activation_stats["name"] = name + activation_stats["class_name"] = class_name + activations.append(activation_stats) params = [] flops = [] @@ -466,7 +495,6 @@ def module_stats( module.register_forward_hook(partial(module_stats_hook, name=name)) ) - inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] with set_module_mode_safe(model, training=False) as model: model(*inputs) @@ -481,29 +509,37 @@ def module_stats( total_param_dims, total_param_size, total_act_dims, - total_param_size, + total_act_size, ) = (0, 0, 0, 0, 0) - total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) - extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") - extra_info["total_param_size"] = sizeof_fmt(total_param_size) - if log_params: - print_param_stats(params) - - total_flops, flops = sum_op_stats(flops, bar_length_max) - extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") - if log_flops: - print_op_stats(flops) - - total_act_dims, total_act_size, activations = sum_activations_stats( - activations, bar_length_max - ) - 
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") - extra_info["total_act_size"] = sizeof_fmt(total_act_size) - if log_activations: - print_activations_stats(activations) - - extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) + if cal_params: + total_param_dims, total_param_size, params = sum_param_stats( + params, bar_length_max + ) + extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") + extra_info["total_param_size"] = sizeof_fmt(total_param_size) + if logging_to_stdout: + print_param_stats(params) + + if cal_flops: + total_flops, flops = sum_op_stats(flops, bar_length_max) + extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") + if logging_to_stdout: + print_op_stats(flops) + + if cal_activations: + total_act_dims, total_act_size, activations = sum_activations_stats( + activations, bar_length_max + ) + extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") + extra_info["total_act_size"] = sizeof_fmt(total_act_size) + if logging_to_stdout: + print_activations_stats(activations, has_inputs) + + if cal_flops and cal_params: + extra_info["flops/param_size"] = "{:3.3f}".format( + total_flops / total_param_size + ) print_summary(**extra_info) diff --git a/imperative/python/test/unit/utils/test_module_stats.py b/imperative/python/test/unit/utils/test_module_stats.py index 2a5706890..2c73596d4 100644 --- a/imperative/python/test/unit/utils/test_module_stats.py +++ b/imperative/python/test/unit/utils/test_module_stats.py @@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats def test_module_stats(): net = ResNet(BasicBlock, [2, 2, 2, 2]) input_shape = (1, 3, 224, 224) - total_stats, stats_details = module_stats(net, input_shape) - - x1 = mge.tensor(np.zeros((1, 3, 224, 224))) - gt_flops, gt_acts = net.get_stats(x1) + total_stats, stats_details = module_stats(net, input_shapes=input_shape) + x1 = np.random.random((1, 3, 224, 224)).astype("float32") + gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) + assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( + gt_flops, + gt_acts, + ) + total_stats, stats_details = module_stats(net, inputs=x1) assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( gt_flops, gt_acts, -- GitLab