提交 dac2b9e7 编写于 作者: M Megvii Engine Team

feat(mge/tools): optimize statistical tools

BREAKING CHANGE:

GitOrigin-RevId: cd2a1acd1128e482f2cb0963e4e7e1fe0cc47b14
上级 7b68bf77
......@@ -12,6 +12,7 @@ import re
from collections import namedtuple
import numpy as np
from tqdm import tqdm
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
......@@ -37,10 +38,13 @@ logger = get_logger(__name__)
def visualize(
model_path: str,
log_path: str,
input: np.ndarray = None,
inp_dict: dict = None,
cal_params: bool = True,
cal_flops: bool = True,
cal_activations: bool = True,
logging_to_stdout: bool = True,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
):
r"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
......@@ -48,10 +52,14 @@ def visualize(
:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.
:param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
:param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used.
:param cal_params: whether calculate and record params size.
:param cal_flops: whether calculate and record op flops.
:param cal_activations: whether calculate and record op activations.
:param logging_to_stdout: whether print all calculated statistic details.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
"""
if log_path:
try:
......@@ -78,6 +86,27 @@ def visualize(
enable_receptive_field()
graph = Network.load(model_path)
graph.reset_batch_size(1)
has_input = False
if input is not None or inp_dict is not None:
has_input = True
repl_dict = {}
inp_vars = graph.input_vars
if inp_dict is not None:
assert len(inp_dict) == len(
inp_vars
), "Inputs are not sufficient for calculation."
for v in inp_vars:
new_input = graph.make_const(inp_dict[v.name], name=v.name)
repl_dict[v] = new_input
else:
assert len(inp_vars) == 1, "The graph needs more than one input."
inp_var = inp_vars[0]
repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
graph.replace_vars(repl_dict=repl_dict)
graph._compile()
def process_name(name):
# nodes that start with point or contain float const will lead to display bug
......@@ -93,7 +122,7 @@ def visualize(
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for node in graph.all_oprs:
for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
else:
......@@ -123,31 +152,35 @@ def visualize(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
acts = get_activation_stats(node_oup)
if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if cal_activations:
acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup)
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)
if cal_params:
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy())
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)
if log_path:
node_list.append(
......@@ -169,31 +202,37 @@ def visualize(
total_param_dims,
total_param_size,
total_act_dims,
total_param_size,
total_act_size,
) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_params:
print_param_stats(params)
total_flops, flops = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_flops:
print_op_stats(flops)
total_act_dims, total_act_size, activations = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if log_activations:
print_activations_stats(activations)
if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params_list)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)
if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
......@@ -211,7 +250,9 @@ def visualize(
total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size,
),
stats_details(params=params, flops=flops, activations=activations),
stats_details(
params=params_list, flops=flops_list, activations=activations_list
),
)
......@@ -229,12 +270,24 @@ def main():
help="size of bar indicating max flops or parameter size in net stats.",
)
parser.add_argument(
"--log_params",
"--cal_params",
action="store_true",
help="whether calculate and record params size.",
)
parser.add_argument(
"--cal_flops",
action="store_true",
help="whether calculate and record op flops.",
)
parser.add_argument(
"--cal_activations",
action="store_true",
help="whether print and record params size.",
help="whether calculate and record op activations.",
)
parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.",
"--logging_to_stdout",
action="store_true",
help="whether print all calculated statistic details.",
)
parser.add_argument(
"--all",
......@@ -243,8 +296,10 @@ def main():
)
args = parser.parse_args()
if args.all:
args.log_params = True
args.log_flops = True
args.cal_params = True
args.cal_flops = True
args.cal_activations = True
args.logging_to_stdout = True
if not args.log_path:
args.log_path = "./log"
kwargs = vars(args)
......
......@@ -5,8 +5,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import namedtuple
from collections import Iterable, namedtuple
from functools import partial
from typing import Iterable
import numpy as np
import tabulate
......@@ -19,6 +20,7 @@ from megengine import Tensor
from megengine import functional as F
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros
from megengine.tensor import Tensor
from .module_utils import set_module_mode_safe
......@@ -335,21 +337,23 @@ def print_param_stats(params):
)
def get_activation_stats(output: Tensor):
def get_activation_stats(output: np.ndarray, has_input=False):
out_shape = output.shape
activations_dtype = np.dtype(output.dtype)
nbits = get_dtype_bit(activations_dtype.name)
act_dim = np.prod(out_shape)
act_size = act_dim * nbits // 8
return {
activation_stats = {
"dtype": activations_dtype,
"shape": out_shape,
"act_dim": act_dim,
"mean": "{:.3g}".format(_mean(output)),
"std": "{:.3g}".format(_std(output)),
"nbits": nbits,
"size": act_size,
}
if has_input:
activation_stats["mean"] = "{:.3g}".format(output.mean())
activation_stats["std"] = "{:.3g}".format(output.std())
return activation_stats
def sum_activations_stats(activations, bar_length_max=20):
......@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20):
return total_act_dims, total_act_size, activations
def print_activations_stats(activations):
def print_activations_stats(activations, has_input=False):
header = [
"name",
"class_name",
"dtype",
"shape",
"mean",
"std",
"nbits",
"act_dim",
"size",
......@@ -388,6 +390,9 @@ def print_activations_stats(activations):
"percentage",
"size_bar",
]
if has_input:
header.insert(4, "mean")
header.insert(5, "std")
logger.info(
"activations stats: \n"
+ tabulate.tabulate(dict2table(activations, header=header))
......@@ -402,56 +407,80 @@ def print_summary(**kwargs):
def module_stats(
model: m.Module,
input_shapes: list,
inputs: Iterable[np.ndarray] = None,
input_shapes: list = None,
cal_params: bool = True,
cal_flops: bool = True,
cal_activations: bool = True,
logging_to_stdout: bool = True,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
log_activations: bool = True,
):
r"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info.
:param input_shapes: shapes of inputs for running model and calculating stats.
:param inputs: user defined input data for running model and calculating stats, alternative with input_shapes.
:param input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs.
:param cal_params: whether calculate and record params size.
:param cal_flops: whether calculate and record op flops.
:param cal_activations: whether calculate and record op activations.
:param logging_to_stdout: whether print all calculated statistic details.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
"""
has_inputs = False
if inputs is not None:
has_inputs = True
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
inputs = [Tensor(input, dtype=np.float32) for input in inputs]
else:
if input_shapes:
if not isinstance(input_shapes[0], tuple):
input_shapes = [input_shapes]
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
else:
logger.error(
"Inputs or input_shapes is required for running model and calculating stats.",
exc_info=True,
)
return
if not cal_activations:
log_activations = False
disable_receptive_field()
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)
if hasattr(module, "weight") and module.weight is not None:
w = module.weight
param_stats = get_param_stats(w)
param_stats["name"] = name + "-w"
params.append(param_stats)
if hasattr(module, "bias") and module.bias is not None:
b = module.bias
param_stats = get_param_stats(b)
param_stats["name"] = name + "-b"
params.append(param_stats)
if not isinstance(outputs, tuple) or not isinstance(outputs, list):
output = outputs
else:
output = outputs[0]
activation_stats = get_activation_stats(output)
activation_stats["name"] = name
activation_stats["class_name"] = class_name
activations.append(activation_stats)
# multiple inputs to the network
if not isinstance(input_shapes[0], tuple):
input_shapes = [input_shapes]
if cal_flops:
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)
if cal_params:
if hasattr(module, "weight") and module.weight is not None:
w = module.weight
param_stats = get_param_stats(w.numpy())
param_stats["name"] = name + "-w"
params.append(param_stats)
if hasattr(module, "bias") and module.bias is not None:
b = module.bias
param_stats = get_param_stats(b.numpy())
param_stats["name"] = name + "-b"
params.append(param_stats)
if cal_activations:
if not isinstance(outputs, (tuple, list)):
output = outputs.numpy()
else:
output = outputs[0].numpy()
activation_stats = get_activation_stats(output, has_inputs)
activation_stats["name"] = name
activation_stats["class_name"] = class_name
activations.append(activation_stats)
params = []
flops = []
......@@ -466,7 +495,6 @@ def module_stats(
module.register_forward_hook(partial(module_stats_hook, name=name))
)
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
with set_module_mode_safe(model, training=False) as model:
model(*inputs)
......@@ -481,29 +509,37 @@ def module_stats(
total_param_dims,
total_param_size,
total_act_dims,
total_param_size,
total_act_size,
) = (0, 0, 0, 0, 0)
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_params:
print_param_stats(params)
total_flops, flops = sum_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_flops:
print_op_stats(flops)
total_act_dims, total_act_size, activations = sum_activations_stats(
activations, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if log_activations:
print_activations_stats(activations)
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
if cal_params:
total_param_dims, total_param_size, params = sum_param_stats(
params, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params)
if cal_flops:
total_flops, flops = sum_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops)
if cal_activations:
total_act_dims, total_act_size, activations = sum_activations_stats(
activations, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations, has_inputs)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
print_summary(**extra_info)
......
......@@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats
def test_module_stats():
net = ResNet(BasicBlock, [2, 2, 2, 2])
input_shape = (1, 3, 224, 224)
total_stats, stats_details = module_stats(net, input_shape)
x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
gt_flops, gt_acts = net.get_stats(x1)
total_stats, stats_details = module_stats(net, input_shapes=input_shape)
x1 = np.random.random((1, 3, 224, 224)).astype("float32")
gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops,
gt_acts,
)
total_stats, stats_details = module_stats(net, inputs=x1)
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
gt_flops,
gt_acts,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册