MegEngine 天元 / MegEngine · Commit ab9f44f1

Authored Nov 19, 2020 by Megvii Engine Team

feat(mge/quantization): add support for easyquant

GitOrigin-RevId: 060d908349ca6bdcee293be5a2e47a5bee98af5e

Parent: fc0fcd2f
Showing 10 changed files with 561 additions and 44 deletions.
imperative/python/megengine/module/quantized/linear.py        +1  −3
imperative/python/megengine/quantization/__init__.py          +2  −1
imperative/python/megengine/quantization/fake_quant.py        +19 −13
imperative/python/megengine/quantization/observer.py          +41 −6
imperative/python/megengine/quantization/qconfig.py           +30 −10
imperative/python/megengine/quantization/quantize.py          +117 −4
imperative/python/megengine/quantization/utils.py             +1  −3
imperative/python/test/unit/quantization/test_observer.py     +70 −4
imperative/python/test/unit/quantization/test_qconfig.py      +14 −0
imperative/python/test/unit/quantization/test_quantize.py     +266 −0
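At a high level, the commit adds EasyQuant (https://arxiv.org/pdf/2006.16669) as a post-training scale search: `reset_qconfig` swaps a calibrated QAT network's observers for the new `PassiveObserver` via `passive_qconfig`, and `apply_easy_quant` then tunes every scale by cosine-similarity search. A minimal sketch of the intended workflow, assuming `float_net` is a placeholder for an already-built float model whose observers get calibrated before the reset:

```python
import numpy as np
from megengine import tensor
from megengine.quantization import passive_qconfig
from megengine.quantization.quantize import (
    apply_easy_quant,
    quantize_qat,
    reset_qconfig,
)

# float_net is a placeholder float model; after quantize_qat, run
# representative data through it once so every observer holds an
# initialized scale (PassiveObserver rejects uninitialized q_dicts).
qat_net = quantize_qat(float_net)

eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
calib = tensor(np.random.rand(8, 3, 32, 32).astype(np.float32))
apply_easy_quant(eq_net, calib, start=0.8, stop=1.2, num=40)
```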
imperative/python/megengine/module/quantized/linear.py
@@ -17,9 +17,7 @@ from .module import QuantizedModule
 class Linear(QuantizedModule):
     r"""Quantized version of :class:`~.qat.linear.Linear`."""
 
-    def __init__(
-        self, dtype: np.dtype = None,
-    ):
+    def __init__(self, dtype: np.dtype = None):
         super().__init__()
         self.weight = None
         self.bias = None
imperative/python/megengine/quantization/__init__.py
@@ -15,7 +15,8 @@ from .qconfig import (
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    passive_qconfig,
     sync_ema_fakequant_qconfig,
-    tqt_quant_qconfig,
+    tqt_qconfig,
 )
 from .utils import QuantMode
imperative/python/megengine/quantization/fake_quant.py
@@ -28,7 +28,9 @@ class _FakeQuantize(Module):
     :param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
     """
 
-    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
+    def __init__(
+        self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
+    ):
         super().__init__()
         if not dtype in _metadata_dict.keys():
             raise ValueError(
@@ -114,24 +116,28 @@ class TQT(_FakeQuantize):
     for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
     """
 
-    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
-        super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter([0.0], dtype=np.float32)
+    def __init__(
+        self,
+        q_dict,
+        dtype: str,
+        narrow_range: bool = False,
+        enable: bool = True,
+        **kwargs
+    ):
+        super().__init__(dtype, narrow_range, enable, **kwargs)
+        assert (
+            q_dict["mode"] == QuantMode.SYMMERTIC
+        ), "only symmetric quantization is supported by TQT"
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.scale = F.log(q_dict["scale"]) / math.log(2)
 
     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
         return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
 
-    def normal_foward(self, inp, q_dict=None):
-        if q_dict["enable_observer"]:
-            # when disable, TQT will do normal forward, initialize scale weight
-            tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
-            tmp_scale = F.log(tmp_scale / 127) / math.log(2)
-            self.scale[...] = tmp_scale
-        return inp
-
     def get_qparams(self):
-        q_dict = get_qparam_dict(QuantMode.TQT)
+        q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
         q_dict["scale"] = 2 ** self.scale
         return q_dict
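The rewritten `TQT.__init__` now seeds its trainable scale from an observer-produced `q_dict`, storing it in log2 domain (`F.log(scale) / math.log(2)`) and recovering the linear-domain value in `get_qparams` via `2 ** scale`. A plain-numpy sketch of that round trip (the `calibrated_scale` value is illustrative):

```python
import numpy as np

# Illustrative calibrated scale, standing in for q_dict["scale"].
calibrated_scale = 0.05

# TQT.__init__ stores the scale in log2 domain so gradient descent can
# finetune it as an unconstrained parameter:
log2_scale = np.log(calibrated_scale) / np.log(2)

# TQT.get_qparams() maps it back to the linear domain:
assert np.isclose(2 ** log2_scale, calibrated_scale)
```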
imperative/python/megengine/quantization/observer.py
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import math
 from abc import abstractmethod
+from copy import deepcopy
 
 import numpy as np
@@ -28,7 +29,7 @@ class Observer(Module):
     instead of 1 greater. Usually True for weight and False for activation.
     """
 
-    def __init__(self, dtype: str, narrow_range: bool = False):
+    def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
         super().__init__()
         if dtype not in _metadata_dict.keys():
             raise ValueError(
@@ -81,8 +82,9 @@ class MinMaxObserver(Observer):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(dtype, narrow_range)
+        super().__init__(dtype, narrow_range, **kwargs)
         self.mode = mode
         self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
         self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
@@ -105,7 +107,7 @@ class MinMaxObserver(Observer):
         else:
             # use maximun to avoid scale too small at the begin
             q_dict["scale"] = F.maximum(
-                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
+                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
             )
             # caculate zero_point
             q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(mode, eps, dtype, narrow_range)
+        super().__init__(mode, eps, dtype, narrow_range, **kwargs)
         self.momentum = Tensor(momentum)
         self.runtime_momentum = Tensor(0.0)
@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(mode, eps, dtype, narrow_range)
+        super().__init__(mode, eps, dtype, narrow_range, **kwargs)
         self.bins = bins
         self.upsample_rate = upsample_rate
         self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver):
         # combine the existing histogram and new histogram into 1 histogram
         # We do this by first upsampling the histogram to a dense grid
         # and then downsampling the histogram efficiently
-        (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
+        (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
             new_min, new_max, self.upsample_rate
         )
@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver):
     def forward(self, x_orig):
         self.sideeffect_forward(x_orig)
         return x_orig
+
+
+class PassiveObserver(Observer):
+    r"""
+    This class can be set :attr:`scale` derectly.
+    """
+
+    def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
+        super().__init__(dtype, narrow_range, **kwargs)
+        self.q_dict = deepcopy(q_dict)
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.orig_scale = q_dict["scale"].numpy()
+
+    @property
+    def scale(self):
+        return self.q_dict["scale"]
+
+    @scale.setter
+    def scale(self, value):
+        assert value > 0
+        self.q_dict["scale"].set_value(value)
+
+    def get_qparams(self):
+        return self.q_dict
+
+    def forward(self, x):
+        r"""
+        Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
+        """
+        return x
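`PassiveObserver` is the piece that lets `apply_easy_quant` overwrite scales from outside: it never updates statistics in `forward`, and its `scale` property writes through to the stored `q_dict`. A short usage sketch mirroring `test_passive_observer` from this commit:

```python
import megengine as mge
from megengine.quantization.observer import PassiveObserver

# The q_dict must already carry an initialized scale; otherwise __init__
# raises "Can not get an initialized scale".
q_dict = {"scale": mge.tensor(1.0)}
ob = PassiveObserver(q_dict, "qint8")

ob.scale = 2.0           # setter writes through to the internal q_dict
assert ob.scale == 2.0   # get_qparams() now reports the new scale too
```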
imperative/python/megengine/quantization/qconfig.py
@@ -13,6 +13,7 @@ from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     MinMaxObserver,
+    PassiveObserver,
     SyncExponentialMovingAverageObserver,
     SyncMinMaxObserver,
 )
@@ -66,17 +67,22 @@ class QConfig:
         self.weight_fake_quant = weight_fake_quant
         self.act_fake_quant = act_fake_quant
 
+    def __eq__(self, other):
+        def eq(a, b):
+            if isinstance(a, partial) and isinstance(b, partial):
+                return all(
+                    [a.func == b.func, a.args == b.args, a.keywords == b.keywords]
+                )
+            else:
+                return a == b
+
+        return (
+            eq(self.weight_observer, other.weight_observer)
+            and eq(self.act_observer, other.act_observer)
+            and eq(self.weight_fake_quant, other.weight_fake_quant)
+            and eq(self.act_fake_quant, other.act_fake_quant)
+        )
 
-tqt_quant_qconfig = QConfig(
-    weight_observer=partial(
-        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True
-    ),
-    act_observer=partial(
-        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
-    ),
-    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
-    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
-)
 
 min_max_fakequant_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
@@ -118,3 +124,17 @@ calibration_qconfig = QConfig(
     weight_fake_quant=None,
     act_fake_quant=None,
 )
+
+tqt_qconfig = QConfig(
+    weight_observer=None,
+    act_observer=None,
+    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
+)
+
+passive_qconfig = QConfig(
+    weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+)
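`QConfig.__eq__` exists because two `functools.partial` objects wrapping the same callable with the same arguments do not compare equal by default; equality has to be decided on `func`/`args`/`keywords`, exactly as the new method does. A minimal demonstration:

```python
from functools import partial


def observer(dtype=None, narrow_range=False):
    pass


a = partial(observer, dtype="qint8", narrow_range=True)
b = partial(observer, dtype="qint8", narrow_range=True)

assert a != b  # partial defines no __eq__, so this falls back to identity
assert a.func == b.func and a.args == b.args and a.keywords == b.keywords
```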
imperative/python/megengine/quantization/quantize.py
@@ -6,15 +6,18 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from copy import copy, deepcopy
+from functools import partial
 from typing import Callable, Dict, Tuple
 
+import numpy as np
+
 from .. import module as Float
+from ..functional import concat, norm
 from ..module import Module
 from ..module import qat as QAT
 from ..module import quantized as Quantized
 from ..module.qat import QATModule
 from ..module.quantized import QuantizedModule
-from .fake_quant import TQT
 from .qconfig import QConfig, ema_fakequant_qconfig
@@ -32,9 +35,7 @@ def _get_quantable_module_names():
     return quantable_module_names
 
 
-def _get_convert_dict() -> Tuple[
-    Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
-]:
+def _get_convert_dict():
     quantable_module_names = _get_quantable_module_names()
     quantable_modules = [getattr(Float, key) for key in quantable_module_names]
@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[
 _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
+qat_modules = tuple(_qat2quantized_dict.keys())
+
+
+def is_qat(mod: Module):
+    return isinstance(mod, qat_modules)
 
 
 def quantize(module: Module, inplace: bool = True, mapping: dict = None):
@@ -133,6 +139,34 @@ def quantize_qat(
     return module
 
 
+def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
+    r"""
+    Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``
+
+    :param module: root module to reset recursively.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+    :param inplace: whether to reset submodules in-place.
+    """
+    if not inplace:
+        module = deepcopy(module)
+
+    def safe_call(func, q_dict):
+        return func(q_dict=q_dict) if func is not None else None
+
+    for m in list(module._flatten(predicate=is_qat)):
+        if m.with_weight:
+            weight_q_dict = m.get_weight_qparams()
+            m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
+            m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
+        if m.with_act:
+            act_q_dict = m.get_activation_qparams()
+            m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
+            m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
+
+    return module
+
+
 def _propagate(module: Module, func_str: str, *args, **kargs):
     def fn(mod: Module):
         if isinstance(mod, QATModule):
@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
     _propagate(module, "set_qconfig", qconfig)
 
 
+def hook_qat_module(module: Module, func: Callable):
+    r"""
+    Add hooks for all :class:`~.QATModule` submodule
+    """
+    hooks = []
+    for submodule in list(module._flatten(predicate=is_qat)):
+        hooks.append(submodule.register_forward_hook(func))
+
+    return hooks
+
+
+def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
+    r"""
+    Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
+    Search for optimal scales.
+
+    :param module: root module.
+    :param data: input tensor used to search optimal scale.
+    :param start: lower bound of the search interval.
+    :param stop: upper bound of the search interval.
+    :param num: number of samples to search.
+    """
+    batch_size = data.shape[0]
+
+    def get_cosine(x, y):
+        ndim = len(x.shape)
+        axis = tuple(range(1, ndim))
+        up = (x * y).sum(axis=axis)
+        down = norm(x, axis=axis) * norm(y, axis=axis)
+        sim = up / down
+        return sim.mean(axis=0)
+
+    def search(mod, inputs, outputs, where):
+        mod._forward_hooks.clear()
+
+        fp32_in = [_[:batch_size] for _ in inputs]
+        int8_in = [_[batch_size:] for _ in inputs]
+
+        disable_fake_quant(mod)
+        fp32_out = mod(*fp32_in)
+        enable_fake_quant(mod)
+
+        ob = getattr(mod, where)
+        if ob is None:
+            return
+
+        orig_scale = ob.orig_scale
+        distance = 0
+        best_scale = 0
+        for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
+            ob.scale = scale
+            int8_out = mod(*int8_in)
+            dis = get_cosine(fp32_out, int8_out)
+            if dis > distance:
+                distance = dis
+                best_scale = scale
+        ob.scale = best_scale
+
+        if where == "act_observer":
+            int8_out = mod(*int8_in)
+            return concat([fp32_out, int8_out])
+        else:
+            int8_out = outputs[batch_size:]
+            return concat([fp32_out, int8_out])
+
+    data = concat([data, data])
+
+    hook_qat_module(module, partial(search, where="weight_observer"))
+    module(data)
+
+    hook_qat_module(module, partial(search, where="act_observer"))
+    module(data)
+
+    return module
+
+
 def disable_fake_quant(module: Module):
     r"""
     Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
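The heart of `apply_easy_quant` is the `search` closure: it duplicates the batch, runs the fp32 half with fake quant disabled, then grid-searches each observer's scale over `[start, stop] * orig_scale` for the value maximizing batch-mean cosine similarity between fp32 and quantized outputs. A self-contained numpy sketch of that search, with a toy fake-quant standing in for MegEngine's int8 path:

```python
import numpy as np


def get_cosine(x, y):
    # Per-sample cosine similarity, averaged over the batch axis,
    # matching the metric used inside apply_easy_quant.
    axis = tuple(range(1, x.ndim))
    up = (x * y).sum(axis=axis)
    down = np.sqrt((x ** 2).sum(axis=axis)) * np.sqrt((y ** 2).sum(axis=axis))
    return (up / down).mean(axis=0)


def toy_fake_quant(x, scale, qmin=-128, qmax=127):
    # Toy symmetric quantize/dequantize; not MegEngine's kernel.
    return np.clip(np.round(x / scale), qmin, qmax) * scale


rng = np.random.RandomState(0)
fp32_out = rng.randn(8, 16).astype("float32")
orig_scale = 0.1

best_scale, best_sim = orig_scale, -1.0
for scale in np.linspace(0.8 * orig_scale, 1.2 * orig_scale, 40):
    sim = get_cosine(fp32_out, toy_fake_quant(fp32_out, scale))
    if sim > best_sim:
        best_sim, best_scale = sim, scale
print("best scale:", best_scale)
```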
imperative/python/megengine/quantization/utils.py
@@ -54,17 +54,15 @@ class QuantMode(Enum):
     SYMMERTIC = 1
     ASYMMERTIC = 2
-    TQT = 3
 
 
 qparam_dict = {
-    QuantMode.SYMMERTIC: {
-        "mode": QuantMode.SYMMERTIC,
-        "scale": None,
-    },
+    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
     QuantMode.ASYMMERTIC: {
         "mode": QuantMode.ASYMMERTIC,
         "scale": None,
         "zero_point": None,
     },
-    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
 }
imperative/python/test/unit/quantization/test_observer.py
@@ -6,17 +6,53 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
-import megengine.quantization.observer as ob
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.quantization.observer import (
+    ExponentialMovingAverageObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
+
+
+def test_observer():
+    with pytest.raises(TypeError):
+        Observer("qint8")
 
 
 def test_min_max_observer():
     x = np.random.rand(3, 3, 3, 3).astype("float32")
     np_min, np_max = x.min(), x.max()
     x = mge.tensor(x)
-    m = ob.MinMaxObserver()
+    m = MinMaxObserver()
     m(x)
-    assert m.min_val == np_min and m.max_val == np_max
+    np.testing.assert_allclose(m.min_val.numpy(), np_min)
+    np.testing.assert_allclose(m.max_val.numpy(), np_max)
+
+
+def test_exponential_moving_average_observer():
+    t = np.random.rand()
+    x1 = np.random.rand(3, 3, 3, 3).astype("float32")
+    x2 = np.random.rand(3, 3, 3, 3).astype("float32")
+    expected_min = x1.min() * t + x2.min() * (1 - t)
+    expected_max = x1.max() * t + x2.max() * (1 - t)
+    m = ExponentialMovingAverageObserver(momentum=t)
+    m(mge.tensor(x1, dtype=np.float32))
+    m(mge.tensor(x2, dtype=np.float32))
+    np.testing.assert_allclose(m.min_val.numpy(), expected_min)
+    np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+
+
+def test_passive_observer():
+    q_dict = {"scale": mge.tensor(1.0)}
+    m = PassiveObserver(q_dict, "qint8")
+    assert m.orig_scale == 1.0
+    assert m.scale == 1.0
+    m.scale = 2.0
+    assert m.scale == 2.0
+    assert m.get_qparams() == {"scale": mge.tensor(2.0)}
 
 
 @pytest.mark.skipif(
@@ -35,9 +71,39 @@ def test_sync_min_max_observer():
     @dist.launcher
     def worker():
         rank = dist.get_rank()
-        m = ob.SyncMinMaxObserver()
+        m = SyncMinMaxObserver()
         y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
         m(y)
         assert m.min_val == np_min and m.max_val == np_max
 
     worker()
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_sync_exponential_moving_average_observer():
+    word_size = get_device_count_by_fork("gpu")
+    t = np.random.rand()
+    x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
+    x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
+    expected_min = x1.min() * t + x2.min() * (1 - t)
+    expected_max = x1.max() * t + x2.max() * (1 - t)
+
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        m = SyncExponentialMovingAverageObserver(momentum=t)
+        y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3])
+        y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
+        m(y1)
+        m(y2)
+        np.testing.assert_allclose(m.min_val.numpy(), expected_min)
+        np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+
+    worker()
imperative/python/test/unit/quantization/test_qconfig.py (new file, mode 100644)
@@ -0,0 +1,14 @@
+from functools import partial
+
+from megengine.quantization import QConfig, tqt_qconfig
+from megengine.quantization.fake_quant import TQT
+
+
+def test_equal():
+    qconfig = QConfig(
+        weight_observer=None,
+        act_observer=None,
+        weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
+        act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
+    )
+    assert qconfig == tqt_qconfig
imperative/python/test/unit/quantization/quantize.py → imperative/python/test/unit/quantization/test_quantize.py (renamed)
@@ -8,17 +8,194 @@
 import numpy as np
 import pytest
 
+from megengine import functional
 from megengine import module as Float
 from megengine import tensor
 from megengine.module import qat as QAT
-from megengine.quantization import min_max_fakequant_qconfig
+from megengine.module import quantized as Q
+from megengine.quantization import (
+    min_max_fakequant_qconfig,
+    passive_qconfig,
+    tqt_qconfig,
+)
+from megengine.quantization.fake_quant import TQT, FakeQuantize
+from megengine.quantization.observer import MinMaxObserver, PassiveObserver
 from megengine.quantization.quantize import (
     _get_quantable_module_names,
+    apply_easy_quant,
     disable_fake_quant,
+    disable_observer,
+    enable_fake_quant,
+    enable_observer,
+    propagate_qconfig,
+    quantize,
     quantize_qat,
+    reset_qconfig,
 )
 
 
+class Net(Float.Module):
+    def __init__(self):
+        super().__init__()
+        self.quant = Float.QuantStub()
+        self.linear = Float.Linear(3, 3)
+        self.dequant = Float.DequantStub()
+        self.linear.bias.set_value(np.random.rand(3))
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.linear(x)
+        x = self.dequant(x)
+        return x
+
+
+class QATNet(Float.Module):
+    def __init__(self):
+        super().__init__()
+        self.quant = QAT.QuantStub()
+        self.linear = QAT.Linear(3, 3)
+        self.dequant = QAT.DequantStub()
+        self.linear.bias.set_value(np.random.rand(3))
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.linear(x)
+        x = self.dequant(x)
+        return x
+
+
+def test_propagate_qconfig():
+    net = QATNet()
+    propagate_qconfig(net, min_max_fakequant_qconfig)
+    assert all(
+        [
+            net.quant.weight_observer is None,
+            net.quant.weight_fake_quant is None,
+            isinstance(net.quant.act_observer, MinMaxObserver),
+            isinstance(net.quant.act_fake_quant, FakeQuantize),
+            isinstance(net.linear.weight_observer, MinMaxObserver),
+            isinstance(net.linear.weight_fake_quant, FakeQuantize),
+            isinstance(net.linear.act_observer, MinMaxObserver),
+            isinstance(net.linear.act_fake_quant, FakeQuantize),
+            net.dequant.weight_observer is None,
+            net.dequant.weight_fake_quant is None,
+            net.dequant.act_observer is None,
+            net.dequant.act_observer is None,
+        ]
+    )
+
+
+def init_qat_net():
+    net = QATNet()
+    propagate_qconfig(net, min_max_fakequant_qconfig)
+    min_val = np.random.randint(-127, 0, size=(2,))
+    max_val = np.random.randint(1, 127, size=(2,))
+    net.linear.weight_observer.min_val.set_value(min_val[0])
+    net.linear.weight_observer.max_val.set_value(max_val[0])
+    net.linear.act_observer.min_val.set_value(min_val[1])
+    net.linear.act_observer.max_val.set_value(max_val[1])
+    return net
+
+
+def test_reset_qconfig():
+    qat_net = init_qat_net()
+    new_qat_net = reset_qconfig(qat_net, passive_qconfig)
+    assert (
+        new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams()
+    )
+    assert (
+        new_qat_net.linear.get_activation_qparams()
+        == qat_net.linear.get_activation_qparams()
+    )
+
+
+def test_enable_and_disable_observer():
+    net = init_qat_net()
+    enable_observer(net)
+    assert net.quant.act_observer.enabled == True
+    assert net.linear.weight_observer.enabled == True
+    assert net.linear.act_observer.enabled == True
+    disable_observer(net)
+    assert net.quant.act_observer.enabled == False
+    assert net.linear.weight_observer.enabled == False
+    assert net.linear.act_observer.enabled == False
+
+
+def test_enable_and_disable_fake_quant():
+    net = init_qat_net()
+    disable_fake_quant(net)
+    assert net.quant.act_fake_quant.enabled == False
+    assert net.linear.weight_fake_quant.enabled == False
+    assert net.linear.act_fake_quant.enabled == False
+    enable_fake_quant(net)
+    assert net.quant.act_fake_quant.enabled == True
+    assert net.linear.weight_fake_quant.enabled == True
+    assert net.linear.act_fake_quant.enabled == True
+
+
+def init_observer(module, data):
+    enable_observer(module)
+    disable_fake_quant(module)
+    module(data)
+    disable_observer(module)
+    enable_fake_quant(module)
+
+
+def test_enable_and_disable_all():
+    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
+    net = Net()
+    y1 = net(x).numpy()
+    net = quantize_qat(net, min_max_fakequant_qconfig)
+    init_observer(net, x)
+    y2 = net(x).numpy()
+    disable_fake_quant(net)
+    y3 = net(x).numpy()
+    enable_fake_quant(net)
+    y4 = net(x).numpy()
+    np.testing.assert_allclose(y1, y3)
+    np.testing.assert_allclose(y2, y4)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y2, y3)
+
+
+def test_quantize_qat():
+    net = Net()
+    qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
+    assert isinstance(qat_net.quant, QAT.QuantStub)
+    assert isinstance(qat_net.linear, QAT.Linear)
+    assert isinstance(qat_net.dequant, QAT.DequantStub)
+
+
+def test_quantize():
+    qat_net = init_qat_net()
+    q_net = quantize(qat_net, inplace=False)
+    assert isinstance(q_net.quant, Q.QuantStub)
+    assert isinstance(q_net.linear, Q.Linear)
+    assert isinstance(q_net.dequant, Q.DequantStub)
+
+
+def test_apply_easy_quant():
+    qat_net = init_qat_net()
+    data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
+    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
+    apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
+    assert isinstance(eq_net.quant.act_observer, PassiveObserver)
+    assert isinstance(eq_net.linear.weight_observer, PassiveObserver)
+    assert isinstance(eq_net.linear.act_observer, PassiveObserver)
+    assert eq_net.dequant.act_observer is None
+
+
+def test_apply_tqt():
+    qat_net = init_qat_net()
+    tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
+    assert isinstance(tqt_net.quant.act_fake_quant, TQT)
+    assert isinstance(tqt_net.linear.weight_fake_quant, TQT)
+    assert isinstance(tqt_net.linear.act_fake_quant, TQT)
+    assert tqt_net.dequant.act_fake_quant is None
+
+
 def test_get_quantable_module_names():
     # need to make sure names from Quantized and QAT are the same
     def _get_qat_module_names():
@@ -87,30 +264,3 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
-
-
-def test_disable_fake_quant():
-    class Net(Float.Module):
-        def __init__(self):
-            super().__init__()
-            self.quant = Float.QuantStub()
-            self.linear = Float.Linear(3, 3)
-            self.dequant = Float.DequantStub()
-            self.linear.bias.set_value(np.random.rand(3))
-
-        def forward(self, x):
-            x = self.quant(x)
-            x = self.linear(x)
-            x = self.dequant(x)
-            return x
-
-    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
-    net = Net()
-    y1 = net(x).numpy()
-    net = quantize_qat(net, min_max_fakequant_qconfig)
-    y2 = net(x).numpy()
-    disable_fake_quant(net)
-    y3 = net(x).numpy()
-    np.testing.assert_allclose(y1, y3)
-    with pytest.raises(AssertionError):
-        np.testing.assert_allclose(y2, y3)