Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c3a1ac3d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c3a1ac3d
编写于
4月 28, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/tools): module_status-add-functions
BREAKING CHANGE: GitOrigin-RevId: ced3da3a129713c652d93b73756b93273bf1cc9b
上级
05e4c826
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
623 addition
and
77 deletion
+623
-77
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+51
-13
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+169
-64
imperative/python/megengine/utils/module_utils.py
imperative/python/megengine/utils/module_utils.py
+26
-0
imperative/python/test/unit/utils/test_module_stats.py
imperative/python/test/unit/utils/test_module_stats.py
+377
-0
未找到文件。
imperative/python/megengine/tools/network_visualize.py
浏览文件 @
c3a1ac3d
...
...
@@ -9,6 +9,7 @@
import
argparse
import
logging
import
re
from
collections
import
namedtuple
import
numpy
as
np
...
...
@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
from
megengine.logger
import
_imperative_rt_logger
,
get_logger
,
set_mgb_log_level
from
megengine.utils.module_stats
import
(
enable_receptive_field
,
get_activation_stats
,
get_op_stats
,
get_param_stats
,
print_activations_stats
,
print_op_stats
,
print_param_stats
,
print_summary
,
sizeof_fmt
,
sum_activations_stats
,
sum_op_stats
,
sum_param_stats
,
)
from
megengine.utils.network
import
Network
...
...
@@ -34,6 +40,7 @@ def visualize(
bar_length_max
:
int
=
20
,
log_params
:
bool
=
True
,
log_flops
:
bool
=
True
,
log_activations
:
bool
=
True
,
):
r
"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
...
...
@@ -44,6 +51,7 @@ def visualize(
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
"""
if
log_path
:
try
:
...
...
@@ -83,6 +91,10 @@ def visualize(
node_list
=
[]
flops_list
=
[]
params_list
=
[]
activations_list
=
[]
total_stats
=
namedtuple
(
"total_stats"
,
[
"param_size"
,
"flops"
,
"act_size"
])
stats_details
=
namedtuple
(
"module_stats"
,
[
"params"
,
"flops"
,
"activations"
])
for
node
in
graph
.
all_oprs
:
if
hasattr
(
node
,
"output_idx"
):
node_oup
=
node
.
outputs
[
node
.
output_idx
]
...
...
@@ -124,6 +136,11 @@ def visualize(
flops_stats
[
"class_name"
]
=
node
.
type
flops_list
.
append
(
flops_stats
)
acts
=
get_activation_stats
(
node_oup
.
numpy
())
acts
[
"name"
]
=
node
.
name
acts
[
"class_name"
]
=
node
.
type
activations_list
.
append
(
acts
)
if
node
.
type
==
"ImmutableTensor"
:
param_stats
=
get_param_stats
(
node
.
numpy
())
# add tensor size attr
...
...
@@ -149,20 +166,36 @@ def visualize(
"#params"
:
len
(
params_list
),
}
total_flops
,
total_param_dims
,
total_param_size
=
0
,
0
,
0
(
total_flops
,
total_param_dims
,
total_param_size
,
total_act_dims
,
total_param_size
,
)
=
(
0
,
0
,
0
,
0
,
0
)
total_param_dims
,
total_param_size
,
params
=
sum_param_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
,
suffix
=
""
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_params
:
total_param_dims
,
total_param_size
=
print_param_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
print_param_stats
(
params
)
total_flops
,
flops
=
sum_op_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_flops
:
total_flops
=
print_op_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
print_op_stats
(
flops
)
total_act_dims
,
total_act_size
,
activations
=
sum_activations_stats
(
activations_list
,
bar_length_max
)
extra_info
[
"total_act_dims"
]
=
sizeof_fmt
(
total_act_dims
,
suffix
=
""
)
extra_info
[
"total_act_size"
]
=
sizeof_fmt
(
total_act_size
)
if
log_activations
:
print_activations_stats
(
activations
)
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
if
log_path
:
graph_def
=
GraphDef
(
node
=
node_list
,
versions
=
VersionDef
(
producer
=
22
))
...
...
@@ -179,7 +212,12 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger
.
set_log_level
(
old_level
)
return
total_param_size
,
total_flops
return
(
total_stats
(
param_size
=
total_param_size
,
flops
=
total_flops
,
act_size
=
total_act_size
,
),
stats_details
(
params
=
params
,
flops
=
flops
,
activations
=
activations
),
)
def
main
():
...
...
imperative/python/megengine/utils/module_stats.py
浏览文件 @
c3a1ac3d
...
...
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
contextlib
from
collections
import
namedtuple
from
functools
import
partial
import
numpy
as
np
...
...
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
from
megengine.core.tensor.dtype
import
get_dtype_bit
from
megengine.functional.tensor
import
zeros
from
.module_utils
import
set_module_mode_safe
try
:
mge
.
logger
.
MegEngineLogFormatter
.
max_lines
=
float
(
"inf"
)
except
AttributeError
as
e
:
...
...
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
)
@
register_flops
(
m
.
batchnorm
.
_BatchNorm
,
m
.
SyncBatchNorm
,
m
.
GroupNorm
,
m
.
LayerNorm
,
m
.
InstanceNorm
,
)
def
flops_norm
(
module
:
m
.
Linear
,
inputs
,
outputs
):
return
np
.
prod
(
inputs
[
0
].
shape
)
*
7
@
register_flops
(
m
.
AvgPool2d
,
m
.
MaxPool2d
)
def
flops_pool
(
module
:
m
.
AvgPool2d
,
inputs
,
outputs
):
return
np
.
prod
(
outputs
[
0
].
shape
)
*
(
module
.
kernel_size
**
2
)
@
register_flops
(
m
.
AdaptiveAvgPool2d
,
m
.
AdaptiveMaxPool2d
)
def
flops_adaptivePool
(
module
:
m
.
AdaptiveAvgPool2d
,
inputs
,
outputs
):
stride_h
=
np
.
floor
(
inputs
[
0
].
shape
[
2
]
/
(
inputs
[
0
].
shape
[
2
]
-
1
))
kernel_h
=
inputs
[
0
].
shape
[
2
]
-
(
inputs
[
0
].
shape
[
2
]
-
1
)
*
stride_h
stride_w
=
np
.
floor
(
inputs
[
0
].
shape
[
3
]
/
(
inputs
[
0
].
shape
[
3
]
-
1
))
kernel_w
=
inputs
[
0
].
shape
[
3
]
-
(
inputs
[
0
].
shape
[
3
]
-
1
)
*
stride_w
return
np
.
prod
(
outputs
[
0
].
shape
)
*
kernel_h
*
kernel_w
@
register_flops
(
m
.
Linear
)
def
flops_linear
(
module
:
m
.
Linear
,
inputs
,
outputs
):
bias
=
module
.
out_features
if
module
.
bias
is
not
None
else
0
...
...
@@ -120,6 +143,12 @@ hook_modules = (
m
.
conv
.
_ConvNd
,
m
.
Linear
,
m
.
BatchMatMulActivation
,
m
.
batchnorm
.
_BatchNorm
,
m
.
LayerNorm
,
m
.
GroupNorm
,
m
.
InstanceNorm
,
m
.
pooling
.
_PoolNd
,
m
.
adaptive_pooling
.
_AdaptivePoolNd
,
)
...
...
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
def
sizeof_fmt
(
num
,
suffix
=
"B"
):
for
unit
in
[
""
,
"Ki"
,
"Mi"
,
"Gi"
,
"Ti"
,
"Pi"
,
"Ei"
,
"Zi"
]:
if
abs
(
num
)
<
1024.0
:
if
suffix
==
"B"
:
scale
=
1024.0
units
=
[
""
,
"Ki"
,
"Mi"
,
"Gi"
,
"Ti"
,
"Pi"
,
"Ei"
,
"Zi"
,
"Yi"
]
else
:
scale
=
1000.0
units
=
[
""
,
"K"
,
"M"
,
"G"
,
"T"
,
"P"
,
"E"
,
"Z"
,
"Y"
]
for
unit
in
units
:
if
abs
(
num
)
<
scale
or
unit
==
units
[
-
1
]:
return
"{:3.3f} {}{}"
.
format
(
num
,
unit
,
suffix
)
num
/=
1024.0
sign_str
=
"-"
if
num
<
0
else
""
return
"{}{:.1f} {}{}"
.
format
(
sign_str
,
num
,
"Yi"
,
suffix
)
num
/=
scale
def
preprocess_receptive_field
(
module
,
inputs
,
outputs
):
...
...
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
def
get_op_stats
(
module
,
inputs
,
outputs
):
if
not
isinstance
(
outputs
,
tuple
)
and
not
isinstance
(
outputs
,
list
):
outputs
=
(
outputs
,)
rst
=
{
"input_shapes"
:
[
i
.
shape
for
i
in
inputs
],
"output_shapes"
:
[
o
.
shape
for
o
in
outputs
],
...
...
@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
return
def
print
_op_stats
(
flops
,
bar_length_max
=
20
):
def
sum
_op_stats
(
flops
,
bar_length_max
=
20
):
max_flops_num
=
max
([
i
[
"flops_num"
]
for
i
in
flops
]
+
[
0
])
total_flops_num
=
0
for
d
in
flops
:
...
...
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
d
[
"bar"
]
=
"#"
*
bar_length
d
[
"flops"
]
=
sizeof_fmt
(
d
[
"flops_num"
],
suffix
=
"OPs"
)
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
sum
(
s
[
1
]
if
len
(
s
)
>
1
else
0
for
s
in
d
[
"output_shapes"
])
for
d
in
flops
)
flops
.
append
(
dict
(
name
=
"total"
,
flops
=
total_flops_str
,
output_shapes
=
total_var_size
)
)
return
total_flops_num
,
flops
def
print_op_stats
(
flops
):
header
=
[
"name"
,
"class_name"
,
...
...
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
if
_receptive_field_enabled
:
header
.
insert
(
4
,
"receptive_field"
)
header
.
insert
(
5
,
"stride"
)
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
sum
(
s
[
1
]
if
len
(
s
)
>
1
else
0
for
s
in
d
[
"output_shapes"
])
for
d
in
flops
)
flops
.
append
(
dict
(
name
=
"total"
,
flops
=
total_flops_str
,
output_shapes
=
total_var_size
)
)
logger
.
info
(
"flops stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
flops
,
header
=
header
)))
return
total_flops_num
def
get_param_stats
(
param
:
np
.
ndarray
):
nbits
=
get_dtype_bit
(
param
.
dtype
.
name
)
...
...
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
}
def
print
_param_stats
(
params
,
bar_length_max
=
20
):
def
sum
_param_stats
(
params
,
bar_length_max
=
20
):
max_size
=
max
([
d
[
"size"
]
for
d
in
params
]
+
[
0
])
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
...
...
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
param_size
=
sizeof_fmt
(
total_param_size
)
params
.
append
(
dict
(
name
=
"total"
,
param_dim
=
total_param_dims
,
size
=
param_size
,))
return
total_param_dims
,
total_param_size
,
params
def
print_param_stats
(
params
):
header
=
[
"name"
,
"dtype"
,
...
...
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
"mean"
,
"std"
,
"param_dim"
,
"bits"
,
"
n
bits"
,
"size"
,
"size_cum"
,
"percentage"
,
"size_bar"
,
]
logger
.
info
(
"param stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
params
,
header
=
header
))
)
return
total_param_dims
,
total_param_size
def
get_activation_stats
(
output
:
np
.
ndarray
):
out_shape
=
output
.
shape
activations_dtype
=
output
.
dtype
nbits
=
get_dtype_bit
(
activations_dtype
.
name
)
act_dim
=
np
.
prod
(
out_shape
)
act_size
=
act_dim
*
nbits
//
8
return
{
"dtype"
:
activations_dtype
,
"shape"
:
out_shape
,
"act_dim"
:
act_dim
,
"mean"
:
"{:.3g}"
.
format
(
output
.
mean
()),
"std"
:
"{:.3g}"
.
format
(
output
.
std
()),
"nbits"
:
nbits
,
"size"
:
act_size
,
}
def
sum_activations_stats
(
activations
,
bar_length_max
=
20
):
max_act_size
=
max
([
i
[
"size"
]
for
i
in
activations
]
+
[
0
])
total_act_dims
,
total_act_size
=
0
,
0
for
d
in
activations
:
total_act_size
+=
int
(
d
[
"size"
])
total_act_dims
+=
int
(
d
[
"act_dim"
])
d
[
"size_cum"
]
=
sizeof_fmt
(
total_act_size
)
for
d
in
activations
:
ratio
=
d
[
"ratio"
]
=
d
[
"size"
]
/
total_act_size
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
bar_length
=
int
(
d
[
"size"
]
/
max_act_size
*
bar_length_max
)
d
[
"size_bar"
]
=
"#"
*
bar_length
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
act_size
=
sizeof_fmt
(
total_act_size
)
activations
.
append
(
dict
(
name
=
"total"
,
act_dim
=
total_act_dims
,
size
=
act_size
,))
return
total_act_dims
,
total_act_size
,
activations
def
print_activations_stats
(
activations
):
header
=
[
"name"
,
"class_name"
,
"dtype"
,
"shape"
,
"mean"
,
"std"
,
"nbits"
,
"act_dim"
,
"size"
,
"size_cum"
,
"percentage"
,
"size_bar"
,
]
logger
.
info
(
"activations stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
activations
,
header
=
header
))
)
def
print_summary
(
**
kwargs
):
...
...
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
def
module_stats
(
model
:
m
.
Module
,
input_s
ize
:
in
t
,
input_s
hapes
:
lis
t
,
bar_length_max
:
int
=
20
,
log_params
:
bool
=
True
,
log_flops
:
bool
=
True
,
log_activations
:
bool
=
True
,
):
r
"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info.
:param input_s
ize: size of input
for running model and calculating stats.
:param input_s
hapes: shapes of inputs
for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
:param log_activations: whether print and record op activations.
"""
disable_receptive_field
()
def
module_stats_hook
(
module
,
inputs
,
outputs
,
name
=
""
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
flops_stats
=
get_op_stats
(
module
,
inputs
,
outputs
)
if
flops_stats
is
not
None
:
flops_stats
[
"name"
]
=
name
...
...
@@ -331,38 +428,25 @@ def module_stats(
param_stats
[
"name"
]
=
name
+
"-b"
params
.
append
(
param_stats
)
@
contextlib
.
contextmanager
def
adjust_stats
(
module
,
training
=
False
):
"""Adjust module to training/eval mode temporarily.
Args:
module (M.Module): used module.
training (bool): training mode. True for train mode, False fro eval mode.
"""
def
recursive_backup_stats
(
module
,
mode
):
for
m
in
module
.
modules
():
# save prev status to _prev_training
m
.
_prev_training
=
m
.
training
m
.
train
(
mode
,
recursive
=
False
)
def
recursive_recover_stats
(
module
):
for
m
in
module
.
modules
():
# recover prev status and delete attribute
m
.
training
=
m
.
_prev_training
delattr
(
m
,
"_prev_training"
)
recursive_backup_stats
(
module
,
mode
=
training
)
yield
module
recursive_recover_stats
(
module
)
if
not
isinstance
(
outputs
,
tuple
)
or
not
isinstance
(
outputs
,
list
):
output
=
outputs
.
numpy
()
else
:
output
=
outputs
[
0
].
numpy
()
activation_stats
=
get_activation_stats
(
output
)
activation_stats
[
"name"
]
=
name
activation_stats
[
"class_name"
]
=
class_name
activations
.
append
(
activation_stats
)
# multiple inputs to the network
if
not
isinstance
(
input_s
ize
[
0
],
tuple
):
input_s
ize
=
[
input_size
]
if
not
isinstance
(
input_s
hapes
[
0
],
tuple
):
input_s
hapes
=
[
input_shapes
]
params
=
[]
flops
=
[]
hooks
=
[]
activations
=
[]
total_stats
=
namedtuple
(
"total_stats"
,
[
"param_size"
,
"flops"
,
"act_size"
])
stats_details
=
namedtuple
(
"module_stats"
,
[
"params"
,
"flops"
,
"activations"
])
for
(
name
,
module
)
in
model
.
named_modules
():
if
isinstance
(
module
,
hook_modules
):
...
...
@@ -370,8 +454,8 @@ def module_stats(
module
.
register_forward_hook
(
partial
(
module_stats_hook
,
name
=
name
))
)
inputs
=
[
zeros
(
in_size
,
dtype
=
np
.
float32
)
for
in_size
in
input_s
ize
]
with
adjust_stats
(
model
,
training
=
False
)
as
model
:
inputs
=
[
zeros
(
in_size
,
dtype
=
np
.
float32
)
for
in_size
in
input_s
hapes
]
with
set_module_mode_safe
(
model
,
training
=
False
)
as
model
:
model
(
*
inputs
)
for
h
in
hooks
:
...
...
@@ -380,19 +464,40 @@ def module_stats(
extra_info
=
{
"#params"
:
len
(
params
),
}
total_flops
,
total_param_dims
,
total_param_size
=
0
,
0
,
0
(
total_flops
,
total_param_dims
,
total_param_size
,
total_act_dims
,
total_param_size
,
)
=
(
0
,
0
,
0
,
0
,
0
)
total_param_dims
,
total_param_size
,
params
=
sum_param_stats
(
params
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
,
suffix
=
""
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_params
:
total_param_dims
,
total_param_size
=
print_param_stats
(
params
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
print_param_stats
(
params
)
total_flops
,
flops
=
sum_op_stats
(
flops
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_flops
:
total_flops
=
print_op_stats
(
flops
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
print_op_stats
(
flops
)
total_act_dims
,
total_act_size
,
activations
=
sum_activations_stats
(
activations
,
bar_length_max
)
extra_info
[
"total_act_dims"
]
=
sizeof_fmt
(
total_act_dims
,
suffix
=
""
)
extra_info
[
"total_act_size"
]
=
sizeof_fmt
(
total_act_size
)
if
log_activations
:
print_activations_stats
(
activations
)
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
print_summary
(
**
extra_info
)
return
total_param_size
,
total_flops
return
(
total_stats
(
param_size
=
total_param_size
,
flops
=
total_flops
,
act_size
=
total_act_size
,
),
stats_details
(
params
=
params
,
flops
=
flops
,
activations
=
activations
),
)
imperative/python/megengine/utils/module_utils.py
浏览文件 @
c3a1ac3d
...
...
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
contextlib
from
collections
import
Iterable
from
..module
import
Sequential
...
...
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
parent
[
key
]
=
value
_access_structure
(
obj
,
key
,
callback
=
f
)
@
contextlib
.
contextmanager
def
set_module_mode_safe
(
module
:
Module
,
training
:
bool
=
False
,
):
"""Adjust module to training/eval mode temporarily.
:param module: used module.
:param training: training (bool): training mode. True for train mode, False fro eval mode.
"""
backup_stats
=
{}
def
recursive_backup_stats
(
module
,
mode
):
for
m
in
module
.
modules
():
backup_stats
[
m
]
=
m
.
training
m
.
train
(
mode
,
recursive
=
False
)
def
recursive_recover_stats
(
module
):
for
m
in
module
.
modules
():
m
.
training
=
backup_stats
.
pop
(
m
)
recursive_backup_stats
(
module
,
mode
=
training
)
yield
module
recursive_recover_stats
(
module
)
imperative/python/test/unit/utils/test_module_stats.py
0 → 100644
浏览文件 @
c3a1ac3d
import
math
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine.functional
as
F
import
megengine.hub
as
hub
import
megengine.module
as
M
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.utils.module_stats
import
module_stats
@
pytest
.
mark
.
skipif
(
use_symbolic_shape
(),
reason
=
"This test do not support symbolic shape."
,
)
def
test_module_stats
():
net
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
input_shape
=
(
1
,
3
,
224
,
224
)
total_stats
,
stats_details
=
module_stats
(
net
,
input_shape
)
x1
=
mge
.
tensor
(
np
.
zeros
((
1
,
3
,
224
,
224
)))
gt_flops
,
gt_acts
=
net
.
get_stats
(
x1
)
assert
(
total_stats
.
flops
,
stats_details
.
activations
[
-
1
][
"act_dim"
])
==
(
gt_flops
,
gt_acts
,
)
class
BasicBlock
(
M
.
Module
):
expansion
=
1
def
__init__
(
self
,
in_channels
,
channels
,
stride
=
1
,
groups
=
1
,
base_width
=
64
,
dilation
=
1
,
norm
=
M
.
BatchNorm2d
,
):
super
().
__init__
()
self
.
tmp_in_channels
=
in_channels
self
.
tmp_channels
=
channels
self
.
stride
=
stride
if
groups
!=
1
or
base_width
!=
64
:
raise
ValueError
(
"BasicBlock only supports groups=1 and base_width=64"
)
if
dilation
>
1
:
raise
NotImplementedError
(
"Dilation > 1 not supported in BasicBlock"
)
self
.
conv1
=
M
.
Conv2d
(
in_channels
,
channels
,
3
,
stride
,
padding
=
dilation
,
bias
=
False
)
self
.
bn1
=
norm
(
channels
)
self
.
conv2
=
M
.
Conv2d
(
channels
,
channels
,
3
,
1
,
padding
=
1
,
bias
=
False
)
self
.
bn2
=
norm
(
channels
)
self
.
downsample_id
=
M
.
Identity
()
self
.
downsample_conv
=
M
.
Conv2d
(
in_channels
,
channels
,
1
,
stride
,
bias
=
False
)
self
.
downsample_norm
=
norm
(
channels
)
def
forward
(
self
,
x
):
identity
=
x
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
bn2
(
x
)
if
self
.
tmp_in_channels
==
self
.
tmp_channels
and
self
.
stride
==
1
:
identity
=
self
.
downsample_id
(
identity
)
else
:
identity
=
self
.
downsample_conv
(
identity
)
identity
=
self
.
downsample_norm
(
identity
)
x
+=
identity
x
=
F
.
relu
(
x
)
return
x
def
get_stats
(
self
,
x
):
activations
,
flops
=
0
,
0
identity
=
x
in_x
=
deepcopy
(
x
)
x
=
self
.
conv1
(
x
)
tmp_flops
,
tmp_acts
=
cal_conv_stats
(
self
.
conv1
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
in_x
=
deepcopy
(
x
)
x
=
self
.
bn1
(
x
)
tmp_flops
,
tmp_acts
=
cal_norm_stats
(
self
.
bn1
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
=
F
.
relu
(
x
)
in_x
=
deepcopy
(
x
)
x
=
self
.
conv2
(
x
)
tmp_flops
,
tmp_acts
=
cal_conv_stats
(
self
.
conv2
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
in_x
=
deepcopy
(
x
)
x
=
self
.
bn2
(
x
)
tmp_flops
,
tmp_acts
=
cal_norm_stats
(
self
.
bn2
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
if
self
.
tmp_in_channels
==
self
.
tmp_channels
and
self
.
stride
==
1
:
identity
=
self
.
downsample_id
(
identity
)
else
:
in_x
=
deepcopy
(
identity
)
identity
=
self
.
downsample_conv
(
identity
)
tmp_flops
,
tmp_acts
=
cal_conv_stats
(
self
.
downsample_conv
,
in_x
,
identity
)
activations
+=
tmp_acts
flops
+=
tmp_flops
in_x
=
deepcopy
(
identity
)
identity
=
self
.
downsample_norm
(
identity
)
tmp_flops
,
tmp_acts
=
cal_norm_stats
(
self
.
downsample_norm
,
in_x
,
identity
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
+=
identity
x
=
F
.
relu
(
x
)
return
x
,
flops
,
activations
class
ResNet
(
M
.
Module
):
def
__init__
(
self
,
block
,
layers
=
[
2
,
2
,
2
,
2
],
num_classes
=
1000
,
zero_init_residual
=
False
,
groups
=
1
,
width_per_group
=
64
,
replace_stride_with_dilation
=
None
,
norm
=
M
.
BatchNorm2d
,
):
super
().
__init__
()
self
.
in_channels
=
64
self
.
dilation
=
1
if
replace_stride_with_dilation
is
None
:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation
=
[
False
,
False
,
False
]
if
len
(
replace_stride_with_dilation
)
!=
3
:
raise
ValueError
(
"replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}"
.
format
(
replace_stride_with_dilation
)
)
self
.
groups
=
groups
self
.
base_width
=
width_per_group
self
.
conv1
=
M
.
Conv2d
(
3
,
self
.
in_channels
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
norm
(
self
.
in_channels
)
self
.
maxpool
=
M
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
layer1_0
=
BasicBlock
(
self
.
in_channels
,
64
,
stride
=
1
,
groups
=
self
.
groups
,
base_width
=
self
.
base_width
,
dilation
=
self
.
dilation
,
norm
=
M
.
BatchNorm2d
,
)
self
.
layer1_1
=
BasicBlock
(
self
.
in_channels
,
64
,
stride
=
1
,
groups
=
self
.
groups
,
base_width
=
self
.
base_width
,
dilation
=
self
.
dilation
,
norm
=
M
.
BatchNorm2d
,
)
self
.
layer2_0
=
BasicBlock
(
64
,
128
,
stride
=
2
)
self
.
layer2_1
=
BasicBlock
(
128
,
128
)
self
.
layer3_0
=
BasicBlock
(
128
,
256
,
stride
=
2
)
self
.
layer3_1
=
BasicBlock
(
256
,
256
)
self
.
layer4_0
=
BasicBlock
(
256
,
512
,
stride
=
2
)
self
.
layer4_1
=
BasicBlock
(
512
,
512
)
self
.
layer1
=
self
.
_make_layer
(
block
,
64
,
layers
[
0
],
norm
=
norm
)
self
.
layer2
=
self
.
_make_layer
(
block
,
128
,
2
,
stride
=
2
,
dilate
=
replace_stride_with_dilation
[
0
],
norm
=
norm
)
self
.
layer3
=
self
.
_make_layer
(
block
,
256
,
2
,
stride
=
2
,
dilate
=
replace_stride_with_dilation
[
1
],
norm
=
norm
)
self
.
layer4
=
self
.
_make_layer
(
block
,
512
,
2
,
stride
=
2
,
dilate
=
replace_stride_with_dilation
[
2
],
norm
=
norm
)
self
.
fc
=
M
.
Linear
(
512
,
num_classes
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
M
.
Conv2d
):
M
.
init
.
msra_normal_
(
m
.
weight
,
mode
=
"fan_out"
,
nonlinearity
=
"relu"
)
if
m
.
bias
is
not
None
:
fan_in
,
_
=
M
.
init
.
calculate_fan_in_and_fan_out
(
m
.
weight
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
M
.
init
.
uniform_
(
m
.
bias
,
-
bound
,
bound
)
elif
isinstance
(
m
,
M
.
BatchNorm2d
):
M
.
init
.
ones_
(
m
.
weight
)
M
.
init
.
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
M
.
Linear
):
M
.
init
.
msra_uniform_
(
m
.
weight
,
a
=
math
.
sqrt
(
5
))
if
m
.
bias
is
not
None
:
fan_in
,
_
=
M
.
init
.
calculate_fan_in_and_fan_out
(
m
.
weight
)
bound
=
1
/
math
.
sqrt
(
fan_in
)
M
.
init
.
uniform_
(
m
.
bias
,
-
bound
,
bound
)
if
zero_init_residual
:
for
m
in
self
.
modules
():
M
.
init
.
zeros_
(
m
.
bn2
.
weight
)
def
_make_layer
(
self
,
block
,
channels
,
blocks
,
stride
=
1
,
dilate
=
False
,
norm
=
M
.
BatchNorm2d
):
previous_dilation
=
self
.
dilation
if
dilate
:
self
.
dilation
*=
stride
stride
=
1
layers
=
[]
layers
.
append
(
block
(
self
.
in_channels
,
channels
,
stride
,
groups
=
self
.
groups
,
base_width
=
self
.
base_width
,
dilation
=
previous_dilation
,
norm
=
norm
,
)
)
self
.
in_channels
=
channels
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
in_channels
,
channels
,
groups
=
self
.
groups
,
base_width
=
self
.
base_width
,
dilation
=
self
.
dilation
,
norm
=
norm
,
)
)
return
M
.
Sequential
(
*
layers
)
def
extract_features
(
self
,
x
):
outputs
=
{}
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
outputs
[
"stem"
]
=
x
x
=
self
.
layer1
(
x
)
outputs
[
"res2"
]
=
x
x
=
self
.
layer2
(
x
)
outputs
[
"res3"
]
=
x
x
=
self
.
layer3
(
x
)
outputs
[
"res4"
]
=
x
x
=
self
.
layer4
(
x
)
outputs
[
"res5"
]
=
x
return
outputs
def
forward
(
self
,
x
):
x
=
self
.
extract_features
(
x
)[
"res5"
]
x
=
F
.
avg_pool2d
(
x
,
7
)
x
=
F
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
return
x
def
get_stats
(
self
,
x
):
flops
,
activations
=
0
,
0
in_x
=
deepcopy
(
x
)
x
=
self
.
conv1
(
x
)
tmp_flops
,
tmp_acts
=
cal_conv_stats
(
self
.
conv1
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
in_x
=
deepcopy
(
x
)
x
=
self
.
bn1
(
x
)
tmp_flops
,
tmp_acts
=
cal_norm_stats
(
self
.
bn1
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
=
F
.
relu
(
x
)
in_x
=
deepcopy
(
x
)
x
=
self
.
maxpool
(
x
)
tmp_flops
,
tmp_acts
=
cal_pool_stats
(
self
.
maxpool
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer1_0
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer1_1
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer2_0
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer2_1
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer3_0
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer3_1
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer4_0
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
,
tmp_flops
,
tmp_acts
=
self
.
layer4_1
.
get_stats
(
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
x
=
F
.
avg_pool2d
(
x
,
7
)
x
=
F
.
flatten
(
x
,
1
)
in_x
=
deepcopy
(
x
)
x
=
self
.
fc
(
x
)
tmp_flops
,
tmp_acts
=
cal_linear_stats
(
self
.
fc
,
in_x
,
x
)
activations
+=
tmp_acts
flops
+=
tmp_flops
return
flops
,
activations
def
cal_conv_stats
(
module
,
input
,
output
):
bias
=
1
if
module
.
bias
is
not
None
else
0
flops
=
np
.
prod
(
output
[
0
].
shape
)
*
(
module
.
in_channels
//
module
.
groups
*
np
.
prod
(
module
.
kernel_size
)
+
bias
)
acts
=
np
.
prod
(
output
[
0
].
shape
)
return
flops
,
acts
def
cal_norm_stats
(
module
,
input
,
output
):
return
np
.
prod
(
input
[
0
].
shape
)
*
7
,
np
.
prod
(
output
[
0
].
shape
)
def
cal_linear_stats
(
module
,
inputs
,
outputs
):
bias
=
module
.
out_features
if
module
.
bias
is
not
None
else
0
return
(
np
.
prod
(
outputs
[
0
].
shape
)
*
module
.
in_features
+
bias
,
np
.
prod
(
outputs
[
0
].
shape
),
)
def
cal_pool_stats
(
module
,
inputs
,
outputs
):
return
(
np
.
prod
(
outputs
[
0
].
shape
)
*
(
module
.
kernel_size
**
2
),
np
.
prod
(
outputs
[
0
].
shape
),
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录