MegEngine / MegEngine · Commit 980ebf2c

Authored on May 16, 2020 by Megvii Engine Team

feat(mge/module): add fused conv_bn qat approximate version

GitOrigin-RevId: 1b7284a5951229c8924cb880de41ebf58db19fea
Parent: 6972bfde

Showing 2 changed files with 138 additions and 50 deletions:

    python_module/megengine/module/conv_bn_relu.py        +99  -50
    python_module/test/unit/module/test_conv_bn_relu.py   +39  -0
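The diff below rewrites _ConvBn2d.fold_weight_bias and calc_conv_bn_qat around the Conv-BN folding identities that appear as comments in the code (w_fold = gamma / bn_std * W and b_fold = gamma * (b - bn_mean) / bn_std + beta). As a quick orientation, here is a minimal NumPy sketch of that folding; only the formulas come from the diff, while the shapes and variable values are illustrative.

    import numpy as np

    # Illustrative shapes: OC output channels, IC input channels, KxK kernel.
    OC, IC, K = 8, 4, 3
    W = np.random.randn(OC, IC, K, K).astype("float32")   # conv weight
    b = np.zeros(OC, dtype="float32")                      # conv bias
    gamma = np.random.rand(OC).astype("float32")           # BN weight
    beta = np.random.rand(OC).astype("float32")            # BN bias
    bn_mean = np.random.randn(OC).astype("float32")
    bn_var = np.random.rand(OC).astype("float32")
    eps = 1e-5

    bn_istd = 1.0 / np.sqrt(bn_var + eps)                  # bn_istd = 1 / bn_std
    scale_factor = gamma * bn_istd
    w_fold = W * scale_factor.reshape(-1, 1, 1, 1)          # w_fold = gamma / bn_std * W
    b_fold = beta + gamma * (b - bn_mean) * bn_istd         # b_fold = gamma * (b - bn_mean) / bn_std + beta

    # Running a convolution with (w_fold, b_fold) is equivalent to the original
    # convolution followed by BatchNorm with these statistics; the QAT module
    # fake-quantizes w_fold so training sees the same weights inference will use.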
python_module/megengine/module/conv_bn_relu.py

@@ -8,7 +8,7 @@
 from typing import Tuple, Union

 from ..core import ones, zeros
-from ..functional import flatten, relu, sqrt, sum
+from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
 from .module import QATModule
@@ -31,7 +31,6 @@ class _ConvBn2d(QATModule):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
-        freeze_bn=False,
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -47,28 +46,6 @@ class _ConvBn2d(QATModule):
             compute_mode,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
-        self.freeze_bn = freeze_bn
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        return self
-
-    def get_bn_gamma_beta(self):
-        if self.bn.weight is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        else:
-            gamma = self.bn.weight
-
-        if self.bn.bias is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        else:
-            beta = self.bn.bias
-
-        return gamma, beta

     def get_batch_mean_var(self, inp):
         def _sum_channel(inp, axis=0, keepdims=True):
@@ -83,8 +60,7 @@ class _ConvBn2d(QATModule):
         sum2 = _sum_channel(inp ** 2, (0, 2, 3))
         reduce_size = inp.shapeof().prod() / inp.shapeof(1)
         batch_mean = sum1 / reduce_size
-        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
         return batch_mean, batch_var

     def fold_weight_bias(self, bn_mean, bn_var):
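Note the change above: get_batch_mean_var now returns the biased batch variance (divided by reduce_size rather than reduce_size - 1); the unbiased correction is applied only when the running variance is updated, in the new update_running_mean_and_running_var further down. A small NumPy sketch of the two estimators, purely for illustration:

    import numpy as np

    x = np.random.randn(4, 3, 8, 8).astype("float32")   # N, C, H, W
    n = x.shape[0] * x.shape[2] * x.shape[3]             # elements per channel

    biased_var = x.var(axis=(0, 2, 3))                    # divide by n: used in the forward pass
    unbiased_var = biased_var * n / (n - 1)               # divide by n - 1: used for running_var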
@@ -92,50 +68,123 @@ class _ConvBn2d(QATModule):
         # bn_istd = 1 / bn_std
         # w_fold = gamma / bn_std * W
         # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        gamma, beta = self.get_bn_gamma_beta()
-
-        b = self.conv.bias
-        if b is None:
-            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+
+        if bn_mean is None:
+            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
+        if bn_var is None:
+            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
+
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         # bn_istd = 1 / bn_std
         # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
         if self.conv.groups == 1:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(-1, 1, 1, 1)
-                * bn_istd.reshape(-1, 1, 1, 1)
-            )
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
         else:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
-                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
-            )
-
-        b_fold = flatten(beta) + (
-            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
-        )
-        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         return w_fold, b_fold

-    def calc_conv_bn_qat(self, inp):
-        # TODO: use pytorch method as
-        conv = self.conv(inp)
-        self.bn(conv)
+    def update_running_mean_and_running_var(
+        self, bn_mean, bn_var, num_elements_per_channel
+    ):
+        # update running mean and running var. no grad, use unbiased bn var
+        bn_mean = zero_grad(bn_mean)
+        bn_var = (
+            zero_grad(bn_var)
+            * num_elements_per_channel
+            / (num_elements_per_channel - 1)
+        )
+        exponential_average_factor = 1 - self.bn.momentum
+        add_update(
+            self.bn.running_mean,
+            delta=bn_mean,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+        add_update(
+            self.bn.running_var,
+            delta=bn_var,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )

-        if self.training:
+    def calc_conv_bn_qat(self, inp, approx=True):
+        if self.training and not approx:
             conv = self.conv(inp)
             bn_mean, bn_var = self.get_batch_mean_var(conv)
+            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+            self.update_running_mean_and_running_var(
+                bn_mean, bn_var, num_elements_per_channel
+            )
         else:
             bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

-        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
+        # get gamma and beta in BatchNorm
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+        # conv_bias
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+
+        b_fold = None
+        if not (self.training and approx):
+            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
+            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
+
         w_qat = self.apply_fakequant_with_observer(
             w_fold, self.weight_fake_quant, self.weight_observer
         )
-        return self.conv.calc_conv(inp, w_qat, b_fold)
+        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        if not (self.training and approx):
+            return conv
+
+        # rescale conv to get original conv output
+        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
+        if self.conv.bias is not None:
+            orig_conv = orig_conv + self.conv.bias
+        # calculate batch norm
+        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
+        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+        self.update_running_mean_and_running_var(
+            bn_mean, bn_var, num_elements_per_channel
+        )
+        return conv


 class ConvBn2d(_ConvBn2d):
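The new approximate path (approx=True, the default) works as follows during training: the weights are folded with the BN running statistics, fake-quantized, and convolved without a bias; the output is then divided by scale_factor (and the original conv bias re-added, if any) to recover the plain convolution output, from which batch statistics are computed and a standard BatchNorm affine transform is applied; finally the running statistics are updated by an exponential moving average with factor 1 - momentum and an unbiased variance. Below is a minimal NumPy rendering of that flow. Names mirror the diff, but the 1x1 convolution, the shapes, and the omission of fake quantization and conv bias are simplifications, and the add_update semantics (dest = alpha * dest + beta * delta) are assumed from the keyword names.

    import numpy as np

    momentum, eps = 0.9, 1e-5
    N, C, H, W = 4, 3, 8, 8
    x = np.random.randn(N, C, H, W).astype("float32")
    weight = np.random.randn(C, C, 1, 1).astype("float32")      # 1x1 conv for brevity
    gamma = np.ones(C, dtype="float32")
    beta = np.zeros(C, dtype="float32")
    running_mean = np.zeros(C, dtype="float32")
    running_var = np.ones(C, dtype="float32")

    def conv1x1(inp, w):
        # Pointwise convolution over the channel axis.
        return np.einsum("nchw,kc->nkhw", inp, w[:, :, 0, 0])

    # 1. Fold with *running* statistics and convolve (fake quant omitted here).
    bn_istd = 1.0 / np.sqrt(running_var + eps)
    scale_factor = gamma * bn_istd
    conv = conv1x1(x, weight * scale_factor[:, None, None, None])

    # 2. Rescale to recover the original (unfolded) conv output.
    orig_conv = conv / scale_factor[None, :, None, None]

    # 3. Normalise with *batch* statistics, as an ordinary BatchNorm would.
    bn_mean = orig_conv.mean(axis=(0, 2, 3))
    bn_var = orig_conv.var(axis=(0, 2, 3))
    out = (gamma[None, :, None, None]
           * (orig_conv - bn_mean[None, :, None, None])
           / np.sqrt(bn_var[None, :, None, None] + eps)
           + beta[None, :, None, None])

    # 4. EMA update of the running statistics, using the unbiased variance.
    n = N * H * W
    running_mean = momentum * running_mean + (1 - momentum) * bn_mean
    running_var = momentum * running_var + (1 - momentum) * bn_var * n / (n - 1)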
python_module/test/unit/module/test_conv_bn_relu.py (new file, mode 100644)
import copy
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.test import assertTensorClose


def test_convbn2d():
    in_channels = 32
    out_channels = 64
    kernel_size = 3
    module = ConvBn2d(in_channels, out_channels, kernel_size)
    quantize_qat(module)

    for groups, bias in product([1, 4], [True, False]):
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        module.train()
        qat_module = copy.deepcopy(module)
        disable_fake_quant(qat_module)
        normal_outputs = module.forward(inputs)
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
        a = module.bn.running_mean.numpy()
        b = qat_module.bn.running_mean.numpy()
        assertTensorClose(
            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
        assertTensorClose(
            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )

        module.eval()
        normal_outputs = module.forward(inputs)
        qat_module.eval()
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
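The test checks that, with fake quantization disabled, the fused QAT forward matches the ordinary Conv+BN forward in both training and eval mode, and that the two keep their BN running statistics in sync. Outside the test, the same entry points give a minimal QAT workflow; this sketch uses only the names exercised above (ConvBn2d, quantize_qat, tensor, forward_qat) and elides the actual training loop.

    import numpy as np

    from megengine import tensor
    from megengine.module import ConvBn2d
    from megengine.quantization import quantize_qat

    # Build the fused Conv+BN module and switch it to QAT mode.
    m = ConvBn2d(32, 64, 3)
    quantize_qat(m)

    x = tensor(np.random.randn(4, 32, 32, 32).astype(np.float32))

    # Training: forward_qat folds BN into the conv weights, fake-quantizes them,
    # and keeps the BN running statistics up to date (the approximate path above).
    m.train()
    y_train = m.forward_qat(x)

    # Evaluation: folding uses the frozen running statistics instead.
    m.eval()
    y_eval = m.forward_qat(x)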