MegEngine 天元 / MegEngine — commit caf1fac2
Authored on May 25, 2020 by Megvii Engine Team
refactor(mge/quantization): split `QATModule` and refactor convert api
GitOrigin-RevId: 80cfb12d10590bbc88fd98370f5e3cf5d196d586
Parent: ad3c9315
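In summary, this commit replaces the old in-place mode switching on a single `QATModule` base class (`set_qat_mode` / `to_quantized`) with three parallel module families converted explicitly via `quantize_qat` and `quantize`. A minimal orientation sketch (imports only, taken from the namespaces this commit introduces):

    # three parallel namespaces after this commit, with matching class names in each:
    from megengine import module as Float          # float modules:   Float.Linear, ...
    from megengine.module import qat as QAT        # QAT versions:    QAT.Linear, ...
    from megengine.module import quantized as Q    # inference-only:  Q.Linear, ...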
Showing 27 changed files with 735 additions and 480 deletions (+735 −480)
python_module/megengine/functional/nn.py                     +4   −4
python_module/megengine/module/__init__.py                   +1   −1
python_module/megengine/module/concat.py                     +4   −9
python_module/megengine/module/conv_bn_relu.py               +11  −159
python_module/megengine/module/elemwise.py                   +4   −9
python_module/megengine/module/linear.py                     +2   −13
python_module/megengine/module/module.py                     +0   −96
python_module/megengine/module/qat/__init__.py               +13  −0
python_module/megengine/module/qat/concat.py                 +30  −0
python_module/megengine/module/qat/conv_bn_relu.py           +193 −0
python_module/megengine/module/qat/elemwise.py               +29  −0
python_module/megengine/module/qat/linear.py                 +37  −0
python_module/megengine/module/qat/module.py                 +96  −0
python_module/megengine/module/qat/quant_dequant.py          +45  −0
python_module/megengine/module/quant_dequant.py              +7   −13
python_module/megengine/module/quantized/__init__.py         +1   −0
python_module/megengine/module/quantized/concat.py           +11  −16
python_module/megengine/module/quantized/conv_bn_relu.py     +32  −36
python_module/megengine/module/quantized/elemwise.py         +11  −18
python_module/megengine/module/quantized/linear.py           +17  −24
python_module/megengine/module/quantized/module.py           +31  −0
python_module/megengine/module/quantized/quant_dequant.py    +23  −29
python_module/megengine/quantization/__init__.py             +0   −9
python_module/megengine/quantization/qconfig.py              +3   −7
python_module/megengine/quantization/quantize.py             +82  −25
python_module/test/unit/module/test_conv_bn_relu.py          +10  −12
python_module/test/unit/quantization/quantize.py             +38  −0
python_module/megengine/functional/nn.py

@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph
 def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     """Applies a linear transformation to the input.

-    Refer to :class:`~.Linear` for more information.
+    Refer to :class:`~.module.linear.Linear` for more information.

     :param inp: the input tensor with shape `(N, in_features)`.
     :param weight: the weight with shape `(out_features, in_features)`.
     :param bias: the bias with shape `(out_features,)`.
         Default: ``None``
     """

@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
 def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
     r"""
     Performs the elementwise function:

     .. math::
         \mathsf{softplus}(x) = \log(1 + \exp(\beta x)) / \beta.

     For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
python_module/megengine/module/__init__.py

@@ -16,7 +16,7 @@ from .elemwise import Elemwise
 from .embedding import Embedding
 from .identity import Identity
 from .linear import Linear
-from .module import Module, QATModule
+from .module import Module
 from .parampack import ParamPack
 from .pooling import AvgPool2d, MaxPool2d
 from .quant_dequant import DequantStub, QuantStub
python_module/megengine/module/concat.py

@@ -9,19 +9,14 @@ from typing import Iterable
 from .. import functional as F
 from ..core.tensor import Tensor
-from .module import QATModule
+from .module import Module


-class Concat(QATModule):
+class Concat(Module):
     r"""
-    A :class:`~.QATModule` to do functional concat, should replace concat with this module,
-    supporting ``qat`` mode and ``quantized`` mode.
+    A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
+    version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
         return F.concat(inps, axis)
-
-    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
-        return self.apply_fakequant_with_observer(
-            self.forward(inps, axis), self.act_fake_quant, self.act_observer
-        )
python_module/megengine/module/conv_bn_relu.py

@@ -7,14 +7,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

-from ..core import ones, zeros
-from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
+from ..functional import relu
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
-from .module import QATModule
+from .module import Module


-class _ConvBn2d(QATModule):
+class _ConvBnActivation2d(Module):
     def __init__(
         self,
         in_channels: int,

@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)

-    def get_batch_mean_var(self, inp): ...
-    def fold_weight_bias(self, bn_mean, bn_var): ...
-    def update_running_mean_and_running_var(self, bn_mean, bn_var, num_elements_per_channel): ...
-    def calc_conv_bn_qat(self, inp, approx=True): ...

(The four QAT helper methods above are deleted from the float module and re-added essentially
verbatim in the new file python_module/megengine/module/qat/conv_bn_relu.py, shown in full below;
the only substantive change there is that the weight fake-quantization call becomes
``self.apply_quant_weight(w_fold)`` instead of
``self.apply_fakequant_with_observer(w_fold, self.weight_fake_quant, self.weight_observer)``.)

-class ConvBn2d(_ConvBn2d):
+class ConvBn2d(_ConvBnActivation2d):
     r"""
-    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
-    and ``normal`` mode.
+    A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    :func:`~.quantize.quantize_qat`.
     """

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
-        )
-
     def forward(self, inp):
         return self.bn(self.conv(inp))


-class ConvBnRelu2d(_ConvBn2d):
+class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
-    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
-    mode and ``normal`` mode.
+    A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    :func:`~.quantize.quantize_qat`.
     """

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
-        )
-
     def forward(self, inp):
         return relu(self.bn(self.conv(inp)))
python_module/megengine/module/elemwise.py

@@ -8,7 +8,7 @@
 from .. import _internal as mgb
 from ..core import Tensor, wrap_io_tensor
 from ..core.graph import _use_default_if_none
-from .module import QATModule
+from .module import Module


 @wrap_io_tensor

@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
     return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


-class Elemwise(QATModule):
+class Elemwise(Module):
     r"""
-    A :class:`~.QATModule` to do elemwise operator, should functional operator with this module,
-    supporting ``qat`` mode and ``normal`` mode.
+    A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
+    version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.

     :param method: the elemwise method, support the following string.
         It will do the normal elemwise operator for float.

@@ -88,8 +88,3 @@ class Elemwise(QATModule):
     def forward(self, *inps):
         return _elemwise_func(self.method, *inps)
-
-    def forward_qat(self, *inps):
-        return self.apply_fakequant_with_observer(
-            self.forward(*inps), self.act_fake_quant, self.act_observer,
-        )
python_module/megengine/module/linear.py

 # -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.

@@ -11,10 +10,10 @@ import numpy as np
 from .. import functional as F
 from ..core import Parameter
 from . import init
-from .module import QATModule
+from .module import Module


-class Linear(QATModule):
+class Linear(Module):
     r"""Applies a linear transformation to the input. For instance, if input
     is x, then output y is:

@@ -60,13 +59,3 @@ class Linear(QATModule):
     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)
-
-    def forward_qat(self, x):
-        w_qat = self.apply_fakequant_with_observer(
-            self.weight, self.weight_fake_quant, self.weight_observer
-        )
-        return self.apply_fakequant_with_observer(
-            self._calc_linear(x, w_qat, self.bias),
-            self.act_fake_quant,
-            self.act_observer,
-        )
python_module/megengine/module/module.py

@@ -7,7 +7,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
-from enum import Enum
 from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

 import numpy as np

@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
                 loaded.append(k)

         return set(loaded), set(skipped)
-
-
-class QATModule(Module):
-    r"""
-    Base class of quantization related Module. Add extra forward methods
-    :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
-    ``qat`` (quantization aware training) mode and ``quantized`` mode respectively.
-
-    Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
-    and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
-    which is irreversible.
-
-    If you want to recursively switch mode for all QATModule in network, use
-    functions in :mod:`~.quantization.quantize`.
-    """
-
-    class QATMode(Enum):
-        DISABLED = 1
-        QAT = 2
-        CALIBRATION = 3
-
-    def __init__(self):
-        from ..quantization import (
-            QConfig,
-            FakeQuantize,
-            Observer,
-        )  # pylint: disable=all
-
-        super().__init__()
-        self.quantizing = self.QATMode.DISABLED
-        self.scale = None
-        self.weight_observer = None  # type: Observer
-        self.act_observer = None  # type: Observer
-        self.weight_fake_quant = None  # type: FakeQuantize
-        self.act_fake_quant = None  # type: FakeQuantize
-
-    def set_qconfig(self, qconfig: "QConfig"):
-        self.weight_observer = qconfig.weight_observer()
-        self.act_observer = qconfig.act_observer()
-
-        self.weight_fake_quant = (
-            None
-            if qconfig.fake_quant is None
-            else qconfig.fake_quant(self.weight_observer.dtype)
-        )
-        self.act_fake_quant = (
-            None
-            if qconfig.fake_quant is None
-            else qconfig.fake_quant(self.act_observer.dtype)
-        )
-
-    def apply_observer(self, target: Tensor, obs: "Observer"):
-        return obs(target)
-
-    def apply_fakequant_with_observer(
-        self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
-    ):
-        oup = self.apply_observer(target, obs)
-        if fq is not None:
-            q_dict = obs.get_qparams()
-            oup = fq(oup, q_dict)
-        return oup
-
-    def set_qat_mode(self, mode: QATMode):
-        r"""
-        Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``,
-        ``QAT``, ``CALIBRATION``.
-        """
-        if not isinstance(mode, self.QATMode):
-            raise TypeError("mode must be QATMode Enum type")
-        self.quantizing = mode
-
-    def to_quantized(self):
-        r"""
-        Return a new :class:`~.Module` with quantized parameters of ``self``
-        according to scale and zero_point in ``self.xxx_observer``.
-        """
-        raise NotImplementedError(
-            "Use megengine.quantization.quantize to register the method."
-        )
-
-    @abstractmethod
-    def forward_qat(self, *args, **kwargs):
-        r"""
-        Forward method for ``qat`` mode.
-        """
-
-    def __call__(self, *args, **kwargs):
-        if self.quantizing == self.QATMode.DISABLED:
-            return self.forward(*args, **kwargs)
-        else:
-            return self.forward_qat(*args, **kwargs)
python_module/megengine/module/qat/__init__.py (new file, mode 100644)

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub
python_module/megengine/module/qat/concat.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from typing import Iterable

from ...core.tensor import Tensor
from .. import concat as Float
from .module import QATModule


class Concat(Float.Concat, QATModule):
    r"""
    A :class:`~.QATModule` to do functional concat with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        return self.apply_quant_activation(super().forward(inps, axis))

    @classmethod
    def from_float_module(cls, float_module):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()
python_module/megengine/module/qat/conv_bn_relu.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = zero_grad(bn_mean)
        bn_var = (
            zero_grad(bn_var)
            * num_elements_per_channel
            / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        add_update(
            self.bn.running_mean,
            delta=bn_mean,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
        add_update(
            self.bn.running_var,
            delta=bn_var,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )

    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        conv = self.conv.calc_conv(inp, w_qat, b_fold)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            bool(float_module.conv.bias),
            float_module.conv.conv_mode.name,
            float_module.conv.compute_mode.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvBn2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_bn_qat(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
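The batch-norm folding arithmetic used in ``fold_weight_bias`` above can be checked directly in NumPy. This standalone snippet (an illustration of the math, not MegEngine code) verifies that a conv output passed through batch norm equals the conv computed with the folded weight and bias:

    import numpy as np

    rng = np.random.default_rng(0)
    C, eps = 4, 1e-5
    w = rng.standard_normal((C, 3, 1, 1)).astype("float32")  # conv weight (Cout, Cin, 1, 1)
    b = rng.standard_normal(C)                               # conv bias
    gamma, beta = rng.standard_normal(C), rng.standard_normal(C)
    mean, var = rng.standard_normal(C), rng.random(C) + 0.5

    istd = 1.0 / np.sqrt(var + eps)
    w_fold = w * (gamma * istd).reshape(-1, 1, 1, 1)         # w_fold = gamma / bn_std * W
    b_fold = beta + gamma * (b - mean) * istd                # b_fold = gamma * (b - bn_mean) / bn_std + beta

    x = rng.standard_normal((8, 3))
    conv = x @ w.reshape(C, 3).T + b                         # a 1x1 conv is just a matmul
    bn_out = gamma * istd * (conv - mean) + beta             # batch norm applied after conv
    folded = x @ w_fold.reshape(C, 3).T + b_fold             # single conv with folded params
    assert np.allclose(bn_out, folded, atol=1e-5)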
python_module/megengine/module/qat/elemwise.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from .. import elemwise as Float
from .module import QATModule


class Elemwise(Float.Elemwise, QATModule):
    r"""
    A :class:`~.QATModule` to do elemwise operator with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

    :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
    """

    def forward(self, *inps):
        return self.apply_quant_activation(super().forward(*inps))

    @classmethod
    def from_float_module(cls, float_module: Float.Elemwise):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls(float_module.method.name)
python_module/megengine/module/qat/linear.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from .. import linear as Float
from .module import QATModule


class Linear(Float.Linear, QATModule):
    r"""
    A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

    :param in_features: size of each input sample.
    :param out_features: size of each output sample.
    :param bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``
    """

    def forward(self, x):
        w_qat = self.apply_quant_weight(self.weight)
        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias))

    @classmethod
    def from_float_module(cls, float_module: Float.Linear):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qmod = cls(float_module.in_features, float_module.out_features)
        qmod.weight = float_module.weight
        qmod.bias = float_module.bias
        return qmod
python_module/megengine/module/qat/module.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from abc import abstractmethod

from ...core import Tensor
from ...quantization import FakeQuantize, Observer, QConfig
from ..module import Module


class QATModule(Module):
    r"""
    Base class of quantized-float related Module, basically for QAT and Calibration.

    Use :meth:`~.QATModule.from_float_module` to generate an instance from float :class:`~.Module`.
    Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.

    Can also be converted to :class:`~.QuantizedModule` for deployment using
    :func:`~.quantize.quantize` further.
    """

    def __init__(self):
        super().__init__()

        self.scale = None
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def set_qconfig(self, qconfig: QConfig):
        r"""
        Set quantization related configs with ``qconfig``, including
        observer and fake_quant for weight and activation.
        """
        self.weight_observer = qconfig.weight_observer()
        self.act_observer = qconfig.act_observer()

        if qconfig.fake_quant is None:
            self.weight_fake_quant = None
            self.act_fake_quant = None
        else:
            self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
            self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

    def _apply_fakequant_with_observer(
        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
    ):
        oup = observer(target)
        if fake_quant is None:
            return oup
        else:
            q_dict = observer.get_qparams()
            return fake_quant(oup, q_dict)

    def apply_quant_weight(self, target: Tensor):
        r"""
        Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.weight_fake_quant, self.weight_observer
        )

    def apply_quant_activation(self, target: Tensor):
        r"""
        Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.act_fake_quant, self.act_observer
        )

    def get_weight_dtype(self):
        r"""
        Get weight's quantization dtype as the method from ``qconfig``.
        """
        return self.weight_observer.get_dtype()

    def get_activation_dtype(self):
        r"""
        Get activation's quantization dtype as the method from ``qconfig``.
        """
        return self.act_observer.get_dtype()

    @classmethod
    @abstractmethod
    def from_float_module(cls, float_module: Module):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
python_module/megengine/module/qat/quant_dequant.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from .. import quant_dequant as Float
from .module import QATModule


class QuantStub(Float.QuantStub, QATModule):
    r"""
    A helper QATModule simply returning input, but will quantize
    input after converted to :class:`~.QuantizedModule`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(inp)

    @classmethod
    def from_float_module(cls, float_module: Float.QuantStub):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()


class DequantStub(Float.DequantStub, QATModule):
    r"""
    A helper QATModule simply returning input, but will de-quantize
    input after converted to :class:`~.QuantizedModule`.
    """

    def forward(self, inp):
        return inp

    @classmethod
    def from_float_module(cls, float_module: Float.DequantStub):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()
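QuantStub and DequantStub mark where activations enter and leave the quantized domain. A hedged sketch of typical placement under this commit's API (the surrounding network is hypothetical; module names come from the diff):

    import megengine.module as M

    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.quant = M.QuantStub()      # identity as float; quantizes input after conversion
            self.conv = M.ConvBnRelu2d(3, 16, 3)
            self.dequant = M.DequantStub()  # identity as float; restores float32 after conversion

        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            return self.dequant(x)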
python_module/megengine/module/quant_dequant.py

@@ -5,30 +5,24 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .module import QATModule
+from .module import Module


-class QuantStub(QATModule):
+class QuantStub(Module):
     r"""
-    A helper QATModule doing quantize operation on input.
+    A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
+    version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inp):
         return inp

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            inp, self.act_fake_quant, self.act_observer
-        )
-

-class DequantStub(QATModule):
+class DequantStub(Module):
     r"""
-    A helper QATModule doing de-quantize operation on input.
+    A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
+    version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inp):
         return inp
-
-    def forward_qat(self, inp):
-        return inp
python_module/megengine/module/quantized/__init__.py

@@ -9,4 +9,5 @@ from .concat import Concat
 from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
+from .module import QuantizedModule
 from .quant_dequant import DequantStub, QuantStub
python_module/megengine/module/quantized/concat.py

@@ -7,17 +7,15 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable

-from ... import _internal as mgb
 from ... import functional as F
-from ... import module as Float
 from ...core.tensor import Tensor
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import concat as QAT
+from .module import QuantizedModule


-class Concat(Module):
+class Concat(QuantizedModule):
     r"""
-    A :class:`~.Module` to do quantized concat, inference only.
+    A :class:`~.QuantizedModule` to do quantized concat, inference only.
     """

     def __init__(self, dtype=None):

@@ -25,16 +23,13 @@ class Concat(Module):
         self.output_dtype = dtype

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
         if self.training:
             raise ValueError("quantized module only support inference.")
         new_inps = (x.astype(self.output_dtype) for x in inps)
         return F.concat(new_inps, axis)

-
-@register_method_to_class(Float.Concat)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-    Implemented here to avoid circular import.
-    """
-    return Concat(float_module.act_observer.get_dtype())
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Concat):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.get_activation_dtype())
python_module/megengine/module/quantized/conv_bn_relu.py

@@ -5,7 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from functools import partial
 from typing import Tuple, Union

 import megengine._internal as mgb

@@ -13,11 +12,11 @@ import megengine._internal as mgb
 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ...module import Conv2d
-from ...quantization.utils import register_method_to_class
+from ..qat import conv_bn_relu as QAT
+from .module import QuantizedModule


-class _ConvBnActivation2d(Conv2d):
+class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
     r"""Applies a 2D convolution over a quantized input tensor, inference only.

     The parameter is same with :class:`~.Conv2d`.

@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
             nonlinear_mode=nonlinear_mode,
         )

+    @classmethod
+    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qconv = cls(
+            qat_module.conv.in_channels,
+            qat_module.conv.out_channels,
+            qat_module.conv.kernel_size,
+            qat_module.conv.stride,
+            qat_module.conv.padding,
+            qat_module.conv.dilation,
+            qat_module.conv.groups,
+            dtype=output_dtype,
+        )
+        w_fold, b_fold = qat_module.fold_weight_bias(
+            qat_module.bn.running_mean, qat_module.bn.running_var
+        )
+        weight = w_fold.astype(qat_module.get_weight_dtype())
+        qconv.weight = Parameter(weight.numpy())
+        qconv.bias = Parameter(b_fold.numpy())
+        return qconv
+

 class ConvBn2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
+
     def forward(self, inp):
         if self.training:
             raise ValueError("quantized module only support inference.")
         return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


 class ConvBnRelu2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+
     def forward(self, inp):
         if self.training:
             raise ValueError("quantized module only support inference.")
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
-
-
-def to_quantized(quantized_class, float_module):
-    output_dtype = float_module.act_observer.get_dtype()
-    qconv = quantized_class(
-        float_module.conv.in_channels,
-        float_module.conv.out_channels,
-        float_module.conv.kernel_size,
-        float_module.conv.stride,
-        float_module.conv.padding,
-        float_module.conv.dilation,
-        float_module.conv.groups,
-        dtype=output_dtype,
-    )
-    w_fold, b_fold = float_module.fold_weight_bias(
-        float_module.bn.running_mean, float_module.bn.running_var
-    )
-    weight = w_fold.astype(float_module.weight_observer.get_dtype())
-    qconv.weight = Parameter(weight.numpy())
-    qconv.bias = Parameter(b_fold.numpy())
-    return qconv
-
-
-# replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-# implemented here to avoid circular import.
-register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
-register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
python_module/megengine/module/quantized/elemwise.py

@@ -6,11 +6,10 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import _internal as mgb
-from ... import module as Float
 from ...core import Tensor, wrap_io_tensor
 from ...core.graph import _use_default_if_none
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import elemwise as QAT
+from .module import QuantizedModule


 @wrap_io_tensor

@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
     return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


-class Elemwise(Module):
-    r"""
-    quantized module for elemwise operator, inference only.
-
-    :param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
-        it will do quantized operator with specified output quantized dtype.
-    """
+class Elemwise(QuantizedModule):
+    r"""quantized version of :class:`~.qat.elemwise.Elemwise`."""

     _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

@@ -44,11 +38,10 @@ class Elemwise(Module):
             raise ValueError("quantized module only support inference.")
         return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)

-
-@register_method_to_class(Float.Elemwise)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-    Implemented here to avoid circular import.
-    """
-    return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Elemwise):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.method.name, qat_module.get_activation_dtype())
python_module/megengine/module/quantized/linear.py

@@ -10,19 +10,13 @@ import numpy as np
 import megengine._internal as mgb

 from ... import functional as F
-from ... import module as Float
 from ...core import Parameter
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import linear as QAT
+from .module import QuantizedModule


-class Linear(Module):
-    r"""Applies a quantized linear transformation to the input. The module
-    usually convert from QAT module by to_quantized method.
-
-    :param dtype: output data type.
-    """
+class Linear(QuantizedModule):
+    r"""quantized version of :class:`~.qat.linear.Linear`."""

     def __init__(
         self,
         dtype: np.dtype = None,

@@ -44,17 +38,16 @@ class Linear(Module):
             None if self.bias is None else self.bias.astype(bias_dtype),
         ).astype(self.output_dtype)

-
-@register_method_to_class(Float.Linear)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-    Implemented here to avoid circular import.
-    """
-    output_dtype = float_module.act_observer.get_dtype()
-    qmod = Linear(dtype=output_dtype)
-    weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
-    qmod.weight = Parameter(weight.numpy())
-    if float_module.bias is not None:
-        qmod.bias = Parameter(float_module.bias.numpy())
-    return qmod
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Linear):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qmod = cls(dtype=output_dtype)
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
+        qmod.weight = Parameter(weight.numpy())
+        if qat_module.bias is not None:
+            qmod.bias = Parameter(qat_module.bias.numpy())
+        return qmod
python_module/megengine/module/quantized/module.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from abc import abstractmethod

from ..module import Module
from ..qat import QATModule


class QuantizedModule(Module):
    r"""
    Base class of quantized Module, which should be converted from QATModule
    and does not support training.
    """

    def __call__(self, *inputs, **kwargs):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return super().__call__(*inputs, **kwargs)

    @classmethod
    @abstractmethod
    def from_qat_module(cls, qat_module: QATModule):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
python_module/megengine/module/quantized/quant_dequant.py

@@ -5,15 +5,14 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ... import _internal as mgb
-from ... import module as Float
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import quant_dequant as QAT
+from .module import QuantizedModule


-class QuantStub(Module):
+class QuantStub(QuantizedModule):
     r"""
-    A helper quantize operation on input and inference only.
+    quantized version of :class:`~.qat.quant_dequant.QuantStub`,
+    will convert input to quantized dtype.
     """

     def __init__(self, dtype=None):

@@ -21,35 +20,30 @@ class QuantStub(Module):
         self.output_dtype = dtype

     def forward(self, inp):
         if self.training:
             raise ValueError("quantized module only support inference.")
         return inp.astype(self.output_dtype)

+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.QuantStub):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.get_activation_dtype())

-class DequantStub(Module):
+
+class DequantStub(QuantizedModule):
     r"""
-    A helper de-quantize operation and inference only.
+    quantized version of :class:`~.qat.quant_dequant.DequantStub`,
+    will restore quantized input to float32 dtype.
     """

     def forward(self, inp):
         if self.training:
             raise ValueError("quantized module only support inference.")
         return inp.astype("float32")

-
-@register_method_to_class(Float.QuantStub)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-    Implemented here to avoid circular import.
-    """
-    return QuantStub(float_module.act_observer.get_dtype())
-
-
-@register_method_to_class(Float.DequantStub)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-    Implemented here to avoid circular import.
-    """
-    return DequantStub()
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.DequantStub):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls()
python_module/megengine/quantization/__init__.py

@@ -13,12 +13,3 @@ from .qconfig import (
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
 )
-from .quantize import (
-    disable_fake_quant,
-    disable_observer,
-    enable_fake_quant,
-    enable_observer,
-    quantize,
-    quantize_calibration,
-    quantize_qat,
-)
python_module/megengine/quantization/qconfig.py

@@ -15,16 +15,12 @@ from .observer import (


 class QConfig:
-    """
+    r"""
     A config class indicating how to do quantize toward :class:`~.QATModule`'s
-    ``activation`` and ``weight``. And ``fake_quant`` parameter to indicate
-    See :meth:`~.QATModule.set_qconfig` for detail usage.
+    ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.

     :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
-        - how to collect scales and zero_point of weight.
+        how to collect scales and zero_point of weight.
     :param act_observer: similar to ``weight_observer`` but toward activation.
     :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
         how to do fake_quant calculation. can be invoked multi times to get different
python_module/megengine/quantization/quantize.py

@@ -6,68 +6,125 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from copy import deepcopy
+from typing import Dict, Tuple

-from ..module import Module, QATModule, Sequential, quantized
+from .. import module as Float
+from ..module import Module
+from ..module import qat as QAT
+from ..module import quantized as Quantized
+from ..module.qat import QATModule
+from ..module.quantized import QuantizedModule
 from .qconfig import QConfig, ema_fakequant_qconfig


+def _get_quantable_module_names():
+    def is_quantable(key: str):
+        value = getattr(Quantized, key)
+        return (
+            isinstance(value, type)
+            and issubclass(value, QuantizedModule)
+            and value != QuantizedModule
+        )
+
+    # source should have all quantable modules' names
+    quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
+    return quantable_module_names
+
+
+def _get_convert_dict() -> Tuple[
+    Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
+]:
+    quantable_module_names = _get_quantable_module_names()
+
+    quantable_modules = [getattr(Float, key) for key in quantable_module_names]
+    qat_modules = [getattr(QAT, key) for key in quantable_module_names]
+    quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]
+
+    float2qat_dict = dict(zip(quantable_modules, qat_modules))
+    qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
+    return float2qat_dict, qat2quantized_dict
+
+
+_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
+
+
 def quantize(module: Module, inplace=True):
     r"""
-    Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`.
+    Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
+    through :meth:`~.Module.apply`.

     :param module: root module to do convert recursively.
+    :param inplace: whether to convert submodules in-place.
     """

     if not inplace:
         module = deepcopy(module)

-    def is_qat_module(obj):
-        return isinstance(obj, QATModule)
+    qat_modules = tuple(_qat2quantized_dict.keys())
+
+    def is_qat(mod: Module):
+        return isinstance(mod, qat_modules)

     # no need to pass prefix and get pure key of parent Module.
     for key, submodule, parent in module._flatten(
-        with_key=True, with_parent=True, predicate=is_qat_module
+        with_key=True, with_parent=True, predicate=is_qat
     ):
-        if isinstance(parent, Sequential):
+        new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
+        if isinstance(parent, Float.Sequential):
             # cannot use setattr to be compatible with Sequential's ``__setitem__``
-            parent[int(key.split(".")[-1])] = submodule.to_quantized()
+            parent[int(key.split(".")[-1])] = new_mod
         else:
-            setattr(parent, key.split(".")[-1], submodule.to_quantized())
+            setattr(parent, key.split(".")[-1], new_mod)

     return module


-def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+def quantize_qat(module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig):
     r"""
-    Recursively convert `module` to `qat` mode through :meth:`~.Module.apply`
-    and set qconfig relatively.
+    Recursively convert float :class:`~.Module` to :class:`~.QATModule`
+    through :meth:`~.Module.apply` and set qconfig relatively.

     :param module: root module to do convert recursively.
-    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
-        default is :any:`~.qconfig.ema_fakequant_qconfig`.
+    :param inplace: whether to convert submodules in-place.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+        default is ``ema_fakequant_qconfig``.
     """

-    def fn(mod: Module):
-        if isinstance(mod, QATModule):
-            mod.set_qat_mode(QATModule.QATMode.QAT)
-            mod.set_qconfig(qconfig)
+    if not inplace:
+        module = deepcopy(module)

-    module.apply(fn)
+    quantable_modules = tuple(_float2qat_dict.keys())
+
+    def is_quantable(mod: Module):
+        return isinstance(mod, quantable_modules)
+
+    # no need to pass prefix and get pure key of parent Module.
+    for key, submodule, parent in module._flatten(
+        with_key=True, with_parent=True, predicate=is_quantable
+    ):
+        new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
+        if isinstance(parent, Float.Sequential):
+            # cannot use setattr to be compatible with Sequential's ``__setitem__``
+            parent[int(key.split(".")[-1])] = new_mod
+        else:
+            setattr(parent, key.split(".")[-1], new_mod)
+
+    propagate_qconfig(module, qconfig)
+    return module


-def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
+def propagate_qconfig(module: QATModule, qconfig: QConfig):
     r"""
-    Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
-    and set qconfig relatively.
+    Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.

-    :param module: root module to do convert recursively.
+    :param module: root module to traverse recursively.
     :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
-        default is :any:`~.qconfig.ema_fakequant_qconfig`.
     """

     def fn(mod: Module):
         if isinstance(mod, QATModule):
-            mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
             mod.set_qconfig(qconfig)

     module.apply(fn)
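Putting the new convert functions together, the intended end-to-end flow is roughly the following sketch (assuming the hypothetical ``Net`` with quant/dequant stubs shown earlier; untested against this exact revision):

    from megengine.quantization.qconfig import ema_fakequant_qconfig
    from megengine.quantization.quantize import quantize, quantize_qat

    net = Net()
    net.train()
    # swap quantable float modules for their QAT versions and attach the qconfig
    qat_net = quantize_qat(net, inplace=False, qconfig=ema_fakequant_qconfig)
    # ... fine-tune qat_net so observers collect scales/zero_points ...
    qat_net.eval()
    # swap QAT modules for inference-only QuantizedModules (calling them in
    # training mode raises "quantized module only support inference.")
    qnet = quantize(qat_net)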
python_module/test/unit/module/test_conv_bn_relu.py

@@ -5,8 +5,7 @@ import numpy as np
 from megengine import tensor
 from megengine.module import ConvBn2d
-from megengine.quantization import quantize_qat
-from megengine.quantization.quantize import disable_fake_quant
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
 from megengine.test import assertTensorClose

@@ -14,18 +13,17 @@ def test_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
-    module = ConvBn2d(in_channels, out_channels, kernel_size)
-    quantize_qat(module)
     for groups, bias in product([1, 4], [True, False]):
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
         module = ConvBn2d(
             in_channels, out_channels, kernel_size, groups=groups, bias=bias
         )
         module.train()
-        qat_module = copy.deepcopy(module)
+        qat_module = quantize_qat(module, inplace=False)
         disable_fake_quant(qat_module)
-        normal_outputs = module.forward(inputs)
-        qat_outputs = qat_module.forward_qat(inputs)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
         assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
-        a = module.bn.running_mean.numpy()
-        b = qat_module.bn.running_mean.numpy()
         assertTensorClose(
             module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
         )

@@ -33,7 +31,7 @@ def test_convbn2d():
         assertTensorClose(
             module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
         )
         module.eval()
-        normal_outputs = module.forward(inputs)
+        normal_outputs = module(inputs)
         qat_module.eval()
-        qat_outputs = qat_module.forward_qat(inputs)
+        qat_outputs = qat_module(inputs)
         assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
python_module/test/unit/quantization/quantize.py (new file, mode 100644)

# (standard MegEngine Apache-2.0 license header, as in qat/__init__.py)
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names


def test_get_quantable_module_names():
    # need to make sure names from Quantized and QAT are the same
    def _get_qat_module_names():
        def is_qat(key: str):
            value = getattr(QAT, key)
            return (
                isinstance(value, type)
                and issubclass(value, QAT.QATModule)
                and value != QAT.QATModule
            )

        # source should have all quantable modules' names
        quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
        return quantable_module_names

    qat_module_names = _get_qat_module_names()
    quantized_module_names = _get_quantable_module_names()
    assert set(qat_module_names) == set(quantized_module_names)

    for key in qat_module_names:
        value = getattr(Float, key)
        assert (
            isinstance(value, type)
            and issubclass(value, Float.Module)
            and value != Float.Module
        )