提交 4cd4a38a 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix network_visualize for op without out shapes

GitOrigin-RevId: fdde52c214a78531d9939938af3170c564bdcf4e
上级 7badcb72
...@@ -17,6 +17,9 @@ from .._imperative_rt.common import ( ...@@ -17,6 +17,9 @@ from .._imperative_rt.common import (
def get_dtype_bit(dtype_name: str): def get_dtype_bit(dtype_name: str):
special_cases = {"bool": 1}
if dtype_name in special_cases:
return special_cases[dtype_name]
numbers = re.findall(r"\d+", dtype_name) numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number." assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0]) return int(numbers[0])
......
...@@ -129,6 +129,7 @@ def visualize( ...@@ -129,6 +129,7 @@ def visualize(
) )
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
disable_stats = False
for node in tqdm(graph.all_oprs): for node in tqdm(graph.all_oprs):
if hasattr(node, "output_idx"): if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx] node_oup = node.outputs[node.output_idx]
...@@ -145,7 +146,11 @@ def visualize( ...@@ -145,7 +146,11 @@ def visualize(
if log_path: if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto # detail format see tensorboard/compat/proto/attr_value.proto
attr = { attr = {
"_output_shapes": AttrValue( "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if node_oup.shape:
attr["_output_shapes"] = AttrValue(
list=AttrValue.ListValue( list=AttrValue.ListValue(
shape=[ shape=[
TensorShapeProto( TensorShapeProto(
...@@ -155,11 +160,14 @@ def visualize( ...@@ -155,11 +160,14 @@ def visualize(
) )
] ]
) )
), )
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), else:
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), disable_stats = True
} logger.warning(
f"OpNode {node.name} do not has shape attr, would not calculate flops/params/activations for this net."
)
if not disable_stats:
if cal_flops: if cal_flops:
flops_stats = get_op_stats(node, node.inputs, node.outputs) flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None: if flops_stats is not None:
...@@ -212,6 +220,7 @@ def visualize( ...@@ -212,6 +220,7 @@ def visualize(
total_act_size, total_act_size,
) = (0, 0, 0, 0, 0) ) = (0, 0, 0, 0, 0)
if not disable_stats:
if cal_params: if cal_params:
total_param_dims, total_param_size, params_list = sum_param_stats( total_param_dims, total_param_size, params_list = sum_param_stats(
params_list, bar_length_max params_list, bar_length_max
...@@ -241,6 +250,8 @@ def visualize( ...@@ -241,6 +250,8 @@ def visualize(
total_flops / total_param_size total_flops / total_param_size
) )
print_summary(**extra_info)
if log_path: if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
...@@ -251,8 +262,6 @@ def visualize( ...@@ -251,8 +262,6 @@ def visualize(
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats)) writer._get_file_writer().add_graph((graph_def, stepstats))
print_summary(**extra_info)
return ( return (
total_stats( total_stats(
param_size=total_param_size, param_size=total_param_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册