Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a1ca50c9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a1ca50c9
编写于
2月 25, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/quantization): add name for quantized module
GitOrigin-RevId: edefbec7b70953144105c558bae34e3f792c02ec
上级
d0f70a44
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
172 addition
and
50 deletion
+172
-50
imperative/python/megengine/module/conv.py
imperative/python/megengine/module/conv.py
+2
-0
imperative/python/megengine/module/deformable_psroi_pooling.py
...ative/python/megengine/module/deformable_psroi_pooling.py
+2
-1
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+12
-2
imperative/python/megengine/module/qat/batch_matmul_activation.py
...ve/python/megengine/module/qat/batch_matmul_activation.py
+1
-0
imperative/python/megengine/module/qat/concat.py
imperative/python/megengine/module/qat/concat.py
+1
-1
imperative/python/megengine/module/qat/conv.py
imperative/python/megengine/module/qat/conv.py
+1
-0
imperative/python/megengine/module/qat/conv_bn.py
imperative/python/megengine/module/qat/conv_bn.py
+1
-0
imperative/python/megengine/module/qat/elemwise.py
imperative/python/megengine/module/qat/elemwise.py
+1
-1
imperative/python/megengine/module/qat/linear.py
imperative/python/megengine/module/qat/linear.py
+3
-1
imperative/python/megengine/module/qat/module.py
imperative/python/megengine/module/qat/module.py
+2
-2
imperative/python/megengine/module/qat/quant_dequant.py
imperative/python/megengine/module/qat/quant_dequant.py
+2
-2
imperative/python/megengine/module/quantized/batch_matmul_activation.py
...hon/megengine/module/quantized/batch_matmul_activation.py
+3
-2
imperative/python/megengine/module/quantized/concat.py
imperative/python/megengine/module/quantized/concat.py
+3
-3
imperative/python/megengine/module/quantized/conv.py
imperative/python/megengine/module/quantized/conv.py
+4
-2
imperative/python/megengine/module/quantized/conv_bn.py
imperative/python/megengine/module/quantized/conv_bn.py
+3
-2
imperative/python/megengine/module/quantized/elemwise.py
imperative/python/megengine/module/quantized/elemwise.py
+5
-3
imperative/python/megengine/module/quantized/linear.py
imperative/python/megengine/module/quantized/linear.py
+5
-5
imperative/python/megengine/module/quantized/quant_dequant.py
...rative/python/megengine/module/quantized/quant_dequant.py
+4
-4
imperative/python/test/unit/test_dump_naming.py
imperative/python/test/unit/test_dump_naming.py
+117
-19
未找到文件。
imperative/python/megengine/module/conv.py
浏览文件 @
a1ca50c9
...
...
@@ -641,6 +641,7 @@ class DeformableConv2d(_ConvNd):
bias
:
bool
=
True
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
**
kwargs
):
kernel_size
=
_pair_nonzero
(
kernel_size
)
stride
=
_pair_nonzero
(
stride
)
...
...
@@ -657,6 +658,7 @@ class DeformableConv2d(_ConvNd):
dilation
,
groups
,
bias
,
**
kwargs
,
)
def
_get_fanin
(
self
):
...
...
imperative/python/megengine/module/deformable_psroi_pooling.py
浏览文件 @
a1ca50c9
...
...
@@ -21,8 +21,9 @@ class DeformablePSROIPooling(Module):
sample_per_part
,
spatial_scale
,
trans_std
:
float
=
0.1
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
(
**
kwargs
)
self
.
no_trans
=
no_trans
self
.
part_size
=
part_size
self
.
pooled_h
=
pooled_h
...
...
imperative/python/megengine/module/module.py
浏览文件 @
a1ca50c9
...
...
@@ -69,7 +69,17 @@ class Module(metaclass=ABCMeta):
Base Module class.
"""
def
__init__
(
self
,
name
=
""
):
def
__init__
(
self
,
name
=
None
):
"""
:param name: module's name, can be initialized by the ``kwargs`` parameter
of child class.
"""
if
name
is
not
None
:
assert
(
isinstance
(
name
,
str
)
and
name
.
strip
()
),
"Module's name must be a non-empty string"
self
.
name
=
name
# runtime attributes
...
...
@@ -109,7 +119,7 @@ class Module(metaclass=ABCMeta):
return
HookHandler
(
self
.
_forward_hooks
,
hook
)
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
auto_naming
.
push_scope
(
self
.
name
if
self
.
name
else
self
.
_name
)
auto_naming
.
push_scope
(
self
.
name
if
self
.
name
is
not
None
else
self
.
_name
)
for
hook
in
self
.
_forward_pre_hooks
.
values
():
modified_inputs
=
hook
(
self
,
inputs
)
if
modified_inputs
is
not
None
:
...
...
imperative/python/megengine/module/qat/batch_matmul_activation.py
浏览文件 @
a1ca50c9
...
...
@@ -28,6 +28,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
float_module
.
in_features
,
float_module
.
out_features
,
float_module
.
bias
is
not
None
,
name
=
float_module
.
name
,
)
qat_module
.
weight
=
float_module
.
weight
qat_module
.
bias
=
float_module
.
bias
...
...
imperative/python/megengine/module/qat/concat.py
浏览文件 @
a1ca50c9
...
...
@@ -27,4 +27,4 @@ class Concat(Float.Concat, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return
cls
()
return
cls
(
name
=
float_module
.
name
)
imperative/python/megengine/module/qat/conv.py
浏览文件 @
a1ca50c9
...
...
@@ -43,6 +43,7 @@ class Conv2d(Float.Conv2d, QATModule):
float_module
.
bias
is
not
None
,
float_module
.
conv_mode
,
float_module
.
compute_mode
,
name
=
float_module
.
name
,
)
qat_module
.
weight
=
float_module
.
weight
qat_module
.
bias
=
float_module
.
bias
...
...
imperative/python/megengine/module/qat/conv_bn.py
浏览文件 @
a1ca50c9
...
...
@@ -155,6 +155,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
float_module
.
conv
.
bias
is
not
None
,
float_module
.
conv
.
conv_mode
,
float_module
.
conv
.
compute_mode
,
name
=
float_module
.
name
,
)
qat_module
.
conv
.
weight
=
float_module
.
conv
.
weight
qat_module
.
conv
.
bias
=
float_module
.
conv
.
bias
...
...
imperative/python/megengine/module/qat/elemwise.py
浏览文件 @
a1ca50c9
...
...
@@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return
cls
(
float_module
.
method
)
return
cls
(
float_module
.
method
,
name
=
float_module
.
name
)
imperative/python/megengine/module/qat/linear.py
浏览文件 @
a1ca50c9
...
...
@@ -36,7 +36,9 @@ class Linear(Float.Linear, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qmod
=
cls
(
float_module
.
in_features
,
float_module
.
out_features
)
qmod
=
cls
(
float_module
.
in_features
,
float_module
.
out_features
,
name
=
float_module
.
name
)
qmod
.
weight
=
float_module
.
weight
qmod
.
bias
=
float_module
.
bias
return
qmod
imperative/python/megengine/module/qat/module.py
浏览文件 @
a1ca50c9
...
...
@@ -26,8 +26,8 @@ class QATModule(Module):
with_weight
=
True
with_act
=
True
def
__init__
(
self
):
super
().
__init__
()
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
weight_observer
=
None
# type: Observer
self
.
act_observer
=
None
# type: Observer
...
...
imperative/python/megengine/module/qat/quant_dequant.py
浏览文件 @
a1ca50c9
...
...
@@ -26,7 +26,7 @@ class QuantStub(Float.QuantStub, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return
cls
()
return
cls
(
name
=
float_module
.
name
)
class
DequantStub
(
Float
.
DequantStub
,
QATModule
):
...
...
@@ -47,4 +47,4 @@ class DequantStub(Float.DequantStub, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return
cls
()
return
cls
(
name
=
float_module
.
name
)
imperative/python/megengine/module/quantized/batch_matmul_activation.py
浏览文件 @
a1ca50c9
...
...
@@ -61,13 +61,14 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
qat_module
.
out_features
,
qat_module
.
bias
is
not
None
,
dtype
=
output_dtype
,
name
=
qat_module
.
name
,
)
weight
=
qat_module
.
weight
.
astype
(
qat_module
.
get_weight_dtype
())
weight
=
expand_dims
(
weight
,
[
-
1
,
-
2
])
qbmm
.
weight
=
Parameter
(
weight
.
numpy
())
qbmm
.
weight
=
Parameter
(
weight
.
numpy
()
,
name
=
qat_module
.
weight
.
name
)
if
qat_module
.
bias
is
not
None
:
bias
=
qat_module
.
bias
.
reshape
((
1
,
qbmm
.
out_features
,
1
,
1
))
qbmm
.
bias
=
Parameter
(
bias
.
numpy
())
qbmm
.
bias
=
Parameter
(
bias
.
numpy
()
,
name
=
qat_module
.
bias
.
name
)
else
:
qbmm
.
bias
=
Parameter
(
np
.
zeros
((
1
,
qbmm
.
out_features
,
1
,
1
),
dtype
=
np
.
float32
)
...
...
imperative/python/megengine/module/quantized/concat.py
浏览文件 @
a1ca50c9
...
...
@@ -18,8 +18,8 @@ class Concat(QuantizedModule):
A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only.
"""
def
__init__
(
self
,
dtype
=
None
):
super
().
__init__
()
def
__init__
(
self
,
dtype
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
output_dtype
=
dtype
def
forward
(
self
,
inps
:
Iterable
[
Tensor
],
axis
:
int
=
0
):
...
...
@@ -32,4 +32,4 @@ class Concat(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return
cls
(
qat_module
.
get_activation_dtype
())
return
cls
(
qat_module
.
get_activation_dtype
()
,
name
=
qat_module
.
name
)
imperative/python/megengine/module/quantized/conv.py
浏览文件 @
a1ca50c9
...
...
@@ -37,6 +37,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
dtype
=
None
,
**
kwargs
):
super
().
__init__
(
in_channels
,
...
...
@@ -86,11 +87,12 @@ class Conv2d(Float.Conv2d, QuantizedModule):
qat_module
.
dilation
,
qat_module
.
groups
,
dtype
=
output_dtype
,
name
=
qat_module
.
name
,
)
weight
=
qat_module
.
weight
.
astype
(
qat_module
.
get_weight_dtype
())
qconv
.
weight
=
Parameter
(
weight
.
numpy
())
qconv
.
weight
=
Parameter
(
weight
.
numpy
()
,
name
=
qat_module
.
weight
.
name
)
if
qat_module
.
bias
is
not
None
:
qconv
.
bias
=
Parameter
(
qat_module
.
bias
.
numpy
())
qconv
.
bias
=
Parameter
(
qat_module
.
bias
.
numpy
()
,
name
=
qat_module
.
bias
.
name
)
else
:
qconv
.
bias
=
Parameter
(
np
.
zeros
(
qat_module
.
_infer_bias_shape
(),
dtype
=
np
.
float32
)
...
...
imperative/python/megengine/module/quantized/conv_bn.py
浏览文件 @
a1ca50c9
...
...
@@ -33,13 +33,14 @@ class _ConvBnActivation2d(Conv2d):
qat_module
.
conv
.
dilation
,
qat_module
.
conv
.
groups
,
dtype
=
output_dtype
,
name
=
qat_module
.
name
,
)
w_fold
,
b_fold
=
qat_module
.
fold_weight_bias
(
qat_module
.
bn
.
running_mean
,
qat_module
.
bn
.
running_var
)
weight
=
w_fold
.
astype
(
qat_module
.
get_weight_dtype
())
qconv
.
weight
=
Parameter
(
weight
.
numpy
())
qconv
.
bias
=
Parameter
(
b_fold
.
numpy
())
qconv
.
weight
=
Parameter
(
weight
.
numpy
()
,
name
=
qat_module
.
conv
.
weight
.
name
)
qconv
.
bias
=
Parameter
(
b_fold
.
numpy
()
,
name
=
qat_module
.
conv
.
bias
.
name
)
return
qconv
...
...
imperative/python/megengine/module/quantized/elemwise.py
浏览文件 @
a1ca50c9
...
...
@@ -14,8 +14,8 @@ from .module import QuantizedModule
class
Elemwise
(
QuantizedModule
):
r
"""Quantized version of :class:`~.qat.Elemwise`."""
def
__init__
(
self
,
method
,
dtype
=
None
):
super
().
__init__
()
def
__init__
(
self
,
method
,
dtype
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
method
=
"Q"
+
method
self
.
output_dtype
=
dtype
...
...
@@ -30,4 +30,6 @@ class Elemwise(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return
cls
(
qat_module
.
method
,
qat_module
.
get_activation_dtype
())
return
cls
(
qat_module
.
method
,
qat_module
.
get_activation_dtype
(),
name
=
qat_module
.
name
)
imperative/python/megengine/module/quantized/linear.py
浏览文件 @
a1ca50c9
...
...
@@ -17,8 +17,8 @@ from .module import QuantizedModule
class
Linear
(
QuantizedModule
):
r
"""Quantized version of :class:`~.qat.Linear`."""
def
__init__
(
self
,
dtype
:
np
.
dtype
=
None
):
super
().
__init__
()
def
__init__
(
self
,
dtype
:
np
.
dtype
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
weight
=
None
self
.
bias
=
None
self
.
output_dtype
=
dtype
...
...
@@ -44,9 +44,9 @@ class Linear(QuantizedModule):
:class:`~.QATModule` instance.
"""
output_dtype
=
qat_module
.
get_activation_dtype
()
qmod
=
cls
(
dtype
=
output_dtype
)
qmod
=
cls
(
dtype
=
output_dtype
,
name
=
qat_module
.
name
)
weight
=
qat_module
.
weight
.
astype
(
qat_module
.
get_weight_dtype
())
qmod
.
weight
=
Parameter
(
weight
.
numpy
())
qmod
.
weight
=
Parameter
(
weight
.
numpy
()
,
name
=
qat_module
.
weight
.
name
)
if
qat_module
.
bias
is
not
None
:
qmod
.
bias
=
Parameter
(
qat_module
.
bias
.
numpy
())
qmod
.
bias
=
Parameter
(
qat_module
.
bias
.
numpy
()
,
name
=
qat_module
.
bias
.
name
)
return
qmod
imperative/python/megengine/module/quantized/quant_dequant.py
浏览文件 @
a1ca50c9
...
...
@@ -15,8 +15,8 @@ class QuantStub(QuantizedModule):
will convert input to quantized dtype.
"""
def
__init__
(
self
,
dtype
=
None
):
super
().
__init__
()
def
__init__
(
self
,
dtype
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
output_dtype
=
dtype
def
forward
(
self
,
inp
):
...
...
@@ -28,7 +28,7 @@ class QuantStub(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return
cls
(
qat_module
.
get_activation_dtype
())
return
cls
(
qat_module
.
get_activation_dtype
()
,
name
=
qat_module
.
name
)
class
DequantStub
(
QuantizedModule
):
...
...
@@ -46,4 +46,4 @@ class DequantStub(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return
cls
()
return
cls
(
name
=
qat_module
.
name
)
imperative/python/test/unit/test_dump_naming.py
浏览文件 @
a1ca50c9
...
...
@@ -17,6 +17,7 @@ import megengine.utils.comp_graph_tools as cgtools
from
megengine
import
Parameter
,
Tensor
from
megengine.core.tensor
import
megbrain_graph
as
G
from
megengine.jit.tracing
import
trace
from
megengine.quantization.quantize
import
quantize
,
quantize_qat
from
megengine.utils.naming
import
auto_naming
...
...
@@ -29,14 +30,14 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
func
.
dump
(
file
,
optimize_for_inference
=
False
,
arg_names
=
"x"
,
arg_names
=
(
"x"
,)
,
keep_opr_name
=
keep_opr_name
,
keep_var_name
=
2
,
)
file
.
seek
(
0
)
*
_
,
outputs
=
G
.
load_graph
(
file
)
op
=
cgtools
.
get_oprs_seq
(
outputs
)[
-
1
]
return
op
op
s
=
cgtools
.
get_oprs_seq
(
outputs
)
return
op
s
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
...
...
@@ -50,7 +51,7 @@ def test_auto_naming(symbolic):
return
x
+
x
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
op
=
_dump_and_load
(
m
,
symbolic
)
[
-
1
]
assert
op
.
name
==
"simple.ADD"
assert
op
.
outputs
[
0
].
name
==
"simple.ADD"
...
...
@@ -70,7 +71,7 @@ def test_user_named_tensor(symbolic):
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
op
=
_dump_and_load
(
m
,
symbolic
)
[
-
1
]
assert
op
.
name
==
"simple.ADD"
assert
op
.
outputs
[
0
].
name
==
"o_x"
...
...
@@ -88,7 +89,7 @@ def test_user_named_param(symbolic):
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
op
=
_dump_and_load
(
m
,
symbolic
)
[
-
1
]
assert
op
.
inputs
[
0
].
name
==
"x"
assert
op
.
inputs
[
1
].
name
==
"simple.k"
...
...
@@ -98,7 +99,7 @@ def test_without_module(symbolic):
def
f
(
x
):
return
2
*
x
op
=
_dump_and_load
(
f
,
symbolic
)
op
=
_dump_and_load
(
f
,
symbolic
)
[
-
1
]
assert
op
.
name
==
"MUL"
...
...
@@ -116,10 +117,10 @@ def test_with_submodule(symbolic):
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
assert
op
.
name
==
"simple.linear.ADD"
assert
op
.
inputs
[
0
].
owner
.
name
==
"simple.linear.MatrixMul"
assert
op
.
outputs
[
0
].
name
==
"simple.linear.ADD"
op
s
=
_dump_and_load
(
m
,
symbolic
)
assert
op
s
[
-
1
]
.
name
==
"simple.linear.ADD"
assert
op
s
[
-
2
]
.
name
==
"simple.linear.MatrixMul"
assert
op
s
[
-
1
]
.
outputs
[
0
].
name
==
"simple.linear.ADD"
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
...
...
@@ -136,10 +137,10 @@ def test_named_submodule(symbolic):
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
assert
op
.
name
==
"simple.x.ADD"
assert
op
.
inputs
[
0
].
owner
.
name
==
"simple.x.MatrixMul"
assert
op
.
outputs
[
0
].
name
==
"simple.x.ADD"
op
s
=
_dump_and_load
(
m
,
symbolic
)
assert
op
s
[
-
1
]
.
name
==
"simple.x.ADD"
assert
op
s
[
-
2
]
.
name
==
"simple.x.MatrixMul"
assert
op
s
[
-
1
]
.
outputs
[
0
].
name
==
"simple.x.ADD"
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
...
...
@@ -156,14 +157,111 @@ def test_with_same_operators(symbolic):
m
=
Simple
(
"simple"
)
op
=
_dump_and_load
(
m
,
symbolic
)
assert
op
.
name
==
"simple.RELU[1]"
assert
op
.
inputs
[
0
].
owner
.
name
==
"simple.RELU[0]"
op
s
=
_dump_and_load
(
m
,
symbolic
)
assert
op
s
[
-
1
]
.
name
==
"simple.RELU[1]"
assert
op
s
[
-
2
]
.
name
==
"simple.RELU[0]"
def
test_not_keep_opr_name
():
def
f
(
x
):
return
2
*
x
op
=
_dump_and_load
(
f
,
True
,
False
)
op
=
_dump_and_load
(
f
,
True
,
False
)
[
-
1
]
assert
op
.
name
==
"MUL(x,2[2])[4]"
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
def
test_quantized_module_auto_naming
(
symbolic
):
class
Simple
(
M
.
Module
):
def
__init__
(
self
,
name
):
super
().
__init__
(
name
=
name
)
self
.
quant
=
M
.
QuantStub
()
self
.
linear
=
M
.
Linear
(
3
,
3
,
bias
=
True
)
self
.
dequant
=
M
.
DequantStub
()
def
forward
(
self
,
x
):
out
=
self
.
quant
(
x
)
out
=
self
.
linear
(
out
)
out
=
self
.
dequant
(
out
)
return
out
m
=
Simple
(
"simple"
)
quantize_qat
(
m
)
quantize
(
m
)
m
.
eval
()
ops
=
_dump_and_load
(
m
,
symbolic
)
ops_name
=
(
"x"
,
"simple.quant.TypeCvt"
,
"simple.linear.MatrixMul"
,
"simple.linear.ADD"
,
"simple.linear.TypeCvt"
,
"simple.dequant.TypeCvt"
,
)
for
op
,
name
in
zip
(
ops
,
ops_name
):
assert
op
.
name
==
name
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
def
test_quantized_module_user_naming
(
symbolic
):
class
Simple
(
M
.
Module
):
def
__init__
(
self
,
name
):
super
().
__init__
(
name
=
name
)
self
.
quant
=
M
.
QuantStub
()
self
.
linear
=
M
.
Linear
(
3
,
3
,
bias
=
True
,
name
=
"user-linear"
)
self
.
dequant
=
M
.
DequantStub
()
def
forward
(
self
,
x
):
out
=
self
.
quant
(
x
)
out
=
self
.
linear
(
out
)
out
=
self
.
dequant
(
out
)
return
out
m
=
Simple
(
"simple"
)
quantize_qat
(
m
)
quantize
(
m
)
m
.
eval
()
ops
=
_dump_and_load
(
m
,
symbolic
)
ops_name
=
(
"x"
,
"simple.quant.TypeCvt"
,
"simple.user-linear.MatrixMul"
,
"simple.user-linear.ADD"
,
"simple.user-linear.TypeCvt"
,
"simple.dequant.TypeCvt"
,
)
for
op
,
name
in
zip
(
ops
,
ops_name
):
assert
op
.
name
==
name
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
def
test_quantized_module_user_naming_param
(
symbolic
):
class
Simple
(
M
.
Module
):
def
__init__
(
self
,
name
):
super
().
__init__
(
name
=
name
)
self
.
quant
=
M
.
QuantStub
()
self
.
linear
=
M
.
Linear
(
3
,
3
,
bias
=
True
)
self
.
dequant
=
M
.
DequantStub
()
self
.
linear
.
weight
.
name
=
"user-weight"
self
.
linear
.
bias
.
name
=
"user-bias"
def
forward
(
self
,
x
):
out
=
self
.
quant
(
x
)
out
=
self
.
linear
(
out
)
out
=
self
.
dequant
(
out
)
return
out
m
=
Simple
(
"simple"
)
quantize_qat
(
m
)
quantize
(
m
)
m
.
eval
()
ops
=
_dump_and_load
(
m
,
symbolic
)
(
matrix_mul_op
,)
=
[
op
for
op
in
ops
if
op
.
name
==
"simple.linear.MatrixMul"
]
for
var
in
matrix_mul_op
.
inputs
:
assert
var
.
name
in
(
"simple.quant.TypeCvt"
,
"simple.linear.user-weight"
)
# BUG bias' name does not meet expectations because of astype operator after quantization
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录