diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index e3ef5e3ed2b60f1f742a929f0a95d69519b77e89..c3bc8b4ae6505140f6abd771a42cd140a3f518b7 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -7,8 +7,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse -import json import logging +import re import numpy as np @@ -71,7 +71,10 @@ def visualize( graph = Network.load(model_path) def process_name(name): - return name.replace(".", "/").encode(encoding="utf-8") + # nodes that start with point or contain float const will lead to display bug + if not re.match(r"^[+-]?\d*\.\d*", name): + name = name.replace(".", "/") + return name.encode(encoding="utf-8") summary = [["item", "value"]] node_list = [] @@ -128,10 +131,6 @@ def visualize( param_stats["name"] = node.name params_list.append(param_stats) - # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug - if not len(node.name.split(".")) > 2 and not node in graph.input_vars: - continue - if log_path: node_list.append( NodeDef( diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 5e94308a0b26e899c1db4252d3b7d7fac2ba78e9..fa7813adc6d3a5576018d4307b042b21940ab44d 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray): param_dim = np.prod(param.shape) param_size = param_dim * nbits // 8 return { + "dtype": param.dtype, "shape": shape, "mean": "{:.3g}".format(param.mean()), "std": "{:.3g}".format(param.std()), @@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20): header = [ "name", + "dtype", "shape", "mean", "std",