Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0b4918b2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0b4918b2
编写于
12月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(mge/quantization): classmethod `from_float_module` of qat module
GitOrigin-RevId: 95c3d45f83349825b7913556899002efdacdc971
上级
bb369383
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
80 additions
and
29 deletions
+80
-29
imperative/python/test/unit/quantization/test_module.py
imperative/python/test/unit/quantization/test_module.py
+80
-29
未找到文件。
imperative/python/test/unit/quantization/test_module.py
浏览文件 @
0b4918b2
...
...
@@ -8,7 +8,11 @@ import megengine.module.qat as QAT
import
megengine.module.quantized
as
Q
from
megengine.core.tensor
import
dtype
from
megengine.quantization
import
min_max_fakequant_qconfig
from
megengine.quantization.quantize
import
disable_observer
,
propagate_qconfig
from
megengine.quantization.quantize
import
(
disable_fake_quant
,
disable_observer
,
propagate_qconfig
,
)
"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
...
...
@@ -47,6 +51,12 @@ def init_qat_net(net):
def
test_quant_stub
():
normal_net
=
Float
.
QuantStub
()
normal_net
.
eval
()
qat_from_float
=
QAT
.
QuantStub
.
from_float_module
(
normal_net
)
qat_from_float
.
eval
()
disable_observer
(
qat_from_float
)
disable_fake_quant
(
qat_from_float
)
qat_net
=
QAT
.
QuantStub
()
qat_net
.
eval
()
disable_observer
(
qat_net
)
...
...
@@ -59,16 +69,25 @@ def test_quant_stub():
x
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
normal_out
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat_out
=
qat_net
(
x
)
q_out
=
q_net
(
x
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_out
,
normal_out
)
np
.
testing
.
assert_allclose
(
q_out
,
normal_out
.
numpy
())
normal
=
normal_net
(
x
)
qat_without_fakequant
=
qat_from_float
(
x
)
fake_quant_normal
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat
=
qat_net
(
x
)
q
=
q_net
(
x
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_without_fakequant
,
normal
)
np
.
testing
.
assert_allclose
(
qat
,
fake_quant_normal
)
np
.
testing
.
assert_allclose
(
q
,
fake_quant_normal
.
numpy
())
def
test_dequant_stub
():
normal_net
=
Float
.
DequantStub
()
normal_net
.
eval
()
qat_from_float
=
QAT
.
DequantStub
.
from_float_module
(
normal_net
)
qat_from_float
.
eval
()
disable_fake_quant
(
qat_from_float
)
disable_observer
(
qat_from_float
)
qat_net
=
QAT
.
DequantStub
()
qat_net
.
eval
()
disable_observer
(
qat_net
)
...
...
@@ -83,17 +102,26 @@ def test_dequant_stub():
x
=
fake_quant
(
x
,
inp_scale
)
x
.
q_dict
[
"scale"
]
=
inp_scale
normal_out
=
normal_net
(
x
)
qat_out
=
qat_net
(
x
)
q_out
=
q_net
(
quant
(
x
,
inp_scale
)).
numpy
()
np
.
testing
.
assert_allclose
(
qat_out
,
normal_out
)
np
.
testing
.
assert_allclose
(
q_out
,
normal_out
.
numpy
())
normal
=
normal_net
(
x
)
qat_without_fakequant
=
qat_from_float
(
x
)
fake_quant_normal
=
normal_net
(
x
)
qat
=
qat_net
(
x
)
q
=
q_net
(
quant
(
x
,
inp_scale
)).
numpy
()
np
.
testing
.
assert_allclose
(
qat_without_fakequant
,
normal
)
np
.
testing
.
assert_allclose
(
qat
,
fake_quant_normal
)
np
.
testing
.
assert_allclose
(
q
,
fake_quant_normal
.
numpy
())
@
pytest
.
mark
.
parametrize
(
"kind"
,
[
"COS"
,
"RELU"
,
"ADD"
,
"MUL"
,
"FUSE_ADD_RELU"
])
def
test_elemwise
(
kind
):
normal_net
=
Float
.
Elemwise
(
kind
)
normal_net
.
eval
()
qat_from_float
=
QAT
.
Elemwise
.
from_float_module
(
normal_net
)
qat_from_float
.
eval
()
disable_observer
(
qat_from_float
)
disable_fake_quant
(
qat_from_float
)
qat_net
=
QAT
.
Elemwise
(
kind
)
qat_net
.
eval
()
disable_observer
(
qat_net
)
...
...
@@ -117,16 +145,22 @@ def test_elemwise(kind):
x1_int8
=
quant
(
x1
,
x1_scale
)
x2_int8
=
quant
(
x2
,
x2_scale
)
# test correctness of `Float`, `QAT` and `Quantized`
if
kind
in
(
"ADD"
,
"MUL"
,
"FUSE_ADD_RELU"
):
normal_out
=
fake_quant
(
normal_net
(
x1
,
x2
),
act_scale
)
qat_out
=
qat_net
(
x1
,
x2
)
q_out
=
q_net
(
x1_int8
,
x2_int8
).
numpy
()
*
act_scale
normal
=
normal_net
(
x1
,
x2
)
qat_without_fakequant
=
qat_from_float
(
x1
,
x2
)
fake_quant_normal
=
fake_quant
(
normal_net
(
x1
,
x2
),
act_scale
)
qat
=
qat_net
(
x1
,
x2
)
q
=
q_net
(
x1_int8
,
x2_int8
).
numpy
()
*
act_scale
else
:
normal_out
=
fake_quant
(
normal_net
(
x1
),
act_scale
)
qat_out
=
qat_net
(
x1
)
q_out
=
q_net
(
x1_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_out
,
normal_out
)
np
.
testing
.
assert_allclose
(
q_out
,
normal_out
.
numpy
())
normal
=
normal_net
(
x1
)
qat_without_fakequant
=
qat_from_float
(
x1
)
fake_quant_normal
=
fake_quant
(
normal_net
(
x1
),
act_scale
)
qat
=
qat_net
(
x1
)
q
=
q_net
(
x1_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_without_fakequant
,
normal
)
np
.
testing
.
assert_allclose
(
qat
,
fake_quant_normal
)
np
.
testing
.
assert_allclose
(
q
,
fake_quant_normal
.
numpy
())
def
test_linear
():
...
...
@@ -153,20 +187,29 @@ def test_linear():
qat_net
.
weight
.
set_value
(
weight
)
qat_net
.
bias
.
set_value
(
bias
)
qat_from_float
=
QAT
.
Linear
.
from_float_module
(
normal_net
)
qat_from_float
.
eval
()
disable_fake_quant
(
qat_from_float
)
disable_observer
(
qat_from_float
)
q_net
=
Q
.
Linear
.
from_qat_module
(
qat_net
)
q_net
.
eval
()
normal_out
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat_out
=
qat_net
(
x
)
q_out
=
q_net
(
x_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_out
,
normal_out
)
np
.
testing
.
assert_allclose
(
q_out
,
normal_out
.
numpy
())
normal
=
normal_net
(
x
)
qat_without_fakequant
=
qat_from_float
(
x
)
fake_quant_normal
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat
=
qat_net
(
x
)
q
=
q_net
(
x_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_without_fakequant
,
normal
)
np
.
testing
.
assert_allclose
(
qat
,
fake_quant_normal
)
np
.
testing
.
assert_allclose
(
q
,
fake_quant_normal
.
numpy
())
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"Conv2d"
,
"ConvBn2d"
,
"ConvBnRelu2d"
])
def
test_conv
(
module
):
normal_net
=
getattr
(
Float
,
module
)(
3
,
3
,
3
,
1
,
1
,
1
,
bias
=
True
)
normal_net
.
eval
()
qat_net
=
getattr
(
QAT
,
module
)(
3
,
3
,
3
,
1
,
1
,
1
,
bias
=
True
)
qat_net
.
eval
()
disable_observer
(
qat_net
)
...
...
@@ -193,11 +236,19 @@ def test_conv(module):
qat_net
.
weight
.
set_value
(
weight
)
qat_net
.
bias
.
set_value
(
bias
)
qat_from_float
=
getattr
(
QAT
,
module
).
from_float_module
(
normal_net
)
qat_from_float
.
eval
()
disable_observer
(
qat_from_float
)
disable_fake_quant
(
qat_from_float
)
q_net
=
getattr
(
Q
,
module
).
from_qat_module
(
qat_net
)
q_net
.
eval
()
normal_out
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat_out
=
qat_net
(
x
)
q_out
=
q_net
(
x_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_out
,
normal_out
)
np
.
testing
.
assert_allclose
(
q_out
,
normal_out
.
numpy
())
normal
=
normal_net
(
x
)
qat_without_fakequant
=
qat_from_float
(
x
)
fake_quant_normal
=
fake_quant
(
normal_net
(
x
),
act_scale
)
qat
=
qat_net
(
x
)
q
=
q_net
(
x_int8
).
numpy
()
*
act_scale
np
.
testing
.
assert_allclose
(
qat_without_fakequant
,
normal
,
atol
=
1e-6
)
np
.
testing
.
assert_allclose
(
qat
,
fake_quant_normal
)
np
.
testing
.
assert_allclose
(
q
,
fake_quant_normal
.
numpy
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录