MegEngine 天元 / MegEngine · commit fda9599a

Authored Jun 05, 2020 by Megvii Engine Team
Committed Jun 22, 2020 by Xu Xinran
feat(mge/quant): add TQT quant method
GitOrigin-RevId: 00b1616e73ed34c8c09e2407b8fc7d90230f8cec
Parent: 285d70cb
Showing 8 changed files with 203 additions and 30 deletions (+203, -30).
python_module/megengine/core/function.py             +1   -0
python_module/megengine/module/qat/module.py         +8   -2
python_module/megengine/quantization/__init__.py     +1   -0
python_module/megengine/quantization/fake_quant.py   +106 -25
python_module/megengine/quantization/observer.py     +2   -0
python_module/megengine/quantization/qconfig.py      +8   -2
python_module/test/unit/core/test_function.py        +0   -1
python_module/test/unit/quantization/test_TQT.py     +77  -0
python_module/megengine/core/function.py

```diff
@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result
```
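Context for this one-liner: `Function.__deepcopy__` stashes `saved_tensors` in `tmp` so they bypass the generic `copy.deepcopy` loop, but before this commit they were never reattached to the copy (hence the test change in test_function.py below). A framework-free sketch of the pattern; the names mirror the diff, while the surrounding `tmp` handling is inferred from the visible lines rather than quoted from the file:

```python
import copy

class FunctionLike:
    """Hypothetical stand-in for megengine's Function, for illustration only."""

    def __init__(self):
        self.saved_tensors = None

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        tmp = self.saved_tensors       # keep them out of the generic deepcopy loop
        self.saved_tensors = None
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)  # the line this commit adds
        self.saved_tensors = tmp       # restore the original object
        return result

f = FunctionLike()
f.save_for_backward([1.0, 2.0])
g = copy.deepcopy(f)
assert g.saved_tensors == ([1.0, 2.0],)  # previously the copy ended up with None
```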
python_module/megengine/module/qat/module.py

```diff
@@ -77,13 +77,19 @@ class QATModule(Module):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        return self.weight_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.weight_fake_quant.get_dtype()
+        else:
+            return self.weight_observer.get_dtype()
 
     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        return self.act_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.act_fake_quant.get_dtype()
+        else:
+            return self.act_observer.get_dtype()
 
     @classmethod
     @abstractmethod
```
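The intent of the new `hasattr` dispatch: a fake quantizer that learns its own scale, like TQT, exposes `get_dtype()` and then takes precedence over the observer's statistics-based dtype. (The weight branch keys off `act_fake_quant`, which works because both fake-quant attributes are built from the same qconfig class.) A framework-free sketch with hypothetical class names:

```python
class _Observer:
    def get_dtype(self):
        return "dtype-from-observed-min/max"

class _TQTLike:
    def get_dtype(self):
        return "dtype-from-learned-scale"

class _QATLike:
    def __init__(self, fake_quant, observer):
        self.act_fake_quant = fake_quant
        self.act_observer = observer

    def get_activation_dtype(self):
        # prefer the fake quantizer's own dtype when it can provide one
        if hasattr(self.act_fake_quant, "get_dtype"):
            return self.act_fake_quant.get_dtype()
        else:
            return self.act_observer.get_dtype()

assert _QATLike(_TQTLike(), _Observer()).get_activation_dtype() == "dtype-from-learned-scale"
assert _QATLike(object(), _Observer()).get_activation_dtype() == "dtype-from-observed-min/max"
```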
python_module/megengine/quantization/__init__.py

```diff
@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )
```
python_module/megengine/quantization/fake_quant.py

```diff
@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+
+import numpy as np
+
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round
 
 
-class FakeQuantize(Module):
+class _FakeQuantize(Module):
     r"""
     A module to do quant and dequant according to observer's scale and zero_point.
     """
 
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False
 
+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_foward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
-        else:
-            return inp
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_foward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enable, TQT will do fakequant forward, finetune the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_foward(self, inp, q_dict):
+        # when disable, TQT will do normal forward, initialize scale weight
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
+        return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup) * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
+            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
```
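For readers tracing `TQT_Function.backward`: with $t = 2^{s}$ and the straight-through estimator applied to `round`, the gradients it accumulates follow from the TQT paper's derivation. A compact restatement, where the clip bounds $l$ and $u$ stand for `qmin` and `qmax`:

$$
\hat{x} = \mathrm{round}\!\left(\mathrm{clip}\!\left(\tfrac{x}{t},\, l,\, u\right)\right) \cdot t,
\qquad t = 2^{s}
$$

$$
\frac{\partial \hat{x}}{\partial t} =
\begin{cases}
\mathrm{round}\!\left(\tfrac{x}{t}\right) - \tfrac{x}{t}, & l \le \tfrac{x}{t} \le u \quad (\texttt{grad\_quant}) \\[4pt]
\mathrm{round}\!\left(\mathrm{clip}\!\left(\tfrac{x}{t},\, l,\, u\right)\right), & \text{otherwise} \quad (\texttt{grad\_clip})
\end{cases}
$$

$$
\frac{\partial \hat{x}}{\partial s} = \frac{\partial \hat{x}}{\partial t} \cdot \frac{\partial t}{\partial s}
= \frac{\partial \hat{x}}{\partial t}\, t \ln 2,
\qquad
\frac{\partial \hat{x}}{\partial x} = \mathbb{1}\!\left[\, l \le \tfrac{x}{t} \le u \,\right]
$$

which is exactly `grad_s = (grad_clip.sum() + grad_quant.sum()) * t * math.log(2)` and `grad_inp = grad_inp_flq * mask_quant` in the code above.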
python_module/megengine/quantization/observer.py

```diff
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
```
python_module/megengine/quantization/qconfig.py

```diff
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     ...
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant
 
 
+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
```
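For orientation, a minimal sketch (not part of the commit, assuming a MegEngine build at this revision) of what the new qconfig wires together; it uses only names introduced or shown in this diff:

```python
from megengine.quantization import tqt_quant_qconfig
from megengine.quantization.fake_quant import TQT
from megengine.quantization.observer import ExponentialMovingAverageObserver
from megengine.quantization.qconfig import QConfig

# The new qconfig routes both weights and activations through the TQT
# fake quantizer, backed by EMA observers (which, after the observer.py
# change, record min_val/max_val in q_dict for TQT's scale initialization).
assert tqt_quant_qconfig.fake_quant is TQT

# An equivalent config can be built by hand from the same constructor arguments:
my_tqt_qconfig = QConfig(
    weight_observer=ExponentialMovingAverageObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=TQT,
)
```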
python_module/test/unit/core/test_function.py

```diff
@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None
 
 
 def test_save_context():
```
python_module/test/unit/quantization/test_TQT.py (new file, mode 100644)

```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.test import assertTensorClose


class numpy_TQT_Function:
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = np.maximum(np.minimum(inp_scaled, self.upperbound), self.lowerbound)
        inp_rounded = np.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.saved_tensors = (inp_scaled, inp_rounded, t)
        return inp_flq

    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |data_scaled|>L
        mask_quant = np.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |data_scaled|<=L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |data_scaled|<=L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |data_scaled|>L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * np.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s


def test_TQT():
    f = TQT_Function(-127, 127)
    nf = numpy_TQT_Function(-127, 127)

    def check_inp(a, b, c, a_np, b_np, c_np):
        assertTensorClose(
            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
        )
        c1, c2 = f.backward(c)
        c1_np, c2_np = nf.backward(c_np)
        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
        assertTensorClose(c2.numpy(), c2_np.astype("float32"))

    a = tensor()
    b = tensor()
    a_np = np.random.random((4, 3)).astype("float32")
    b_np = np.random.random((1)).astype("float32")
    a.set_value(a_np)
    b.set_value(b_np)
    check_inp(a, b, b, a_np, b_np, b_np)
```
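One subtlety in the reference `backward` above is the half-unit slack in `mask_clip`: a scaled value just inside `bound + 0.5` still rounds into the boundary bin and takes the "quant" gradient, while anything beyond it is treated as clipped. A tiny NumPy check (not part of the commit) illustrating the split for the test's bounds of -127/127:

```python
import numpy as np

lower, upper = -127, 127
inp_scaled = np.array([-130.0, -127.2, -3.4, 126.7, 127.6])

# Matches the commit's mask: clipped iff outside [lower - 0.5, upper + 0.5].
mask_clip = (inp_scaled < -0.5 + lower) + (inp_scaled > upper + 0.5)
print(mask_clip)  # [ True False False False  True]
# -127.2 and 126.7 round to the boundary bins (-127, 127) and still
# receive the quantization gradient; -130.0 and 127.6 are clipped.
```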