提交 dac2b9e7 编写于 作者: M Megvii Engine Team

feat(mge/tools): optimize statistical tools

BREAKING CHANGE:

GitOrigin-RevId: cd2a1acd1128e482f2cb0963e4e7e1fe0cc47b14
上级 7b68bf77
...@@ -12,6 +12,7 @@ import re ...@@ -12,6 +12,7 @@ import re
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from tqdm import tqdm
from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
...@@ -37,10 +38,13 @@ logger = get_logger(__name__) ...@@ -37,10 +38,13 @@ logger = get_logger(__name__)
def visualize( def visualize(
model_path: str, model_path: str,
log_path: str, log_path: str,
input: np.ndarray = None,
inp_dict: dict = None,
cal_params: bool = True,
cal_flops: bool = True,
cal_activations: bool = True,
logging_to_stdout: bool = True,
bar_length_max: int = 20, bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
): ):
r""" r"""
Load megengine dumped model and visualize graph structure with tensorboard log files. Load megengine dumped model and visualize graph structure with tensorboard log files.
...@@ -48,10 +52,14 @@ def visualize( ...@@ -48,10 +52,14 @@ def visualize(
:param model_path: dir path for megengine dumped model. :param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log. :param log_path: dir path for tensorboard graph log.
:param input: user-defined input data for running the model and calculating stats; an alternative to ``inp_dict``, used when the model has only one input.
:param inp_dict: dict of input data for running the model and calculating stats; an alternative to ``input``, used when the model has more than one input. When both ``input`` and ``inp_dict`` are None, a random input will be used.
:param cal_params: whether to calculate and record the size of params.
:param cal_flops: whether to calculate and record op flops.
:param cal_activations: whether to calculate and record op activations.
:param logging_to_stdout: whether to print all calculated statistic details.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
""" """
if log_path: if log_path:
try: try:
...@@ -78,6 +86,27 @@ def visualize( ...@@ -78,6 +86,27 @@ def visualize(
enable_receptive_field() enable_receptive_field()
graph = Network.load(model_path) graph = Network.load(model_path)
graph.reset_batch_size(1)
has_input = False
if input is not None or inp_dict is not None:
has_input = True
repl_dict = {}
inp_vars = graph.input_vars
if inp_dict is not None:
assert len(inp_dict) == len(
inp_vars
), "Inputs are not sufficient for calculation."
for v in inp_vars:
new_input = graph.make_const(inp_dict[v.name], name=v.name)
repl_dict[v] = new_input
else:
assert len(inp_vars) == 1, "The graph needs more than one input."
inp_var = inp_vars[0]
repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
graph.replace_vars(repl_dict=repl_dict)
graph._compile()
def process_name(name): def process_name(name):
# nodes that start with point or contain float const will lead to display bug # nodes that start with point or contain float const will lead to display bug
...@@ -93,7 +122,7 @@ def visualize( ...@@ -93,7 +122,7 @@ def visualize(
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for node in graph.all_oprs: for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"): if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx] node_oup = node.outputs[node.output_idx]
else: else:
...@@ -123,31 +152,35 @@ def visualize( ...@@ -123,31 +152,35 @@ def visualize(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
} }
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
acts = get_activation_stats(node_oup) if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if cal_activations:
acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
acts["name"] = node.name acts["name"] = node.name
acts["class_name"] = node.type acts["class_name"] = node.type
activations_list.append(acts) activations_list.append(acts)
if node.type == "ImmutableTensor": if cal_params:
param_stats = get_param_stats(node_oup) if node.type == "ImmutableTensor":
# add tensor size attr param_stats = get_param_stats(node.numpy())
if log_path: # add tensor size attr
attr["size"] = AttrValue( if log_path:
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") attr["size"] = AttrValue(
) s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
param_stats["name"] = node.name )
params_list.append(param_stats) param_stats["name"] = node.name
params_list.append(param_stats)
if log_path: if log_path:
node_list.append( node_list.append(
...@@ -169,31 +202,37 @@ def visualize( ...@@ -169,31 +202,37 @@ def visualize(
total_param_dims, total_param_dims,
total_param_size, total_param_size,
total_act_dims, total_act_dims,
total_param_size, total_act_size,
) = (0, 0, 0, 0, 0) ) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats( if cal_params:
params_list, bar_length_max total_param_dims, total_param_size, params_list = sum_param_stats(
) params_list, bar_length_max
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") )
extra_info["total_param_size"] = sizeof_fmt(total_param_size) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
if log_params: extra_info["total_param_size"] = sizeof_fmt(total_param_size)
print_param_stats(params) if logging_to_stdout:
print_param_stats(params_list)
total_flops, flops = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_flops:
print_op_stats(flops)
total_act_dims, total_act_size, activations = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if log_activations:
print_activations_stats(activations)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)
if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if log_path: if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
...@@ -211,7 +250,9 @@ def visualize( ...@@ -211,7 +250,9 @@ def visualize(
total_stats( total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size, param_size=total_param_size, flops=total_flops, act_size=total_act_size,
), ),
stats_details(params=params, flops=flops, activations=activations), stats_details(
params=params_list, flops=flops_list, activations=activations_list
),
) )
...@@ -229,12 +270,24 @@ def main(): ...@@ -229,12 +270,24 @@ def main():
help="size of bar indicating max flops or parameter size in net stats.", help="size of bar indicating max flops or parameter size in net stats.",
) )
parser.add_argument( parser.add_argument(
"--log_params", "--cal_params",
action="store_true",
help="whether calculate and record params size.",
)
parser.add_argument(
"--cal_flops",
action="store_true",
help="whether calculate and record op flops.",
)
parser.add_argument(
"--cal_activations",
action="store_true", action="store_true",
help="whether print and record params size.", help="whether calculate and record op activations.",
) )
parser.add_argument( parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.", "--logging_to_stdout",
action="store_true",
help="whether print all calculated statistic details.",
) )
parser.add_argument( parser.add_argument(
"--all", "--all",
...@@ -243,8 +296,10 @@ def main(): ...@@ -243,8 +296,10 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
if args.all: if args.all:
args.log_params = True args.cal_params = True
args.log_flops = True args.cal_flops = True
args.cal_activations = True
args.logging_to_stdout = True
if not args.log_path: if not args.log_path:
args.log_path = "./log" args.log_path = "./log"
kwargs = vars(args) kwargs = vars(args)
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import namedtuple from collections import Iterable, namedtuple
from functools import partial from functools import partial
from typing import Iterable
import numpy as np import numpy as np
import tabulate import tabulate
...@@ -19,6 +20,7 @@ from megengine import Tensor ...@@ -19,6 +20,7 @@ from megengine import Tensor
from megengine import functional as F from megengine import functional as F
from megengine.core.tensor.dtype import get_dtype_bit from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros from megengine.functional.tensor import zeros
from megengine.tensor import Tensor
from .module_utils import set_module_mode_safe from .module_utils import set_module_mode_safe
...@@ -335,21 +337,23 @@ def print_param_stats(params): ...@@ -335,21 +337,23 @@ def print_param_stats(params):
) )
def get_activation_stats(output: Tensor): def get_activation_stats(output: np.ndarray, has_input=False):
out_shape = output.shape out_shape = output.shape
activations_dtype = np.dtype(output.dtype) activations_dtype = np.dtype(output.dtype)
nbits = get_dtype_bit(activations_dtype.name) nbits = get_dtype_bit(activations_dtype.name)
act_dim = np.prod(out_shape) act_dim = np.prod(out_shape)
act_size = act_dim * nbits // 8 act_size = act_dim * nbits // 8
return { activation_stats = {
"dtype": activations_dtype, "dtype": activations_dtype,
"shape": out_shape, "shape": out_shape,
"act_dim": act_dim, "act_dim": act_dim,
"mean": "{:.3g}".format(_mean(output)),
"std": "{:.3g}".format(_std(output)),
"nbits": nbits, "nbits": nbits,
"size": act_size, "size": act_size,
} }
if has_input:
activation_stats["mean"] = "{:.3g}".format(output.mean())
activation_stats["std"] = "{:.3g}".format(output.std())
return activation_stats
def sum_activations_stats(activations, bar_length_max=20): def sum_activations_stats(activations, bar_length_max=20):
...@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20): ...@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20):
return total_act_dims, total_act_size, activations return total_act_dims, total_act_size, activations
def print_activations_stats(activations): def print_activations_stats(activations, has_input=False):
header = [ header = [
"name", "name",
"class_name", "class_name",
"dtype", "dtype",
"shape", "shape",
"mean",
"std",
"nbits", "nbits",
"act_dim", "act_dim",
"size", "size",
...@@ -388,6 +390,9 @@ def print_activations_stats(activations): ...@@ -388,6 +390,9 @@ def print_activations_stats(activations):
"percentage", "percentage",
"size_bar", "size_bar",
] ]
if has_input:
header.insert(4, "mean")
header.insert(5, "std")
logger.info( logger.info(
"activations stats: \n" "activations stats: \n"
+ tabulate.tabulate(dict2table(activations, header=header)) + tabulate.tabulate(dict2table(activations, header=header))
...@@ -402,56 +407,80 @@ def print_summary(**kwargs): ...@@ -402,56 +407,80 @@ def print_summary(**kwargs):
def module_stats( def module_stats(
model: m.Module, model: m.Module,
input_shapes: list, inputs: Iterable[np.ndarray] = None,
input_shapes: list = None,
cal_params: bool = True,
cal_flops: bool = True,
cal_activations: bool = True,
logging_to_stdout: bool = True,
bar_length_max: int = 20, bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
): ):
r""" r"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info. :param model: model that need to get stats info.
:param input_shapes: shapes of inputs for running model and calculating stats. :param inputs: user defined input data for running model and calculating stats, alternative with input_shapes.
:param input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs.
:param cal_params: whether calculate and record params size.
:param cal_flops: whether calculate and record op flops.
:param cal_activations: whether calculate and record op activations.
:param logging_to_stdout: whether print all calculated statistic details.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
""" """
has_inputs = False
if inputs is not None:
has_inputs = True
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
inputs = [Tensor(input, dtype=np.float32) for input in inputs]
else:
if input_shapes:
if not isinstance(input_shapes[0], tuple):
input_shapes = [input_shapes]
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
else:
logger.error(
"Inputs or input_shapes is required for running model and calculating stats.",
exc_info=True,
)
return
if not cal_activations:
log_activations = False
disable_receptive_field() disable_receptive_field()
def module_stats_hook(module, inputs, outputs, name=""): def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_stats = get_op_stats(module, inputs, outputs) if cal_flops:
if flops_stats is not None: flops_stats = get_op_stats(module, inputs, outputs)
flops_stats["name"] = name if flops_stats is not None:
flops_stats["class_name"] = class_name flops_stats["name"] = name
flops.append(flops_stats) flops_stats["class_name"] = class_name
flops.append(flops_stats)
if hasattr(module, "weight") and module.weight is not None:
w = module.weight if cal_params:
param_stats = get_param_stats(w) if hasattr(module, "weight") and module.weight is not None:
param_stats["name"] = name + "-w" w = module.weight
params.append(param_stats) param_stats = get_param_stats(w.numpy())
param_stats["name"] = name + "-w"
if hasattr(module, "bias") and module.bias is not None: params.append(param_stats)
b = module.bias
param_stats = get_param_stats(b) if hasattr(module, "bias") and module.bias is not None:
param_stats["name"] = name + "-b" b = module.bias
params.append(param_stats) param_stats = get_param_stats(b.numpy())
param_stats["name"] = name + "-b"
if not isinstance(outputs, tuple) or not isinstance(outputs, list): params.append(param_stats)
output = outputs
else: if cal_activations:
output = outputs[0] if not isinstance(outputs, (tuple, list)):
activation_stats = get_activation_stats(output) output = outputs.numpy()
activation_stats["name"] = name else:
activation_stats["class_name"] = class_name output = outputs[0].numpy()
activations.append(activation_stats) activation_stats = get_activation_stats(output, has_inputs)
activation_stats["name"] = name
# multiple inputs to the network activation_stats["class_name"] = class_name
if not isinstance(input_shapes[0], tuple): activations.append(activation_stats)
input_shapes = [input_shapes]
params = [] params = []
flops = [] flops = []
...@@ -466,7 +495,6 @@ def module_stats( ...@@ -466,7 +495,6 @@ def module_stats(
module.register_forward_hook(partial(module_stats_hook, name=name)) module.register_forward_hook(partial(module_stats_hook, name=name))
) )
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
with set_module_mode_safe(model, training=False) as model: with set_module_mode_safe(model, training=False) as model:
model(*inputs) model(*inputs)
...@@ -481,29 +509,37 @@ def module_stats( ...@@ -481,29 +509,37 @@ def module_stats(
total_param_dims, total_param_dims,
total_param_size, total_param_size,
total_act_dims, total_act_dims,
total_param_size, total_act_size,
) = (0, 0, 0, 0, 0) ) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) if cal_params:
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") total_param_dims, total_param_size, params = sum_param_stats(
extra_info["total_param_size"] = sizeof_fmt(total_param_size) params, bar_length_max
if log_params: )
print_param_stats(params) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
total_flops, flops = sum_op_stats(flops, bar_length_max) if logging_to_stdout:
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") print_param_stats(params)
if log_flops:
print_op_stats(flops) if cal_flops:
total_flops, flops = sum_op_stats(flops, bar_length_max)
total_act_dims, total_act_size, activations = sum_activations_stats( extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
activations, bar_length_max if logging_to_stdout:
) print_op_stats(flops)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size) if cal_activations:
if log_activations: total_act_dims, total_act_size, activations = sum_activations_stats(
print_activations_stats(activations) activations, bar_length_max
)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations, has_inputs)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
print_summary(**extra_info) print_summary(**extra_info)
......
...@@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats ...@@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats
def test_module_stats(): def test_module_stats():
net = ResNet(BasicBlock, [2, 2, 2, 2]) net = ResNet(BasicBlock, [2, 2, 2, 2])
input_shape = (1, 3, 224, 224) input_shape = (1, 3, 224, 224)
total_stats, stats_details = module_stats(net, input_shape) total_stats, stats_details = module_stats(net, input_shapes=input_shape)
x1 = np.random.random((1, 3, 224, 224)).astype("float32")
x1 = mge.tensor(np.zeros((1, 3, 224, 224))) gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
gt_flops, gt_acts = net.get_stats(x1) assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops,
gt_acts,
)
total_stats, stats_details = module_stats(net, inputs=x1)
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops, gt_flops,
gt_acts, gt_acts,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册