提交 4cd4a38a 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix network_visualize for op without out shapes

GitOrigin-RevId: fdde52c214a78531d9939938af3170c564bdcf4e
上级 7badcb72
...@@ -17,6 +17,9 @@ from .._imperative_rt.common import ( ...@@ -17,6 +17,9 @@ from .._imperative_rt.common import (
def get_dtype_bit(dtype_name: str): def get_dtype_bit(dtype_name: str):
special_cases = {"bool": 1}
if dtype_name in special_cases:
return special_cases[dtype_name]
numbers = re.findall(r"\d+", dtype_name) numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number." assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0]) return int(numbers[0])
......
...@@ -129,6 +129,7 @@ def visualize( ...@@ -129,6 +129,7 @@ def visualize(
) )
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
disable_stats = False
for node in tqdm(graph.all_oprs): for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"): if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx] node_oup = node.outputs[node.output_idx]
...@@ -145,7 +146,11 @@ def visualize( ...@@ -145,7 +146,11 @@ def visualize(
if log_path: if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto # detail format see tensorboard/compat/proto/attr_value.proto
attr = { attr = {
"_output_shapes": AttrValue( "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if node_oup.shape:
attr["_output_shapes"] = AttrValue(
list=AttrValue.ListValue( list=AttrValue.ListValue(
shape=[ shape=[
TensorShapeProto( TensorShapeProto(
...@@ -155,39 +160,42 @@ def visualize( ...@@ -155,39 +160,42 @@ def visualize(
) )
] ]
) )
), )
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), else:
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), disable_stats = True
} logger.warning(
f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net."
)
if cal_flops: if not disable_stats:
flops_stats = get_op_stats(node, node.inputs, node.outputs) if cal_flops:
if flops_stats is not None: flops_stats = get_op_stats(node, node.inputs, node.outputs)
# add op flops attr if flops_stats is not None:
if log_path and hasattr(flops_stats, "flops_num"): # add op flops attr
attr["flops"] = AttrValue( if log_path and hasattr(flops_stats, "flops_num"):
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") attr["flops"] = AttrValue(
) s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
flops_stats["name"] = node.name )
flops_stats["class_name"] = node.type flops_stats["name"] = node.name
flops_list.append(flops_stats) flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if cal_activations: if cal_activations:
acts = get_activation_stats(node_oup, has_input=has_input) acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name acts["name"] = node.name
acts["class_name"] = node.type acts["class_name"] = node.type
activations_list.append(acts) activations_list.append(acts)
if cal_params: if cal_params:
if node.type == "ImmutableTensor": if node.type == "ImmutableTensor":
param_stats = get_param_stats(node_oup) param_stats = get_param_stats(node_oup)
# add tensor size attr # add tensor size attr
if log_path: if log_path:
attr["size"] = AttrValue( attr["size"] = AttrValue(
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
) )
param_stats["name"] = node.name param_stats["name"] = node.name
params_list.append(param_stats) params_list.append(param_stats)
if log_path: if log_path:
node_list.append( node_list.append(
...@@ -212,34 +220,37 @@ def visualize( ...@@ -212,34 +220,37 @@ def visualize(
total_act_size, total_act_size,
) = (0, 0, 0, 0, 0) ) = (0, 0, 0, 0, 0)
if cal_params: if not disable_stats:
total_param_dims, total_param_size, params_list = sum_param_stats( if cal_params:
params_list, bar_length_max total_param_dims, total_param_size, params_list = sum_param_stats(
) params_list, bar_length_max
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") )
extra_info["total_param_size"] = sizeof_fmt(total_param_size) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
if logging_to_stdout: extra_info["total_param_size"] = sizeof_fmt(total_param_size)
print_param_stats(params_list) if logging_to_stdout:
print_param_stats(params_list)
if cal_flops: if cal_flops:
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if logging_to_stdout: if logging_to_stdout:
print_op_stats(flops_list) print_op_stats(flops_list)
if cal_activations: if cal_activations:
total_act_dims, total_act_size, activations_list = sum_activations_stats( total_act_dims, total_act_size, activations_list = sum_activations_stats(
activations_list, bar_length_max activations_list, bar_length_max
) )
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
extra_info["total_act_size"] = sizeof_fmt(total_act_size) extra_info["total_act_size"] = sizeof_fmt(total_act_size)
if logging_to_stdout: if logging_to_stdout:
print_activations_stats(activations_list, has_input=has_input) print_activations_stats(activations_list, has_input=has_input)
if cal_flops and cal_params: if cal_flops and cal_params:
extra_info["flops/param_size"] = "{:3.3f}".format( extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size total_flops / total_param_size
) )
print_summary(**extra_info)
if log_path: if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
...@@ -251,8 +262,6 @@ def visualize( ...@@ -251,8 +262,6 @@ def visualize(
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats)) writer._get_file_writer().add_graph((graph_def, stepstats))
print_summary(**extra_info)
return ( return (
total_stats( total_stats(
param_size=total_param_size, param_size=total_param_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册