提交 4cd4a38a 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix network_visualize for op without out shapes

GitOrigin-RevId: fdde52c214a78531d9939938af3170c564bdcf4e
上级 7badcb72
......@@ -17,6 +17,9 @@ from .._imperative_rt.common import (
def get_dtype_bit(dtype_name: str):
special_cases = {"bool": 1}
if dtype_name in special_cases:
return special_cases[dtype_name]
numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0])
......
......@@ -129,6 +129,7 @@ def visualize(
)
stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
disable_stats = False
for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
......@@ -145,7 +146,11 @@ def visualize(
if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto
attr = {
"_output_shapes": AttrValue(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if node_oup.shape:
attr["_output_shapes"] = AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
......@@ -155,39 +160,42 @@ def visualize(
)
]
)
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
)
else:
disable_stats = True
logger.warning(
f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net."
)
if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if not disable_stats:
if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if cal_activations:
acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if cal_activations:
acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name
acts["class_name"] = node.type
activations_list.append(acts)
if cal_params:
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup)
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)
if cal_params:
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup)
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
param_stats["name"] = node.name
params_list.append(param_stats)
if log_path:
node_list.append(
......@@ -212,34 +220,37 @@ def visualize(
total_act_size,
) = (0, 0, 0, 0, 0)
if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params_list)
if not disable_stats:
if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if logging_to_stdout:
print_param_stats(params_list)
if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)
if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout:
print_op_stats(flops_list)
if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)
if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max
)
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
print_summary(**extra_info)
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
......@@ -251,8 +262,6 @@ def visualize(
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))
print_summary(**extra_info)
return (
total_stats(
param_size=total_param_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册