Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4cd4a38a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4cd4a38a
编写于
7月 05, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/tools): fix network_visualize for op without out shapes
GitOrigin-RevId: fdde52c214a78531d9939938af3170c564bdcf4e
上级
7badcb72
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
70 addition
and
58 deletion
+70
-58
imperative/python/megengine/core/tensor/dtype.py
imperative/python/megengine/core/tensor/dtype.py
+3
-0
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+67
-58
未找到文件。
imperative/python/megengine/core/tensor/dtype.py
浏览文件 @
4cd4a38a
...
...
@@ -17,6 +17,9 @@ from .._imperative_rt.common import (
def
get_dtype_bit
(
dtype_name
:
str
):
special_cases
=
{
"bool"
:
1
}
if
dtype_name
in
special_cases
:
return
special_cases
[
dtype_name
]
numbers
=
re
.
findall
(
r
"\d+"
,
dtype_name
)
assert
len
(
numbers
)
==
1
,
"Unsupport dtype name with more than one number."
return
int
(
numbers
[
0
])
...
...
imperative/python/megengine/tools/network_visualize.py
浏览文件 @
4cd4a38a
...
...
@@ -129,6 +129,7 @@ def visualize(
)
stats_details
=
namedtuple
(
"module_stats"
,
[
"params"
,
"flops"
,
"activations"
])
disable_stats
=
False
for
node
in
tqdm
(
graph
.
all_oprs
):
if
hasattr
(
node
,
"output_idx"
):
node_oup
=
node
.
outputs
[
node
.
output_idx
]
...
...
@@ -145,7 +146,11 @@ def visualize(
if
log_path
:
# detail format see tensorboard/compat/proto/attr_value.proto
attr
=
{
"_output_shapes"
:
AttrValue
(
"params"
:
AttrValue
(
s
=
str
(
node
.
params
).
encode
(
encoding
=
"utf-8"
)),
"dtype"
:
AttrValue
(
s
=
str
(
node_oup
.
dtype
).
encode
(
encoding
=
"utf-8"
)),
}
if
node_oup
.
shape
:
attr
[
"_output_shapes"
]
=
AttrValue
(
list
=
AttrValue
.
ListValue
(
shape
=
[
TensorShapeProto
(
...
...
@@ -155,39 +160,42 @@ def visualize(
)
]
)
),
"params"
:
AttrValue
(
s
=
str
(
node
.
params
).
encode
(
encoding
=
"utf-8"
)),
"dtype"
:
AttrValue
(
s
=
str
(
node_oup
.
dtype
).
encode
(
encoding
=
"utf-8"
)),
}
)
else
:
disable_stats
=
True
logger
.
warning
(
f
"OpNode
{
node
.
name
}
do not has shape attr, would not calculate flops/params/activations for this net."
)
if
cal_flops
:
flops_stats
=
get_op_stats
(
node
,
node
.
inputs
,
node
.
outputs
)
if
flops_stats
is
not
None
:
# add op flops attr
if
log_path
and
hasattr
(
flops_stats
,
"flops_num"
):
attr
[
"flops"
]
=
AttrValue
(
s
=
sizeof_fmt
(
flops_stats
[
"flops"
]).
encode
(
encoding
=
"utf-8"
)
)
flops_stats
[
"name"
]
=
node
.
name
flops_stats
[
"class_name"
]
=
node
.
type
flops_list
.
append
(
flops_stats
)
if
not
disable_stats
:
if
cal_flops
:
flops_stats
=
get_op_stats
(
node
,
node
.
inputs
,
node
.
outputs
)
if
flops_stats
is
not
None
:
# add op flops attr
if
log_path
and
hasattr
(
flops_stats
,
"flops_num"
):
attr
[
"flops"
]
=
AttrValue
(
s
=
sizeof_fmt
(
flops_stats
[
"flops"
]).
encode
(
encoding
=
"utf-8"
)
)
flops_stats
[
"name"
]
=
node
.
name
flops_stats
[
"class_name"
]
=
node
.
type
flops_list
.
append
(
flops_stats
)
if
cal_activations
:
acts
=
get_activation_stats
(
node_oup
,
has_input
=
has_input
)
acts
[
"name"
]
=
node
.
name
acts
[
"class_name"
]
=
node
.
type
activations_list
.
append
(
acts
)
if
cal_activations
:
acts
=
get_activation_stats
(
node_oup
,
has_input
=
has_input
)
acts
[
"name"
]
=
node
.
name
acts
[
"class_name"
]
=
node
.
type
activations_list
.
append
(
acts
)
if
cal_params
:
if
node
.
type
==
"ImmutableTensor"
:
param_stats
=
get_param_stats
(
node_oup
)
# add tensor size attr
if
log_path
:
attr
[
"size"
]
=
AttrValue
(
s
=
sizeof_fmt
(
param_stats
[
"size"
]).
encode
(
encoding
=
"utf-8"
)
)
param_stats
[
"name"
]
=
node
.
name
params_list
.
append
(
param_stats
)
if
cal_params
:
if
node
.
type
==
"ImmutableTensor"
:
param_stats
=
get_param_stats
(
node_oup
)
# add tensor size attr
if
log_path
:
attr
[
"size"
]
=
AttrValue
(
s
=
sizeof_fmt
(
param_stats
[
"size"
]).
encode
(
encoding
=
"utf-8"
)
)
param_stats
[
"name"
]
=
node
.
name
params_list
.
append
(
param_stats
)
if
log_path
:
node_list
.
append
(
...
...
@@ -212,34 +220,37 @@ def visualize(
total_act_size
,
)
=
(
0
,
0
,
0
,
0
,
0
)
if
cal_params
:
total_param_dims
,
total_param_size
,
params_list
=
sum_param_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
,
suffix
=
""
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
logging_to_stdout
:
print_param_stats
(
params_list
)
if
not
disable_stats
:
if
cal_params
:
total_param_dims
,
total_param_size
,
params_list
=
sum_param_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
,
suffix
=
""
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
logging_to_stdout
:
print_param_stats
(
params_list
)
if
cal_flops
:
total_flops
,
flops_list
=
sum_op_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
logging_to_stdout
:
print_op_stats
(
flops_list
)
if
cal_flops
:
total_flops
,
flops_list
=
sum_op_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
logging_to_stdout
:
print_op_stats
(
flops_list
)
if
cal_activations
:
total_act_dims
,
total_act_size
,
activations_list
=
sum_activations_stats
(
activations_list
,
bar_length_max
)
extra_info
[
"total_act_dims"
]
=
sizeof_fmt
(
total_act_dims
,
suffix
=
""
)
extra_info
[
"total_act_size"
]
=
sizeof_fmt
(
total_act_size
)
if
logging_to_stdout
:
print_activations_stats
(
activations_list
,
has_input
=
has_input
)
if
cal_activations
:
total_act_dims
,
total_act_size
,
activations_list
=
sum_activations_stats
(
activations_list
,
bar_length_max
)
extra_info
[
"total_act_dims"
]
=
sizeof_fmt
(
total_act_dims
,
suffix
=
""
)
extra_info
[
"total_act_size"
]
=
sizeof_fmt
(
total_act_size
)
if
logging_to_stdout
:
print_activations_stats
(
activations_list
,
has_input
=
has_input
)
if
cal_flops
and
cal_params
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
if
cal_flops
and
cal_params
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
print_summary
(
**
extra_info
)
if
log_path
:
graph_def
=
GraphDef
(
node
=
node_list
,
versions
=
VersionDef
(
producer
=
22
))
...
...
@@ -251,8 +262,6 @@ def visualize(
writer
=
SummaryWriter
(
log_path
)
writer
.
_get_file_writer
().
add_graph
((
graph_def
,
stepstats
))
print_summary
(
**
extra_info
)
return
(
total_stats
(
param_size
=
total_param_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录