Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
abf82cfb
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
abf82cfb
编写于
4月 26, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/dtype): export `_metadata_dict` for consistent dtype property
GitOrigin-RevId: 6840c0b6b49df1deb8c6acd1895cf96fe96f771a
上级
f582c192
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
123 addition
and
133 deletion
+123
-133
python_module/megengine/_internal/dtype.py
python_module/megengine/_internal/dtype.py
+123
-133
未找到文件。
python_module/megengine/_internal/dtype.py
浏览文件 @
abf82cfb
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -11,6 +10,34 @@ import numpy as np
from
.mgb
import
intb1
,
intb2
,
intb4
_metadata_dict
=
{
"quint8"
:
{
"is_unsigned"
:
True
,
"np_dtype_str"
:
"uint8"
,
"mgb_dtype"
:
{
"name"
:
"Quantized8Asymm"
,
"qmin"
:
0
,
"qmax"
:
255
,},
},
"qint8"
:
{
"is_unsigned"
:
False
,
"np_dtype_str"
:
"int8"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS8"
,
"qmin"
:
-
128
,
"qmax"
:
127
,},
},
"quint4"
:
{
"is_unsigned"
:
True
,
"np_dtype_str"
:
"uint8"
,
"mgb_dtype"
:
{
"name"
:
"Quantized4Asymm"
,
"qmin"
:
0
,
"qmax"
:
15
,},
},
"qint4"
:
{
"is_unsigned"
:
False
,
"np_dtype_str"
:
"int8"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS4"
,
"qmin"
:
-
8
,
"qmax"
:
7
,},
},
"qint32"
:
{
"is_unsigned"
:
False
,
"np_dtype_str"
:
"int32"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS32"
,
"qmin"
:
-
(
2
**
31
),
"qmax"
:
2
**
31
-
1
,},
},
}
def
is_quantize
(
dtype
):
return
(
...
...
@@ -36,26 +63,36 @@ def get_zero_point(dtype):
return
metadata
[
"zero_point"
]
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
qmin
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
][
"qmin"
]
qmax
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
][
"qmax"
]
if
zp
<
qmin
or
zp
>
qmax
:
raise
ValueError
(
"zero_point should be within [{}, {}] for {}"
.
format
(
qmin
,
qmax
,
dtype_str
)
)
def
_get_dtype
(
dtype_str
:
str
,
scale
,
zp
):
if
zp
is
not
None
:
if
int
(
zp
)
!=
zp
:
raise
ValueError
(
"zero_point should be an integer"
)
zp
=
int
(
zp
)
_check_zero_point
(
zp
,
dtype_str
)
metadata
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
]
np_dtype_str
=
_metadata_dict
[
dtype_str
][
"np_dtype_str"
]
return
np
.
dtype
(
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
**
metadata
,
"scale"
:
float
(
scale
),
"zero_point"
:
zp
,}},
)
def
quint8
(
scale
,
zero_point
):
"""
Consturct a quantized unsigned int8 data type with ``scale`` (float) and
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
int_zp
=
int
(
zero_point
)
assert
int_zp
==
zero_point
,
"zero_point should be an integer"
if
int_zp
<
0
or
int_zp
>
255
:
raise
ValueError
(
"zero_point should be within [0, 255] for quint8"
)
return
np
.
dtype
(
np
.
uint8
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
"Quantized8Asymm"
,
"scale"
:
float
(
scale
),
"zero_point"
:
int
(
zero_point
),
}
},
)
return
_get_dtype
(
"quint8"
,
scale
,
zero_point
)
def
qint8
(
scale
):
...
...
@@ -63,9 +100,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return
np
.
dtype
(
np
.
int8
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
"QuantizedS8"
,
"scale"
:
float
(
scale
)}}
)
return
_get_dtype
(
"qint8"
,
scale
,
None
)
def
qint32
(
scale
):
...
...
@@ -73,10 +108,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return
np
.
dtype
(
np
.
int32
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
"QuantizedS32"
,
"scale"
:
float
(
scale
)}},
)
return
_get_dtype
(
"qint32"
,
scale
,
None
)
def
quint4
(
scale
,
zero_point
):
...
...
@@ -85,20 +117,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
int_zp
=
int
(
zero_point
)
assert
int_zp
==
zero_point
,
"zero_point should be an integer"
if
int_zp
<
0
or
int_zp
>
15
:
raise
ValueError
(
"zero_point should be within [0, 15] for quint4"
)
return
np
.
dtype
(
np
.
uint8
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
"Quantized4Asymm"
,
"scale"
:
float
(
scale
),
"zero_point"
:
int
(
zero_point
),
}
},
)
return
_get_dtype
(
"quint4"
,
scale
,
zero_point
)
def
qint4
(
scale
):
...
...
@@ -106,94 +125,101 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return
np
.
dtype
(
np
.
int8
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
"QuantizedS4"
,
"scale"
:
float
(
scale
)}}
)
def
convert_to_quint8
(
arr
,
q
):
return
_get_dtype
(
"qint4"
,
scale
,
None
)
def
_convert_to_dtype
(
arr
:
np
.
ndarray
,
dtype
:
np
.
dtype
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
]
arr_metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
[
"name"
]:
raise
ValueError
(
"dtype parameter should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
_metadata_dict
[
dtype_str
][
"is_unsigned"
]
if
is_unsigned
:
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"zero_point"
],
)
return
(
(
np
.
round
(
arr
/
scale
)
+
zp
)
.
clip
(
metadata
[
"qmin"
],
metadata
[
"qmax"
])
.
astype
(
dtype
)
)
else
:
# don't trick to combine with is_unsigned for consistency with cpp interface
scale
=
arr_metadata
[
"scale"
]
return
(
np
.
round
(
arr
/
scale
).
clip
(
metadata
[
"qmin"
],
metadata
[
"qmax"
]).
astype
(
dtype
)
)
def
_convert_from_dtype
(
arr
:
np
.
ndarray
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
]
arr_metadata
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
]
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
arr
.
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
[
"name"
]:
raise
ValueError
(
"arr's dtype should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
_metadata_dict
[
dtype_str
][
"is_unsigned"
]
if
is_unsigned
:
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"zero_point"
],
)
return
(
arr
.
astype
(
np
.
float32
)
-
zp
)
*
scale
else
:
# don't trick to combine with is_unsigned for consistency with cpp interface
scale
=
arr_metadata
[
"scale"
]
return
(
arr
.
astype
(
np
.
float32
))
*
scale
def
convert_to_quint8
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
"""
Quantize a float NumPy ndarray into a quint8 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a quint8.
:type q: :class:`np.dtype`
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
q
.
metadata
and
q
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"Quantized8Asymm"
),
"q should be a quint8 dtype"
scale
,
zp
=
q
.
metadata
[
"mgb_dtype"
][
"scale"
],
q
.
metadata
[
"mgb_dtype"
][
"zero_point"
]
return
(
np
.
round
(
arr
/
scale
)
+
zp
).
clip
(
0
,
255
).
astype
(
q
)
return
_convert_to_dtype
(
arr
,
q
,
"quint8"
)
def
convert_from_quint8
(
arr
):
def
convert_from_quint8
(
arr
:
np
.
ndarray
):
"""
Dequantize a quint8 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
arr
.
dtype
.
metadata
and
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"Quantized8Asymm"
),
"arr should be a ndarray with quint8 dtype"
scale
,
zp
=
(
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
],
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"zero_point"
],
)
return
(
arr
.
astype
(
np
.
float32
)
-
zp
)
*
scale
return
_convert_from_dtype
(
arr
,
"quint8"
)
def
convert_to_qint8
(
arr
,
q
):
def
convert_to_qint8
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
"""
Quantize a float NumPy ndarray into a qint8 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint8.
:type q: :class:`np.dtype`
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
q
.
metadata
and
q
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS8"
),
"q should be a qint8 dtype"
scale
=
q
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
(
np
.
round
(
arr
/
scale
)).
clip
(
-
128
,
127
).
astype
(
q
)
return
_convert_to_dtype
(
arr
,
q
,
"qint8"
)
def
convert_from_qint8
(
arr
):
def
convert_from_qint8
(
arr
:
np
.
ndarray
):
"""
Dequantize a qint8 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
arr
.
dtype
.
metadata
and
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS8"
),
"arr should be a ndarray with qint8 dtype"
scale
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
arr
.
astype
(
np
.
float32
)
*
scale
return
_convert_from_dtype
(
arr
,
"qint8"
)
def
convert_to_qint32
(
arr
,
q
):
def
convert_to_qint32
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
"""
Quantize a float NumPy ndarray into a qint32 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint8.
:type q: :class:`np.dtype`
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
q
.
metadata
and
q
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS32"
),
"q should be a qint32 dtype"
scale
=
q
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
(
np
.
round
(
arr
/
scale
)).
clip
(
-
(
2
**
31
),
2
**
31
).
astype
(
q
)
return
_convert_to_dtype
(
arr
,
q
,
"qint32"
)
def
convert_from_qint32
(
arr
):
...
...
@@ -202,78 +228,42 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
arr
.
dtype
.
metadata
and
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS32"
),
"arr should be a ndarray with qint8 dtype"
scale
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
arr
.
astype
(
np
.
float32
)
*
scale
return
_convert_from_dtype
(
arr
,
"qint32"
)
def
convert_to_quint4
(
arr
,
q
):
def
convert_to_quint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
"""
Quantize a float NumPy ndarray into a quint4 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a quint4.
:type q: :class:`np.dtype`
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
q
.
metadata
and
q
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"Quantized4Asymm"
),
"q should be a quint4 dtype"
scale
,
zp
=
q
.
metadata
[
"mgb_dtype"
][
"scale"
],
q
.
metadata
[
"mgb_dtype"
][
"zero_point"
]
return
(
np
.
round
(
arr
/
scale
)
+
zp
).
clip
(
0
,
15
).
astype
(
q
)
return
_convert_to_dtype
(
arr
,
q
,
"quint4"
)
def
convert_from_quint4
(
arr
):
def
convert_from_quint4
(
arr
:
np
.
ndarray
):
"""
Dequantize a quint4 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
arr
.
dtype
.
metadata
and
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"Quantized4Asymm"
),
"arr should be a ndarray with quint4 dtype"
scale
,
zp
=
(
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
],
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"zero_point"
],
)
return
(
arr
.
astype
(
np
.
float32
)
-
zp
)
*
scale
return
_convert_from_dtype
(
arr
,
"quint4"
)
def
convert_to_qint4
(
arr
,
q
):
def
convert_to_qint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
"""
Quantize a float NumPy ndarray into a qint4 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint4.
:type q: :class:`np.dtype`
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
q
.
metadata
and
q
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS4"
),
"q should be a qint4 dtype"
scale
=
q
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
(
np
.
round
(
arr
/
scale
)).
clip
(
-
8
,
7
).
astype
(
q
)
return
_convert_to_dtype
(
arr
,
q
,
"qint4"
)
def
convert_from_qint4
(
arr
):
def
convert_from_qint4
(
arr
:
np
.
ndarray
):
"""
Dequantize a qint4 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert
isinstance
(
arr
,
np
.
ndarray
)
assert
(
"mgb_dtype"
in
arr
.
dtype
.
metadata
and
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"name"
]
==
"QuantizedS4"
),
"arr should be a ndarray with qint4 dtype"
scale
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
arr
.
astype
(
np
.
float32
)
*
scale
return
_convert_from_dtype
(
arr
,
"qint4"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录