Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
555ecea9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
555ecea9
编写于
8月 12, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/quantization): add bias fakequant support
GitOrigin-RevId: a5e953b3fa3e0cf91b03708c26dca4561243504a
上级
9440842e
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
99 addition
and
43 deletion
+99
-43
python_module/megengine/core/tensor.py
python_module/megengine/core/tensor.py
+4
-3
python_module/megengine/module/qat/conv.py
python_module/megengine/module/qat/conv.py
+3
-1
python_module/megengine/module/qat/conv_bn.py
python_module/megengine/module/qat/conv_bn.py
+3
-1
python_module/megengine/module/qat/linear.py
python_module/megengine/module/qat/linear.py
+3
-1
python_module/megengine/module/qat/module.py
python_module/megengine/module/qat/module.py
+10
-5
python_module/megengine/quantization/fake_quant.py
python_module/megengine/quantization/fake_quant.py
+2
-21
python_module/megengine/quantization/observer.py
python_module/megengine/quantization/observer.py
+2
-10
python_module/megengine/quantization/utils.py
python_module/megengine/quantization/utils.py
+72
-1
未找到文件。
python_module/megengine/core/tensor.py
浏览文件 @
555ecea9
...
@@ -138,6 +138,7 @@ class Tensor:
...
@@ -138,6 +138,7 @@ class Tensor:
def
__init__
(
self
,
val
=
None
,
*
,
requires_grad
=
None
):
def
__init__
(
self
,
val
=
None
,
*
,
requires_grad
=
None
):
self
.
_reset
(
val
,
requires_grad
=
requires_grad
)
self
.
_reset
(
val
,
requires_grad
=
requires_grad
)
self
.
q_dict
=
{
"mode"
:
None
,
"scale"
:
None
,
"zero_point"
:
None
}
def
_reset
(
self
,
val
=
None
,
*
,
requires_grad
=
None
):
def
_reset
(
self
,
val
=
None
,
*
,
requires_grad
=
None
):
self
.
__sym_override
=
None
self
.
__sym_override
=
None
...
@@ -677,9 +678,9 @@ class Tensor:
...
@@ -677,9 +678,9 @@ class Tensor:
def
__deepcopy__
(
self
,
memo
):
def
__deepcopy__
(
self
,
memo
):
"""
"""
Since Tensor have __getstate__ and __setstate__ method,
The default deepcopy will ignore other attributes except those defined at
deepcopy only process the that and ignore the attribute of Parameter
.
__getstate__ and __setstate__ method
.
So we need to add __deepcopy__ method to deepcopy correct attribute.
So we need to add __deepcopy__ method to deepcopy correct attribute
s
.
"""
"""
assert
(
self
.
__val
is
not
None
)
and
(
assert
(
self
.
__val
is
not
None
)
and
(
self
.
__sym
is
None
self
.
__sym
is
None
...
...
python_module/megengine/module/qat/conv.py
浏览文件 @
555ecea9
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
...
import
functional
as
F
from
...
import
functional
as
F
from
...quantization.utils
import
fake_quant_bias
from
..
import
conv
as
Float
from
..
import
conv
as
Float
from
.module
import
QATModule
from
.module
import
QATModule
...
@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
...
@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
def
calc_conv_qat
(
self
,
inp
):
def
calc_conv_qat
(
self
,
inp
):
w_qat
=
self
.
apply_quant_weight
(
self
.
weight
)
w_qat
=
self
.
apply_quant_weight
(
self
.
weight
)
conv
=
self
.
calc_conv
(
inp
,
w_qat
,
self
.
bias
)
b_qat
=
fake_quant_bias
(
self
.
bias
,
inp
,
w_qat
)
conv
=
self
.
calc_conv
(
inp
,
w_qat
,
b_qat
)
return
conv
return
conv
@
classmethod
@
classmethod
...
...
python_module/megengine/module/qat/conv_bn.py
浏览文件 @
555ecea9
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
...core
import
ones
,
zeros
from
...core
import
ones
,
zeros
from
...functional
import
add_update
,
relu
,
sqrt
,
sum
,
zero_grad
from
...functional
import
add_update
,
relu
,
sqrt
,
sum
,
zero_grad
from
...quantization.utils
import
fake_quant_bias
from
..
import
conv_bn
as
Float
from
..
import
conv_bn
as
Float
from
.module
import
QATModule
from
.module
import
QATModule
...
@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
...
@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold
=
beta
+
gamma
*
(
conv_bias
-
bn_mean
)
*
bn_istd
b_fold
=
beta
+
gamma
*
(
conv_bias
-
bn_mean
)
*
bn_istd
w_qat
=
self
.
apply_quant_weight
(
w_fold
)
w_qat
=
self
.
apply_quant_weight
(
w_fold
)
conv
=
self
.
conv
.
calc_conv
(
inp
,
w_qat
,
b_fold
)
b_qat
=
fake_quant_bias
(
b_fold
,
inp
,
w_qat
)
conv
=
self
.
conv
.
calc_conv
(
inp
,
w_qat
,
b_qat
)
if
not
(
self
.
training
and
approx
):
if
not
(
self
.
training
and
approx
):
return
conv
return
conv
...
...
python_module/megengine/module/qat/linear.py
浏览文件 @
555ecea9
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
...quantization.utils
import
fake_quant_bias
from
..
import
linear
as
Float
from
..
import
linear
as
Float
from
.module
import
QATModule
from
.module
import
QATModule
...
@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
...
@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
w_qat
=
self
.
apply_quant_weight
(
self
.
weight
)
w_qat
=
self
.
apply_quant_weight
(
self
.
weight
)
return
self
.
apply_quant_activation
(
self
.
_calc_linear
(
x
,
w_qat
,
self
.
bias
),)
b_qat
=
fake_quant_bias
(
self
.
bias
,
x
,
w_qat
)
return
self
.
apply_quant_activation
(
self
.
_calc_linear
(
x
,
w_qat
,
b_qat
))
@
classmethod
@
classmethod
def
from_float_module
(
cls
,
float_module
:
Float
.
Linear
):
def
from_float_module
(
cls
,
float_module
:
Float
.
Linear
):
...
...
python_module/megengine/module/qat/module.py
浏览文件 @
555ecea9
...
@@ -73,11 +73,16 @@ class QATModule(Module):
...
@@ -73,11 +73,16 @@ class QATModule(Module):
if
observer
is
None
:
if
observer
is
None
:
return
target
return
target
oup
=
observer
(
target
)
oup
=
observer
(
target
)
if
fake_quant
is
None
:
return
oup
else
:
q_dict
=
observer
.
get_qparams
()
q_dict
=
observer
.
get_qparams
()
return
fake_quant
(
oup
,
q_dict
)
# do fake quant
if
fake_quant
is
not
None
:
oup
=
fake_quant
(
oup
,
q_dict
)
# use qparams of fake_quant if have.
if
hasattr
(
fake_quant
,
"get_qparams"
):
q_dict
=
fake_quant
.
get_qparams
()
# set to tensor qparams.
oup
.
q_dict
.
update
(
q_dict
)
return
oup
def
apply_quant_weight
(
self
,
target
:
Tensor
):
def
apply_quant_weight
(
self
,
target
:
Tensor
):
r
"""
r
"""
...
...
python_module/megengine/quantization/fake_quant.py
浏览文件 @
555ecea9
...
@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
...
@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from
..core
import
Buffer
,
Function
,
Parameter
from
..core
import
Buffer
,
Function
,
Parameter
from
..jit
import
sideeffect
from
..jit
import
sideeffect
from
..module
import
Module
from
..module
import
Module
from
.observer
import
Round
from
.utils
import
QuantMode
,
Round
,
fake_quant_tensor
,
get_qparam_dict
from
.utils
import
QuantMode
,
get_qparam_dict
class
_FakeQuantize
(
Module
):
class
_FakeQuantize
(
Module
):
...
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
...
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
"""
"""
def
fake_quant_forward
(
self
,
inp
,
q_dict
):
def
fake_quant_forward
(
self
,
inp
,
q_dict
):
if
q_dict
[
"mode"
]
==
QuantMode
.
SYMMERTIC
:
return
fake_quant_tensor
(
inp
,
self
.
qmin
,
self
.
qmax
,
q_dict
)
scale
=
q_dict
[
"scale"
]
# Quant
oup
=
Round
()(
inp
/
scale
)
# clip
oup
=
F
.
minimum
(
F
.
maximum
(
oup
,
self
.
qmin
),
self
.
qmax
)
# DeQuant
oup
=
(
oup
)
*
scale
return
oup
else
:
scale
=
q_dict
[
"scale"
]
zero_point
=
q_dict
[
"zero_point"
]
# Quant
oup
=
Round
()(
inp
/
scale
)
+
zero_point
# clip
oup
=
F
.
minimum
(
F
.
maximum
(
oup
,
self
.
qmin
),
self
.
qmax
)
# DeQuant
oup
=
(
oup
-
zero_point
)
*
scale
return
oup
python_module/megengine/quantization/observer.py
浏览文件 @
555ecea9
...
@@ -13,18 +13,10 @@ import numpy as np
...
@@ -13,18 +13,10 @@ import numpy as np
from
..
import
functional
as
F
from
..
import
functional
as
F
from
.._internal.dtype
import
_metadata_dict
,
get_quantized_dtype
from
.._internal.dtype
import
_metadata_dict
,
get_quantized_dtype
from
..core
import
Buffer
,
Function
,
tensor
from
..core
import
Buffer
from
..jit
import
sideeffect
from
..jit
import
sideeffect
from
..module
import
Module
from
..module
import
Module
from
.utils
import
QuantMode
,
get_qparam_dict
from
.utils
import
QuantMode
,
Round
,
get_qparam_dict
class
Round
(
Function
):
def
forward
(
self
,
x
):
return
x
.
round
()
def
backward
(
self
,
output_grads
):
return
output_grads
class
Observer
(
Module
):
class
Observer
(
Module
):
...
...
python_module/megengine/quantization/utils.py
浏览文件 @
555ecea9
...
@@ -8,6 +8,24 @@
...
@@ -8,6 +8,24 @@
from
enum
import
Enum
from
enum
import
Enum
from
functools
import
partial
,
update_wrapper
,
wraps
from
functools
import
partial
,
update_wrapper
,
wraps
from
typing
import
Dict
from
..
import
functional
as
F
from
.._internal.dtype
import
_metadata_dict
from
..core
import
Function
,
Tensor
class
Round
(
Function
):
"""
The functional round have no grad and can not use for quantization-aware-training.
We use Function and STE(Straight-Through Estimator) to implement backward propagation.
"""
def
forward
(
self
,
x
):
return
x
.
round
()
def
backward
(
self
,
output_grads
):
return
output_grads
def
register_method_to_class
(
cls
):
def
register_method_to_class
(
cls
):
...
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
...
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
class
QuantMode
(
Enum
):
class
QuantMode
(
Enum
):
"""Quantization mode enumerate class.
"""
SYMMERTIC
=
1
SYMMERTIC
=
1
ASYMMERTIC
=
2
ASYMMERTIC
=
2
TQT
=
3
TQT
=
3
...
@@ -41,5 +62,55 @@ qparam_dict = {
...
@@ -41,5 +62,55 @@ qparam_dict = {
}
}
def
get_qparam_dict
(
mode
):
def
get_qparam_dict
(
mode
:
QuantMode
):
"""Return the quantization parameters dictory according to the mode.
"""
return
qparam_dict
.
get
(
mode
,
None
)
return
qparam_dict
.
get
(
mode
,
None
)
def
fake_quant_tensor
(
inp
:
Tensor
,
qmin
:
int
,
qmax
:
int
,
q_dict
:
Dict
)
->
Tensor
:
"""Apply fake quantization to the inp tensor.
:param inp: the input tensor which need to be faked.
:param qmin: the minimum value which the integer limit to.
:param qmax: the maximum value which the integer limit to.
:param q_dict: the quantization parameter dict.
"""
scale
=
q_dict
[
"scale"
]
zero_point
=
0
if
q_dict
[
"mode"
]
==
QuantMode
.
ASYMMERTIC
:
zero_point
=
q_dict
[
"zero_point"
]
# Quant
oup
=
Round
()(
inp
/
scale
)
+
zero_point
# Clip
oup
=
F
.
minimum
(
F
.
maximum
(
oup
,
qmin
),
qmax
)
# Dequant
oup
=
(
oup
-
zero_point
)
*
scale
return
oup
def
fake_quant_bias
(
bias
:
Tensor
,
inp
:
Tensor
,
w_qat
:
Tensor
)
->
Tensor
:
"""Apply fake quantization to bias, the special scale from input tensor
and weight tensor, the quantized type set to qint32 also.
:param bias: the bias tensor which need to be faked.
:param inp: the input tensor which contain the quantization parameters.
:param qmax: the weight tensor which contain the quantization parameters.
.. warning::
Only work for symmetric quantization method now.
"""
b_qat
=
bias
if
hasattr
(
inp
,
"q_dict"
)
and
b_qat
is
not
None
:
if
inp
.
q_dict
[
"scale"
]
is
not
None
and
w_qat
.
q_dict
[
"scale"
]
is
not
None
:
# use the same mode with weight.
b_dict
=
get_qparam_dict
(
w_qat
.
q_dict
[
"mode"
])
b_dict
[
"scale"
]
=
inp
.
q_dict
[
"scale"
]
*
w_qat
.
q_dict
[
"scale"
]
# TODO: add zero_point for ASYMMERTIC mode.
qmax
=
_metadata_dict
[
"qint32"
].
qmax
qmin
=
_metadata_dict
[
"qint32"
].
qmin
b_qat
=
fake_quant_tensor
(
b_qat
,
qmin
,
qmax
,
b_dict
)
return
b_qat
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录