Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6bb9a255
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6bb9a255
编写于
4月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/tools): fix module stats' receptive field bug for Module
GitOrigin-RevId: b4713638304205c94927d6802e858633343e9d27
上级
acf28603
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
35 addition
and
16 deletion
+35
-16
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+9
-6
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+26
-10
未找到文件。
imperative/python/megengine/tools/network_visualize.py
浏览文件 @
6bb9a255
...
...
@@ -15,10 +15,11 @@ import numpy as np
from
megengine.core.tensor.dtype
import
is_quantize
from
megengine.logger
import
_imperative_rt_logger
,
get_logger
,
set_mgb_log_level
from
megengine.utils.module_stats
import
(
get_flops_stats
,
enable_receptive_field
,
get_op_stats
,
get_param_stats
,
print_
flops
_stats
,
print_param
s
_stats
,
print_
op
_stats
,
print_param_stats
,
print_summary
,
sizeof_fmt
,
)
...
...
@@ -68,6 +69,8 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning
old_level
=
set_mgb_log_level
(
logging
.
ERROR
)
enable_receptive_field
()
graph
=
Network
.
load
(
model_path
)
def
process_name
(
name
):
...
...
@@ -110,7 +113,7 @@ def visualize(
"params"
:
AttrValue
(
s
=
str
(
node
.
params
).
encode
(
encoding
=
"utf-8"
)),
"dtype"
:
AttrValue
(
s
=
str
(
node_oup
.
dtype
).
encode
(
encoding
=
"utf-8"
)),
}
flops_stats
=
get_
flops
_stats
(
node
,
node
.
inputs
,
node
.
outputs
)
flops_stats
=
get_
op
_stats
(
node
,
node
.
inputs
,
node
.
outputs
)
if
flops_stats
is
not
None
:
# add op flops attr
if
log_path
and
hasattr
(
flops_stats
,
"flops_num"
):
...
...
@@ -148,13 +151,13 @@ def visualize(
total_flops
,
total_param_dims
,
total_param_size
=
0
,
0
,
0
if
log_params
:
total_param_dims
,
total_param_size
=
print_param
s
_stats
(
total_param_dims
,
total_param_size
=
print_param_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_flops
:
total_flops
=
print_
flops
_stats
(
flops_list
,
bar_length_max
)
total_flops
=
print_
op
_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
...
...
imperative/python/megengine/utils/module_stats.py
浏览文件 @
6bb9a255
...
...
@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {}
def
_receptive_field_fallback
(
module
,
inputs
,
outputs
):
if
not
_receptive_field_enabled
:
return
assert
not
hasattr
(
module
,
"_rf"
)
assert
not
hasattr
(
module
,
"_stride"
)
if
len
(
inputs
)
==
0
:
...
...
@@ -54,6 +56,8 @@ _iter_list = [
),
]
_receptive_field_enabled
=
False
def
_register_dict
(
*
modules
,
dict
=
None
):
def
callback
(
impl
):
...
...
@@ -72,6 +76,16 @@ def register_receptive_field(*modules):
return
_register_dict
(
*
modules
,
dict
=
_calc_receptive_field_dict
)
def
enable_receptive_field
():
global
_receptive_field_enabled
_receptive_field_enabled
=
True
def
disable_receptive_field
():
global
_receptive_field_enabled
_receptive_field_enabled
=
False
@
register_flops
(
m
.
Conv1d
,
m
.
Conv2d
,
m
.
Conv3d
,
)
...
...
@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf
=
(
max
(
getattr
(
i
.
owner
,
"_rf"
,
(
1
,
1
))[
0
]
for
i
in
inputs
),
max
(
i
.
owner
.
_rf
[
1
]
for
i
in
inputs
),
max
(
getattr
(
i
.
owner
,
"_rf"
,
(
1
,
1
))
[
1
]
for
i
in
inputs
),
)
pre_stride
=
(
max
(
getattr
(
i
.
owner
,
"_stride"
,
(
1
,
1
))[
0
]
for
i
in
inputs
),
max
(
i
.
owner
.
_stride
[
1
]
for
i
in
inputs
),
max
(
getattr
(
i
.
owner
,
"_stride"
,
(
1
,
1
))
[
1
]
for
i
in
inputs
),
)
return
pre_rf
,
pre_stride
def
get_
flops
_stats
(
module
,
inputs
,
outputs
):
def
get_
op
_stats
(
module
,
inputs
,
outputs
):
rst
=
{
"input_shapes"
:
[
i
.
shape
for
i
in
inputs
],
"output_shapes"
:
[
o
.
shape
for
o
in
outputs
],
...
...
@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs):
return
def
print_
flops
_stats
(
flops
,
bar_length_max
=
20
):
def
print_
op
_stats
(
flops
,
bar_length_max
=
20
):
max_flops_num
=
max
([
i
[
"flops_num"
]
for
i
in
flops
]
+
[
0
])
total_flops_num
=
0
for
d
in
flops
:
...
...
@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20):
"class_name"
,
"input_shapes"
,
"output_shapes"
,
"receptive_field"
,
"stride"
,
"flops"
,
"flops_cum"
,
"percentage"
,
"bar"
,
]
if
_receptive_field_enabled
:
header
.
insert
(
4
,
"receptive_field"
)
header
.
insert
(
5
,
"stride"
)
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
...
...
@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray):
}
def
print_param
s
_stats
(
params
,
bar_length_max
=
20
):
def
print_param_stats
(
params
,
bar_length_max
=
20
):
max_size
=
max
([
d
[
"size"
]
for
d
in
params
]
+
[
0
])
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
...
...
@@ -302,11 +317,12 @@ def module_stats(
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
disable_receptive_field
()
def
module_stats_hook
(
module
,
inputs
,
outputs
,
name
=
""
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
flops_stats
=
get_
flops
_stats
(
module
,
inputs
,
outputs
)
flops_stats
=
get_
op
_stats
(
module
,
inputs
,
outputs
)
if
flops_stats
is
not
None
:
flops_stats
[
"name"
]
=
name
flops_stats
[
"class_name"
]
=
class_name
...
...
@@ -349,11 +365,11 @@ def module_stats(
}
total_flops
,
total_param_dims
,
total_param_size
=
0
,
0
,
0
if
log_params
:
total_param_dims
,
total_param_size
=
print_param
s
_stats
(
params
,
bar_length_max
)
total_param_dims
,
total_param_size
=
print_param_stats
(
params
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_flops
:
total_flops
=
print_
flops
_stats
(
flops
,
bar_length_max
)
total_flops
=
print_
op
_stats
(
flops
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录