MegEngine 天元 / MegEngine

Commit c3a1ac3d
Authored on Apr 28, 2021 by Megvii Engine Team
Parent: 05e4c826

feat(mge/tools): module_status-add-functions

BREAKING CHANGE: GitOrigin-RevId: ced3da3a129713c652d93b73756b93273bf1cc9b

Showing 4 changed files with 623 additions and 77 deletions (+623 / -77):
- imperative/python/megengine/tools/network_visualize.py  (+51 / -13)
- imperative/python/megengine/utils/module_stats.py       (+169 / -64)
- imperative/python/megengine/utils/module_utils.py       (+26 / -0)
- imperative/python/test/unit/utils/test_module_stats.py  (+377 / -0, new file)
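The breaking change is the return type: before this commit, `module_stats` (and `visualize`) returned a plain `(total_param_size, total_flops)` tuple; afterwards both return a `total_stats` namedtuple (fields `param_size`, `flops`, `act_size`) plus a `stats_details` namedtuple (fields `params`, `flops`, `activations`). A minimal migration sketch (`net` stands in for any `M.Module`; the input shape is illustrative):

    from megengine.utils.module_stats import module_stats

    # before this commit:
    #   total_param_size, total_flops = module_stats(net, (1, 3, 224, 224))
    # after this commit:
    total_stats, stats_details = module_stats(net, (1, 3, 224, 224))
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)

The per-entry tables (including the appended "total" rows) are now returned in `stats_details` rather than only being logged.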
imperative/python/megengine/tools/network_visualize.py

@@ -9,6 +9,7 @@
 import argparse
 import logging
 import re
+from collections import namedtuple

 import numpy as np
@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
     enable_receptive_field,
+    get_activation_stats,
     get_op_stats,
     get_param_stats,
+    print_activations_stats,
     print_op_stats,
     print_param_stats,
     print_summary,
     sizeof_fmt,
+    sum_activations_stats,
+    sum_op_stats,
+    sum_param_stats,
 )
 from megengine.utils.network import Network
@@ -34,6 +40,7 @@ def visualize(
     bar_length_max: int = 20,
     log_params: bool = True,
     log_flops: bool = True,
+    log_activations: bool = True,
 ):
     r"""
     Load megengine dumped model and visualize graph structure with tensorboard log files.
@@ -44,6 +51,7 @@ def visualize(
     :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
     :param log_params: whether print and record params size.
     :param log_flops: whether print and record op flops.
+    :param log_activations: whether print and record op activations.
     """
     if log_path:
         try:
@@ -83,6 +91,10 @@ def visualize(
     node_list = []
     flops_list = []
     params_list = []
+    activations_list = []
+    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
+
     for node in graph.all_oprs:
         if hasattr(node, "output_idx"):
             node_oup = node.outputs[node.output_idx]
@@ -124,6 +136,11 @@ def visualize(
             flops_stats["class_name"] = node.type
             flops_list.append(flops_stats)

+            acts = get_activation_stats(node_oup.numpy())
+            acts["name"] = node.name
+            acts["class_name"] = node.type
+            activations_list.append(acts)
+
         if node.type == "ImmutableTensor":
             param_stats = get_param_stats(node.numpy())
             # add tensor size attr
@@ -149,20 +166,36 @@ def visualize(
         "#params": len(params_list),
     }
-    total_flops, total_param_dims, total_param_size = 0, 0, 0
+    (
+        total_flops,
+        total_param_dims,
+        total_param_size,
+        total_act_dims,
+        total_param_size,
+    ) = (0, 0, 0, 0, 0)
+
+    total_param_dims, total_param_size, params = sum_param_stats(
+        params_list, bar_length_max
+    )
+    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
+    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_params:
-        total_param_dims, total_param_size = print_param_stats(
-            params_list, bar_length_max
-        )
-        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
-        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
+        print_param_stats(params)
+
+    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
+    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
     if log_flops:
-        total_flops = print_op_stats(flops_list, bar_length_max)
-        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
-    if log_params and log_flops:
-        extra_info["flops/param_size"] = "{:3.3f}".format(
-            total_flops / total_param_size
-        )
+        print_op_stats(flops)
+
+    total_act_dims, total_act_size, activations = sum_activations_stats(
+        activations_list, bar_length_max
+    )
+    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
+    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
+    if log_activations:
+        print_activations_stats(activations)
+
+    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)
     if log_path:
         graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -179,7 +212,12 @@ def visualize(
     # FIXME: remove this after resolving "span dist too large" warning
     _imperative_rt_logger.set_log_level(old_level)

-    return total_param_size, total_flops
+    return (
+        total_stats(
+            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+        ),
+        stats_details(params=params, flops=flops, activations=activations),
+    )


 def main():
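A hedged usage sketch of the updated tool: the leading arguments (the dumped-model path and tensorboard log path) are not shown in these hunks and are assumed here; only `log_activations` and the namedtuple return are new in this commit.

    from megengine.tools.network_visualize import visualize

    total_stats, stats_details = visualize("./model.mge", "./tb_logs", log_activations=True)
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)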
imperative/python/megengine/utils/module_stats.py

@@ -5,7 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import contextlib
+from collections import namedtuple
 from functools import partial

 import numpy as np
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
 from megengine.core.tensor.dtype import get_dtype_bit
 from megengine.functional.tensor import zeros

+from .module_utils import set_module_mode_safe
+
 try:
     mge.logger.MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
     )


+@register_flops(
+    m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
+)
+def flops_norm(module: m.Linear, inputs, outputs):
+    return np.prod(inputs[0].shape) * 7
+
+
+@register_flops(m.AvgPool2d, m.MaxPool2d)
+def flops_pool(module: m.AvgPool2d, inputs, outputs):
+    return np.prod(outputs[0].shape) * (module.kernel_size ** 2)
+
+
+@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
+def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
+    stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1))
+    kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h
+    stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1))
+    kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w
+    return np.prod(outputs[0].shape) * kernel_h * kernel_w
+
+
 @register_flops(m.Linear)
 def flops_linear(module: m.Linear, inputs, outputs):
     bias = module.out_features if module.bias is not None else 0
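The hunk above uses the pre-existing `register_flops` decorator to attach FLOPs formulas to the norm and pooling modules. The same mechanism can cover further module types; a sketch, assuming `register_flops` can be imported from `megengine.utils.module_stats` and using an illustrative one-op-per-element cost model:

    import numpy as np
    import megengine.module as m
    from megengine.utils.module_stats import register_flops  # assumption: importable

    @register_flops(m.Sigmoid)
    def flops_sigmoid(module, inputs, outputs):
        # illustrative cost model: one op per output element
        return np.prod(outputs[0].shape)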
@@ -120,6 +143,12 @@ hook_modules = (
     m.conv._ConvNd,
     m.Linear,
     m.BatchMatMulActivation,
+    m.batchnorm._BatchNorm,
+    m.LayerNorm,
+    m.GroupNorm,
+    m.InstanceNorm,
+    m.pooling._PoolNd,
+    m.adaptive_pooling._AdaptivePoolNd,
 )
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
 def sizeof_fmt(num, suffix="B"):
-    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
-        if abs(num) < 1024.0:
+    if suffix == "B":
+        scale = 1024.0
+        units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
+    else:
+        scale = 1000.0
+        units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
+    for unit in units:
+        if abs(num) < scale or unit == units[-1]:
             return "{:3.3f} {}{}".format(num, unit, suffix)
-        num /= 1024.0
-    sign_str = "-" if num < 0 else ""
-    return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
+        num /= scale


 def preprocess_receptive_field(module, inputs, outputs):
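The rewritten `sizeof_fmt` scales by 1024 with binary prefixes only for the default byte suffix, and by 1000 with decimal prefixes for anything else (such as "OPs"); it also returns at the last unit instead of falling through. Two quick checks that follow directly from the code above:

    sizeof_fmt(4096)                   # -> '4.000 KiB' (1024-based for bytes)
    sizeof_fmt(3500000, suffix="OPs")  # -> '3.500 MOPs' (1000-based otherwise)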
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
 def get_op_stats(module, inputs, outputs):
+    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
+        outputs = (outputs,)
     rst = {
         "input_shapes": [i.shape for i in inputs],
         "output_shapes": [o.shape for o in outputs],
@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
     return


-def print_op_stats(flops, bar_length_max=20):
+def sum_op_stats(flops, bar_length_max=20):
     max_flops_num = max([i["flops_num"] for i in flops] + [0])
     total_flops_num = 0
     for d in flops:
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
         d["bar"] = "#" * bar_length
         d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

+    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
+    total_var_size = sum(
+        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
+    )
+    flops.append(
+        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
+    )
+
+    return total_flops_num, flops
+
+
+def print_op_stats(flops):
     header = [
         "name",
         "class_name",
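This is the pattern the commit applies throughout: a `sum_*` function aggregates, appends a "total" row, and returns the table, while the matching `print_*` function only renders it. A sketch of the new two-step flow, with `flops_list` standing in for entries produced by `get_op_stats`:

    total_flops_num, flops_table = sum_op_stats(flops_list, bar_length_max=20)
    print_op_stats(flops_table)  # logging only; the total comes from sum_op_stats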
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
     if _receptive_field_enabled:
         header.insert(4, "receptive_field")
         header.insert(5, "stride")
-    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
-    total_var_size = sum(
-        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
-    )
-    flops.append(
-        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
-    )
     logger.info("flops stats:\n" + tabulate.tabulate(dict2table(flops, header=header)))
-    return total_flops_num


 def get_param_stats(param: np.ndarray):
     nbits = get_dtype_bit(param.dtype.name)
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
     }


-def print_param_stats(params, bar_length_max=20):
+def sum_param_stats(params, bar_length_max=20):
     max_size = max([d["size"] for d in params] + [0])
     total_param_dims, total_param_size = 0, 0
     for d in params:
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
     param_size = sizeof_fmt(total_param_size)
     params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

+    return total_param_dims, total_param_size, params
+
+
+def print_param_stats(params):
     header = [
         "name",
         "dtype",
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
         "mean",
         "std",
         "param_dim",
-        "bits",
+        "nbits",
         "size",
         "size_cum",
         "percentage",
         "size_bar",
     ]
     logger.info(
         "param stats:\n" + tabulate.tabulate(dict2table(params, header=header))
     )
-    return total_param_dims, total_param_size
+
+
+def get_activation_stats(output: np.ndarray):
+    out_shape = output.shape
+    activations_dtype = output.dtype
+    nbits = get_dtype_bit(activations_dtype.name)
+    act_dim = np.prod(out_shape)
+    act_size = act_dim * nbits // 8
+    return {
+        "dtype": activations_dtype,
+        "shape": out_shape,
+        "act_dim": act_dim,
+        "mean": "{:.3g}".format(output.mean()),
+        "std": "{:.3g}".format(output.std()),
+        "nbits": nbits,
+        "size": act_size,
+    }
+
+
+def sum_activations_stats(activations, bar_length_max=20):
+    max_act_size = max([i["size"] for i in activations] + [0])
+    total_act_dims, total_act_size = 0, 0
+    for d in activations:
+        total_act_size += int(d["size"])
+        total_act_dims += int(d["act_dim"])
+        d["size_cum"] = sizeof_fmt(total_act_size)
+
+    for d in activations:
+        ratio = d["ratio"] = d["size"] / total_act_size
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["size"] / max_act_size * bar_length_max)
+        d["size_bar"] = "#" * bar_length
+        d["size"] = sizeof_fmt(d["size"])
+
+    act_size = sizeof_fmt(total_act_size)
+    activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,))
+
+    return total_act_dims, total_act_size, activations
+
+
+def print_activations_stats(activations):
+    header = [
+        "name",
+        "class_name",
+        "dtype",
+        "shape",
+        "mean",
+        "std",
+        "nbits",
+        "act_dim",
+        "size",
+        "size_cum",
+        "percentage",
+        "size_bar",
+    ]
+    logger.info(
+        "activations stats:\n"
+        + tabulate.tabulate(dict2table(activations, header=header))
    )


 def print_summary(**kwargs):
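As a worked check of the new `get_activation_stats` (a sketch; it assumes `get_dtype_bit` reports 32 for float32):

    import numpy as np
    from megengine.utils.module_stats import get_activation_stats

    out = np.zeros((1, 64, 112, 112), dtype=np.float32)
    stats = get_activation_stats(out)
    assert stats["act_dim"] == 802816         # 1 * 64 * 112 * 112
    assert stats["size"] == 802816 * 32 // 8  # 3211264 bytes, about 3.06 MiB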
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
 def module_stats(
     model: m.Module,
-    input_size: int,
+    input_shapes: list,
     bar_length_max: int = 20,
     log_params: bool = True,
     log_flops: bool = True,
+    log_activations: bool = True,
 ):
     r"""
     Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.

     :param model: model that need to get stats info.
-    :param input_size: size of input for running model and calculating stats.
+    :param input_shapes: shapes of inputs for running model and calculating stats.
     :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
     :param log_params: whether print and record params size.
     :param log_flops: whether print and record op flops.
+    :param log_activations: whether print and record op activations.
     """
     disable_receptive_field()

     def module_stats_hook(module, inputs, outputs, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]

         flops_stats = get_op_stats(module, inputs, outputs)
         if flops_stats is not None:
             flops_stats["name"] = name
@@ -331,38 +428,25 @@ def module_stats(
             param_stats["name"] = name + "-b"
             params.append(param_stats)

-    @contextlib.contextmanager
-    def adjust_stats(module, training=False):
-        """Adjust module to training/eval mode temporarily.
-
-        Args:
-            module (M.Module): used module.
-            training (bool): training mode. True for train mode, False fro eval mode.
-        """
-
-        def recursive_backup_stats(module, mode):
-            for m in module.modules():
-                # save prev status to _prev_training
-                m._prev_training = m.training
-                m.train(mode, recursive=False)
-
-        def recursive_recover_stats(module):
-            for m in module.modules():
-                # recover prev status and delete attribute
-                m.training = m._prev_training
-                delattr(m, "_prev_training")
-
-        recursive_backup_stats(module, mode=training)
-        yield module
-        recursive_recover_stats(module)
+        if not isinstance(outputs, tuple) or not isinstance(outputs, list):
+            output = outputs.numpy()
+        else:
+            output = outputs[0].numpy()
+        activation_stats = get_activation_stats(output)
+        activation_stats["name"] = name
+        activation_stats["class_name"] = class_name
+        activations.append(activation_stats)

     # multiple inputs to the network
-    if not isinstance(input_size[0], tuple):
-        input_size = [input_size]
+    if not isinstance(input_shapes[0], tuple):
+        input_shapes = [input_shapes]

     params = []
     flops = []
     hooks = []
+    activations = []
+    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

     for (name, module) in model.named_modules():
         if isinstance(module, hook_modules):
@@ -370,8 +454,8 @@ def module_stats(
                 module.register_forward_hook(partial(module_stats_hook, name=name))
             )

-    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
-    with adjust_stats(model, training=False) as model:
+    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
+    with set_module_mode_safe(model, training=False) as model:
         model(*inputs)

     for h in hooks:
@@ -380,19 +464,40 @@ def module_stats(
     extra_info = {
         "#params": len(params),
     }
-    total_flops, total_param_dims, total_param_size = 0, 0, 0
+    (
+        total_flops,
+        total_param_dims,
+        total_param_size,
+        total_act_dims,
+        total_param_size,
+    ) = (0, 0, 0, 0, 0)
+
+    total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
+    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
+    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_params:
-        total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
-        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
-        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
+        print_param_stats(params)
+
+    total_flops, flops = sum_op_stats(flops, bar_length_max)
+    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
     if log_flops:
-        total_flops = print_op_stats(flops, bar_length_max)
-        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
-    if log_params and log_flops:
-        extra_info["flops/param_size"] = "{:3.3f}".format(
-            total_flops / total_param_size
-        )
+        print_op_stats(flops)
+
+    total_act_dims, total_act_size, activations = sum_activations_stats(
+        activations, bar_length_max
+    )
+    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
+    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
+    if log_activations:
+        print_activations_stats(activations)
+
+    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

     print_summary(**extra_info)

-    return total_param_size, total_flops
+    return (
+        total_stats(
+            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+        ),
+        stats_details(params=params, flops=flops, activations=activations),
+    )
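A minimal end-to-end sketch of the reworked `module_stats` (the toy network is illustrative):

    import megengine.module as M
    from megengine.utils.module_stats import module_stats

    net = M.Sequential(M.Conv2d(3, 8, 3), M.BatchNorm2d(8))
    total_stats, stats_details = module_stats(net, input_shapes=(1, 3, 32, 32))
    print(total_stats.param_size, total_stats.flops, total_stats.act_size)
    print(stats_details.activations[-1])  # the appended "total" row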
imperative/python/megengine/utils/module_utils.py

@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
 from collections import Iterable

 from ..module import Sequential
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
             parent[key] = value

     _access_structure(obj, key, callback=f)
+
+
+@contextlib.contextmanager
+def set_module_mode_safe(
+    module: Module, training: bool = False,
+):
+    """Adjust module to training/eval mode temporarily.
+
+    :param module: used module.
+    :param training: training (bool): training mode. True for train mode, False fro eval mode.
+    """
+    backup_stats = {}
+
+    def recursive_backup_stats(module, mode):
+        for m in module.modules():
+            backup_stats[m] = m.training
+            m.train(mode, recursive=False)
+
+    def recursive_recover_stats(module):
+        for m in module.modules():
+            m.training = backup_stats.pop(m)
+
+    recursive_backup_stats(module, mode=training)
+    yield module
+    recursive_recover_stats(module)
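`set_module_mode_safe` gives a scoped way to run a network in eval (or train) mode and restore every submodule's previous `training` flag on exit; a brief sketch with placeholder `net` and `data`:

    from megengine.utils.module_utils import set_module_mode_safe

    with set_module_mode_safe(net, training=False) as net:
        out = net(data)  # all submodules temporarily in eval mode
    # per-module training flags are restored here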
imperative/python/test/unit/utils/test_module_stats.py  (new file, mode 100644)

import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shape)
    x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
    gt_flops, gt_acts = net.get_stats(x1)
    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride

        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)

        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)

        return x

    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations


def cal_conv_stats(module, input, output):
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
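A quick arithmetic check of the ground-truth helpers (the shape is illustrative): `cal_norm_stats` charges 7 ops per input element, mirroring `flops_norm` in module_stats.py.

    import numpy as np

    x = np.zeros((1, 64, 112, 112), dtype=np.float32)
    flops, acts = cal_norm_stats(None, x, x)  # the module argument is unused
    assert flops == 7 * np.prod(x[0].shape)   # 7 * 802816 == 5619712
    assert acts == np.prod(x[0].shape)        # 802816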