Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7c4f1a38
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7c4f1a38
编写于
5月 09, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/quantization): add calibration support
GitOrigin-RevId: f16fbba2b7cbc6138c4382fcb96b70f5eb71074c
上级
5eca4da3
变更
8
显示空白变更内容
内联
并排
Showing
8 changed files
with
45 additions
and
43 deletions
+45
-43
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+8
-9
python_module/megengine/module/quantized/concat.py
python_module/megengine/module/quantized/concat.py
+3
-8
python_module/megengine/module/quantized/conv_bn_relu.py
python_module/megengine/module/quantized/conv_bn_relu.py
+4
-7
python_module/megengine/module/quantized/elemwise.py
python_module/megengine/module/quantized/elemwise.py
+3
-8
python_module/megengine/module/quantized/quant_dequant.py
python_module/megengine/module/quantized/quant_dequant.py
+4
-10
python_module/megengine/quantization/__init__.py
python_module/megengine/quantization/__init__.py
+1
-0
python_module/megengine/quantization/observer.py
python_module/megengine/quantization/observer.py
+1
-1
python_module/megengine/quantization/quantize.py
python_module/megengine/quantization/quantize.py
+21
-0
未找到文件。
python_module/megengine/module/module.py
浏览文件 @
7c4f1a38
...
...
@@ -496,6 +496,9 @@ class QATModule(Module):
self
,
target
:
Tensor
,
fq
:
"FakeQuantize"
,
obs
:
"Observer"
):
oup
=
self
.
apply_observer
(
target
,
obs
)
if
self
.
quantizing
==
self
.
QATMode
.
CALIBRATION
:
return
oup
else
:
scale
,
zero_point
=
obs
.
get_qparams
()
return
fq
(
oup
,
scale
,
zero_point
)
...
...
@@ -524,11 +527,7 @@ class QATModule(Module):
"""
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
self
.
quantizing
==
self
.
QATMode
.
QAT
:
return
self
.
forward_qat
(
*
args
,
**
kwargs
)
elif
self
.
quantizing
==
self
.
QATMode
.
CALIBRATION
:
# TODO implement the CALIBRATION
assert
False
return
None
else
:
if
self
.
quantizing
==
self
.
QATMode
.
DISABLED
:
return
self
.
forward
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_qat
(
*
args
,
**
kwargs
)
python_module/megengine/module/quantized/concat.py
浏览文件 @
7c4f1a38
...
...
@@ -20,11 +20,9 @@ class Concat(Module):
A :class:`~.Module` to do quantized concat, inference only.
"""
def
__init__
(
self
):
def
__init__
(
self
,
dtype
=
None
):
super
().
__init__
()
self
.
scale
=
1.0
self
.
zero_point
=
0.0
self
.
output_dtype
=
mgb
.
dtype
.
qint8
(
self
.
scale
)
self
.
output_dtype
=
dtype
def
forward
(
self
,
inps
:
Iterable
[
Tensor
],
axis
:
int
=
0
):
if
self
.
training
:
...
...
@@ -39,7 +37,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod
=
Concat
()
qmod
.
output_dtype
=
float_module
.
act_observer
.
get_dtype
()
qmod
.
scale
,
qmod
.
zero_point
=
float_module
.
act_observer
.
get_qparams
()
return
qmod
return
Concat
(
float_module
.
act_observer
.
get_dtype
())
python_module/megengine/module/quantized/conv_bn_relu.py
浏览文件 @
7c4f1a38
...
...
@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d):
groups
:
int
=
1
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
dtype
=
None
,
):
super
().
__init__
(
in_channels
,
...
...
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d):
conv_mode
,
compute_mode
,
)
self
.
scale
=
1.0
self
.
zero_point
=
0.0
self
.
output_dtype
=
mgb
.
dtype
.
qint8
(
self
.
scale
)
self
.
weight
=
self
.
weight
.
astype
(
self
.
output_dtype
)
self
.
bias
=
self
.
bias
.
astype
(
mgb
.
dtype
.
qint32
(
self
.
scale
))
self
.
output_dtype
=
dtype
def
calc_conv_quantized
(
self
,
inp
,
nonlinear_mode
=
"IDENTITY"
):
inp_scale
=
mgb
.
dtype
.
get_scale
(
inp
.
dtype
)
...
...
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d):
def
to_quantized
(
quantized_class
,
float_module
):
output_dtype
=
float_module
.
act_observer
.
get_dtype
()
qconv
=
quantized_class
(
float_module
.
conv
.
in_channels
,
float_module
.
conv
.
out_channels
,
...
...
@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module):
float_module
.
conv
.
padding
,
float_module
.
conv
.
dilation
,
float_module
.
conv
.
groups
,
dtype
=
output_dtype
,
)
w_fold
,
b_fold
=
float_module
.
fold_weight_bias
(
float_module
.
bn
.
running_mean
,
float_module
.
bn
.
running_var
)
weight
=
w_fold
.
astype
(
float_module
.
weight_observer
.
get_dtype
())
qconv
.
output_dtype
=
float_module
.
act_observer
.
get_dtype
()
qconv
.
weight
=
Parameter
(
weight
.
numpy
())
qconv
.
bias
=
Parameter
(
b_fold
.
numpy
())
qconv
.
scale
,
qconv
.
zero_point
=
float_module
.
act_observer
.
get_qparams
()
return
qconv
...
...
python_module/megengine/module/quantized/elemwise.py
浏览文件 @
7c4f1a38
...
...
@@ -34,12 +34,10 @@ class Elemwise(Module):
_elemwise_multi_type_mode
=
mgb
.
opr_param_defs
.
ElemwiseMultiType
.
Mode
def
__init__
(
self
,
method
):
def
__init__
(
self
,
method
,
dtype
=
None
):
super
().
__init__
()
self
.
method
=
self
.
_elemwise_multi_type_mode
.
convert
(
"Q"
+
method
)
self
.
scale
=
1.0
self
.
zero_point
=
0.0
self
.
output_dtype
=
mgb
.
dtype
.
qint8
(
self
.
scale
)
self
.
output_dtype
=
dtype
def
forward
(
self
,
*
inps
):
if
self
.
training
:
...
...
@@ -53,7 +51,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod
=
Elemwise
(
float_module
.
method
.
name
)
qmod
.
output_dtype
=
float_module
.
act_observer
.
get_dtype
()
qmod
.
scale
,
qmod
.
zero_point
=
float_module
.
act_observer
.
get_qparams
()
return
qmod
return
Elemwise
(
float_module
.
method
.
name
,
float_module
.
act_observer
.
get_dtype
())
python_module/megengine/module/quantized/quant_dequant.py
浏览文件 @
7c4f1a38
...
...
@@ -16,11 +16,9 @@ class QuantStub(Module):
A helper quantize operation on input and inference only.
"""
def
__init__
(
self
):
def
__init__
(
self
,
dtype
=
None
):
super
().
__init__
()
self
.
scale
=
1.0
self
.
zero_point
=
0.0
self
.
output_dtype
=
mgb
.
dtype
.
qint8
(
self
.
scale
)
self
.
output_dtype
=
dtype
def
forward
(
self
,
inp
):
if
self
.
training
:
...
...
@@ -45,10 +43,7 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod
=
QuantStub
()
qmod
.
output_dtype
=
float_module
.
act_observer
.
get_dtype
()
qmod
.
scale
,
qmod
.
zero_point
=
float_module
.
act_observer
.
get_qparams
()
return
qmod
return
QuantStub
(
float_module
.
act_observer
.
get_dtype
())
@
register_method_to_class
(
Float
.
DequantStub
)
...
...
@@ -57,5 +52,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod
=
DequantStub
()
return
qmod
return
DequantStub
()
python_module/megengine/quantization/__init__.py
浏览文件 @
7c4f1a38
...
...
@@ -14,5 +14,6 @@ from .quantize import (
enable_fake_quant
,
enable_observer
,
quantize
,
quantize_calibration
,
quantize_qat
,
)
python_module/megengine/quantization/observer.py
浏览文件 @
7c4f1a38
...
...
@@ -11,7 +11,7 @@ import numpy as np
from
..
import
functional
as
F
from
.._internal.dtype
import
_metadata_dict
,
get_quantized_dtype
from
..core
import
Buffer
,
Function
,
ones
,
tensor
,
zeros
from
..core
import
Buffer
,
Function
,
tensor
from
..module
import
Module
...
...
python_module/megengine/quantization/quantize.py
浏览文件 @
7c4f1a38
...
...
@@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True):
else
:
setattr
(
parent
,
key
.
split
(
"."
)[
-
1
],
submodule
.
to_quantized
())
return
module
def
quantize_qat
(
module
:
Module
,
qconfig
:
QConfig
=
ema_fakequant_qconfig
):
r
"""
...
...
@@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
module
.
apply
(
fn
)
def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
    r"""
    Recursively put `module` into `calibration` mode through :meth:`~.Module.apply`.

    Every :class:`~.QATModule` visited is switched to ``QATMode.CALIBRATION`` and
    given `qconfig`; afterwards ``enable_observer`` is called on `module`.

    :param module: root module to convert recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as the submodules'
        qconfig. Default is :any:`~.qconfig.ema_fakequant_qconfig`.
    """

    def _set_calibration(mod: Module):
        # Modules that are not QATModule instances are left untouched.
        if not isinstance(mod, QATModule):
            return
        mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
        mod.set_qconfig(qconfig)

    module.apply(_set_calibration)
    enable_observer(module)
def
disable_fake_quant
(
module
:
Module
):
r
"""
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录