Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
da7f250c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
da7f250c
编写于
8月 18, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/module): deconv fuse bn and relu
GitOrigin-RevId: 5619b397a4686edec3f98f02c66cf3e70b197092
上级
dbd94839
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
463 addition
and
23 deletion
+463
-23
imperative/python/megengine/module/__init__.py
imperative/python/megengine/module/__init__.py
+2
-0
imperative/python/megengine/module/conv.py
imperative/python/megengine/module/conv.py
+9
-0
imperative/python/megengine/module/conv_transpose_bn.py
imperative/python/megengine/module/conv_transpose_bn.py
+62
-0
imperative/python/megengine/module/qat/__init__.py
imperative/python/megengine/module/qat/__init__.py
+2
-1
imperative/python/megengine/module/qat/conv.py
imperative/python/megengine/module/qat/conv.py
+11
-2
imperative/python/megengine/module/qat/conv_transpose_bn.py
imperative/python/megengine/module/qat/conv_transpose_bn.py
+163
-0
imperative/python/megengine/module/quantized/__init__.py
imperative/python/megengine/module/quantized/__init__.py
+2
-1
imperative/python/megengine/module/quantized/conv.py
imperative/python/megengine/module/quantized/conv.py
+17
-6
imperative/python/megengine/module/quantized/conv_transpose_bn.py
...ve/python/megengine/module/quantized/conv_transpose_bn.py
+53
-0
imperative/python/megengine/utils/bn_fusion.py
imperative/python/megengine/utils/bn_fusion.py
+84
-10
imperative/python/test/unit/module/test_qat.py
imperative/python/test/unit/module/test_qat.py
+58
-3
未找到文件。
imperative/python/megengine/module/__init__.py
浏览文件 @
da7f250c
...
...
@@ -12,11 +12,13 @@ from .conv import (
ConvRelu2d
,
ConvTranspose2d
,
ConvTranspose3d
,
ConvTransposeRelu2d
,
DeformableConv2d
,
LocalConv2d
,
RegionRestrictedConv
,
)
from
.conv_bn
import
ConvBn2d
,
ConvBnRelu2d
from
.conv_transpose_bn
import
ConvTransposeBn2d
,
ConvTransposeBnRelu2d
from
.deformable_psroi_pooling
import
DeformablePSROIPooling
from
.dropout
import
Dropout
from
.elemwise
import
Elemwise
...
...
imperative/python/megengine/module/conv.py
浏览文件 @
da7f250c
...
...
@@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d):
return
relu
(
self
.
calc_conv
(
inp
,
self
.
weight
,
self
.
bias
))
class
ConvTransposeRelu2d
(
ConvTranspose2d
):
r
"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`.
"""
def
forward
(
self
,
inp
):
return
relu
(
self
.
calc_conv_transpose2d
(
inp
,
self
.
weight
,
self
.
bias
))
class
DeformableConv2d
(
_ConvNd
):
r
"""Deformable Convolution.
...
...
imperative/python/megengine/module/conv_transpose_bn.py
0 → 100644
浏览文件 @
da7f250c
from
typing
import
Tuple
,
Union
from
..functional
import
relu
from
.batchnorm
import
BatchNorm2d
from
.conv
import
ConvTranspose2d
from
.module
import
Module
class
_ConvTransposeBnActivation2d
(
Module
):
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
stride
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
1
,
padding
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
0
,
output_padding
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
0
,
dilation
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
conv_mode
:
str
=
"cross_correlation"
,
compute_mode
:
str
=
"default"
,
eps
=
1e-5
,
momentum
=
0.9
,
affine
=
True
,
track_running_stats
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
conv_transpose2d
=
ConvTranspose2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
output_padding
,
dilation
,
groups
,
bias
,
conv_mode
,
compute_mode
,
**
kwargs
,
)
self
.
bn
=
BatchNorm2d
(
out_channels
,
eps
,
momentum
,
affine
,
track_running_stats
)
class
ConvTransposeBn2d
(
_ConvTransposeBnActivation2d
):
r
"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using:func:`~.quantize.quantize_qat`.
"""
def
forward
(
self
,
inp
):
return
self
.
bn
(
self
.
conv_transpose2d
(
inp
))
class
ConvTransposeBnRelu2d
(
_ConvTransposeBnActivation2d
):
r
"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`.
"""
def
forward
(
self
,
inp
):
return
relu
(
self
.
bn
(
self
.
conv_transpose2d
(
inp
)))
imperative/python/megengine/module/qat/__init__.py
浏览文件 @
da7f250c
from
.batch_matmul_activation
import
BatchMatMulActivation
from
.concat
import
Concat
from
.conv
import
Conv2d
,
ConvRelu2d
,
ConvTranspose2d
from
.conv
import
Conv2d
,
ConvRelu2d
,
ConvTranspose2d
,
ConvTransposeRelu2d
from
.conv_bn
import
ConvBn2d
,
ConvBnRelu2d
from
.conv_transpose_bn
import
ConvTransposeBn2d
,
ConvTransposeBnRelu2d
from
.elemwise
import
Elemwise
from
.linear
import
Linear
from
.module
import
QATModule
...
...
imperative/python/megengine/module/qat/conv.py
浏览文件 @
da7f250c
...
...
@@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
def
calc_conv_transpose2d_qat
(
self
,
inp
):
w_qat
=
self
.
apply_quant_weight
(
self
.
weight
)
b_qat
=
self
.
apply_quant_bias
(
self
.
bias
,
inp
,
w_qat
)
conv
=
self
.
calc_conv_transpose2d
(
inp
,
w_qat
,
b_qat
)
return
conv
conv
_transpose2d
=
self
.
calc_conv_transpose2d
(
inp
,
w_qat
,
b_qat
)
return
conv
_transpose2d
@
classmethod
def
from_float_module
(
cls
,
float_module
:
Float
.
ConvTranspose2d
):
...
...
@@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
def
forward
(
self
,
inp
):
return
self
.
apply_quant_activation
(
self
.
calc_conv_transpose2d_qat
(
inp
))
class
ConvTransposeRelu2d
(
ConvTranspose2d
):
r
"""A :class:`~.QATModule` include :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def
forward
(
self
,
inp
):
return
self
.
apply_quant_activation
(
F
.
relu
(
self
.
calc_conv_transpose2d_qat
(
inp
)))
imperative/python/megengine/module/qat/conv_transpose_bn.py
0 → 100644
浏览文件 @
da7f250c
from
...functional
import
ones
,
relu
,
sqrt
,
sum
,
zeros
from
..
import
conv_transpose_bn
as
Float
from
.module
import
QATModule
class
_ConvTransposeBnActivation2d
(
Float
.
_ConvTransposeBnActivation2d
,
QATModule
):
def
get_batch_mean_var
(
self
,
inp
):
def
_sum_channel
(
inp
,
axis
=
0
,
keepdims
=
True
):
if
isinstance
(
axis
,
int
):
out
=
sum
(
inp
,
axis
=
axis
,
keepdims
=
keepdims
)
elif
isinstance
(
axis
,
tuple
):
for
idx
,
elem
in
enumerate
(
axis
):
out
=
sum
(
inp
if
idx
==
0
else
out
,
axis
=
elem
,
keepdims
=
keepdims
)
return
out
sum1
=
_sum_channel
(
inp
,
(
0
,
2
,
3
))
sum2
=
_sum_channel
(
inp
**
2
,
(
0
,
2
,
3
))
reduce_size
=
inp
.
size
/
inp
.
shape
[
1
]
batch_mean
=
sum1
/
reduce_size
batch_var
=
(
sum2
-
sum1
**
2
/
reduce_size
)
/
reduce_size
return
batch_mean
,
batch_var
def
fold_weight_bias
(
self
,
bn_mean
,
bn_var
):
# get fold bn conv_transpose2d param
gamma
=
self
.
bn
.
weight
if
gamma
is
None
:
gamma
=
ones
((
self
.
bn
.
num_features
),
dtype
=
"float32"
)
gamma
=
gamma
.
reshape
(
1
,
-
1
,
1
,
1
)
beta
=
self
.
bn
.
bias
if
beta
is
None
:
beta
=
zeros
((
1
,
self
.
bn
.
num_features
,
1
,
1
),
dtype
=
"float32"
)
if
bn_mean
is
None
:
bn_mean
=
zeros
((
1
,
self
.
bn
.
num_features
,
1
,
1
),
dtype
=
"float32"
)
if
bn_var
is
None
:
bn_var
=
ones
((
1
,
self
.
bn
.
num_features
,
1
,
1
),
dtype
=
"float32"
)
conv_transpose2d_bias
=
self
.
conv_transpose2d
.
bias
if
conv_transpose2d_bias
is
None
:
conv_transpose2d_bias
=
zeros
(
self
.
conv_transpose2d
.
_infer_bias_shape
(),
dtype
=
"float32"
)
bn_istd
=
1.0
/
sqrt
(
bn_var
+
self
.
bn
.
eps
)
scale_factor
=
gamma
*
bn_istd
if
self
.
conv_transpose2d
.
groups
==
1
:
w_fold
=
self
.
conv_transpose2d
.
weight
*
scale_factor
.
reshape
(
-
1
,
1
,
1
,
1
)
else
:
w_fold
=
self
.
conv_transpose2d
.
weight
*
scale_factor
.
reshape
(
self
.
conv_transpose2d
.
groups
,
-
1
,
1
,
1
,
1
)
w_fold
=
self
.
apply_quant_weight
(
w_fold
)
b_fold
=
beta
+
gamma
*
(
conv_transpose2d_bias
-
bn_mean
)
*
bn_istd
return
w_fold
,
b_fold
def
update_running_mean_and_running_var
(
self
,
bn_mean
,
bn_var
,
num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean
=
bn_mean
.
detach
()
bn_var
=
(
bn_var
.
detach
()
*
num_elements_per_channel
/
(
num_elements_per_channel
-
1
)
)
exponential_average_factor
=
1
-
self
.
bn
.
momentum
self
.
bn
.
running_mean
*=
self
.
bn
.
momentum
self
.
bn
.
running_mean
+=
exponential_average_factor
*
bn_mean
self
.
bn
.
running_var
*=
self
.
bn
.
momentum
self
.
bn
.
running_var
+=
exponential_average_factor
*
bn_var
def
calc_conv_transpose2d_bn_qat
(
self
,
inp
,
approx
=
True
):
if
self
.
training
and
not
approx
:
conv_transpose2d
=
self
.
conv_transpose2d
(
inp
)
bn_mean
,
bn_var
=
self
.
get_batch_mean_var
(
conv_transpose2d
)
num_elements_per_channel
=
conv_transpose2d
.
size
/
conv_transpose2d
.
shape
[
1
]
self
.
update_running_mean_and_running_var
(
bn_mean
,
bn_var
,
num_elements_per_channel
)
else
:
bn_mean
,
bn_var
=
self
.
bn
.
running_mean
,
self
.
bn
.
running_var
# get gamma and beta in BatchNorm
gamma
=
self
.
bn
.
weight
if
gamma
is
None
:
gamma
=
ones
((
self
.
bn
.
num_features
),
dtype
=
"float32"
)
gamma
=
gamma
.
reshape
(
1
,
-
1
,
1
,
1
)
beta
=
self
.
bn
.
bias
if
beta
is
None
:
beta
=
zeros
((
self
.
bn
.
num_features
),
dtype
=
"float32"
)
beta
=
beta
.
reshape
(
1
,
-
1
,
1
,
1
)
# conv_transpose2d_bias
conv_transpose2d_bias
=
self
.
conv_transpose2d
.
bias
if
conv_transpose2d_bias
is
None
:
conv_transpose2d_bias
=
zeros
(
self
.
conv_transpose2d
.
_infer_bias_shape
(),
dtype
=
"float32"
)
bn_istd
=
1.0
/
sqrt
(
bn_var
+
self
.
bn
.
eps
)
scale_factor
=
gamma
*
bn_istd
if
self
.
conv_transpose2d
.
groups
==
1
:
w_fold
=
self
.
conv_transpose2d
.
weight
*
scale_factor
.
reshape
(
1
,
-
1
,
1
,
1
)
else
:
w_fold
=
self
.
conv_transpose2d
.
weight
*
scale_factor
.
reshape
(
self
.
conv_transpose2d
.
groups
,
1
,
-
1
,
1
,
1
)
b_fold
=
None
if
not
(
self
.
training
and
approx
):
b_fold
=
beta
+
gamma
*
(
conv_transpose2d_bias
-
bn_mean
)
*
bn_istd
w_qat
=
self
.
apply_quant_weight
(
w_fold
)
b_qat
=
self
.
apply_quant_bias
(
b_fold
,
inp
,
w_qat
)
conv_transpose2d
=
self
.
conv_transpose2d
.
calc_conv_transpose2d
(
inp
,
w_qat
,
b_qat
)
if
not
(
self
.
training
and
approx
):
return
conv_transpose2d
# rescale conv_transpose2d to get original conv_transpose2d output
orig_conv_transpose2d
=
conv_transpose2d
/
scale_factor
.
reshape
(
1
,
-
1
,
1
,
1
)
if
self
.
conv_transpose2d
.
bias
is
not
None
:
orig_conv_transpose2d
=
orig_conv_transpose2d
+
self
.
conv_transpose2d
.
bias
# calculate batch norm
conv_transpose2d
=
self
.
bn
(
orig_conv_transpose2d
)
return
conv_transpose2d
@
classmethod
def
from_float_module
(
cls
,
float_module
:
Float
.
_ConvTransposeBnActivation2d
):
qat_module
=
cls
(
float_module
.
conv_transpose2d
.
in_channels
,
float_module
.
conv_transpose2d
.
out_channels
,
float_module
.
conv_transpose2d
.
kernel_size
,
float_module
.
conv_transpose2d
.
stride
,
float_module
.
conv_transpose2d
.
padding
,
float_module
.
conv_transpose2d
.
output_padding
,
float_module
.
conv_transpose2d
.
dilation
,
float_module
.
conv_transpose2d
.
groups
,
float_module
.
conv_transpose2d
.
bias
is
not
None
,
float_module
.
conv_transpose2d
.
conv_mode
,
float_module
.
conv_transpose2d
.
compute_mode
,
name
=
float_module
.
name
,
)
qat_module
.
conv_transpose2d
.
weight
=
float_module
.
conv_transpose2d
.
weight
qat_module
.
conv_transpose2d
.
bias
=
float_module
.
conv_transpose2d
.
bias
qat_module
.
bn
=
float_module
.
bn
return
qat_module
class
ConvTransposeBn2d
(
_ConvTransposeBnActivation2d
):
r
"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def
forward
(
self
,
inp
):
return
self
.
apply_quant_activation
(
self
.
calc_conv_transpose2d_bn_qat
(
inp
))
class
ConvTransposeBnRelu2d
(
_ConvTransposeBnActivation2d
):
r
"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
"""
def
forward
(
self
,
inp
):
return
self
.
apply_quant_activation
(
relu
(
self
.
calc_conv_transpose2d_bn_qat
(
inp
)))
imperative/python/megengine/module/quantized/__init__.py
浏览文件 @
da7f250c
from
.batch_matmul_activation
import
BatchMatMulActivation
from
.concat
import
Concat
from
.conv
import
Conv2d
,
ConvRelu2d
,
ConvTranspose2d
from
.conv
import
Conv2d
,
ConvRelu2d
,
ConvTranspose2d
,
ConvTransposeRelu2d
from
.conv_bn
import
ConvBn2d
,
ConvBnRelu2d
from
.conv_transpose_bn
import
ConvTransposeBn2d
,
ConvTransposeBnRelu2d
from
.elemwise
import
Elemwise
from
.linear
import
Linear
from
.module
import
QuantizedModule
...
...
imperative/python/megengine/module/quantized/conv.py
浏览文件 @
da7f250c
...
...
@@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
:class:`~.QATModule` instance.
"""
output_dtype
=
qat_module
.
get_activation_dtype
()
qconv
=
cls
(
qconv
_transpose2d
=
cls
(
qat_module
.
in_channels
,
qat_module
.
out_channels
,
qat_module
.
kernel_size
,
...
...
@@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
name
=
qat_module
.
name
,
)
weight
=
qat_module
.
weight
.
astype
(
qat_module
.
get_weight_dtype
())
qconv
.
weight
=
Parameter
(
weight
.
numpy
(),
name
=
qat_module
.
weight
.
name
)
qconv
.
bias
=
(
qconv_transpose2d
.
weight
=
Parameter
(
weight
.
numpy
(),
name
=
qat_module
.
weight
.
name
)
qconv_transpose2d
.
bias
=
(
Parameter
(
qat_module
.
bias
.
numpy
(),
name
=
qat_module
.
bias
.
name
)
if
qat_module
.
bias
is
not
None
else
None
)
return
qconv
return
qconv_transpose2d
def
calc_conv_transpose2d_quantized
(
self
,
inp
,
nonlinear_mode
):
assert
nonlinear_mode
==
"identity"
,
"nonlinear_mode shoule be 'identity'"
def
calc_conv_transpose2d_quantized
(
self
,
inp
):
if
self
.
bias
is
not
None
:
inp_scale
=
dtype
.
get_scale
(
inp
.
dtype
)
w_scale
=
dtype
.
get_scale
(
self
.
weight
.
dtype
)
...
...
@@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
)
def
forward
(
self
,
inp
):
return
self
.
calc_conv_transpose2d_quantized
(
inp
)
return
self
.
calc_conv_transpose2d_quantized
(
inp
,
nonlinear_mode
=
"identity"
)
class
ConvTransposeRelu2d
(
ConvTranspose2d
):
r
"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""
def
forward
(
self
,
inp
):
return
self
.
calc_conv_transpose2d_quantized
(
inp
,
nonlinear_mode
=
"relu"
)
imperative/python/megengine/module/quantized/conv_transpose_bn.py
0 → 100644
浏览文件 @
da7f250c
from
...tensor
import
Parameter
from
..qat
import
conv_transpose_bn
as
QAT
from
.conv
import
ConvTranspose2d
class
_ConvTransposeBnActivation2d
(
ConvTranspose2d
):
r
"""Applies a 2D deconvolution over a quantized input tensor, used for inference only.
"""
@
classmethod
def
from_qat_module
(
cls
,
qat_module
:
QAT
.
_ConvTransposeBnActivation2d
):
r
"""
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype
=
qat_module
.
get_activation_dtype
()
qconv_transpose2d
=
cls
(
qat_module
.
conv_transpose2d
.
in_channels
,
qat_module
.
conv_transpose2d
.
out_channels
,
qat_module
.
conv_transpose2d
.
kernel_size
,
qat_module
.
conv_transpose2d
.
stride
,
qat_module
.
conv_transpose2d
.
padding
,
qat_module
.
conv_transpose2d
.
output_padding
,
qat_module
.
conv_transpose2d
.
dilation
,
qat_module
.
conv_transpose2d
.
groups
,
dtype
=
output_dtype
,
name
=
qat_module
.
name
,
)
w_fold
,
b_fold
=
qat_module
.
fold_weight_bias
(
qat_module
.
bn
.
running_mean
,
qat_module
.
bn
.
running_var
)
weight
=
w_fold
.
astype
(
qat_module
.
get_weight_dtype
())
qconv_transpose2d
.
weight
=
Parameter
(
weight
.
numpy
(),
name
=
qat_module
.
conv_transpose2d
.
weight
.
name
)
qconv_transpose2d
.
bias
=
Parameter
(
b_fold
.
numpy
())
if
qat_module
.
conv_transpose2d
.
bias
is
not
None
:
qconv_transpose2d
.
bias
.
name
=
qat_module
.
conv_transpose2d
.
bias
.
name
return
qconv_transpose2d
class
ConvTransposeBn2d
(
_ConvTransposeBnActivation2d
):
r
"""Quantized version of :class:`~.qat.ConvTransposeBn2d`."""
def
forward
(
self
,
inp
):
return
self
.
calc_conv_transpose2d_quantized
(
inp
,
nonlinear_mode
=
"identity"
)
class
ConvTransposeBnRelu2d
(
_ConvTransposeBnActivation2d
):
r
"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`."""
def
forward
(
self
,
inp
):
return
self
.
calc_conv_transpose2d_quantized
(
inp
,
nonlinear_mode
=
"relu"
)
imperative/python/megengine/utils/bn_fusion.py
浏览文件 @
da7f250c
from
copy
import
deepcopy
from
..functional
import
ones
,
sqrt
,
zeros
from
..module
import
BatchNorm2d
,
Conv2d
,
ConvBn2d
,
ConvBnRelu2d
,
ConvRelu2d
,
ReLU
from
..module
import
(
BatchNorm2d
,
Conv2d
,
ConvBn2d
,
ConvBnRelu2d
,
ConvRelu2d
,
ConvTranspose2d
,
ConvTransposeBn2d
,
ConvTransposeBnRelu2d
,
ConvTransposeRelu2d
,
ReLU
,
)
from
..tensor
import
Parameter
_MAP_TO_FUSED_MODULE
=
{
(
Conv2d
,
BatchNorm2d
,
ReLU
,
False
):
ConvRelu2d
,
(
Conv2d
,
BatchNorm2d
,
ReLU
,
True
):
ConvBnRelu2d
,
(
ConvTranspose2d
,
BatchNorm2d
,
ReLU
,
False
):
ConvTransposeRelu2d
,
(
ConvTranspose2d
,
BatchNorm2d
,
ReLU
,
True
):
ConvTransposeBnRelu2d
,
(
Conv2d
,
BatchNorm2d
,
False
):
Conv2d
,
(
Conv2d
,
BatchNorm2d
,
True
):
ConvBn2d
,
(
Conv2d
,
ReLU
):
ConvRelu2d
,
(
ConvTranspose2d
,
BatchNorm2d
,
False
):
ConvTranspose2d
,
(
ConvTranspose2d
,
BatchNorm2d
,
True
):
ConvTransposeBn2d
,
(
ConvTranspose2d
,
ReLU
):
ConvTransposeRelu2d
,
}
def
fold_weight_bias
(
weight
,
bias
,
gamma
,
beta
,
bn_mean
,
bn_var
,
eps
=
1e-5
):
# get fold bn conv param
def
fold_weight_bias
(
weight
,
bias
,
gamma
,
beta
,
bn_mean
,
bn_var
,
eps
=
1e-5
,
transpose
=
False
):
shape
=
(
1
,
-
1
,
1
,
1
)
if
transpose
:
shape
=
(
-
1
,
1
,
1
,
1
)
kernel_shape
=
weight
.
shape
if
len
(
kernel_shape
)
==
5
:
groups
,
num_features
=
kernel_shape
[
0
],
kernel_shape
[
1
]
else
:
groups
,
num_features
=
1
,
kernel_shape
[
0
]
out_channels
=
groups
*
num_features
if
gamma
is
None
:
gamma
=
ones
((
num_features
),
dtype
=
"float32"
)
gamma
=
ones
((
out_channels
,
),
dtype
=
"float32"
)
gamma
=
gamma
.
reshape
(
1
,
-
1
,
1
,
1
)
if
beta
is
None
:
beta
=
zeros
((
num_features
),
dtype
=
"float32"
)
beta
=
zeros
((
out_channels
,
),
dtype
=
"float32"
)
beta
=
beta
.
reshape
(
1
,
-
1
,
1
,
1
)
if
bn_mean
is
None
:
bn_mean
=
zeros
((
1
,
num_feature
s
,
1
,
1
),
dtype
=
"float32"
)
bn_mean
=
zeros
((
1
,
out_channel
s
,
1
,
1
),
dtype
=
"float32"
)
if
bn_var
is
None
:
bn_var
=
ones
((
1
,
num_feature
s
,
1
,
1
),
dtype
=
"float32"
)
bn_var
=
ones
((
1
,
out_channel
s
,
1
,
1
),
dtype
=
"float32"
)
if
bias
is
None
:
bias
=
zeros
((
1
,
num_feature
s
,
1
,
1
),
dtype
=
"float32"
)
bias
=
zeros
((
1
,
out_channel
s
,
1
,
1
),
dtype
=
"float32"
)
bn_istd
=
1.0
/
sqrt
(
bn_var
+
eps
)
scale_factor
=
gamma
*
bn_istd
if
groups
==
1
:
w_fold
=
weight
*
scale_factor
.
reshape
(
-
1
,
1
,
1
,
1
)
w_fold
=
weight
*
scale_factor
.
reshape
(
*
shape
)
else
:
w_fold
=
weight
*
scale_factor
.
reshape
(
groups
,
-
1
,
1
,
1
,
1
)
w_fold
=
weight
*
scale_factor
.
reshape
(
groups
,
*
shape
)
b_fold
=
beta
+
gamma
*
(
bias
-
bn_mean
)
*
bn_istd
return
w_fold
,
b_fold
...
...
@@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
module
.
bn
=
deepcopy
(
bn
)
new_conv
.
training
=
conv
.
training
return
module
def
fuse_conv_transpose2d_bn_relu_module
(
conv_transpose2d
:
ConvTranspose2d
,
bn
:
BatchNorm2d
,
relu
:
ReLU
):
module_key
=
tuple
([
type
(
m
)
for
m
in
[
conv_transpose2d
,
bn
,
relu
]
if
m
])
if
bn
:
assert
(
conv_transpose2d
.
training
==
bn
.
training
),
"ConvTranspose2d and BN both must be in the same mode (train or eval)."
assert
(
bn
.
num_features
==
conv_transpose2d
.
out_channels
),
"Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
module_key
=
module_key
+
(
conv_transpose2d
.
training
,)
module
=
_MAP_TO_FUSED_MODULE
[
module_key
](
in_channels
=
conv_transpose2d
.
in_channels
,
out_channels
=
conv_transpose2d
.
out_channels
,
kernel_size
=
conv_transpose2d
.
kernel_size
,
stride
=
conv_transpose2d
.
stride
,
padding
=
conv_transpose2d
.
padding
,
output_padding
=
conv_transpose2d
.
output_padding
,
dilation
=
conv_transpose2d
.
dilation
,
groups
=
conv_transpose2d
.
groups
,
bias
=
conv_transpose2d
.
bias
is
not
None
,
conv_mode
=
conv_transpose2d
.
conv_mode
,
compute_mode
=
conv_transpose2d
.
compute_mode
,
name
=
conv_transpose2d
.
name
,
)
new_conv_transpose2d
=
(
module
if
bn
is
None
or
not
conv_transpose2d
.
training
else
module
.
conv_transpose2d
)
weight
,
bias
=
conv_transpose2d
.
weight
,
conv_transpose2d
.
bias
if
not
conv_transpose2d
.
training
and
bn
is
not
None
:
weight
,
bias
=
fold_weight_bias
(
weight
,
bias
,
bn
.
weight
,
bn
.
bias
,
bn
.
running_mean
,
bn
.
running_var
,
bn
.
eps
,
transpose
=
False
,
)
new_conv_transpose2d
.
weight
=
Parameter
(
weight
)
if
bias
is
not
None
:
new_conv_transpose2d
.
bias
=
Parameter
(
bias
)
if
bn
is
not
None
and
conv_transpose2d
.
training
:
module
.
bn
=
deepcopy
(
bn
)
new_conv_transpose2d
.
training
=
conv_transpose2d
.
training
return
module
imperative/python/test/unit/module/test_qat.py
浏览文件 @
da7f250c
...
...
@@ -5,7 +5,9 @@ import numpy as np
import
pytest
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine
import
jit
,
tensor
from
megengine
import
jit
from
megengine
import
module
as
M
from
megengine
import
tensor
from
megengine.device
import
get_device_count
from
megengine.functional
import
expand_dims
from
megengine.module
import
(
...
...
@@ -14,6 +16,8 @@ from megengine.module import (
ConvBn2d
,
ConvRelu2d
,
ConvTranspose2d
,
ConvTransposeBn2d
,
ConvTransposeRelu2d
,
DequantStub
,
Module
,
QuantStub
,
...
...
@@ -34,6 +38,49 @@ def test_qat_convbn2d():
module
=
ConvBn2d
(
in_channels
,
out_channels
,
kernel_size
,
groups
=
groups
,
bias
=
bias
)
M
.
init
.
normal_
(
module
.
bn
.
weight
)
M
.
init
.
normal_
(
module
.
bn
.
bias
)
module
.
train
()
qat_module
=
quantize_qat
(
module
,
inplace
=
False
)
disable_fake_quant
(
qat_module
)
inputs
=
tensor
(
np
.
random
.
randn
(
4
,
in_channels
,
32
,
32
).
astype
(
np
.
float32
))
normal_outputs
=
module
(
inputs
)
qat_outputs
=
qat_module
(
inputs
)
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
5e-6
)
np
.
testing
.
assert_allclose
(
module
.
bn
.
running_mean
.
numpy
(),
qat_module
.
bn
.
running_mean
.
numpy
(),
atol
=
5e-8
,
)
np
.
testing
.
assert_allclose
(
module
.
bn
.
running_var
.
numpy
(),
qat_module
.
bn
.
running_var
.
numpy
(),
atol
=
5e-7
,
)
module
.
eval
()
normal_outputs
=
module
(
inputs
)
qat_module
.
eval
()
qat_outputs
=
qat_module
(
inputs
)
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
5e-6
)
def
test_qat_convtransposebn2d
():
in_channels
=
32
out_channels
=
64
kernel_size
=
3
for
groups
,
bias
in
product
([
1
,
4
],
[
True
,
False
]):
module
=
ConvTransposeBn2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
output_padding
=
0
,
groups
=
groups
,
bias
=
bias
,
)
M
.
init
.
normal_
(
module
.
bn
.
weight
)
M
.
init
.
normal_
(
module
.
bn
.
bias
)
module
.
train
()
qat_module
=
quantize_qat
(
module
,
inplace
=
False
)
disable_fake_quant
(
qat_module
)
...
...
@@ -235,10 +282,14 @@ def test_qat_conv_transpose2d():
self
.
conv
=
ConvTranspose2d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
bias
)
self
.
conv_transpose2d_relu
=
ConvTransposeRelu2d
(
out_channels
,
in_channels
,
kernel_size
,
bias
=
bias
)
def
forward
(
self
,
inp
):
out
=
self
.
quant
(
inp
)
out
=
self
.
conv
(
out
)
out
=
self
.
conv_transpose2d_relu
(
out
)
out
=
self
.
dequant
(
out
)
return
out
...
...
@@ -250,10 +301,14 @@ def test_qat_conv_transpose2d():
disable_fake_quant
(
qat_net
)
normal_outputs
=
net
(
inputs
)
qat_outputs
=
qat_net
(
inputs
)
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
())
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
1e-6
)
net
.
eval
()
normal_outputs
=
net
(
inputs
)
qat_net
.
eval
()
qat_outputs
=
qat_net
(
inputs
)
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
())
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
1e-6
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录