MegEngine 天元 / MegEngine

Commit bc9f9cd4
Authored May 18, 2023 by Megvii Engine Team
feat(imperative/module): add linear fuse bn and relu support
GitOrigin-RevId: c342f687974f9b47304c26ac39a59b74caaf7f70
Parent: f0a3ab97
Showing 12 changed files with 581 additions and 11 deletions (+581 / -11)
imperative/python/megengine/module/__init__.py              +2   -1
imperative/python/megengine/module/linear.py                +12  -3
imperative/python/megengine/module/linear_bn.py             +50  -0
imperative/python/megengine/module/qat/__init__.py          +2   -1
imperative/python/megengine/module/qat/linear.py            +17  -1
imperative/python/megengine/module/qat/linear_bn.py         +164 -0
imperative/python/megengine/module/quantized/__init__.py    +2   -1
imperative/python/megengine/module/quantized/linear.py      +21  -2
imperative/python/megengine/module/quantized/linear_bn.py   +40  -0
imperative/python/megengine/utils/bn_fusion.py              +96  -1
imperative/python/test/unit/module/test_qat.py              +128 -0
imperative/python/test/unit/quantization/test_quantize.py   +47  -1
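Taken together, the changes mirror the existing Conv + BatchNorm + ReLU fusion path for fully connected layers: float fused modules (module/linear_bn.py), QAT counterparts that fold the BatchNorm statistics into the weights (module/qat/linear_bn.py), inference-only quantized versions (module/quantized/linear_bn.py), and a fuse_linear_bn_relu_module helper (utils/bn_fusion.py). A minimal end-to-end sketch in the style of the tests added below; the network itself is illustrative, and quantize_qat / disable_fake_quant are assumed to be the usual megengine.quantization.quantize entry points those tests rely on:

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization.quantize import disable_fake_quant, quantize_qat

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()
        self.dequant = M.DequantStub()
        # New fused float module added by this commit.
        self.linear_bn_relu = M.LinearBnRelu1d(10, 5, bias=True)

    def forward(self, x):
        return self.dequant(self.linear_bn_relu(self.quant(x)))

net = Net()
net.train()
qat_net = quantize_qat(net, inplace=False)   # swaps in qat.LinearBnRelu1d
disable_fake_quant(qat_net)                  # with fake quant off, numerics match the float net
x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
np.testing.assert_allclose(net(x).numpy(), qat_net(x).numpy(), atol=1e-6)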
imperative/python/megengine/module/__init__.py

@@ -24,7 +24,8 @@ from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
 from .identity import Identity
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .lrn import LocalResponseNorm
 from .module import Module
 from .multiheadattn import MultiHeadAttention
imperative/python/megengine/module/linear.py

 import numpy as np

-from ..functional.nn import linear
+from ..functional.nn import linear, relu
 from ..tensor import Parameter
 from . import init
 from .module import Module

@@ -62,13 +62,22 @@ class Linear(Module):
         if self.bias is not None:
             init.zeros_(self.bias)

-    def _calc_linear(self, x, weight, bias):
+    def calc_linear(self, x, weight, bias):
         return linear(x, weight, bias, compute_mode=self.compute_mode)

     def forward(self, x):
-        return self._calc_linear(x, self.weight, self.bias)
+        return self.calc_linear(x, self.weight, self.bias)

     def _module_info_string(self) -> str:
         return "in_features={}, out_features={}, bias={}".format(
             self.in_features, self.out_features, self.bias is not None
         )
+
+
+class LinearRelu(Linear):
+    r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :func:`~.relu`.
+    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearRelu` using :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return relu(self.calc_linear(inp, self.weight, self.bias))
imperative/python/megengine/module/linear_bn.py (new file, mode 100644)

import numpy as np

from ..functional import relu
from ..tensor import Parameter
from .batchnorm import BatchNorm1d
from .linear import Linear
from .module import Module


class _LinearBnActivation1d(Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.in_features = in_features
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        self.linear = Linear(in_features, out_features, bias, compute_mode, **kwargs,)
        self.bn = BatchNorm1d(out_features, eps, momentum, affine, track_running_stats)


class LinearBn1d(_LinearBnActivation1d):
    r"""A fused :class:`~.Module` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBn1d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return self.bn(self.linear(inp))


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""A fused :class:`~.Module` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.LinearBnRelu1d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return relu(self.bn(self.linear(inp)))
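The float modules above are plain containers: LinearBn1d.forward is literally self.bn(self.linear(inp)), and LinearBnRelu1d only adds relu on top. A quick sanity-check sketch (not part of the commit) that composes the submodules by hand:

import numpy as np
import megengine as mge
from megengine.module import LinearBn1d

m = LinearBn1d(in_features=10, out_features=5, bias=True)
m.eval()  # use the running statistics held by the BatchNorm1d submodule

x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
ref = m.bn(m.linear(x))  # same computation as m(x)
np.testing.assert_allclose(m(x).numpy(), ref.numpy(), atol=1e-6)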
imperative/python/megengine/module/qat/__init__.py

@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .module import QATModule
 from .quant_dequant import DequantStub, QuantStub
imperative/python/megengine/module/qat/linear.py

+from ... import functional as F
 from .. import linear as Float
 from .module import QATModule

@@ -13,10 +14,16 @@ class Linear(Float.Linear, QATModule):
         Default: True
     """

+    def calc_linear_qat(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
+        linear = self.calc_linear(inp, w_qat, b_qat)
+        return linear
+
     def forward(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
         b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
-        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
+        return self.apply_quant_activation(self.calc_linear(inp, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):

@@ -30,3 +37,12 @@ class Linear(Float.Linear, QATModule):
         qmod.weight = float_module.weight
         qmod.bias = float_module.bias
         return qmod
+
+
+class LinearRelu(Linear):
+    r"""A :class:`~.QATModule` include :class:`~.module.Linear` and :func:`~.relu` with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(F.relu(self.calc_linear_qat(inp)))
imperative/python/megengine/module/qat/linear_bn.py (new file, mode 100644)

from ...functional import linear, ones, relu, sqrt, sum, zeros
from .. import linear_bn as Float
from .module import QATModule


class _LinearBnActivation1d(Float._LinearBnActivation1d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        weight_shape = [1] * len(self.linear.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.linear.weight.shape)
        bias_shape[1] = -1
        # get fold bn linear param
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features,), dtype="float32")
        gamma = gamma.reshape(-1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features,), dtype="float32")
        beta = beta.reshape(-1)

        if bn_mean is None:
            bn_mean = zeros((self.bn.num_features,), dtype="float32")
        bn_mean = bn_mean.reshape(-1)
        if bn_var is None:
            bn_var = ones((self.bn.num_features,), dtype="float32")
        bn_var = bn_var.reshape(-1)

        linear_bias = self.linear.bias
        if linear_bias is None:
            linear_bias = zeros(beta.shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
        w_fold = self.apply_quant_weight(w_fold)
        b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_linear_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            linear = self.linear(inp)
            bn_mean, bn_var = self.get_batch_mean_var(linear)
            num_elements_per_channel = linear.size / linear.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        bn_mean, bn_var = (
            self.bn.running_mean.reshape(-1),
            self.bn.running_var.reshape(-1),
        )

        weight_shape = [1] * len(self.linear.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.linear.weight.shape)
        bias_shape[1] = -1

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features,), dtype="float32")
        gamma = gamma.reshape(-1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features,), dtype="float32")
        beta = beta.reshape(-1)

        # linear_bias
        linear_bias = self.linear.bias
        if linear_bias is None:
            linear_bias = zeros(beta.shape, dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        w_fold = self.linear.weight * scale_factor.reshape(weight_shape)
        b_fold = None
        if not (self.training and approx):
            b_fold = beta + gamma * (linear_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        linear = self.linear.calc_linear(inp, w_qat, b_qat)
        if not (self.training and approx):
            return linear

        # rescale linear to get original linear output
        orig_linear = linear / scale_factor.reshape(bias_shape)
        if self.linear.bias is not None:
            orig_linear = orig_linear + self.linear.bias.reshape(bias_shape)
        # calculate batch norm
        linear = self.bn(orig_linear)
        return linear

    @classmethod
    def from_float_module(cls, float_module: Float._LinearBnActivation1d):
        qat_module = cls(
            float_module.linear.in_features,
            float_module.linear.out_features,
            float_module.linear.bias is not None,
            float_module.linear.compute_mode,
            float_module.bn.eps,
            float_module.bn.momentum,
            float_module.bn.affine,
            float_module.bn.track_running_stats,
            name=float_module.name,
        )
        qat_module.linear.weight = float_module.linear.weight
        qat_module.linear.bias = float_module.linear.bias
        qat_module.bn = float_module.bn
        return qat_module


class LinearBn1d(_LinearBnActivation1d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Linear` and :class:`~.module.BatchNorm1d` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_linear_bn_qat(inp))


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.Linear`, :class:`~.module.BatchNorm1d` and :func:`~.relu` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_linear_bn_qat(inp)))
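The heart of the QAT module is the BatchNorm folding performed in fold_weight_bias / calc_linear_bn_qat: with a per-channel scale s = gamma / sqrt(running_var + eps), the folded parameters are w_fold = W * s (broadcast over output channels) and b_fold = beta + gamma * (b - running_mean) / sqrt(running_var + eps), so that a single linear layer with (w_fold, b_fold) reproduces bn(linear(x; W, b)) at inference time. A NumPy-only sketch of that identity (illustrative shapes and values, independent of MegEngine):

import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((5, 10)).astype(np.float32)   # Linear weight, (out_features, in_features)
b = rng.standard_normal(5).astype(np.float32)         # Linear bias
gamma = rng.standard_normal(5).astype(np.float32)     # bn.weight
beta = rng.standard_normal(5).astype(np.float32)      # bn.bias
mean = rng.standard_normal(5).astype(np.float32)      # bn.running_mean
var = rng.random(5).astype(np.float32) + 0.5          # bn.running_var
eps = 1e-5

istd = 1.0 / np.sqrt(var + eps)
w_fold = W * (gamma * istd)[:, None]                  # scale each output row
b_fold = beta + gamma * (b - mean) * istd

x = rng.standard_normal((4, 10)).astype(np.float32)
y_bn_of_linear = gamma * ((x @ W.T + b) - mean) * istd + beta
y_folded = x @ w_fold.T + b_fold
np.testing.assert_allclose(y_bn_of_linear, y_folded, rtol=1e-4, atol=1e-5)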
imperative/python/megengine/module/quantized/__init__.py

@@ -4,6 +4,7 @@ from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
-from .linear import Linear
+from .linear import Linear, LinearRelu
+from .linear_bn import LinearBn1d, LinearBnRelu1d
 from .module import QuantizedModule
 from .quant_dequant import DequantStub, QuantStub
imperative/python/megengine/module/quantized/linear.py

 import numpy as np

 from ... import functional as F
+from ... import module as Float
 from ...core.tensor import dtype
 from ...tensor import Parameter
 from ..qat import linear as QAT

@@ -16,20 +17,30 @@ class Linear(QuantizedModule):
         self.bias = None
         self.output_dtype = dtype

-    def forward(self, inp):
+    def calc_linear_quantized(self, inp, nonlinear_mode="identity"):
         if self.training:
             raise ValueError("quantized module only support inference.")
+        assert nonlinear_mode in ["identity", "relu"]
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_dtype = dtype.qint32(inp_scale * w_scale)
-        ret = F.nn.linear(
+        ret = F.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
         )
         ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        if nonlinear_mode == "relu":
+            ret = F.relu(ret)
         return ret

+    def forward(self, inp):
+        return self.calc_linear_quantized(inp)
+
     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):
         r"""

@@ -38,8 +49,16 @@ class Linear(QuantizedModule):
         """
         output_dtype = qat_module.get_activation_dtype()
         qmod = cls(dtype=output_dtype, name=qat_module.name)
+        qmod.name = qat_module.name
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
         if qat_module.bias is not None:
             qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
         return qmod
+
+
+class LinearRelu(Linear):
+    r"""Quantized version of :class:`~.qat.LinearRelu`."""
+
+    def forward(self, inp):
+        return self.calc_linear_quantized(inp, nonlinear_mode="relu")
imperative/python/megengine/module/quantized/linear_bn.py (new file, mode 100644)

from ...tensor import Parameter
from ..qat import linear_bn as QAT
from .linear import Linear


class _LinearBnActivation1d(Linear):
    r"""Applies a Linear over a quantized input tensor, used for inference only."""

    @classmethod
    def from_qat_module(cls, qat_module: QAT._LinearBnActivation1d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qlinear = cls(dtype=output_dtype, name=qat_module.name,)
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qlinear.weight = Parameter(weight.numpy(), name=qat_module.linear.weight.name)
        qlinear.bias = Parameter(b_fold.numpy())
        if qat_module.linear.bias is not None:
            qlinear.bias.name = qat_module.linear.bias.name
        return qlinear


class LinearBn1d(_LinearBnActivation1d):
    r"""Quantized version of :class:`~.qat.LinearBn1d`."""

    def forward(self, inp):
        return self.calc_linear_quantized(inp, nonlinear_mode="identity")


class LinearBnRelu1d(_LinearBnActivation1d):
    r"""Quantized version of :class:`~.qat.LinearBnRelu1d`."""

    def forward(self, inp):
        return self.calc_linear_quantized(inp, nonlinear_mode="relu")
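These inference-only modules are built from their QAT counterparts by from_qat_module, which calls fold_weight_bias on the running statistics and stores the folded, quantized weights. End to end, the conversion chain would look roughly like this (a sketch, assuming quantize / quantize_qat from megengine.quantization.quantize and that the observers have seen at least one batch before conversion):

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.quantization.quantize import quantize, quantize_qat

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()
        self.dequant = M.DequantStub()
        self.linear_bn = M.LinearBn1d(10, 5)

    def forward(self, x):
        return self.dequant(self.linear_bn(self.quant(x)))

qat_net = quantize_qat(Net())      # float LinearBn1d -> qat.LinearBn1d
x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
_ = qat_net(x)                     # let observers record statistics (stand-in for QAT training)
qat_net.eval()
q_net = quantize(qat_net)          # qat.LinearBn1d -> quantized.LinearBn1d with folded weights
out = q_net(x)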
imperative/python/megengine/utils/bn_fusion.py

@@ -2,6 +2,7 @@ from copy import deepcopy
 from ..functional import ones, sqrt, zeros
 from ..module import (
+    BatchNorm1d,
     BatchNorm2d,
     Conv2d,
     ConvBn2d,

@@ -11,6 +12,10 @@ from ..module import (
     ConvTransposeBn2d,
     ConvTransposeBnRelu2d,
     ConvTransposeRelu2d,
+    Linear,
+    LinearBn1d,
+    LinearBnRelu1d,
+    LinearRelu,
     ReLU,
 )
 from ..tensor import Parameter

@@ -26,10 +31,15 @@ _MAP_TO_FUSED_MODULE = {
     (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
     (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
     (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
+    (Linear, BatchNorm1d, ReLU, False): LinearRelu,
+    (Linear, BatchNorm1d, ReLU, True): LinearBnRelu1d,
+    (Linear, BatchNorm1d, False): Linear,
+    (Linear, BatchNorm1d, True): LinearBn1d,
+    (Linear, ReLU): LinearRelu,
 }


-def fold_weight_bias(
+def _fold_conv_bn_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
     shape = (-1, 1, 1, 1)

@@ -76,6 +86,57 @@ def fold_weight_bias(
     return w_fold, b_fold


+def _fold_linear_bn_weight_bias(
+    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5
+):
+    bn_mean, bn_var = bn_mean.reshape(-1), bn_var.reshape(-1)
+    weight_shape = [1] * len(weight.shape)
+    weight_shape[0] = -1
+    bias_shape = [1] * len(weight.shape)
+    bias_shape[1] = -1
+    out_features = weight.shape[0]
+
+    if gamma is None:
+        gamma = ones((out_features,), dtype="float32")
+    else:
+        gamma = gamma.reshape(-1)
+    if beta is None:
+        beta = zeros((out_features,), dtype="float32")
+    else:
+        beta = beta.reshape(-1)
+
+    if bn_mean is None:
+        bn_mean = zeros((out_features,), dtype="float32")
+    else:
+        bn_mean = bn_mean.reshape(-1)
+    if bn_var is None:
+        bn_var = ones((out_features,), dtype="float32")
+    else:
+        bn_var = bn_var.reshape(-1)
+
+    if bias is None:
+        bias = zeros((beta.shape), dtype="float32")
+    else:
+        bias = bias.reshape(-1)
+
+    bn_istd = 1.0 / sqrt(bn_var + eps)
+    scale_factor = gamma * bn_istd
+    w_fold = weight * scale_factor.reshape(*weight_shape)
+    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
+    return w_fold, b_fold
+
+
+def fold_weight_bias(
+    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
+):
+    if weight.ndim != 2:
+        return _fold_conv_bn_weight_bias(
+            weight, bias, gamma, beta, bn_mean, bn_var, eps, transpose
+        )
+    return _fold_linear_bn_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps)
+
+
 def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
     module_key = tuple([type(m) for m in [conv, bn, relu] if m])
     if bn:

@@ -137,3 +198,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
+
+
+def fuse_linear_bn_relu_module(linear: Linear, bn: BatchNorm1d, relu: ReLU):
+    module_key = tuple([type(m) for m in [linear, bn, relu] if m])
+    if bn:
+        assert (
+            linear.training == bn.training
+        ), "Linear and BN both must be in the same mode (train or eval)."
+        assert (
+            bn.num_features == linear.out_features
+        ), "Output channel of Linear must match num_features of BatchNorm1d"
+        module_key = module_key + (linear.training,)
+    module = _MAP_TO_FUSED_MODULE[module_key](
+        in_features=linear.in_features,
+        out_features=linear.out_features,
+        bias=linear.bias is not None,
+        compute_mode=linear.compute_mode,
+        name=linear.name,
+    )
+    new_linear = module if bn is None or not linear.training else module.linear
+    weight, bias = linear.weight, linear.bias
+    if not linear.training and bn is not None:
+        weight, bias = fold_weight_bias(
+            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
+        )
+    new_linear.weight = Parameter(weight)
+    if bias is not None:
+        new_linear.bias = Parameter(bias)
+    if bn is not None and linear.training:
+        module.bn = deepcopy(bn)
+    new_linear.training = linear.training
+    return module
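fuse_linear_bn_relu_module mirrors fuse_conv_bn_relu_module: in eval mode it folds the BatchNorm1d statistics into the Linear weights and returns a plain Linear / LinearRelu, while in train mode it returns a LinearBn1d / LinearBnRelu1d that keeps a deep copy of the BN. A short usage sketch for the eval-mode path (the test added below exercises the same helper):

import numpy as np
import megengine as mge
from megengine.module import BatchNorm1d, Linear, ReLU
from megengine.utils.bn_fusion import fuse_linear_bn_relu_module

linear, bn, relu = Linear(10, 5), BatchNorm1d(5), ReLU()
linear.eval()
bn.eval()
relu.eval()

fused = fuse_linear_bn_relu_module(linear, bn, relu)   # BN folded away -> LinearRelu
fused.eval()

x = mge.tensor(np.random.randn(4, 10).astype(np.float32))
np.testing.assert_allclose(
    fused(x).numpy(), relu(bn(linear(x))).numpy(), atol=1e-4
)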
imperative/python/test/unit/module/test_qat.py

@@ -19,6 +19,10 @@ from megengine.module import (
     ConvTransposeBn2d,
     ConvTransposeRelu2d,
     DequantStub,
+    Linear,
+    LinearBn1d,
+    LinearBnRelu1d,
+    LinearRelu,
     Module,
     QuantStub,
 )

@@ -330,3 +334,127 @@ def test_qat_conv_transpose2d():
     np.testing.assert_allclose(
         normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
     )
+
+
+def test_qat_linearbn1d():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_bn = LinearBn1d(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        np.testing.assert_allclose(
+            net.linear_bn.bn.running_mean.numpy(),
+            qat_net.linear_bn.bn.running_mean.numpy(),
+            atol=5e-8,
+        )
+        np.testing.assert_allclose(
+            net.linear_bn.bn.running_var.numpy(),
+            qat_net.linear_bn.bn.running_var.numpy(),
+            atol=5e-7,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+
+
+def test_qat_linear_relu():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_relu = LinearRelu(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+
+
+def test_qat_linear_bn_relu():
+    in_features = 10
+    out_features = 5
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.linear_bn_relu = LinearBnRelu1d(in_features, out_features, bias=bias,)
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.linear_bn_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_features).astype(np.float32))
+    for bias in [True, False]:
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6,
+        )
imperative/python/test/unit/quantization/test_quantize.py

@@ -5,11 +5,14 @@ from megengine import Parameter, Tensor
 from megengine import module as Float
 from megengine.functional import ones, zeros
 from megengine.module import (
+    BatchNorm1d,
     BatchNorm2d,
     Conv2d,
     ConvBn2d,
     ConvTranspose2d,
     ConvTransposeBn2d,
+    Linear,
+    LinearBn1d,
     ReLU,
 )
 from megengine.module import qat as QAT

@@ -33,7 +36,10 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
-from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
+from megengine.utils.bn_fusion import (
+    fuse_conv_bn_relu_module,
+    fuse_linear_bn_relu_module,
+)


 class FloatNet(Float.Module):

@@ -383,3 +389,43 @@ def test_ConvTransposeBn2d_fold_weight_bias():
     np.testing.assert_allclose(
         expected_result.numpy(), actual_result.numpy(), atol=1e-4
     )
+
+
+def test_LinearBn1d_fold_weight_bias():
+    in_features = 10
+    out_features = 5
+
+    linear = Linear(in_features, out_features)
+    bn = BatchNorm1d(out_features)
+    relu = ReLU()
+
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
+    bn.eval()
+    fused_linear.eval()
+    inputs = Tensor(np.random.randn(4, in_features).astype(np.float32))
+    expected_result = relu(bn(linear(inputs)))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    linear.eval()
+    bn.eval()
+    relu.eval()
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, relu)
+    fused_linear.eval()
+    expected_result = relu(linear(inputs))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    linear.train()
+    bn.train()
+    fused_linear = fuse_linear_bn_relu_module(linear, bn, None)
+    fused_linear.train()
+    expected_result = bn(linear(inputs))
+    actual_result = fused_linear(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )