Commit e6820b91
Authored May 28, 2020 by Megvii Engine Team
Committed by Xu Xinran on Jun 19, 2020
feat(mge/module): add conv and conv_relu quantization module
GitOrigin-RevId: 9cd668d97b4ccae8adfa801fd43856cd3fdca813
Parent: a1f8ecc7

Showing 11 changed files with 250 additions and 35 deletions (+250, -35)
Changed files:
  python_module/megengine/module/__init__.py             +2   -2
  python_module/megengine/module/conv.py                  +16  -6
  python_module/megengine/module/conv_bn.py               +2   -2
  python_module/megengine/module/qat/__init__.py          +2   -1
  python_module/megengine/module/qat/conv.py              +57  -0
  python_module/megengine/module/qat/conv_bn.py           +2   -2
  python_module/megengine/module/quantized/__init__.py    +2   -1
  python_module/megengine/module/quantized/conv.py        +22  -21
  python_module/megengine/module/quantized/conv_bn.py     +56  -0
  python_module/megengine/quantization/quantize.py        +4   -0
  python_module/test/unit/module/test_qat.py              +85  -0
python_module/megengine/module/__init__.py
@@ -9,8 +9,8 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
-from .conv import Conv2d, ConvTranspose2d, LocalConv2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
python_module/megengine/module/conv.py
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,8 +12,8 @@ import numpy as np
 import megengine._internal as mgb

+from .. import functional as F
 from ..core import Parameter
-from ..functional import conv2d, conv_transpose2d, local_conv2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -183,7 +182,7 @@ class Conv2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def calc_conv(self, inp, weight, bias):
-        return conv2d(
+        return F.conv2d(
             inp,
             weight,
             bias,
@@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def forward(self, inp):
-        return conv_transpose2d(
+        return F.conv_transpose2d(
             inp,
             self.weight,
             self.bias,
@@ -324,7 +323,7 @@ class LocalConv2d(Conv2d):
         spatial dimensions. Only zero-padding is supported. Default: 0
     :param groups: number of groups to divide input and output channels into,
         so as to perform a "grouped convolution". When ``groups`` is not 1,
         ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
         The shape of weight is ``(groups, output_height, output_width,
         in_channels // groups, *kernel_size, out_channels // groups)``.
     """
@@ -377,6 +376,17 @@ class LocalConv2d(Conv2d):
         )

     def forward(self, inp):
-        return local_conv2d(
+        return F.local_conv2d(
             inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
         )
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
+    :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return F.relu(self.calc_conv(inp, self.weight, self.bias))
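For context (not part of the commit): the new float ConvRelu2d keeps Conv2d's constructor and only fuses the relu into forward, so it can stand in for a Conv2d followed by F.relu. A minimal usage sketch, assuming the MegEngine 0.x Python API used elsewhere in this diff:

import numpy as np

from megengine import tensor
from megengine.module import ConvRelu2d

# Same constructor arguments as Conv2d: in_channels, out_channels, kernel_size.
conv_relu = ConvRelu2d(32, 64, 3)
x = tensor(np.random.randn(4, 32, 32, 32).astype(np.float32))
y = conv_relu(x)  # convolution and relu in a single forward pass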
python_module/megengine/module/conv_bn_relu.py → python_module/megengine/module/conv_bn.py
@@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module):
 class ConvBn2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d):
 class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
     :func:`~.quantize.quantize_qat`.
     """
python_module/megengine/module/qat/__init__.py
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QATModule
python_module/megengine/module/qat/conv.py (new file, mode 100644)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from .. import conv as Float
from .module import QATModule


class Conv2d(Float.Conv2d, QATModule):
    r"""
    A :class:`~.QATModule` Conv2d with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def calc_conv_qat(self, inp):
        w_qat = self.apply_quant_weight(self.weight)
        conv = self.calc_conv(inp, w_qat, self.bias)
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float.Conv2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.in_channels,
            float_module.out_channels,
            float_module.kernel_size,
            float_module.stride,
            float_module.padding,
            float_module.dilation,
            float_module.groups,
            float_module.bias is not None,
            float_module.conv_mode.name,
            float_module.compute_mode.name,
        )
        qat_module.weight = float_module.weight
        qat_module.bias = float_module.bias
        return qat_module

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_qat(inp))


class ConvRelu2d(Conv2d):
    r"""
    A :class:`~.QATModule` include Conv2d and Relu with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp)))
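For context (not part of the commit): these QAT modules are swapped in for their float counterparts by quantize_qat, which calls from_float_module. A minimal sketch of that flow, mirroring the test added later in this diff and assuming the stubs it imports:

import numpy as np

from megengine import tensor
from megengine.module import ConvRelu2d, DequantStub, Module, QuantStub
from megengine.quantization.quantize import quantize_qat

class TinyNet(Module):  # hypothetical wrapper, for illustration only
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv_relu = ConvRelu2d(32, 64, 3)
        self.dequant = DequantStub()

    def forward(self, x):
        return self.dequant(self.conv_relu(self.quant(x)))

net = TinyNet()
qat_net = quantize_qat(net, inplace=False)  # conv_relu becomes qat.ConvRelu2d
x = tensor(np.random.randn(4, 32, 32, 32).astype(np.float32))
y = qat_net(x)  # fake-quantized conv + relu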
python_module/megengine/module/qat/conv_bn_relu.py → python_module/megengine/module/qat/conv_bn.py
@@ -7,7 +7,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
-from .. import conv_bn_relu as Float
+from .. import conv_bn as Float
 from .module import QATModule
@@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.padding,
             float_module.conv.dilation,
             float_module.conv.groups,
-            bool(float_module.conv.bias),
+            float_module.conv.bias is not None,
             float_module.conv.conv_mode.name,
             float_module.conv.compute_mode.name,
         )
python_module/megengine/module/quantized/__init__.py
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QuantizedModule
python_module/megengine/module/quantized/conv_bn_relu.py → python_module/megengine/module/quantized/conv.py
@@ -7,16 +7,19 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

+import numpy as np
+
 import megengine._internal as mgb

 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ..qat import conv_bn_relu as QAT
+from ..qat import conv as QAT
 from .module import QuantizedModule


-class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
+class Conv2d(Float.Conv2d, QuantizedModule):
+    r"""quantized version of :class:`~.qat.conv.Conv2d`."""
     r"""Applies a 2D convolution over an quantized input tensor, inference only.

     The parameter is same with :class: `~.Conv2d`
@@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
         )

     @classmethod
-    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+    def from_qat_module(cls, qat_module: QAT.Conv2d):
         r"""
         return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
         qconv = cls(
-            qat_module.conv.in_channels,
-            qat_module.conv.out_channels,
-            qat_module.conv.kernel_size,
-            qat_module.conv.stride,
-            qat_module.conv.padding,
-            qat_module.conv.dilation,
-            qat_module.conv.groups,
+            qat_module.in_channels,
+            qat_module.out_channels,
+            qat_module.kernel_size,
+            qat_module.stride,
+            qat_module.padding,
+            qat_module.dilation,
+            qat_module.groups,
             dtype=output_dtype,
         )
-        w_fold, b_fold = qat_module.fold_weight_bias(
-            qat_module.bn.running_mean, qat_module.bn.running_var
-        )
-        weight = w_fold.astype(qat_module.get_weight_dtype())
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qconv.weight = Parameter(weight.numpy())
-        qconv.bias = Parameter(b_fold.numpy())
+        if qat_module.bias is not None:
+            qconv.bias = Parameter(qat_module.bias.numpy())
+        else:
+            qconv.bias = Parameter(
+                np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
+            )

         return qconv

-
-class ConvBn2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
-
     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


-class ConvBnRelu2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+class ConvRelu2d(Conv2d):
+    r"""quantized version of :class:`~.qat.conv.ConvRelu2d`."""

     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
python_module/megengine/module/quantized/conv_bn.py (new file, mode 100644)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import Parameter
from ..qat import conv_bn as QAT
from .conv import Conv2d


class _ConvBnActivation2d(Conv2d):
    r"""Applies a 2D convolution over an quantized input tensor, inference only.

    The parameter is same with :class: `~.Conv2d`
    """

    @classmethod
    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
        r"""
        return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv = cls(
            qat_module.conv.in_channels,
            qat_module.conv.out_channels,
            qat_module.conv.kernel_size,
            qat_module.conv.stride,
            qat_module.conv.padding,
            qat_module.conv.dilation,
            qat_module.conv.groups,
            dtype=output_dtype,
        )
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qconv.weight = Parameter(weight.numpy())
        qconv.bias = Parameter(b_fold.numpy())
        return qconv


class ConvBn2d(_ConvBnActivation2d):
    r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""

    def forward(self, inp):
        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
python_module/megengine/quantization/quantize.py
@@ -104,6 +104,10 @@ def quantize_qat(
     for key, submodule, parent in module._flatten(
         with_key=True, with_parent=True, predicate=is_quantable
     ):
+        # only convert top quantable module.
+        if is_quantable(parent):
+            continue
+
         new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
         if isinstance(parent, Float.Sequential):
             # cannnot use setattr to be compatible with Sequential's ``__setitem__``
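For context (not part of the commit): the new guard matters because a quantable module such as ConvBn2d owns a float Conv2d submodule (self.conv), and with Conv2d now presumably registered in _float2qat_dict that inner conv would otherwise also match the predicate. A minimal sketch, under that assumption:

from megengine.module import ConvBn2d, Module
from megengine.quantization.quantize import quantize_qat

class Net(Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.conv_bn = ConvBn2d(32, 64, 3)

    def forward(self, x):
        return self.conv_bn(x)

qat_net = quantize_qat(Net(), inplace=False)
# Only the top-level quantable module (conv_bn) is converted; its inner float
# conv is consumed by qat.ConvBn2d.from_float_module rather than swapped separately.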
python_module/test/unit/module/test_conv_bn_relu.py → python_module/test/unit/module/test_qat.py
-import copy
 from itertools import product

 import numpy as np

 from megengine import tensor
-from megengine.module import ConvBn2d
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
 from megengine.quantization.quantize import disable_fake_quant, quantize_qat
 from megengine.test import assertTensorClose


-def test_convbn2d():
+def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
@@ -35,3 +41,45 @@ def test_convbn2d():
     qat_module.eval()
     qat_outputs = qat_module(inputs)
     assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)