From 2df847544db6f3fb544c5c57302a15ad429a519d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 31 Mar 2021 17:32:24 +0800 Subject: [PATCH] fix(mge/tools): fix node display bug in tensorboard GitOrigin-RevId: c997d6cccbfbdeaf2d24d6115650b1fee4bc0763 --- .../python/megengine/tools/network_visualize.py | 11 +++++------ imperative/python/megengine/utils/module_stats.py | 2 ++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index e3ef5e3e..c3bc8b4a 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -7,8 +7,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse -import json import logging +import re import numpy as np @@ -71,7 +71,10 @@ def visualize( graph = Network.load(model_path) def process_name(name): - return name.replace(".", "/").encode(encoding="utf-8") + # nodes that start with point or contain float const will lead to display bug + if not re.match(r"^[+-]?\d*\.\d*", name): + name = name.replace(".", "/") + return name.encode(encoding="utf-8") summary = [["item", "value"]] node_list = [] @@ -128,10 +131,6 @@ def visualize( param_stats["name"] = node.name params_list.append(param_stats) - # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug - if not len(node.name.split(".")) > 2 and not node in graph.input_vars: - continue - if log_path: node_list.append( NodeDef( diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 5e94308a..fa7813ad 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray): param_dim = np.prod(param.shape) param_size = param_dim * nbits // 8 return { + "dtype": param.dtype, "shape": shape, "mean": "{:.3g}".format(param.mean()), "std": "{:.3g}".format(param.std()), @@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20): header = [ "name", + "dtype", "shape", "mean", "std", -- GitLab