Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1d7dd001
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1d7dd001
编写于
2月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/quantization): add QParams and QuantDtypeMeta for quantization data structure
GitOrigin-RevId: df3416fe13fbff1cc6dd8f88f0a937aa1b6b58a9
上级
4130dcd3
变更
24
显示空白变更内容
内联
并排
Showing
24 changed file
with
570 addition
and
411 deletion
+570
-411
imperative/python/megengine/core/ops/builtin/__init__.py
imperative/python/megengine/core/ops/builtin/__init__.py
+0
-3
imperative/python/megengine/core/tensor/dtype.py
imperative/python/megengine/core/tensor/dtype.py
+118
-62
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+1
-1
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+0
-1
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+4
-7
imperative/python/megengine/module/qat/module.py
imperative/python/megengine/module/qat/module.py
+12
-9
imperative/python/megengine/quantization/__init__.py
imperative/python/megengine/quantization/__init__.py
+13
-3
imperative/python/megengine/quantization/fake_quant.py
imperative/python/megengine/quantization/fake_quant.py
+61
-50
imperative/python/megengine/quantization/internal_fake_quant.py
...tive/python/megengine/quantization/internal_fake_quant.py
+2
-0
imperative/python/megengine/quantization/observer.py
imperative/python/megengine/quantization/observer.py
+84
-80
imperative/python/megengine/quantization/qconfig.py
imperative/python/megengine/quantization/qconfig.py
+44
-61
imperative/python/megengine/quantization/quantize.py
imperative/python/megengine/quantization/quantize.py
+12
-9
imperative/python/megengine/quantization/utils.py
imperative/python/megengine/quantization/utils.py
+111
-31
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+28
-9
imperative/python/test/unit/core/test_dtype_quant.py
imperative/python/test/unit/core/test_dtype_quant.py
+3
-3
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+3
-3
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+22
-0
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+0
-25
imperative/python/test/unit/quantization/test_fake_quant.py
imperative/python/test/unit/quantization/test_fake_quant.py
+11
-7
imperative/python/test/unit/quantization/test_module.py
imperative/python/test/unit/quantization/test_module.py
+16
-10
imperative/python/test/unit/quantization/test_observer.py
imperative/python/test/unit/quantization/test_observer.py
+7
-6
imperative/python/test/unit/quantization/test_op.py
imperative/python/test/unit/quantization/test_op.py
+3
-2
imperative/python/test/unit/quantization/test_qconfig.py
imperative/python/test/unit/quantization/test_qconfig.py
+0
-14
imperative/python/test/unit/quantization/test_quantize.py
imperative/python/test/unit/quantization/test_quantize.py
+15
-15
未找到文件。
imperative/python/megengine/core/ops/builtin/__init__.py
浏览文件 @
1d7dd001
...
...
@@ -6,9 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
warnings
from
typing
import
Union
from
..._imperative_rt
import
OpDef
,
ops
__all__
=
[
"OpDef"
]
...
...
imperative/python/megengine/core/tensor/dtype.py
浏览文件 @
1d7dd001
...
...
@@ -5,22 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
from
collections
import
namedtuple
from
typing
import
Union
import
numpy
as
np
# normal dtype related
from
.._imperative_rt
import
bfloat16
,
intb1
,
intb2
,
intb4
from
.._imperative_rt.common
import
(
bfloat16
,
get_scale
,
get_zero_point
,
intb1
,
intb2
,
intb4
,
is_dtype_equal
,
is_quantize
,
)
# normal dtype related
def
is_lowbit
(
dtype
):
return
(
dtype
is
intb1
)
or
(
dtype
is
intb2
)
or
(
dtype
is
intb4
)
...
...
@@ -30,34 +32,80 @@ def is_bfloat16(dtype):
# quantization dtype related
_QuantDtypeMetadata
=
collections
.
namedtuple
(
"QuantDtypeMetadata"
,
[
"name"
,
"np_dtype_str"
,
"is_unsigned"
,
"qmin"
,
"qmax"
,]
)
_metadata_dict
=
{
"quint8"
:
_QuantDtypeMetadata
(
"Quantized8Asymm"
,
"uint8"
,
True
,
0
,
255
),
"qint8"
:
_QuantDtypeMetadata
(
"QuantizedS8"
,
"int8"
,
False
,
-
128
,
127
),
"quint4"
:
_QuantDtypeMetadata
(
"Quantized4Asymm"
,
"uint8"
,
True
,
0
,
15
),
"qint4"
:
_QuantDtypeMetadata
(
"QuantizedS4"
,
"int8"
,
False
,
-
8
,
7
),
"qint32"
:
_QuantDtypeMetadata
(
"QuantizedS32"
,
"int32"
,
False
,
-
(
2
**
31
),
2
**
31
-
1
,
# use namedtuple to make class immutable, comparable and easy to print
class
QuantDtypeMeta
(
namedtuple
(
"QuantDtypeMeta"
,
[
"name"
,
"cname"
,
"np_dtype_str"
,
"qmin"
,
"qmax"
,
"is_unsigned"
],
)
):
r
"""
Store metadata for quantize dtype. Could be used to create custom quant dtype
for QAT when the network don't need to be converted for inference, but only
to export network metadata for third-party platform inference.
:param name: a unique name string.
:param cname: used in :func:`~.create_quantized_dtype` for model dump and inference.
:param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``.
:param qmin: a int number indicating quant dtype's lowerbound.
:param qmax: a int number indicating quant dtype's upperbound.
:param is_unsigned: a helper value that could be inference from np_dtype_str.
"""
def
__new__
(
cls
,
name
:
str
,
cname
:
str
,
np_dtype_str
:
str
,
qmin
:
int
,
qmax
:
int
,
is_unsigned
:
bool
=
None
,
):
assert
isinstance
(
np_dtype_str
,
str
)
is_unsigned
=
np_dtype_str
[
0
]
==
"u"
if
is_unsigned
is
None
else
is_unsigned
return
super
().
__new__
(
cls
,
name
,
cname
,
np_dtype_str
,
qmin
,
qmax
,
is_unsigned
)
def
__copy__
(
self
):
return
self
def
__deepcopy__
(
self
,
_
):
"""
Ignore deepcopy so that a dtype meta can be treated as singleton, for more
strict check in :meth:`~.FakeQuantize.fake_quant_forward`.
"""
return
self
_builtin_quant_dtypes
=
{
"quint8"
:
QuantDtypeMeta
(
"quint8"
,
"Quantized8Asymm"
,
"uint8"
,
0
,
255
),
"qint8"
:
QuantDtypeMeta
(
"qint8"
,
"QuantizedS8"
,
"int8"
,
-
128
,
127
),
"qint8_narrow"
:
QuantDtypeMeta
(
"qint8_narrow"
,
"QuantizedS8"
,
"int8"
,
-
127
,
127
),
"quint4"
:
QuantDtypeMeta
(
"quint4"
,
"Quantized4Asymm"
,
"uint8"
,
0
,
15
),
"qint4"
:
QuantDtypeMeta
(
"qint4"
,
"QuantizedS4"
,
"int8"
,
-
8
,
7
),
"qint32"
:
QuantDtypeMeta
(
"qint32"
,
"QuantizedS32"
,
"int32"
,
-
(
2
**
31
),
2
**
31
-
1
,
),
# NOTE: int2 is not supported for model dump yet
"quint2"
:
_QuantDtypeMetadata
(
None
,
"uint8"
,
True
,
0
,
3
),
"qint2"
:
_QuantDtypeMetadata
(
None
,
"int8"
,
False
,
-
2
,
1
),
"quint2"
:
QuantDtypeMeta
(
"quint2"
,
None
,
"uint8"
,
0
,
3
),
"qint2"
:
QuantDtypeMeta
(
"qint2"
,
None
,
"int8"
,
-
2
,
1
),
}
def
_check_zero_point
(
zp
:
int
,
dtype_
str
:
str
):
qmin
=
_metadata_dict
[
dtype_str
]
.
qmin
qmax
=
_metadata_dict
[
dtype_str
]
.
qmax
def
_check_zero_point
(
zp
:
int
,
dtype_
meta
:
QuantDtypeMeta
):
qmin
=
dtype_meta
.
qmin
qmax
=
dtype_meta
.
qmax
if
zp
<
qmin
or
zp
>
qmax
:
raise
ValueError
(
"zero_point should be within [{}, {}] for {}"
.
format
(
qmin
,
qmax
,
dtype_str
)
"zero_point should be within [{}, {}] for {}"
.
format
(
qmin
,
qmax
,
dtype_meta
.
name
)
)
def
get_quantized_dtype
(
dtype_str
:
str
,
scale
:
float
,
zp
:
Union
[
int
,
None
]):
def
create_quantized_dtype
(
dtype_meta
:
QuantDtypeMeta
,
scale
:
float
,
zp
:
Union
[
int
,
None
]
):
r
"""
Get quantized dtype with metadata attribute according to _metadata_dict.
...
...
@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
not have ``zero_point``, to be consitent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.
:param dtype: a string indicating which dtype to return
:param dtype_meta: a QuantDtypeMeta indicating which dtype to return. the
``cname`` attribute cannot be ``None``.
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata
=
_metadata_dict
[
dtype_str
]
np_dtype_str
=
metadata
.
np_dtype_str
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
dtype_meta
.
cname
is
None
:
raise
ValueError
(
"dtype {} without cname attr is not supported."
)
if
dtype_meta
.
is_unsigned
:
if
zp
is
None
or
int
(
zp
)
!=
zp
:
raise
ValueError
(
"zero_point should be an integer"
)
zp
=
int
(
zp
)
_check_zero_point
(
zp
,
dtype_
str
)
_check_zero_point
(
zp
,
dtype_
meta
)
return
np
.
dtype
(
np_dtype_str
,
dtype_meta
.
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
metadata
.
name
,
"name"
:
dtype_meta
.
c
name
,
"scale"
:
float
(
scale
),
"zero_point"
:
zp
,
}
},
)
else
:
# Don't trick to combine with is_unsigned. Metadata should not contain
# invalid field to keep consistent with c dtype.
return
np
.
dtype
(
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
metadata
.
name
,
"scale"
:
float
(
scale
)}},
dtype_meta
.
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
dtype_meta
.
c
name
,
"scale"
:
float
(
scale
)}},
)
...
...
@@ -100,7 +150,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
return
get_quantized_dtype
(
"quint8"
,
scale
,
zero_point
)
return
create_quantized_dtype
(
_builtin_quant_dtypes
[
"quint8"
]
,
scale
,
zero_point
)
def
qint8
(
scale
):
...
...
@@ -108,7 +158,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return
get_quantized_dtype
(
"qint8"
,
scale
,
None
)
return
create_quantized_dtype
(
_builtin_quant_dtypes
[
"qint8"
]
,
scale
,
None
)
def
qint32
(
scale
):
...
...
@@ -116,7 +166,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return
get_quantized_dtype
(
"qint32"
,
scale
,
None
)
return
create_quantized_dtype
(
_builtin_quant_dtypes
[
"qint32"
]
,
scale
,
None
)
def
quint4
(
scale
,
zero_point
):
...
...
@@ -125,7 +175,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
return
get_quantized_dtype
(
"quint4"
,
scale
,
zero_point
)
return
create_quantized_dtype
(
_builtin_quant_dtypes
[
"quint4"
]
,
scale
,
zero_point
)
def
qint4
(
scale
):
...
...
@@ -133,42 +183,48 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return
get_quantized_dtype
(
"qint4"
,
scale
,
None
)
return
create_quantized_dtype
(
_builtin_quant_dtypes
[
"qint4"
]
,
scale
,
None
)
def
_convert_to_quantized_dtype
(
arr
:
np
.
ndarray
,
dtype
:
np
.
dtype
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
]
arr_metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
def
_convert_to_quantized_dtype
(
arr
:
np
.
ndarray
,
dtype
:
np
.
dtype
,
dtype_meta
:
QuantDtypeMeta
):
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
.
name
:
raise
ValueError
(
"dtype parameter should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
(
not
is_quantize
(
dtype
)
or
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
!=
dtype_meta
.
cname
):
raise
ValueError
(
"dtype parameter should be a {} dtype"
.
format
(
dtype_meta
))
arr_metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
if
dtype_meta
.
is_unsigned
:
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"zero_point"
],
)
return
(
(
np
.
round
(
arr
/
scale
)
+
zp
)
.
clip
(
metadata
.
qmin
,
metada
ta
.
qmax
)
.
clip
(
dtype_meta
.
qmin
,
dtype_me
ta
.
qmax
)
.
astype
(
dtype
)
)
else
:
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale
=
arr_metadata
[
"scale"
]
return
np
.
round
(
arr
/
scale
).
clip
(
metadata
.
qmin
,
metadata
.
qmax
).
astype
(
dtype
)
return
(
np
.
round
(
arr
/
scale
).
clip
(
dtype_meta
.
qmin
,
dtype_meta
.
qmax
).
astype
(
dtype
)
)
def
_convert_from_quantized_dtype
(
arr
:
np
.
ndarray
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
]
arr_metadata
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
]
def
_convert_from_quantized_dtype
(
arr
:
np
.
ndarray
,
dtype_meta
:
QuantDtypeMeta
):
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
arr
.
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
.
name
:
raise
ValueError
(
"arr's dtype should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
(
not
is_quantize
(
arr
.
dtype
)
or
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
!=
dtype_meta
.
cname
):
raise
ValueError
(
"arr's dtype should be a {} dtype"
.
format
(
dtype_meta
))
arr_metadata
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
]
if
dtype_meta
.
is_unsigned
:
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"zero_point"
],
...
...
@@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
"""
return
_convert_to_quantized_dtype
(
arr
,
q
,
"quint8"
)
return
_convert_to_quantized_dtype
(
arr
,
q
,
_builtin_quant_dtypes
[
"quint8"
]
)
def
convert_from_quint8
(
arr
:
np
.
ndarray
):
...
...
@@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return
_convert_from_quantized_dtype
(
arr
,
"quint8"
)
return
_convert_from_quantized_dtype
(
arr
,
_builtin_quant_dtypes
[
"quint8"
]
)
def
convert_to_qint8
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
...
@@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return
_convert_to_quantized_dtype
(
arr
,
q
,
"qint8"
)
return
_convert_to_quantized_dtype
(
arr
,
q
,
_builtin_quant_dtypes
[
"qint8"
]
)
def
convert_from_qint8
(
arr
:
np
.
ndarray
):
...
...
@@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return
_convert_from_quantized_dtype
(
arr
,
"qint8"
)
return
_convert_from_quantized_dtype
(
arr
,
_builtin_quant_dtypes
[
"qint8"
]
)
def
convert_to_qint32
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
...
@@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return
_convert_to_quantized_dtype
(
arr
,
q
,
"qint32"
)
return
_convert_to_quantized_dtype
(
arr
,
q
,
_builtin_quant_dtypes
[
"qint32"
]
)
def
convert_from_qint32
(
arr
):
...
...
@@ -234,7 +290,7 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
"""
return
_convert_from_quantized_dtype
(
arr
,
"qint32"
)
return
_convert_from_quantized_dtype
(
arr
,
_builtin_quant_dtypes
[
"qint32"
]
)
def
convert_to_quint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
...
@@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
"""
return
_convert_to_quantized_dtype
(
arr
,
q
,
"quint4"
)
return
_convert_to_quantized_dtype
(
arr
,
q
,
_builtin_quant_dtypes
[
"quint4"
]
)
def
convert_from_quint4
(
arr
:
np
.
ndarray
):
...
...
@@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return
_convert_from_quantized_dtype
(
arr
,
"quint4"
)
return
_convert_from_quantized_dtype
(
arr
,
_builtin_quant_dtypes
[
"quint4"
]
)
def
convert_to_qint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
...
@@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
"""
return
_convert_to_quantized_dtype
(
arr
,
q
,
"qint4"
)
return
_convert_to_quantized_dtype
(
arr
,
q
,
_builtin_quant_dtypes
[
"qint4"
]
)
def
convert_from_qint4
(
arr
:
np
.
ndarray
):
...
...
@@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return
_convert_from_quantized_dtype
(
arr
,
"qint4"
)
return
_convert_from_quantized_dtype
(
arr
,
_builtin_quant_dtypes
[
"qint4"
]
)
imperative/python/megengine/functional/nn.py
浏览文件 @
1d7dd001
...
...
@@ -203,7 +203,7 @@ def conv_transpose2d(
assert
compute_mode
==
"DEFAULT"
or
compute_mode
.
name
==
"DEFAULT"
if
groups
!=
1
:
raise
NotImplementedError
(
"
TODO
"
)
raise
NotImplementedError
(
"
group transposed conv2d is not supported yet.
"
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
1d7dd001
...
...
@@ -13,7 +13,6 @@ import itertools
import
json
import
os
import
typing
import
warnings
import
weakref
import
numpy
as
np
...
...
imperative/python/megengine/module/module.py
浏览文件 @
1d7dd001
...
...
@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
warnings
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
OrderedDict
from
typing
import
Any
,
Callable
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
...
...
@@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta):
if
"requires_grad"
in
kwargs
:
del
kwargs
[
"requires_grad"
]
warnings
.
warn
(
logger
.
warning
(
"Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here"
,
DeprecationWarning
,
"so requires_grad argument is ignored here"
)
def
predicate
(
obj
)
->
bool
:
...
...
@@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta):
if
"requires_grad"
in
kwargs
:
del
kwargs
[
"requires_grad"
]
warnings
.
warn
(
logger
.
warning
(
"Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here"
,
DeprecationWarning
,
"so requires_grad argument is ignored here"
)
def
predicate
(
obj
)
->
bool
:
...
...
imperative/python/megengine/module/qat/module.py
浏览文件 @
1d7dd001
...
...
@@ -7,7 +7,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
abc
import
abstractmethod
from
...quantization
import
FakeQuantize
,
Observer
,
QConfig
# avoid circular reference
from
...quantization.fake_quant
import
FakeQuantize
from
...quantization.observer
import
Observer
from
...quantization.qconfig
import
QConfig
from
...tensor
import
Tensor
from
..module
import
Module
...
...
@@ -73,19 +76,19 @@ class QATModule(Module):
# do observer
if
observer
is
None
:
oup
=
target
q
_dict
=
None
q
params
=
None
else
:
oup
=
observer
(
target
)
q
_dict
=
observer
.
get_qparams
()
q
params
=
observer
.
get_qparams
()
# do fake quant
if
fake_quant
is
not
None
:
oup
=
fake_quant
(
oup
,
q
_dict
)
oup
=
fake_quant
(
oup
,
q
params
)
# use qparams of fake_quant if have.
if
hasattr
(
fake_quant
,
"get_qparams"
):
q
_dict
=
fake_quant
.
get_qparams
()
q
params
=
fake_quant
.
get_qparams
()
# set to tensor qparams.
if
q
_dict
is
not
None
:
oup
.
q
_dict
.
update
(
q_dict
)
if
q
params
is
not
None
:
oup
.
q
params
.
update
(
qparams
)
return
oup
def
apply_quant_weight
(
self
,
target
:
Tensor
):
...
...
@@ -118,7 +121,7 @@ class QATModule(Module):
Get weight's quantization dtype as the method from ``qconfig``.
"""
return
self
.
_get_method_result
(
"get_dtype"
,
self
.
weight_fake_quant
,
self
.
weight_observer
"get_
quantized_
dtype"
,
self
.
weight_fake_quant
,
self
.
weight_observer
)
def
get_activation_dtype
(
self
):
...
...
@@ -126,7 +129,7 @@ class QATModule(Module):
Get activation's quantization dtype as the method from ``qconfig``.
"""
return
self
.
_get_method_result
(
"get_dtype"
,
self
.
act_fake_quant
,
self
.
act_observer
"get_
quantized_
dtype"
,
self
.
act_fake_quant
,
self
.
act_observer
)
def
get_weight_qparams
(
self
):
...
...
imperative/python/megengine/quantization/__init__.py
浏览文件 @
1d7dd001
...
...
@@ -7,8 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.fake_quant
import
FakeQuantize
from
.internal_fake_quant
import
*
from
.observer
import
HistogramObserver
,
Observer
from
.observer
import
Observer
from
.qconfig
import
(
QConfig
,
calibration_qconfig
,
...
...
@@ -20,4 +19,15 @@ from .qconfig import (
sync_ema_fakequant_qconfig
,
tqt_qconfig
,
)
from
.utils
import
QuantMode
from
.quantize
import
(
apply_easy_quant
,
disable_fake_quant
,
disable_observer
,
enable_fake_quant
,
enable_observer
,
propagate_qconfig
,
quantize
,
quantize_qat
,
reset_qconfig
,
)
from
.utils
import
QParams
,
QuantMode
,
create_qparams
imperative/python/megengine/quantization/fake_quant.py
浏览文件 @
1d7dd001
...
...
@@ -6,40 +6,48 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
math
from
typing
import
Union
from
..
import
functional
as
F
from
..core.tensor.dtype
import
_metadata_dict
,
get_quantized_dtype
from
..core.tensor.dtype
import
QuantDtypeMeta
,
_builtin_quant_dtypes
from
..logger
import
get_logger
from
..module
import
Module
from
..tensor
import
Parameter
,
Tensor
from
.utils
import
QuantMode
,
fake_quant_tensor
,
get_qparam_dict
,
tqt_forward
from
..tensor
import
Parameter
from
.utils
import
(
QParams
,
QParamsModuleMixin
,
QuantMode
,
create_qparams
,
fake_quant_tensor
,
tqt_forward
,
)
logger
=
get_logger
(
__name__
)
class
_FakeQuantize
(
Module
):
r
"""
A Basic Fake Quant module.
:param dtype: a string indicating the target quantization type of input.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
"""
class
_FakeQuantize
(
Module
):
def
__init__
(
self
,
dtype
:
str
,
narrow_range
:
bool
=
False
,
enable
:
bool
=
True
,
**
kwargs
self
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
,
enable
:
bool
=
True
,
**
kwargs
):
super
().
__init__
()
if
not
dtype
in
_metadata_dict
.
keys
():
if
isinstance
(
dtype
,
str
):
if
not
dtype
in
_builtin_quant_dtypes
:
raise
ValueError
(
"unknown dtype: {}, only support {}"
.
format
(
dtype
,
_metadata_dict
.
keys
()
dtype
,
_builtin_quant_dtypes
.
keys
()
)
)
self
.
dtype
=
dtype
self
.
narrow_range
=
narrow_range
self
.
qmin
=
(
-
_metadata_dict
[
dtype
].
qmax
if
narrow_range
else
_metadata_dict
[
dtype
].
qmin
dtype
=
_builtin_quant_dtypes
[
dtype
]
if
"narrow_range"
in
kwargs
:
del
kwargs
[
"narrow_range"
]
logger
.
warning
(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here"
,
exc_info
=
DeprecationWarning
,
)
self
.
qmax
=
_metadata_dict
[
dtype
].
qmax
self
.
dtype
=
dtype
self
.
qmin
=
dtype
.
qmin
self
.
qmax
=
dtype
.
qmax
self
.
enabled
=
enable
def
enable
(
self
):
...
...
@@ -48,61 +56,64 @@ class _FakeQuantize(Module):
def
disable
(
self
):
self
.
enabled
=
False
def
fake_quant_forward
(
self
,
inp
,
q
_dict
=
None
):
r
eturn
inp
def
fake_quant_forward
(
self
,
inp
,
q
params
:
QParams
=
None
):
r
aise
NotImplementedError
def
normal_foward
(
self
,
inp
,
q
_dict
=
None
):
def
normal_foward
(
self
,
inp
,
q
params
:
QParams
=
None
):
return
inp
def
forward
(
self
,
inp
,
q
_dict
=
None
):
def
forward
(
self
,
inp
,
q
params
:
QParams
=
None
):
if
self
.
enabled
:
return
self
.
fake_quant_forward
(
inp
,
q
_dict
=
q_dict
)
return
self
.
fake_quant_forward
(
inp
,
q
params
=
qparams
)
else
:
return
self
.
normal_foward
(
inp
,
q
_dict
=
q_dict
)
return
self
.
normal_foward
(
inp
,
q
params
=
qparams
)
class
TQT
(
_FakeQuantize
):
class
TQT
(
_FakeQuantize
,
QParamsModuleMixin
):
r
"""
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
"""
def
__init__
(
self
,
dtype
:
str
,
narrow_range
:
bool
=
False
,
enable
:
bool
=
True
,
**
kwargs
self
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
,
enable
:
bool
=
True
,
**
kwargs
):
super
().
__init__
(
dtype
,
narrow_range
,
enable
,
**
kwargs
)
super
().
__init__
(
dtype
,
enable
,
**
kwargs
)
self
.
scale
=
Parameter
(
0.0
,
dtype
=
"float32"
)
def
fake_quant_forward
(
self
,
inp
,
q
_dict
=
None
):
def
fake_quant_forward
(
self
,
inp
,
q
params
:
QParams
=
None
):
# when enable, TQT will do fakequant forward, finetune the scale
return
tqt_forward
(
self
.
qmin
,
self
.
qmax
,
inp
,
self
.
scale
)
def
get_qparams
(
self
):
q_dict
=
get_qparam_dict
(
QuantMode
.
SYMMERTIC
)
q_dict
[
"scale"
]
=
2
**
self
.
scale
.
detach
()
return
q_dict
def
set_qparams
(
self
,
q_dict
):
def
set_qparams
(
self
,
qparams
:
QParams
):
assert
(
q
_dict
[
"mode"
]
==
QuantMode
.
SYMMERTIC
q
params
.
mode
==
QuantMode
.
SYMMERTIC
),
"only symmetric quantization is supported by TQT"
if
"scale"
not
in
q_dict
or
q_dict
[
"scale"
]
is
None
:
if
qparams
.
scale
is
None
:
raise
AssertionError
(
"Can not get an initialized scale"
)
self
.
scale
.
_reset
(
F
.
log
(
q_dict
[
"scale"
])
/
math
.
log
(
2
)
)
self
.
scale
[...]
=
F
.
log
(
qparams
.
scale
)
/
math
.
log
(
2
)
def
get_dtype
(
self
):
q_dict
=
self
.
get_qparams
()
scale
=
None
if
"scale"
not
in
q_dict
else
q_dict
[
"scale"
].
numpy
()
zero_point
=
(
None
if
"zero_point"
not
in
q_dict
else
q_dict
[
"zero_point"
].
numpy
()
)
return
get_quantized_dtype
(
self
.
dtype
,
scale
,
zero_point
)
def
get_qparams
(
self
):
return
create_qparams
(
QuantMode
.
SYMMERTIC
,
self
.
dtype
,
scale
=
2
**
self
.
scale
)
class
FakeQuantize
(
_FakeQuantize
):
r
"""
A module to do quant and dequant according to observer's scale and zero_point.
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
"""
def
fake_quant_forward
(
self
,
inp
,
q_dict
=
None
):
return
fake_quant_tensor
(
inp
,
self
.
qmin
,
self
.
qmax
,
q_dict
)
def
fake_quant_forward
(
self
,
inp
,
qparams
:
QParams
=
None
):
assert
(
qparams
.
dtype_meta
is
self
.
dtype
),
"input qparams' dtype is not equal to self.dtype.
\n
qparams.dtype_meta={}
\n
self.dtype={}"
.
format
(
qparams
.
dtype_meta
,
self
.
dtype
)
return
fake_quant_tensor
(
inp
,
qparams
)
imperative/python/megengine/quantization/internal_fake_quant.py
浏览文件 @
1d7dd001
...
...
@@ -16,4 +16,6 @@ from ..autodiff import Function
from
.fake_quant
import
_FakeQuantize
from
.observer
import
MinMaxObserver
from
.qconfig
import
QConfig
from
.utils
import
QParams
imperative/python/megengine/quantization/observer.py
浏览文件 @
1d7dd001
...
...
@@ -8,51 +8,51 @@
import
math
from
abc
import
abstractmethod
from
copy
import
deepcopy
from
typing
import
Union
import
numpy
as
np
from
..
import
functional
as
F
from
..core.tensor.dtype
import
_metadata_dict
,
get_quantized_dtype
from
..core.tensor.dtype
import
QuantDtypeMeta
,
_builtin_quant_dtypes
from
..distributed
import
WORLD
,
get_rank
,
is_distributed
from
..functional.distributed
import
all_reduce_max
,
all_reduce_min
from
..logger
import
get_logger
from
..module
import
Module
from
..tensor
import
Tensor
from
.utils
import
Q
uantMode
,
get_qparam_dict
from
.utils
import
Q
Params
,
QParamsModuleMixin
,
QuantMode
,
create_qparams
logger
=
get_logger
(
__name__
)
class
Observer
(
Module
):
class
Observer
(
Module
,
QParamsModuleMixin
):
r
"""
A base class for Observer Module.
:param dtype: a string indicating to collect scale and zero_point of which dtype.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
"""
def
__init__
(
self
,
dtype
:
str
,
narrow_range
:
bool
=
False
,
**
kwargs
):
def
__init__
(
self
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
,
**
kwargs
):
super
().
__init__
()
if
dtype
not
in
_metadata_dict
.
keys
():
if
isinstance
(
dtype
,
str
):
if
not
dtype
in
_builtin_quant_dtypes
:
raise
ValueError
(
"unknown dtype: {}, only support {}"
.
format
(
dtype
,
_metadata_dict
.
keys
()
dtype
,
_builtin_quant_dtypes
.
keys
()
)
)
self
.
dtype
=
dtype
self
.
narrow_range
=
narrow_range
self
.
qmin
=
(
-
_metadata_dict
[
dtype
].
qmax
if
narrow_range
else
_metadata_dict
[
dtype
].
qmin
dtype
=
_builtin_quant_dtypes
[
dtype
]
if
"narrow_range"
in
kwargs
:
del
kwargs
[
"narrow_range"
]
logger
.
warning
(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here"
,
exc_info
=
DeprecationWarning
,
)
self
.
qmax
=
_metadata_dict
[
dtype
].
qmax
self
.
dtype
=
dtype
self
.
qmin
=
dtype
.
qmin
self
.
qmax
=
dtype
.
qmax
self
.
enabled
=
True
def
get_dtype
(
self
):
q_dict
=
self
.
get_qparams
()
numpy_scale
=
None
if
"scale"
not
in
q_dict
else
q_dict
[
"scale"
].
numpy
()
numpy_zero_point
=
(
None
if
"zero_point"
not
in
q_dict
else
q_dict
[
"zero_point"
].
numpy
()
)
return
get_quantized_dtype
(
self
.
dtype
,
numpy_scale
,
numpy_zero_point
)
def
enable
(
self
):
self
.
enabled
=
True
...
...
@@ -70,21 +70,16 @@ class Observer(Module):
def
forward
(
self
,
x
):
pass
@
abstractmethod
def
get_qparams
(
self
,
**
kwargs
):
pass
class
MinMaxObserver
(
Observer
):
def
__init__
(
self
,
mode
=
QuantMode
.
SYMMERTIC
,
eps
=
0.00001
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
mode
:
QuantMode
=
QuantMode
.
SYMMERTIC
,
eps
:
float
=
0.00001
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
=
"qint8"
,
**
kwargs
):
super
().
__init__
(
dtype
,
narrow_range
,
**
kwargs
)
super
().
__init__
(
dtype
,
**
kwargs
)
self
.
mode
=
mode
self
.
min_val
=
Tensor
(
np
.
finfo
(
np
.
float32
).
max
,
dtype
=
np
.
float32
)
self
.
max_val
=
Tensor
(
np
.
finfo
(
np
.
float32
).
min
,
dtype
=
np
.
float32
)
...
...
@@ -93,26 +88,22 @@ class MinMaxObserver(Observer):
def
_calculate_qparams
(
self
,
inp_min_val
,
inp_max_val
):
min_val
=
F
.
minimum
(
0.0
,
inp_min_val
)
max_val
=
F
.
maximum
(
0.0
,
inp_max_val
)
q_dict
=
get_qparam_dict
(
self
.
mode
)
q_dict
[
"min_val"
]
=
inp_min_val
q_dict
[
"max_val"
]
=
inp_max_val
q_dict
[
"enable_observer"
]
=
self
.
enable
if
self
.
mode
==
QuantMode
.
SYMMERTIC
:
symmetric_max_vals
=
F
.
maximum
(
-
min_val
,
max_val
)
# use maximun to avoid scale too small at the begin
q_dict
[
"scale"
]
=
F
.
maximum
(
scale
=
F
.
maximum
(
symmetric_max_vals
/
((
self
.
qmax
-
self
.
qmin
)
/
2
),
self
.
scale_limit
)
# zero_point = self.zero_point
zero_point
=
None
else
:
# use maximun to avoid scale too small at the begin
q_dict
[
"scale"
]
=
F
.
maximum
(
scale
=
F
.
maximum
(
(
max_val
-
min_val
)
/
(
self
.
qmax
-
self
.
qmin
),
self
.
scale_limit
)
# caculate zero_point
q_dict
[
"zero_point"
]
=
self
.
qmin
-
F
.
round
(
min_val
/
q_dict
[
"scale"
]
)
zero_point
=
self
.
qmin
-
F
.
round
((
min_val
/
scale
)
)
return
q_dict
return
create_qparams
(
self
.
mode
,
self
.
dtype
,
scale
=
scale
,
zero_point
=
zero_point
)
def
get_qparams
(
self
):
return
self
.
_calculate_qparams
(
self
.
min_val
,
self
.
max_val
)
...
...
@@ -122,8 +113,8 @@ class MinMaxObserver(Observer):
# stop gradient
x
=
x_orig
.
detach
()
# find max and min
self
.
min_val
.
_reset
(
F
.
minimum
(
self
.
min_val
,
x
.
min
()
))
self
.
max_val
.
_reset
(
F
.
maximum
(
self
.
max_val
,
x
.
max
()
))
self
.
min_val
[...]
=
F
.
minimum
(
self
.
min_val
,
x
.
min
(
))
self
.
max_val
[...]
=
F
.
maximum
(
self
.
max_val
,
x
.
max
(
))
return
x_orig
...
...
@@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver):
else
:
min_x
=
x
.
min
()
max_x
=
x
.
max
()
self
.
min_val
.
_reset
(
F
.
minimum
(
self
.
min_val
,
min_x
)
)
self
.
max_val
.
_reset
(
F
.
maximum
(
self
.
max_val
,
max_x
)
)
self
.
min_val
[...]
=
F
.
minimum
(
self
.
min_val
,
min_x
)
self
.
max_val
[...]
=
F
.
maximum
(
self
.
max_val
,
max_x
)
return
x_orig
class
ExponentialMovingAverageObserver
(
MinMaxObserver
):
def
__init__
(
self
,
momentum
=
0.9
,
mode
=
QuantMode
.
SYMMERTIC
,
eps
=
0.00001
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
momentum
:
float
=
0.9
,
mode
:
QuantMode
=
QuantMode
.
SYMMERTIC
,
eps
:
float
=
0.00001
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
=
"qint8"
,
**
kwargs
):
super
().
__init__
(
mode
,
eps
,
dtype
,
narrow_range
,
**
kwargs
)
super
().
__init__
(
mode
,
eps
,
dtype
,
**
kwargs
)
self
.
momentum
=
Tensor
(
momentum
,
dtype
=
"float32"
)
# used to avoid if-clauses in the first forward which is not supported
# in trace mode.
self
.
runtime_momentum
=
Tensor
(
0.0
)
def
set_momentum
(
self
,
momentum
):
self
.
momentum
=
Ten
os
r
(
momentum
,
dtype
=
"float32"
)
self
.
momentum
=
Ten
so
r
(
momentum
,
dtype
=
"float32"
)
def
forward
(
self
,
x_orig
):
if
self
.
enabled
:
# stop gradient
x
=
x_orig
.
detach
()
# Exponential Moving Average
self
.
min_val
.
_reset
(
self
.
min_val
[...]
=
(
self
.
min_val
*
self
.
runtime_momentum
+
(
1
-
self
.
runtime_momentum
)
*
x
.
min
()
)
self
.
max_val
.
_reset
(
self
.
max_val
[...]
=
(
self
.
max_val
*
self
.
runtime_momentum
+
(
1
-
self
.
runtime_momentum
)
*
x
.
max
()
)
self
.
runtime_momentum
=
self
.
momentum
self
.
runtime_momentum
[...]
=
self
.
momentum
return
x_orig
...
...
@@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
else
:
min_x
=
x
.
min
()
max_x
=
x
.
max
()
self
.
min_val
.
_reset
(
self
.
min_val
[...]
=
(
self
.
min_val
*
self
.
runtime_momentum
+
(
1
-
self
.
runtime_momentum
)
*
min_x
)
self
.
max_val
.
_reset
(
self
.
max_val
[...]
=
(
self
.
max_val
*
self
.
runtime_momentum
+
(
1
-
self
.
runtime_momentum
)
*
max_x
)
self
.
runtime_momentum
=
self
.
momentum
self
.
runtime_momentum
[...]
=
self
.
momentum
return
x_orig
class
HistogramObserver
(
MinMaxObserver
):
def
__init__
(
self
,
bins
=
2048
,
upsample_rate
=
128
,
mode
=
QuantMode
.
SYMMERTIC
,
eps
=
0.00001
,
dtype
=
"qint8"
,
narrow_range
:
bool
=
False
,
bins
:
int
=
2048
,
upsample_rate
:
int
=
128
,
mode
:
QuantMode
=
QuantMode
.
SYMMERTIC
,
eps
:
float
=
0.00001
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
=
"qint8"
,
**
kwargs
):
super
().
__init__
(
mode
,
eps
,
dtype
,
narrow_range
,
**
kwargs
)
super
().
__init__
(
mode
,
eps
,
dtype
,
**
kwargs
)
self
.
bins
=
bins
self
.
upsample_rate
=
upsample_rate
self
.
dst_nbins
=
_metadata_dict
[
dtype
].
qmax
-
_metadata_dict
[
dtype
].
qmin
+
1
self
.
dst_nbins
=
(
_builtin_quant_dtypes
[
dtype
].
qmax
-
_builtin_quant_dtypes
[
dtype
].
qmin
+
1
)
self
.
histogram
=
Tensor
([
-
1
]
+
[
0.0
]
*
(
bins
-
1
),
dtype
=
"float32"
)
def
_non_linear_param_search
(
self
):
...
...
@@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver):
class
PassiveObserver
(
Observer
):
r
"""
This class can be set :attr:`scale` de
rectly.
An Observer that supports setting :attr:`scale` di
rectly.
"""
def
__init__
(
self
,
dtype
:
str
,
narrow_range
:
bool
=
False
,
**
kwargs
):
super
().
__init__
(
dtype
,
narrow_range
,
**
kwargs
)
self
.
q
_dict
=
None
def
__init__
(
self
,
dtype
:
Union
[
str
,
QuantDtypeMeta
]
,
**
kwargs
):
super
().
__init__
(
dtype
,
**
kwargs
)
self
.
q
params
=
None
self
.
orig_scale
=
None
@
property
def
scale
(
self
):
return
self
.
q
_dict
[
"scale"
]
return
self
.
q
params
.
scale
@
scale
.
setter
def
scale
(
self
,
value
):
assert
value
>
0
self
.
q
_dict
[
"scale"
]
[...]
=
Tensor
(
value
)
def
scale
(
self
,
value
:
np
.
ndarray
):
assert
np
.
all
(
value
>
0
)
self
.
q
params
.
scale
[...]
=
Tensor
(
value
)
def
get_qparams
(
self
):
return
self
.
q
_dict
return
self
.
q
params
def
set_qparams
(
self
,
q_dict
):
self
.
q_dict
=
deepcopy
(
q_dict
)
if
"scale"
not
in
q_dict
or
q_dict
[
"scale"
]
is
None
:
def
set_qparams
(
self
,
qparams
:
QParams
):
"""
:param qparams: used to set initial scale.
"""
self
.
qparams
=
deepcopy
(
qparams
)
if
qparams
.
scale
is
None
:
raise
AssertionError
(
"Can not get an initialized scale"
)
self
.
orig_scale
=
q_dict
[
"scale"
].
numpy
()
if
qparams
.
dtype_meta
is
None
:
qparams
.
dtype_meta
=
self
.
dtype
else
:
assert
(
qparams
.
dtype_meta
is
self
.
dtype
),
"input qparams' dtype is not equal to self.dtype.
\n
qparams.dtype_meta={}
\n
self.dtype={}"
.
format
(
qparams
.
dtype_meta
,
self
.
dtype
)
self
.
orig_scale
=
qparams
.
scale
.
numpy
()
def
forward
(
self
,
x
):
r
"""
Just return input because :attr:`q
_dict
` is set by :func:`~.apply_easy_quant`.
Just return input because :attr:`q
params
` is set by :func:`~.apply_easy_quant`.
"""
return
x
imperative/python/megengine/quantization/qconfig.py
浏览文件 @
1d7dd001
...
...
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
collections
import
namedtuple
from
functools
import
partial
from
..module
import
Module
...
...
@@ -19,7 +20,13 @@ from .observer import (
)
class
QConfig
:
# use namedtuple to make class immutable, comparable and easy to print
class
QConfig
(
namedtuple
(
"QConfig"
,
[
"weight_observer"
,
"act_observer"
,
"weight_fake_quant"
,
"act_fake_quant"
],
)
):
r
"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
...
...
@@ -37,90 +44,66 @@ class QConfig:
# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8
", narrow_range=True
),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"
, narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint8
", narrow_range=True
),
act_fake_quant=partial(FakeQuantize, dtype="qint8"
, narrow_range=False
),
weight_observer=partial(MinMaxObserver, dtype="qint8
_narrow"
),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8
_narrow"
),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial``
to add initialization parameters of the ``class``, so that don't need to provide parameters in
:meth:`~.QATModule.set_qconfig`.
Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so needed to narrow the range.
Usually we choose narrow version dtype (like ``qint8_narrow``) for weight related
paramters and normal version for activation related ones. For the result of
multiplication and addition as ``a * b + c * d``, if four variables are all -128 of
dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so need to narrow qmin to -127.
"""
def
__init__
(
self
,
weight_observer
,
act_observer
,
weight_fake_quant
,
act_fake_quant
):
def
__new__
(
cls
,
weight_observer
,
act_observer
,
weight_fake_quant
,
act_fake_quant
):
if
isinstance
(
act_observer
,
Module
)
or
isinstance
(
weight_observer
,
Module
):
raise
ValueError
(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
" partial(MyObserver, x=1) to override arguments to constructor if needed"
)
self
.
weight_observer
=
weight_observer
self
.
act_observer
=
act_observer
self
.
weight_fake_quant
=
weight_fake_quant
self
.
act_fake_quant
=
act_fake_quant
def
__eq__
(
self
,
other
):
def
eq
(
a
,
b
):
if
isinstance
(
a
,
partial
)
and
isinstance
(
b
,
partial
):
return
all
(
[
a
.
func
==
b
.
func
,
a
.
args
==
b
.
args
,
a
.
keywords
==
b
.
keywords
]
)
else
:
return
a
==
b
return
(
eq
(
self
.
weight_observer
,
other
.
weight_observer
)
and
eq
(
self
.
act_observer
,
other
.
act_observer
)
and
eq
(
self
.
weight_fake_quant
,
other
.
weight_fake_quant
)
and
eq
(
self
.
act_fake_quant
,
other
.
act_fake_quant
)
return
super
().
__new__
(
cls
,
weight_observer
,
act_observer
,
weight_fake_quant
,
act_fake_quant
)
min_max_fakequant_qconfig
=
QConfig
(
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
_narrow"
),
act_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
_narrow"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
),
)
ema_fakequant_qconfig
=
QConfig
(
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8"
,
narrow_range
=
True
),
act_observer
=
partial
(
ExponentialMovingAverageObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8_narrow"
),
act_observer
=
partial
(
ExponentialMovingAverageObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8_narrow"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
),
)
sync_ema_fakequant_qconfig
=
QConfig
(
weight_observer
=
partial
(
SyncMinMaxObserver
,
dtype
=
"qint8"
,
narrow_range
=
True
),
act_observer
=
partial
(
SyncExponentialMovingAverageObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
SyncMinMaxObserver
,
dtype
=
"qint8_narrow"
),
act_observer
=
partial
(
SyncExponentialMovingAverageObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8_narrow"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
),
)
ema_lowbit_fakequant_qconfig
=
QConfig
(
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint4"
,
narrow_range
=
False
),
act_observer
=
partial
(
ExponentialMovingAverageObserver
,
dtype
=
"qint4"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint4"
,
narrow_range
=
False
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint4"
,
narrow_range
=
False
),
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint4"
),
act_observer
=
partial
(
ExponentialMovingAverageObserver
,
dtype
=
"qint4"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint4"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint4"
),
)
calibration_qconfig
=
QConfig
(
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_observer
=
partial
(
HistogramObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
_narrow"
),
act_observer
=
partial
(
HistogramObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
None
,
act_fake_quant
=
None
,
)
...
...
@@ -128,15 +111,15 @@ calibration_qconfig = QConfig(
tqt_qconfig
=
QConfig
(
weight_observer
=
None
,
act_observer
=
None
,
weight_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8
_narrow"
),
act_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8"
),
)
passive_qconfig
=
QConfig
(
weight_observer
=
partial
(
PassiveObserver
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_observer
=
partial
(
PassiveObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
PassiveObserver
,
dtype
=
"qint8
_narrow"
),
act_observer
=
partial
(
PassiveObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
_narrow"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
),
)
easyquant_qconfig
=
passive_qconfig
imperative/python/megengine/quantization/quantize.py
浏览文件 @
1d7dd001
...
...
@@ -18,6 +18,7 @@ from ..module import qat as QAT
from
..module
import
quantized
as
Quantized
from
..module.qat
import
QATModule
from
..module.quantized
import
QuantizedModule
from
..tensor
import
Tensor
from
.qconfig
import
QConfig
,
ema_fakequant_qconfig
...
...
@@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
if
not
inplace
:
module
=
deepcopy
(
module
)
def
safe_call
(
func
,
q
_dict
):
def
safe_call
(
func
,
q
params
):
inst
=
func
()
if
func
is
not
None
else
None
if
inst
is
not
None
and
getattr
(
inst
,
"set_qparams"
,
None
)
is
not
None
:
inst
.
set_qparams
(
q
_dict
)
inst
.
set_qparams
(
q
params
)
return
inst
def
is_qat
(
mod
:
Module
):
...
...
@@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
for
m
in
list
(
module
.
_flatten
(
predicate
=
is_qat
)):
if
m
.
with_weight
:
weight_
q_dict
=
m
.
get_weight_qparams
()
m
.
weight_observer
=
safe_call
(
qconfig
.
weight_observer
,
weight_
q_dict
)
m
.
weight_fake_quant
=
safe_call
(
qconfig
.
weight_fake_quant
,
weight_
q_dict
)
weight_
params
=
m
.
get_weight_qparams
()
m
.
weight_observer
=
safe_call
(
qconfig
.
weight_observer
,
weight_
params
)
m
.
weight_fake_quant
=
safe_call
(
qconfig
.
weight_fake_quant
,
weight_
params
)
if
m
.
with_act
:
act_
q_dict
=
m
.
get_activation_qparams
()
m
.
act_observer
=
safe_call
(
qconfig
.
act_observer
,
act_
q_dict
)
m
.
act_fake_quant
=
safe_call
(
qconfig
.
act_fake_quant
,
act_
q_dict
)
act_
params
=
m
.
get_activation_qparams
()
m
.
act_observer
=
safe_call
(
qconfig
.
act_observer
,
act_
params
)
m
.
act_fake_quant
=
safe_call
(
qconfig
.
act_fake_quant
,
act_
params
)
return
module
...
...
@@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable):
return
hooks
def
apply_easy_quant
(
module
,
data
,
start
=
0.8
,
stop
=
1.2
,
num
=
40
):
def
apply_easy_quant
(
module
:
Module
,
data
:
Tensor
,
start
:
float
=
0.8
,
stop
:
float
=
1.2
,
num
:
int
=
40
):
r
"""
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
Search for optimal scales.
...
...
imperative/python/megengine/quantization/utils.py
浏览文件 @
1d7dd001
...
...
@@ -5,9 +5,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
abc
from
enum
import
Enum
from
functools
import
partial
,
update_wrapper
,
wraps
from
typing
import
Dict
from
typing
import
Union
import
numpy
as
np
...
...
@@ -15,7 +16,11 @@ from .. import functional as F
from
..autodiff
import
Function
from
..core._imperative_rt.core2
import
apply
from
..core.ops
import
builtin
from
..core.tensor.dtype
import
_metadata_dict
from
..core.tensor.dtype
import
(
QuantDtypeMeta
,
_builtin_quant_dtypes
,
create_quantized_dtype
,
)
from
..tensor
import
Tensor
...
...
@@ -61,37 +66,100 @@ class QuantMode(Enum):
ASYMMERTIC
=
2
qparam_dict
=
{
QuantMode
.
SYMMERTIC
:
{
"mode"
:
QuantMode
.
SYMMERTIC
,
"scale"
:
None
},
QuantMode
.
ASYMMERTIC
:
{
"mode"
:
QuantMode
.
ASYMMERTIC
,
"scale"
:
None
,
"zero_point"
:
None
,
},
class
QParams
:
"""
To standardize FakeQuant, Observer and Tensor's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""
__slots__
=
"mode"
,
"dtype_meta"
,
"scale"
,
"zero_point"
def
__init__
(
self
,
mode
:
QuantMode
,
dtype_meta
:
QuantDtypeMeta
,
scale
:
Tensor
,
zero_point
:
Tensor
,
):
self
.
mode
=
mode
self
.
dtype_meta
=
dtype_meta
self
.
scale
=
scale
self
.
zero_point
=
zero_point
def
update
(
self
,
qparams
:
"QParams"
):
for
key
in
self
.
__slots__
:
setattr
(
self
,
key
,
getattr
(
qparams
,
key
))
def
__eq__
(
self
,
other
):
if
len
(
self
.
__slots__
)
!=
len
(
other
.
__slots__
):
return
False
for
key
in
self
.
__slots__
:
if
not
hasattr
(
other
,
key
)
or
getattr
(
self
,
key
)
!=
getattr
(
other
,
key
):
return
False
return
True
def
__repr__
(
self
):
content
=
", "
.
join
(
[
"{}={}"
.
format
(
key
,
getattr
(
self
,
key
))
for
key
in
self
.
__slots__
]
)
return
"QParams({})"
.
format
(
content
)
class
QParamsModuleMixin
(
abc
.
ABC
):
def
get_quantized_dtype
(
self
):
qparams
=
self
.
get_qparams
()
dtype
=
qparams
.
dtype_meta
scale
=
float
(
qparams
.
scale
.
numpy
())
if
qparams
.
scale
is
not
None
else
None
zero_point
=
(
int
(
qparams
.
zero_point
.
numpy
())
if
qparams
.
zero_point
is
not
None
else
None
)
return
create_quantized_dtype
(
dtype
,
scale
,
zero_point
)
@
abc
.
abstractmethod
def
get_qparams
(
self
)
->
QParams
:
pass
_builtin_qparams
=
{
QuantMode
.
SYMMERTIC
:
partial
(
QParams
,
mode
=
QuantMode
.
SYMMERTIC
),
QuantMode
.
ASYMMERTIC
:
partial
(
QParams
,
mode
=
QuantMode
.
ASYMMERTIC
),
}
def
get_qparam_dict
(
mode
:
QuantMode
):
def
create_qparams
(
mode
:
QuantMode
=
QuantMode
.
SYMMERTIC
,
dtype_meta
:
Union
[
str
,
QuantDtypeMeta
]
=
None
,
scale
:
Tensor
=
None
,
zero_point
:
Tensor
=
None
,
):
"""
Return
the quantization parameters dictionary
according to the mode.
Return
:class:`~.QParams`
according to the mode.
"""
return
qparam_dict
.
get
(
mode
,
None
)
if
isinstance
(
dtype_meta
,
str
):
dtype_meta
=
_builtin_quant_dtypes
[
dtype_meta
]
if
mode
is
None
:
return
QParams
(
mode
,
dtype_meta
,
scale
,
zero_point
)
assert
isinstance
(
mode
,
QuantMode
)
return
_builtin_qparams
[
mode
](
dtype_meta
=
dtype_meta
,
scale
=
scale
,
zero_point
=
zero_point
)
def
fake_quant_tensor
(
inp
:
Tensor
,
q
min
:
int
,
qmax
:
int
,
q_dict
:
Dict
)
->
Tensor
:
def
fake_quant_tensor
(
inp
:
Tensor
,
q
params
:
QParams
)
->
Tensor
:
"""
Apply fake quantization to the inp tensor.
:param inp: the input tensor which need to be faked.
:param qmin: the minimum value which the integer limit to.
:param qmax: the maximum value which the integer limit to.
:param q_dict: the quantization parameter dict.
:param qparams: to get mode, qmin, qmax, scale and zero_point from.
"""
scale
=
q_dict
[
"scale"
]
scale
=
qparams
.
scale
if
qparams
.
mode
==
QuantMode
.
ASYMMERTIC
:
zero_point
=
qparams
.
zero_point
else
:
zero_point
=
Tensor
([
0.0
],
dtype
=
np
.
float32
)
if
q_dict
[
"mode"
]
==
QuantMode
.
ASYMMERTIC
:
zero_point
=
q_dict
[
"zero_point"
]
qmin
=
qparams
.
dtype_meta
.
qmin
qmax
=
qparams
.
dtype_meta
.
qmax
op
=
builtin
.
FakeQuant
(
qmin
=
qmin
,
qmax
=
qmax
)
return
apply
(
op
,
inp
,
scale
,
zero_point
)[
0
]
...
...
@@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
:param bias: the bias tensor which need to be faked.
:param inp: the input tensor which contain the quantization parameters.
:param
qmax
: the weight tensor which contain the quantization parameters.
:param
w_qat
: the weight tensor which contain the quantization parameters.
.. warning::
Only work for symmetric quantization method now.
"""
b_qat
=
bias
if
hasattr
(
inp
,
"q_dict"
)
and
b_qat
is
not
None
:
if
inp
.
q_dict
[
"scale"
]
is
not
None
and
w_qat
.
q_dict
[
"scale"
]
is
not
None
:
# use the same mode with weight.
b_dict
=
get_qparam_dict
(
w_qat
.
q_dict
[
"mode"
])
b_dict
[
"scale"
]
=
inp
.
q_dict
[
"scale"
]
*
w_qat
.
q_dict
[
"scale"
]
# TODO: add zero_point for ASYMMERTIC mode.
qmax
=
_metadata_dict
[
"qint32"
].
qmax
qmin
=
_metadata_dict
[
"qint32"
].
qmin
b_qat
=
fake_quant_tensor
(
b_qat
,
qmin
,
qmax
,
b_dict
)
b_qat
.
q_dict
.
update
(
b_dict
)
if
(
getattr
(
inp
,
"qparams"
,
None
)
is
not
None
and
getattr
(
w_qat
,
"qparams"
,
None
)
is
not
None
and
bias
is
not
None
):
inp_params
=
inp
.
qparams
w_params
=
w_qat
.
qparams
if
inp_params
.
scale
is
not
None
and
w_params
.
scale
is
not
None
:
assert
inp_params
.
mode
==
w_params
.
mode
,
"incompatible QuantMode"
# TODO: support quint8 dtype.
assert
(
inp_params
.
dtype_meta
.
np_dtype_str
==
"int8"
and
w_params
.
dtype_meta
.
np_dtype_str
==
"int8"
),
"fake_quant_bias only support int8 like dtype now"
# use the same mode with weight.
# TODO: avoid hardcode
b_dtype
=
_builtin_quant_dtypes
[
"qint32"
]
b_param
=
create_qparams
(
w_params
.
mode
,
b_dtype
,
scale
=
inp_params
.
scale
*
w_params
.
scale
)
b_qat
=
fake_quant_tensor
(
bias
,
b_param
)
b_qat
.
qparams
.
update
(
b_param
)
return
b_qat
imperative/python/megengine/tensor.py
浏览文件 @
1d7dd001
...
...
@@ -22,6 +22,8 @@ from .logger import get_logger
from
.utils.deprecation
import
deprecated
from
.utils.naming
import
auto_naming
logger
=
get_logger
(
__name__
)
class
Tensor
(
_Tensor
,
ArrayMethodMixin
):
r
"""
...
...
@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
grad
=
None
dmap_callback
=
None
_q
_dict
=
None
_q
params
=
None
def
__new__
(
cls
,
data
,
dtype
=
None
,
device
=
None
,
is_const
=
False
,
no_cache
=
False
,
name
=
None
...
...
@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
if
isinstance
(
data
,
_Tensor
):
if
dtype
is
not
None
:
get_logger
()
.
warning
(
logger
.
warning
(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj
=
_Tensor
.
__new__
(
cls
,
data
)
...
...
@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin):
return
super
().
dtype
@
property
def
q_dict
(
self
):
if
self
.
_q_dict
is
None
:
self
.
_q_dict
=
{
"mode"
:
None
,
"scale"
:
None
,
"zero_point"
:
None
}
return
self
.
_q_dict
def
qparams
(
self
):
from
.quantization.utils
import
create_qparams
# pylint: disable=all
if
self
.
_qparams
is
None
:
self
.
_qparams
=
create_qparams
()
return
self
.
_qparams
def
numpy
(
self
)
->
np
.
ndarray
:
r
"""
...
...
@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin):
def
__getstate__
(
self
):
r
""" __getstate__ will be called for pickle serialization or deep copy
"""
state
=
{
"qdict"
:
self
.
q_dict
,
"numpy"
:
self
.
numpy
(),
"dtype"
:
self
.
dtype
,
"device"
:
self
.
device
.
logical_name
,
}
if
self
.
_qparams
is
not
None
:
state
[
"qparams"
]
=
self
.
_qparams
return
state
def
__setstate__
(
self
,
state
):
self
.
_q_dict
=
state
.
pop
(
"qdict"
)
from
.quantization.utils
import
create_qparams
# pylint: disable=all
if
"qdict"
in
state
:
qparams
=
state
.
pop
(
"qdict"
)
logger
.
warning
(
"Tensor's 'qdict' state is depreciated. Use 'qparams' instead"
)
elif
"qparams"
in
state
:
qparams
=
state
.
pop
(
"qparams"
)
else
:
qparams
=
None
self
.
_reset
(
Tensor
(
state
.
pop
(
"numpy"
),
state
.
pop
(
"dtype"
),
state
.
pop
(
"device"
)))
self
.
_qparams
=
qparams
tensor
=
Tensor
...
...
imperative/python/test/unit/core/test_dtype_quant.py
浏览文件 @
1d7dd001
...
...
@@ -14,7 +14,7 @@ import pytest
import
megengine.core.tensor.megbrain_graph
as
G
from
megengine.core.ops
import
builtin
as
ops
from
megengine.core.tensor.dtype
import
(
_
metadata_dict
,
_
builtin_quant_dtypes
,
convert_from_qint4
,
convert_from_qint8
,
convert_from_quint4
,
...
...
@@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None):
def
_check_result_attr
(
oup
,
dtype
,
dtype_str
,
is_unsigned
=
True
):
metadata
=
_
metadata_dict
[
dtype_str
]
metadata
=
_
builtin_quant_dtypes
[
dtype_str
]
assert
"mgb_dtype"
in
oup
.
dtype
.
metadata
assert
is_quantize
(
oup
.
dtype
)
np
.
testing
.
assert_equal
(
oup
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
],
metadata
.
name
)
np
.
testing
.
assert_equal
(
oup
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
],
metadata
.
c
name
)
np
.
testing
.
assert_allclose
(
get_scale
(
oup
.
dtype
),
get_scale
(
dtype
))
if
is_unsigned
:
np
.
testing
.
assert_equal
(
get_zero_point
(
oup
.
dtype
),
get_zero_point
(
dtype
))
...
...
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
1d7dd001
...
...
@@ -65,9 +65,9 @@ def test_tensor_serialization():
with
TemporaryFile
()
as
f
:
a
=
Tensor
(
0
)
a
.
q
_dict
[
"scale"
]
=
Tensor
(
1.0
)
a
.
q
params
.
scale
=
Tensor
(
1.0
)
pickle
.
dump
(
a
,
f
)
f
.
seek
(
0
)
b
=
pickle
.
load
(
f
)
assert
isinstance
(
b
.
q
_dict
[
"scale"
]
,
Tensor
)
np
.
testing
.
assert_equal
(
b
.
q
_dict
[
"scale"
]
.
numpy
(),
1.0
)
assert
isinstance
(
b
.
q
params
.
scale
,
Tensor
)
np
.
testing
.
assert_equal
(
b
.
q
params
.
scale
.
numpy
(),
1.0
)
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
1d7dd001
...
...
@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
copy
import
numpy
as
np
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
...
...
@@ -86,3 +88,23 @@ def test_as_type():
b
=
a
.
astype
(
quint8
(
0.3
,
128
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
b
.
dtype
),
0.3
)
np
.
testing
.
assert_equal
(
get_zero_point
(
b
.
dtype
),
128
)
def
test_qparams
():
x
=
Tensor
(
1
)
assert
x
.
qparams
.
scale
is
None
x
.
qparams
.
scale
=
Tensor
(
1.0
)
assert
x
.
qparams
.
scale
.
numpy
()
==
1.0
x2
=
copy
.
copy
(
x
)
assert
x
.
qparams
is
x2
.
qparams
and
x2
.
qparams
.
scale
.
numpy
()
==
1.0
x3
=
copy
.
deepcopy
(
x
)
assert
x
.
qparams
is
not
x3
.
qparams
and
x3
.
qparams
.
scale
.
numpy
()
==
1.0
def
test_name
():
x
=
Tensor
(
0
)
assert
x
.
name
==
""
x
.
name
=
"x"
assert
x
.
name
==
"x"
x
=
Tensor
(
0
,
name
=
"x"
)
assert
x
.
name
==
"x"
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
1d7dd001
...
...
@@ -406,28 +406,3 @@ def test_copy_d2h():
def
test_copy_d2d
():
copy_test
(
"gpu0"
,
"gpu1"
)
copy_test
(
"gpu0:0"
,
"gpu0:1"
)
def
test_name
():
x
=
tensor
(
0
)
assert
x
.
name
==
""
x
.
name
=
"x"
assert
x
.
name
==
"x"
x
=
tensor
(
0
,
name
=
"x"
)
assert
x
.
name
==
"x"
def
test_q_dict
():
x
=
tensor
(
1
)
assert
x
.
q_dict
[
"scale"
]
is
None
x
.
q_dict
[
"scale"
]
=
tensor
(
1.0
)
y
=
tensor
(
1
)
assert
y
.
q_dict
[
"scale"
]
is
None
y
.
q_dict
[
"scale"
]
=
tensor
(
2.0
)
assert
x
.
q_dict
[
"scale"
].
numpy
()
==
1.0
assert
y
.
q_dict
[
"scale"
].
numpy
()
==
2.0
z
=
x
+
y
assert
z
.
q_dict
[
"scale"
]
is
None
imperative/python/test/unit/quantization/test_fake_quant.py
浏览文件 @
1d7dd001
...
...
@@ -12,9 +12,15 @@ import pytest
import
megengine
as
mge
from
megengine
import
tensor
from
megengine.core.autodiff.grad
import
Function
,
Grad
from
megengine.core.tensor.dtype
import
QuantDtypeMeta
from
megengine.core.tensor.utils
import
make_shape_tuple
from
megengine.quantization.internal_fake_quant
import
*
from
megengine.quantization.utils
import
QuantMode
,
fake_quant_tensor
,
tqt_forward
from
megengine.quantization.utils
import
(
QuantMode
,
create_qparams
,
fake_quant_tensor
,
tqt_forward
,
)
class
TQT_numpy
:
...
...
@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
def
test_fakequant
():
qmin
=
-
126
qmax
=
129
test_dtype
=
QuantDtypeMeta
(
"test_qint8"
,
None
,
"int8"
,
qmin
,
qmax
)
def
run
(
zero_point
,
scale
):
q_dict
=
{}
q_dict
[
"mode"
]
=
QuantMode
.
ASYMMERTIC
q_dict
[
"scale"
]
=
scale
q_dict
[
"zero_point"
]
=
zero_point
qparams
=
create_qparams
(
QuantMode
.
ASYMMERTIC
,
test_dtype
,
scale
,
zero_point
)
inp_data
=
np
.
random
.
uniform
(
low
=-
512.0
,
high
=
512.0
,
size
=
(
1
,
32
,
32
,
32
))
inp
=
tensor
(
inp_data
,
dtype
=
np
.
float32
)
# test forward
oup
=
fake_quant_tensor
(
inp
,
q
min
,
qmax
,
q_dict
).
numpy
()
oup
=
fake_quant_tensor
(
inp
,
q
params
).
numpy
()
oup_gt
=
fake_quant_tensor_gt
(
inp
,
scale
,
zero_point
,
qmin
,
qmax
).
numpy
()
assert
np
.
allclose
(
oup
,
oup_gt
)
assert
oup
.
shape
==
oup_gt
.
shape
...
...
@@ -128,7 +132,7 @@ def test_fakequant():
# test backward
x
=
tensor
(
inp_data
,
dtype
=
np
.
float32
)
grad
=
Grad
().
wrt
(
x
,
callback
=
_save_to
(
x
))
y
=
fake_quant_tensor
(
x
,
q
min
,
qmax
,
q_dict
)
y
=
fake_quant_tensor
(
x
,
q
params
)
grad
(
y
,
tensor
(
F
.
ones_like
(
x
)))
x1
=
tensor
(
inp_data
,
dtype
=
np
.
float32
)
...
...
imperative/python/test/unit/quantization/test_module.py
浏览文件 @
1d7dd001
...
...
@@ -10,7 +10,13 @@ import megengine.module.qat as QAT
import
megengine.module.quantized
as
Q
from
megengine
import
Parameter
,
Tensor
from
megengine.core.tensor
import
dtype
from
megengine.quantization
import
FakeQuantize
,
MinMaxObserver
,
QConfig
from
megengine.quantization
import
(
FakeQuantize
,
MinMaxObserver
,
QConfig
,
QuantMode
,
create_qparams
,
)
from
megengine.quantization.quantize
import
(
disable_fake_quant
,
disable_observer
,
...
...
@@ -18,10 +24,10 @@ from megengine.quantization.quantize import (
)
min_max_fakequant_qconfig
=
QConfig
(
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
,
narrow_range
=
False
),
weight_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8
_narrow"
),
act_observer
=
partial
(
MinMaxObserver
,
dtype
=
"qint8"
),
weight_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8
_narrow"
),
act_fake_quant
=
partial
(
FakeQuantize
,
dtype
=
"qint8"
),
)
inp_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
...
...
@@ -111,7 +117,7 @@ def test_dequant_stub():
x
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x
=
fake_quant_act
(
x
,
inp_scale
)
x
.
q
_dict
[
"scale"
]
=
inp_scale
x
.
q
params
.
scale
=
inp_scale
normal
=
normal_net
(
x
)
qat_without_fakequant
=
qat_from_float
(
x
)
...
...
@@ -146,12 +152,12 @@ def test_elemwise(kind):
x1_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
x1
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x1
=
fake_quant_act
(
x1
,
x1_scale
)
x1
.
q
_dict
[
"scale"
]
=
x1_scale
x1
.
q
params
.
scale
=
x1_scale
x2_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
x2
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x2
=
fake_quant_act
(
x2
,
x2_scale
)
x2
.
q
_dict
[
"scale"
]
=
x2_scale
x2
.
q
params
.
scale
=
x2_scale
x1_int8
=
quant
(
x1
,
x1_scale
)
x2_int8
=
quant
(
x2
,
x2_scale
)
...
...
@@ -187,7 +193,7 @@ def test_linear():
x
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x
=
fake_quant_act
(
x
,
inp_scale
)
x
.
q
_dict
[
"scale"
]
=
inp_scale
x
.
q
params
.
update
(
create_qparams
(
QuantMode
.
SYMMERTIC
,
"qint8"
,
inp_scale
))
x_int8
=
quant
(
x
,
inp_scale
)
...
...
@@ -230,7 +236,7 @@ def test_conv(module):
x
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
1
,
3
,
3
,
3
)).
astype
(
"float32"
))
x
=
fake_quant_act
(
x
,
inp_scale
)
x
.
q
_dict
[
"scale"
]
=
inp_scale
x
.
q
params
.
update
(
create_qparams
(
QuantMode
.
SYMMERTIC
,
"qint8"
,
inp_scale
))
x_int8
=
quant
(
x
,
inp_scale
)
...
...
imperative/python/test/unit/quantization/test_observer.py
浏览文件 @
1d7dd001
...
...
@@ -6,6 +6,7 @@ import pytest
import
megengine
as
mge
import
megengine.distributed
as
dist
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.quantization
import
QuantMode
,
create_qparams
from
megengine.quantization.observer
import
(
ExponentialMovingAverageObserver
,
HistogramObserver
,
...
...
@@ -56,14 +57,14 @@ def test_histogram_observer():
def
test_passive_observer
():
q
_dict
=
{
"scale"
:
mge
.
tensor
(
1.0
)}
q
params
=
create_qparams
(
QuantMode
.
SYMMERTIC
,
"qint8"
,
mge
.
tensor
(
1.0
))
m
=
PassiveObserver
(
"qint8"
)
m
.
set_qparams
(
q
_dict
)
m
.
set_qparams
(
q
params
)
assert
m
.
orig_scale
==
1.0
assert
m
.
scale
==
1.0
m
.
scale
=
2.0
assert
m
.
scale
==
2.0
assert
m
.
get_qparams
()
==
{
"scale"
:
mge
.
tensor
(
2.0
)}
assert
m
.
scale
.
numpy
()
==
1.0
assert
m
.
get_qparams
().
dtype_meta
==
qparams
.
dtype_meta
assert
m
.
get_qparams
().
scale
==
qparams
.
scale
assert
m
.
get_qparams
()
==
qparams
@
pytest
.
mark
.
require_ngpu
(
2
)
...
...
imperative/python/test/unit/quantization/test_op.py
浏览文件 @
1d7dd001
...
...
@@ -6,6 +6,7 @@ import megengine.functional as F
from
megengine.core.tensor
import
dtype
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.functional.elemwise
import
_elemwise_multi_type
,
_elwise
from
megengine.quantization
import
QuantMode
,
create_qparams
def
quant
(
x
,
scale
):
...
...
@@ -26,13 +27,13 @@ def test_elemwise(kind):
x1
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x1_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
x1
=
fake_quant
(
x1
,
x1_scale
)
x1
.
q
_dict
[
"scale"
]
=
x1_scale
x1
.
q
params
.
update
(
create_qparams
(
QuantMode
.
SYMMERTIC
,
"qint8"
,
x1_scale
))
x1_int8
=
quant
(
x1
,
x1_scale
)
x2
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
3
,
3
)).
astype
(
"float32"
))
x2_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
x2
=
fake_quant
(
x2
,
x2_scale
)
x2
.
q
_dict
[
"scale"
]
=
x2_scale
x2
.
q
params
.
update
(
create_qparams
(
QuantMode
.
SYMMERTIC
,
"qint8"
,
x2_scale
))
x2_int8
=
quant
(
x2
,
x2_scale
)
output_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
...
...
imperative/python/test/unit/quantization/test_qconfig.py
已删除
100644 → 0
浏览文件 @
4130dcd3
from
functools
import
partial
from
megengine.quantization
import
QConfig
,
tqt_qconfig
from
megengine.quantization.fake_quant
import
TQT
def
test_equal
():
qconfig
=
QConfig
(
weight_observer
=
None
,
act_observer
=
None
,
weight_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8"
,
narrow_range
=
True
),
act_fake_quant
=
partial
(
TQT
,
dtype
=
"qint8"
,
narrow_range
=
False
),
)
assert
qconfig
==
tqt_qconfig
imperative/python/test/unit/quantization/test_quantize.py
浏览文件 @
1d7dd001
...
...
@@ -33,7 +33,7 @@ from megengine.quantization.quantize import (
)
class
Net
(
Float
.
Module
):
class
Float
Net
(
Float
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
quant
=
Float
.
QuantStub
()
...
...
@@ -113,25 +113,25 @@ def test_reset_qconfig():
def
test_enable_and_disable_observer
():
net
=
init_qat_net
()
enable_observer
(
net
)
assert
net
.
quant
.
act_observer
.
enabled
==
True
assert
net
.
linear
.
weight_observer
.
enabled
==
True
assert
net
.
linear
.
act_observer
.
enabled
==
True
assert
net
.
quant
.
act_observer
.
enabled
is
True
assert
net
.
linear
.
weight_observer
.
enabled
is
True
assert
net
.
linear
.
act_observer
.
enabled
is
True
disable_observer
(
net
)
assert
net
.
quant
.
act_observer
.
enabled
==
False
assert
net
.
linear
.
weight_observer
.
enabled
==
False
assert
net
.
linear
.
act_observer
.
enabled
==
False
assert
net
.
quant
.
act_observer
.
enabled
is
False
assert
net
.
linear
.
weight_observer
.
enabled
is
False
assert
net
.
linear
.
act_observer
.
enabled
is
False
def
test_enable_and_disable_fake_quant
():
net
=
init_qat_net
()
disable_fake_quant
(
net
)
assert
net
.
quant
.
act_fake_quant
.
enabled
==
False
assert
net
.
linear
.
weight_fake_quant
.
enabled
==
False
assert
net
.
linear
.
act_fake_quant
.
enabled
==
False
assert
net
.
quant
.
act_fake_quant
.
enabled
is
False
assert
net
.
linear
.
weight_fake_quant
.
enabled
is
False
assert
net
.
linear
.
act_fake_quant
.
enabled
is
False
enable_fake_quant
(
net
)
assert
net
.
quant
.
act_fake_quant
.
enabled
==
True
assert
net
.
linear
.
weight_fake_quant
.
enabled
==
True
assert
net
.
linear
.
act_fake_quant
.
enabled
==
True
assert
net
.
quant
.
act_fake_quant
.
enabled
is
True
assert
net
.
linear
.
weight_fake_quant
.
enabled
is
True
assert
net
.
linear
.
act_fake_quant
.
enabled
is
True
def
init_observer
(
module
,
data
):
...
...
@@ -144,7 +144,7 @@ def init_observer(module, data):
def
test_enable_and_disable_all
():
x
=
Tensor
(
np
.
random
.
randint
(
1
,
10
,
size
=
(
3
,
3
)).
astype
(
np
.
float32
))
net
=
Net
()
net
=
Float
Net
()
y1
=
net
(
x
).
numpy
()
net
=
quantize_qat
(
net
,
min_max_fakequant_qconfig
)
...
...
@@ -162,7 +162,7 @@ def test_enable_and_disable_all():
def
test_quantize_qat
():
net
=
Net
()
net
=
Float
Net
()
qat_net
=
quantize_qat
(
net
,
inplace
=
False
,
qconfig
=
min_max_fakequant_qconfig
)
assert
isinstance
(
qat_net
.
quant
,
QAT
.
QuantStub
)
assert
isinstance
(
qat_net
.
linear
,
QAT
.
Linear
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录