Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4495c0cc
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4495c0cc
编写于
7月 17, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/quantization): add get quantize parameters support
GitOrigin-RevId: 5727f6356075658691c66a135b073c914969d9c9
上级
9b097859
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
63 addition
and
28 deletion
+63
-28
python_module/megengine/module/qat/module.py
python_module/megengine/module/qat/module.py
+19
-0
python_module/megengine/quantization/__init__.py
python_module/megengine/quantization/__init__.py
+2
-1
python_module/megengine/quantization/fake_quant.py
python_module/megengine/quantization/fake_quant.py
+14
-3
python_module/megengine/quantization/observer.py
python_module/megengine/quantization/observer.py
+6
-24
python_module/megengine/quantization/utils.py
python_module/megengine/quantization/utils.py
+22
-0
未找到文件。
python_module/megengine/module/qat/module.py
浏览文件 @
4495c0cc
...
@@ -92,6 +92,25 @@ class QATModule(Module):
...
@@ -92,6 +92,25 @@ class QATModule(Module):
else
:
else
:
return
self
.
act_observer
.
get_dtype
()
return
self
.
act_observer
.
get_dtype
()
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the provider of quantization parameters and query it.

    ``fake_quant`` is used when it exposes a ``get_qparams`` attribute;
    otherwise ``observer`` is used. Returns ``None`` when ``fake_quant`` lacks
    that attribute and ``observer`` is ``None``.
    """
    # Anything without ``get_qparams`` (including ``None``) defers to the observer.
    provider = fake_quant if hasattr(fake_quant, "get_qparams") else observer
    if provider is None:
        return None
    return provider.get_qparams()
def get_weight_qparams(self):
    r"""
    Return the quantization parameters of the weight.

    Delegates to ``_get_qparams`` with ``weight_fake_quant`` and
    ``weight_observer``.
    """
    resolve = self._get_qparams
    return resolve(self.weight_fake_quant, self.weight_observer)
def get_activation_qparams(self):
    r"""
    Return the quantization parameters of the activation.

    Delegates to ``_get_qparams`` with ``act_fake_quant`` and
    ``act_observer``.
    """
    resolve = self._get_qparams
    return resolve(self.act_fake_quant, self.act_observer)
@
classmethod
@
classmethod
@
abstractmethod
@
abstractmethod
def
from_float_module
(
cls
,
float_module
:
Module
):
def
from_float_module
(
cls
,
float_module
:
Module
):
...
...
python_module/megengine/quantization/__init__.py
浏览文件 @
4495c0cc
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
from
.fake_quant
import
FakeQuantize
from
.fake_quant
import
FakeQuantize
from
.internal_fake_quant
import
*
from
.internal_fake_quant
import
*
from
.observer
import
HistogramObserver
,
Observer
,
ObserverMode
from
.observer
import
HistogramObserver
,
Observer
from
.qconfig
import
(
from
.qconfig
import
(
QConfig
,
QConfig
,
calibration_qconfig
,
calibration_qconfig
,
...
@@ -16,3 +16,4 @@ from .qconfig import (
...
@@ -16,3 +16,4 @@ from .qconfig import (
min_max_fakequant_qconfig
,
min_max_fakequant_qconfig
,
tqt_quant_qconfig
,
tqt_quant_qconfig
,
)
)
from
.utils
import
QuantMode
python_module/megengine/quantization/fake_quant.py
浏览文件 @
4495c0cc
...
@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
...
@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from
..core
import
Buffer
,
Function
,
Parameter
from
..core
import
Buffer
,
Function
,
Parameter
from
..jit
import
sideeffect
from
..jit
import
sideeffect
from
..module
import
Module
from
..module
import
Module
from
.observer
import
ObserverMode
,
Round
from
.observer
import
Round
from
.utils
import
QuantMode
,
get_qparam_dict
class
_FakeQuantize
(
Module
):
class
_FakeQuantize
(
Module
):
...
@@ -121,8 +122,18 @@ class TQT(_FakeQuantize):
...
@@ -121,8 +122,18 @@ class TQT(_FakeQuantize):
F
.
add_update
(
self
.
scale
,
tmp_scale
,
alpha
=
0.0
,
beta
=
1.0
,
bias
=
0.0
)
F
.
add_update
(
self
.
scale
,
tmp_scale
,
alpha
=
0.0
,
beta
=
1.0
,
bias
=
0.0
)
return
inp
return
inp
def get_qparams(self):
    r"""
    Return the TQT quantization parameters.

    The result is a dict based on the ``QuantMode.TQT`` entry obtained from
    ``get_qparam_dict``, with ``"scale"`` set to ``2 ** self.scale``.
    """
    # Work on a private copy so that writing the scale never mutates a dict
    # that may be shared with other callers of ``get_qparam_dict``.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict
def
get_dtype
(
self
):
def
get_dtype
(
self
):
return
get_quantized_dtype
(
self
.
dtype
,
2
**
self
.
scale
.
numpy
()[
0
],
None
)
q_dict
=
self
.
get_qparams
()
scale
=
None
if
"scale"
not
in
q_dict
else
q_dict
[
"scale"
].
numpy
()[
0
]
zero_point
=
(
None
if
"zero_point"
not
in
q_dict
else
q_dict
[
"zero_point"
].
numpy
()[
0
]
)
return
get_quantized_dtype
(
self
.
dtype
,
scale
,
zero_point
)
class
FakeQuantize
(
_FakeQuantize
):
class
FakeQuantize
(
_FakeQuantize
):
...
@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize):
...
@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize):
"""
"""
def
fake_quant_forward
(
self
,
inp
,
q_dict
):
def
fake_quant_forward
(
self
,
inp
,
q_dict
):
if
q_dict
[
"mode"
]
==
Observer
Mode
.
SYMMERTIC
:
if
q_dict
[
"mode"
]
==
Quant
Mode
.
SYMMERTIC
:
scale
=
q_dict
[
"scale"
]
scale
=
q_dict
[
"scale"
]
# Quant
# Quant
oup
=
Round
()(
inp
/
scale
)
oup
=
Round
()(
inp
/
scale
)
...
...
python_module/megengine/quantization/observer.py
浏览文件 @
4495c0cc
...
@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
...
@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from
..core
import
Buffer
,
Function
,
tensor
from
..core
import
Buffer
,
Function
,
tensor
from
..jit
import
sideeffect
from
..jit
import
sideeffect
from
..module
import
Module
from
..module
import
Module
from
.utils
import
QuantMode
,
get_qparam_dict
class
Round
(
Function
):
class
Round
(
Function
):
...
@@ -81,29 +82,10 @@ class Observer(Module):
...
@@ -81,29 +82,10 @@ class Observer(Module):
pass
pass
class
ObserverMode
(
Enum
):
SYMMERTIC
=
1
ASYMMERTIC
=
2
def
create_observer_dict
(
mode
):
if
mode
==
ObserverMode
.
SYMMERTIC
:
return
{
"mode"
:
ObserverMode
.
SYMMERTIC
,
"scale"
:
None
,
}
else
:
return
{
"mode"
:
ObserverMode
.
ASYMMERTIC
,
"scale"
:
None
,
"zero_point"
:
None
,
}
class
MinMaxObserver
(
Observer
):
class
MinMaxObserver
(
Observer
):
def
__init__
(
def
__init__
(
self
,
self
,
mode
=
Observer
Mode
.
SYMMERTIC
,
mode
=
Quant
Mode
.
SYMMERTIC
,
eps
=
0.00001
,
eps
=
0.00001
,
dtype
=
"qint8"
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
narrow_range
:
bool
=
False
,
...
@@ -117,10 +99,10 @@ class MinMaxObserver(Observer):
...
@@ -117,10 +99,10 @@ class MinMaxObserver(Observer):
def
_calculate_qparams
(
self
,
inp_min_val
,
inp_max_val
):
def
_calculate_qparams
(
self
,
inp_min_val
,
inp_max_val
):
min_val
=
F
.
minimum
(
0.0
,
inp_min_val
)
min_val
=
F
.
minimum
(
0.0
,
inp_min_val
)
max_val
=
F
.
maximum
(
0.0
,
inp_max_val
)
max_val
=
F
.
maximum
(
0.0
,
inp_max_val
)
q_dict
=
create_observer
_dict
(
self
.
mode
)
q_dict
=
get_qparam
_dict
(
self
.
mode
)
q_dict
[
"min_val"
]
=
inp_min_val
q_dict
[
"min_val"
]
=
inp_min_val
q_dict
[
"max_val"
]
=
inp_max_val
q_dict
[
"max_val"
]
=
inp_max_val
if
self
.
mode
==
Observer
Mode
.
SYMMERTIC
:
if
self
.
mode
==
Quant
Mode
.
SYMMERTIC
:
symmetric_max_vals
=
F
.
maximum
(
-
min_val
,
max_val
)
symmetric_max_vals
=
F
.
maximum
(
-
min_val
,
max_val
)
# use maximum to avoid the scale being too small at the beginning
# use maximum to avoid the scale being too small at the beginning
q_dict
[
"scale"
]
=
F
.
maximum
(
q_dict
[
"scale"
]
=
F
.
maximum
(
...
@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
...
@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
def
__init__
(
def
__init__
(
self
,
self
,
momentum
=
0.9
,
momentum
=
0.9
,
mode
=
Observer
Mode
.
SYMMERTIC
,
mode
=
Quant
Mode
.
SYMMERTIC
,
eps
=
0.00001
,
eps
=
0.00001
,
dtype
=
"qint8"
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
narrow_range
:
bool
=
False
,
...
@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver):
...
@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver):
self
,
self
,
bins
=
2048
,
bins
=
2048
,
upsample_rate
=
128
,
upsample_rate
=
128
,
mode
=
Observer
Mode
.
SYMMERTIC
,
mode
=
Quant
Mode
.
SYMMERTIC
,
eps
=
0.00001
,
eps
=
0.00001
,
dtype
=
"qint8"
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
narrow_range
:
bool
=
False
,
...
...
python_module/megengine/quantization/utils.py
浏览文件 @
4495c0cc
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
enum
import
Enum
from
functools
import
partial
,
update_wrapper
,
wraps
from
functools
import
partial
,
update_wrapper
,
wraps
...
@@ -21,3 +22,24 @@ def register_method_to_class(cls):
...
@@ -21,3 +22,24 @@ def register_method_to_class(cls):
return
func
return
func
return
decorator
return
decorator
class QuantMode(Enum):
    """Quantization modes; each one has a parameter template in ``qparam_dict``."""

    # NOTE: the member spellings are part of the public interface (other
    # modules refer to e.g. ``QuantMode.SYMMERTIC``), so they are kept as-is.
    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Per-mode template of quantization parameters. Every value other than
# "mode" is a ``None`` placeholder that callers fill in on their own copy.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None},
}


def get_qparam_dict(mode):
    """
    Return a fresh quantization-parameter dict for ``mode``.

    :param mode: a :class:`QuantMode` member.
    :return: a new dict copied from the template registered for ``mode`` in
        ``qparam_dict``, or ``None`` when ``mode`` has no template.
    """
    template = qparam_dict.get(mode, None)
    if template is None:
        return None
    # Callers write "scale", "min_val", ... into the result. Handing out the
    # template itself would make every caller share (and overwrite) one dict.
    # A shallow copy suffices: the template values are an enum member and None.
    return dict(template)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录