提交 c45f1eb2 编写于 作者: M Megvii Engine Team

fix(mge/tools): improve `module_visualize` result's robustness and beauty

GitOrigin-RevId: ef7b57377619fabcf50d4a48235b8d196659f1d4
上级 786c36ff
...@@ -7,11 +7,12 @@ ...@@ -7,11 +7,12 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import logging
import numpy as np import numpy as np
from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import ( from megengine.utils.module_stats import (
print_flops_stats, print_flops_stats,
print_params_stats, print_params_stats,
...@@ -58,6 +59,8 @@ def visualize( ...@@ -58,6 +59,8 @@ def visualize(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True "TensorBoard and TensorboardX are required for visualize.", exc_info=True
) )
return return
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)
graph = Network.load(model_path) graph = Network.load(model_path)
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
...@@ -126,6 +129,9 @@ def visualize( ...@@ -126,6 +129,9 @@ def visualize(
std="{:.2g}".format(node.numpy().std()), std="{:.2g}".format(node.numpy().std()),
) )
) )
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
node_list.append( node_list.append(
NodeDef( NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr, name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
...@@ -145,6 +151,10 @@ def visualize( ...@@ -145,6 +151,10 @@ def visualize(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
) )
writer._get_file_writer().add_graph((graph_def, stepstats)) writer._get_file_writer().add_graph((graph_def, stepstats))
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
return total_params, total_flops return total_params, total_flops
......
...@@ -135,7 +135,9 @@ def print_flops_stats(flops, bar_length_max=20): ...@@ -135,7 +135,9 @@ def print_flops_stats(flops, bar_length_max=20):
] ]
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
)
flops.append( flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册