Commit bc9f9cd4
Authored on May 18, 2023 by Megvii Engine Team
feat(imperative/module): add linear fuse bn and relu support
GitOrigin-RevId: c342f687974f9b47304c26ac39a59b74caaf7f70
Parent: f0a3ab97
Showing 12 changed files with 581 additions and 11 deletions (+581, −11):
imperative/python/megengine/module/__init__.py (+2, −1)
imperative/python/megengine/module/linear.py (+12, −3)
imperative/python/megengine/module/linear_bn.py (+50, −0)
imperative/python/megengine/module/qat/__init__.py (+2, −1)
imperative/python/megengine/module/qat/linear.py (+17, −1)
imperative/python/megengine/module/qat/linear_bn.py (+164, −0)
imperative/python/megengine/module/quantized/__init__.py (+2, −1)
imperative/python/megengine/module/quantized/linear.py (+21, −2)
imperative/python/megengine/module/quantized/linear_bn.py (+40, −0)
imperative/python/megengine/utils/bn_fusion.py (+96, −1)
imperative/python/test/unit/module/test_qat.py (+128, −0)
imperative/python/test/unit/quantization/test_quantize.py (+47, −1)
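Taken together, the diff adds a LinearRelu / LinearBn1d / LinearBnRelu1d family mirroring the existing ConvBn2d one across the float, QAT, and quantized module namespaces. A rough usage sketch of the intended workflow, modeled on the tests in this commit (the Net wrapper and variable names here are illustrative, not part of the diff):

    import numpy as np
    import megengine as mge
    from megengine.module import DequantStub, LinearBnRelu1d, Module, QuantStub
    from megengine.quantization.quantize import quantize_qat

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.linear_bn_relu = LinearBnRelu1d(10, 5, bias=True)
            self.dequant = DequantStub()

        def forward(self, x):
            return self.dequant(self.linear_bn_relu(self.quant(x)))

    net = Net()
    net.train()
    # Swap float modules for the QAT versions added by this commit; the QAT
    # LinearBnRelu1d folds BN statistics into the linear weight before
    # fake-quantizing it.
    qat_net = quantize_qat(net, inplace=False)

    x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
    y = qat_net(x)  # shape: (4, 5)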
imperative/python/megengine/module/__init__.py

@@ -24,7 +24,8 @@ from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
 from .identity import Identity
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .lrn import LocalResponseNorm
 from .module import Module
 from .multiheadattn import MultiHeadAttention
imperative/python/megengine/module/linear.py

 import numpy as np

-from ..functional.nn import linear
+from ..functional.nn import linear, relu
 from ..tensor import Parameter
 from . import init
 from .module import Module
@@ -62,13 +62,22 @@ class Linear(Module):
         if self.bias is not None:
             init.zeros_(self.bias)

-    def _calc_linear(self, x, weight, bias):
+    def calc_linear(self, x, weight, bias):
         return linear(x, weight, bias, compute_mode=self.compute_mode)

     def forward(self, x):
-        return self._calc_linear(x, self.weight, self.bias)
+        return self.calc_linear(x, self.weight, self.bias)

     def _module_info_string(self) -> str:
         return "in_features={}, out_features={}, bias={}".format(
             self.in_features, self.out_features, self.bias is not None
         )
+
+
+class LinearRelu(Linear):
+    r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :func:`~.relu`.
+    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearRelu` using :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return relu(self.calc_linear(inp, self.weight, self.bias))
imperative/python/megengine/module/linear_bn.py (new file, mode 100644)

import numpy as np

from ..functional import relu
from ..tensor import Parameter
from .batchnorm import BatchNorm1d
from .linear import Linear
from .module import Module


class _LinearBnActivation1d(Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.in_features = in_features
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        self.linear = Linear(in_features, out_features, bias, compute_mode, **kwargs,)
        self.bn = BatchNorm1d(out_features, eps, momentum, affine, track_running_stats)


class LinearBn1d(_LinearBnActivation1d):
    r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBn1d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return self.bn(self.linear(inp))


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""A fused :class:`~.Module` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBnRelu1d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return relu(self.bn(self.linear(inp)))
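In float mode these fused modules are plain composition (bn(linear(x))), so they are numerically identical to an unfused pair with shared parameters; a quick sanity check along the lines of the tests further below (hypothetical snippet, not part of the diff):

    import numpy as np
    import megengine as mge
    from megengine.module import BatchNorm1d, Linear, LinearBn1d

    fused = LinearBn1d(10, 5)
    plain_linear, plain_bn = Linear(10, 5), BatchNorm1d(5)
    # Share the linear parameters so both paths compute the same function
    # (fresh BatchNorm1d modules start from identical gamma/beta/statistics).
    plain_linear.weight = fused.linear.weight
    plain_linear.bias = fused.linear.bias

    x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
    np.testing.assert_allclose(
        fused(x).numpy(), plain_bn(plain_linear(x)).numpy(), atol=1e-6
    )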
imperative/python/megengine/module/qat/__init__.py

@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .module import QATModule
 from .quant_dequant import DequantStub, QuantStub
imperative/python/megengine/module/qat/linear.py

+from ... import functional as F
 from .. import linear as Float
 from .module import QATModule
@@ -13,10 +14,16 @@ class Linear(Float.Linear, QATModule):
     Default: True
     """

+    def calc_linear_qat(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
+        linear = self.calc_linear(inp, w_qat, b_qat)
+        return linear
+
     def forward(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
         b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
-        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
+        return self.apply_quant_activation(self.calc_linear(inp, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):
@@ -30,3 +37,12 @@ class Linear(Float.Linear, QATModule):
         qmod.weight = float_module.weight
         qmod.bias = float_module.bias
         return qmod
+
+
+class LinearRelu(Linear):
+    r"""A :class:`~.QATModule` include :class:`~.module.Linear` and :func:`~.relu` with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(F.relu(self.calc_linear_qat(inp)))
imperative/python/megengine/module/qat/linear_bn.py (new file, mode 100644)

from ...functional import linear, ones, relu, sqrt, sum, zeros
from .. import linear_bn as Float
from .module import QATModule


class _LinearBnActivation1d(Float._LinearBnActivation1d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        weight_shape = [1] * len(self.linear.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.linear.weight.shape)
        bias_shape[1] = -1

        # get fold bn linear param
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features,), dtype="float32")
        gamma = gamma.reshape(-1)

        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features,), dtype="float32")
        beta = beta.reshape(-1)

        if bn_mean is None:
            bn_mean = zeros((self.bn.num_features,), dtype="float32")
        bn_mean = bn_mean.reshape(-1)

        if bn_var is None:
            bn_var = ones((self.bn.num_features,), dtype="float32")
        bn_var = bn_var.reshape(-1)

        linear_bias = self.linear.bias
        if linear_bias is None:
            linear_bias = zeros(beta.shape, dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
        w_fold = self.apply_quant_weight(w_fold)
        b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_linear_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            linear = self.linear(inp)
            bn_mean, bn_var = self.get_batch_mean_var(linear)
            num_elements_per_channel = linear.size / linear.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
            bn_mean, bn_var = (
                self.bn.running_mean.reshape(-1),
                self.bn.running_var.reshape(-1),
            )

        weight_shape = [1] * len(self.linear.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.linear.weight.shape)
        bias_shape[1] = -1

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features,), dtype="float32")
        gamma = gamma.reshape(-1)

        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features,), dtype="float32")
        beta = beta.reshape(-1)

        # linear_bias
        linear_bias = self.linear.bias
        if linear_bias is None:
            linear_bias = zeros(beta.shape, dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
        b_fold = None
        if not (self.training and approx):
            b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        linear = self.linear.calc_linear(inp, w_qat, b_qat)
        if not (self.training and approx):
            return linear

        # rescale linear to get original linear output
        orig_linear = linear / scale_factor.reshape(bias_shape)
        if self.linear.bias is not None:
            orig_linear = orig_linear + self.linear.bias.reshape(bias_shape)
        # calculate batch norm
        linear = self.bn(orig_linear)
        return linear

    @classmethod
    def from_float_module(cls, float_module: Float._LinearBnActivation1d):
        qat_module = cls(
            float_module.linear.in_features,
            float_module.linear.out_features,
            float_module.linear.bias is not None,
            float_module.linear.compute_mode,
            float_module.bn.eps,
            float_module.bn.momentum,
            float_module.bn.affine,
            float_module.bn.track_running_stats,
            name=float_module.name,
        )
        qat_module.linear.weight = float_module.linear.weight
        qat_module.linear.bias = float_module.linear.bias
        qat_module.bn = float_module.bn
        return qat_module


class LinearBn1d(_LinearBnActivation1d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_linear_bn_qat(inp))


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_linear_bn_qat(inp)))
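The folding above relies on the standard BN-folding identity: with y = γ · (Wx + b − μ) / √(σ² + ε) + β and the per-channel factor s = γ / √(σ² + ε), the same output is produced by a single linear layer with W_fold = s ⊙ W (each output row scaled) and b_fold = β + s · (b − μ), which is what fold_weight_bias and calc_linear_bn_qat compute. A plain-NumPy check of the identity (illustrative only, shapes chosen to match the tests):

    import numpy as np

    rng = np.random.default_rng(0)
    N, C_in, C_out, eps = 4, 10, 5, 1e-5

    x = rng.standard_normal((N, C_in)).astype(np.float32)
    W = rng.standard_normal((C_out, C_in)).astype(np.float32)
    b = rng.standard_normal(C_out).astype(np.float32)
    gamma, beta = rng.standard_normal(C_out), rng.standard_normal(C_out)
    mean, var = rng.standard_normal(C_out), rng.random(C_out) + 0.5

    # Reference: linear followed by batch norm with frozen statistics.
    y_ref = gamma * ((x @ W.T + b) - mean) / np.sqrt(var + eps) + beta

    # Folded: one linear layer with rescaled weight and shifted bias.
    scale = gamma / np.sqrt(var + eps)   # per-output-channel factor s
    W_fold = W * scale[:, None]          # matches weight_shape = [-1, 1]
    b_fold = beta + scale * (b - mean)   # matches b_fold in the diff
    y_fold = x @ W_fold.T + b_fold

    np.testing.assert_allclose(y_ref, y_fold, rtol=1e-5, atol=1e-5)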
imperative/python/megengine/module/quantized/__init__.py

@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .module import QuantizedModule
 from .quant_dequant import DequantStub, QuantStub
imperative/python/megengine/module/quantized/linear.py

 import numpy as np

 from ... import functional as F
+from ... import module as Float
 from ...core.tensor import dtype
 from ...tensor import Parameter
 from ..qat import linear as QAT
@@ -16,20 +17,30 @@ class Linear(QuantizedModule):
         self.bias = None
         self.output_dtype = dtype

-    def forward(self, inp):
+    def calc_linear_quantized(self, inp, nonlinear_mode="identity"):
         if self.training:
             raise ValueError("quantized module only support inference.")
+        assert nonlinear_mode in ["identity", "relu"]
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_dtype = dtype.qint32(inp_scale * w_scale)
-        ret = F.nn.linear(
+        ret = F.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
         )
         ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        if nonlinear_mode == "relu":
+            ret = F.relu(ret)
         return ret

+    def forward(self, inp):
+        return self.calc_linear_quantized(inp)
+
     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):
         r"""
@@ -38,8 +49,16 @@ class Linear(QuantizedModule):
         """
         output_dtype = qat_module.get_activation_dtype()
         qmod = cls(dtype=output_dtype, name=qat_module.name)
+        qmod.name = qat_module.name
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
         if qat_module.bias is not None:
             qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
         return qmod
+
+
+class LinearRelu(Linear):
+    r"""Quantized version of :class:`~.qat.LinearRelu`."""
+
+    def forward(self, inp):
+        return self.calc_linear_quantized(inp, nonlinear_mode="relu")
imperative/python/megengine/module/quantized/linear_bn.py (new file, mode 100644)

from ...tensor import Parameter
from ..qat import linear_bn as QAT
from .linear import Linear


class _LinearBnActivation1d(Linear):
    r"""Applies a Linear over a quantized input tensor, used for inference only."""

    @classmethod
    def from_qat_module(cls, qat_module: QAT._LinearBnActivation1d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qlinear = cls(dtype=output_dtype, name=qat_module.name,)
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qlinear.weight = Parameter(weight.numpy(), name=qat_module.linear.weight.name)
        qlinear.bias = Parameter(b_fold.numpy())
        if qat_module.linear.bias is not None:
            qlinear.bias.name = qat_module.linear.bias.name
        return qlinear


class LinearBn1d(_LinearBnActivation1d):
    r"""Quantized version of :class:`~.qat.LinearBn1d`."""

    def forward(self, inp):
        return self.calc_linear_quantized(inp, nonlinear_mode="identity")


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""Quantized version of :class:`~.qat.LinearBnRelu1d`."""

    def forward(self, inp):
        return self.calc_linear_quantized(inp, nonlinear_mode="relu")
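Conversion to these inference-only modules happens through from_qat_module, which folds the frozen running statistics into a single quantized linear (bias kept in qint32, output cast to the activation dtype). A minimal sketch of that last conversion step, assuming a QAT network whose observers have already seen data (qat_net is a placeholder name):

    from megengine.quantization.quantize import quantize

    qat_net.eval()                            # freeze BN running statistics
    q_net = quantize(qat_net, inplace=False)  # qat.LinearBn1d -> quantized.LinearBn1d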
imperative/python/megengine/utils/bn_fusion.py

@@ -2,6 +2,7 @@ from copy import deepcopy
 from ..functional import ones, sqrt, zeros
 from ..module import (
+    BatchNorm1d,
     BatchNorm2d,
     Conv2d,
     ConvBn2d,
@@ -11,6 +12,10 @@ from ..module import (
     ConvTransposeBn2d,
     ConvTransposeBnRelu2d,
     ConvTransposeRelu2d,
+    Linear,
+    LinearBn1d,
+    LinearBnRelu1d,
+    LinearRelu,
     ReLU,
 )
 from ..tensor import Parameter
@@ -26,10 +31,15 @@ _MAP_TO_FUSED_MODULE = {
     (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
     (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
     (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
+    (Linear, BatchNorm1d, ReLU, False): LinearRelu,
+    (Linear, BatchNorm1d, ReLU, True): LinearBnRelu1d,
+    (Linear, BatchNorm1d, False): Linear,
+    (Linear, BatchNorm1d, True): LinearBn1d,
+    (Linear, ReLU): LinearRelu,
 }


-def fold_weight_bias(
+def _fold_conv_bn_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
     shape = (-1, 1, 1, 1)
@@ -76,6 +86,57 @@ def fold_weight_bias(
     return w_fold, b_fold


+def _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
+    bn_mean, bn_var = bn_mean.reshape(-1), bn_var.reshape(-1)
+    weight_shape = [1] * len(weight.shape)
+    weight_shape[0] = -1
+    bias_shape = [1] * len(weight.shape)
+    bias_shape[1] = -1
+    out_features = weight.shape[0]
+
+    if gamma is None:
+        gamma = ones((out_features,), dtype="float32")
+    else:
+        gamma = gamma.reshape(-1)
+    if beta is None:
+        beta = zeros((out_features,), dtype="float32")
+    else:
+        beta = beta.reshape(-1)
+    if bn_mean is None:
+        bn_mean = zeros((out_features,), dtype="float32")
+    else:
+        bn_mean = bn_mean.reshape(-1)
+    if bn_var is None:
+        bn_var = ones((out_features,), dtype="float32")
+    else:
+        bn_var = bn_var.reshape(-1)
+    if bias is None:
+        bias = zeros((beta.shape), dtype="float32")
+    else:
+        bias = bias.reshape(-1)
+
+    bn_istd = 1.0 / sqrt(bn_var + eps)
+    scale_factor = gamma * bn_istd
+    w_fold = weight * scale_factor.reshape(*weight_shape)
+    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
+    return w_fold, b_fold
+
+
+def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False):
+    if weight.ndim != 2:
+        return _fold_conv_bn_weight_bias(
+            weight, bias, gamma, beta, bn_mean, bn_var, eps, transpose
+        )
+    return _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps)
+
+
 def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
     module_key = tuple([type(m) for m in [conv, bn, relu] if m])
     if bn:
@@ -137,3 +198,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
+
+
+def fuse_linear_bn_relu_module(linear: Linear, bn: BatchNorm1d, relu: ReLU):
+    module_key = tuple([type(m) for m in [linear, bn, relu] if m])
+    if bn:
+        assert (
+            linear.training == bn.training
+        ), "Linear and BN both must be in the same mode (train or eval)."
+        assert (
+            bn.num_features == linear.out_features
+        ), "Output channel of Linear must match num_features of BatchNorm1d"
+        module_key = module_key + (linear.training,)
+    module = _MAP_TO_FUSED_MODULE[module_key](
+        in_features=linear.in_features,
+        out_features=linear.out_features,
+        bias=linear.bias is not None,
+        compute_mode=linear.compute_mode,
+        name=linear.name,
+    )
+    new_linear = module if bn is None or not linear.training else module.linear
+    weight, bias = linear.weight, linear.bias
+    if not linear.training and bn is not None:
+        weight, bias = fold_weight_bias(
+            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
+        )
+    new_linear.weight = Parameter(weight)
+    if bias is not None:
+        new_linear.bias = Parameter(bias)
+    if bn is not None and linear.training:
+        module.bn = deepcopy(bn)
+    new_linear.training = linear.training
+    return module
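Note that fold_weight_bias now dispatches on weight.ndim: 2-D weights take the new linear folding path, everything else keeps the original conv path. fuse_linear_bn_relu_module mirrors fuse_conv_bn_relu_module; in eval mode it folds BN away entirely. A small usage sketch mirroring the test added below (variable names illustrative):

    import numpy as np
    from megengine import Tensor
    from megengine.module import BatchNorm1d, Linear, ReLU
    from megengine.utils.bn_fusion import fuse_linear_bn_relu_module

    linear, bn, relu = Linear(10, 5), BatchNorm1d(5), ReLU()
    linear.eval(); bn.eval(); relu.eval()

    # In eval mode the BN statistics are folded directly into the Linear
    # weights, so the fused module is a plain LinearRelu.
    fused = fuse_linear_bn_relu_module(linear, bn, relu)
    fused.eval()

    x = Tensor(np.random.randn(4, 10).astype(np.float32))
    np.testing.assert_allclose(
        relu(bn(linear(x))).numpy(), fused(x).numpy(), atol=1e-4
    )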
imperative/python/test/unit/module/test_qat.py

@@ -19,6 +19,10 @@ from megengine.module import (
     ConvTransposeBn2d,
     ConvTransposeRelu2d,
     DequantStub,
+    Linear,
+    LinearBn1d,
+    LinearBnRelu1d,
+    LinearRelu,
     Module,
     QuantStub,
 )
@@ -330,3 +334,127 @@ def test_qat_conv_transpose2d():
     np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6)
+
+
+def test_qat_linearbn1d():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_bn = LinearBn1d(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        np.testing.assert_allclose(
+            net.linear_bn.bn.running_mean.numpy(),
+            qat_net.linear_bn.bn.running_mean.numpy(),
+            atol=5e-8,
+        )
+        np.testing.assert_allclose(
+            net.linear_bn.bn.running_var.numpy(),
+            qat_net.linear_bn.bn.running_var.numpy(),
+            atol=5e-7,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+
+
+def test_qat_linear_relu():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_relu = LinearRelu(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+
+
+def test_qat_linear_bn_relu():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_bn_relu = LinearBnRelu1d(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_bn_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
imperative/python/test/unit/quantization/test_quantize.py

@@ -5,11 +5,14 @@ from megengine import Parameter, Tensor
 from megengine import module as Float
 from megengine.functional import ones, zeros
 from megengine.module import (
+    BatchNorm1d,
     BatchNorm2d,
     Conv2d,
     ConvBn2d,
     ConvTranspose2d,
     ConvTransposeBn2d,
+    Linear,
+    LinearBn1d,
     ReLU,
 )
 from megengine.module import qat as QAT
@@ -33,7 +36,10 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
-from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
+from megengine.utils.bn_fusion import (
+    fuse_conv_bn_relu_module,
+    fuse_linear_bn_relu_module,
+)


 class FloatNet(Float.Module):
@@ -383,3 +389,43 @@ def test_ConvTransposeBn2d_fold_weight_bias():
     np.testing.assert_allclose(expected_result.numpy(), actual_result.numpy(), atol=1e-4)
+
+
+def test_LinearBn1d_fold_weight_bias():
+    in_features = 10
+    out_features = 5
+    linear = Linear(in_features, out_features)
+    bn = BatchNorm1d(out_features)
+    relu = ReLU()
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
+    bn.eval()
+    fused_linear.eval()
+    inputs = Tensor(np.random.randn(4, in_features).astype(np.float32))
+    expected_result = relu(bn(linear(inputs)))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(expected_result.numpy(), actual_result.numpy(), atol=1e-4)
+
+    linear.eval()
+    bn.eval()
+    relu.eval()
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
+    fused_linear.eval()
+    expected_result = relu(linear(inputs))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(expected_result.numpy(), actual_result.numpy(), atol=1e-4)
+
+    linear.train()
+    bn.train()
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, None)
+    fused_linear.train()
+    expected_result = bn(linear(inputs))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(expected_result.numpy(), actual_result.numpy(), atol=1e-4)