提交 2df84754 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix node display bug in tensorboard

GitOrigin-RevId: c997d6cccbfbdeaf2d24d6115650b1fee4bc0763
上级 13481fd2
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import json
import logging import logging
import re
import numpy as np import numpy as np
...@@ -71,7 +71,10 @@ def visualize( ...@@ -71,7 +71,10 @@ def visualize(
graph = Network.load(model_path) graph = Network.load(model_path)
def process_name(name): def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8") # nodes that start with point or contain float const will lead to display bug
if not re.match(r"^[+-]?\d*\.\d*", name):
name = name.replace(".", "/")
return name.encode(encoding="utf-8")
summary = [["item", "value"]] summary = [["item", "value"]]
node_list = [] node_list = []
...@@ -128,10 +131,6 @@ def visualize( ...@@ -128,10 +131,6 @@ def visualize(
param_stats["name"] = node.name param_stats["name"] = node.name
params_list.append(param_stats) params_list.append(param_stats)
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
if log_path: if log_path:
node_list.append( node_list.append(
NodeDef( NodeDef(
......
...@@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray): ...@@ -230,6 +230,7 @@ def get_param_stats(param: np.ndarray):
param_dim = np.prod(param.shape) param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8 param_size = param_dim * nbits // 8
return { return {
"dtype": param.dtype,
"shape": shape, "shape": shape,
"mean": "{:.3g}".format(param.mean()), "mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()), "std": "{:.3g}".format(param.std()),
...@@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20): ...@@ -260,6 +261,7 @@ def print_params_stats(params, bar_length_max=20):
header = [ header = [
"name", "name",
"dtype",
"shape", "shape",
"mean", "mean",
"std", "std",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册