Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a6f45641
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a6f45641
编写于
4月 27, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/dtype): modify some interface name and enrich comments
GitOrigin-RevId: f9217f6d27b2235aa1b541d0b7953f503dfd7d33
上级
d3730036
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
75 addition
and
65 deletion
+75
-65
python_module/megengine/_internal/dtype.py
python_module/megengine/_internal/dtype.py
+75
-65
未找到文件。
python_module/megengine/_internal/dtype.py
浏览文件 @
a6f45641
...
@@ -6,36 +6,25 @@
...
@@ -6,36 +6,25 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
from
.mgb
import
intb1
,
intb2
,
intb4
from
.mgb
import
intb1
,
intb2
,
intb4
_QuantDtypeMetadata
=
collections
.
namedtuple
(
"QuantDtypeMetadata"
,
[
"name"
,
"np_dtype_str"
,
"is_unsigned"
,
"qmin"
,
"qmax"
,]
)
_metadata_dict
=
{
_metadata_dict
=
{
"quint8"
:
{
"quint8"
:
_QuantDtypeMetadata
(
"Quantized8Asymm"
,
"uint8"
,
True
,
0
,
255
),
"is_unsigned"
:
True
,
"qint8"
:
_QuantDtypeMetadata
(
"QuantizedS8"
,
"int8"
,
False
,
-
128
,
127
),
"np_dtype_str"
:
"uint8"
,
"quint4"
:
_QuantDtypeMetadata
(
"Quantized4Asymm"
,
"uint8"
,
True
,
0
,
15
),
"mgb_dtype"
:
{
"name"
:
"Quantized8Asymm"
,
"qmin"
:
0
,
"qmax"
:
255
,},
"qint4"
:
_QuantDtypeMetadata
(
"QuantizedS4"
,
"int8"
,
False
,
-
8
,
7
),
},
"qint32"
:
_QuantDtypeMetadata
(
"qint8"
:
{
"QuantizedS32"
,
"int32"
,
False
,
-
(
2
**
31
),
2
**
31
-
1
,
"is_unsigned"
:
False
,
),
"np_dtype_str"
:
"int8"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS8"
,
"qmin"
:
-
128
,
"qmax"
:
127
,},
},
"quint4"
:
{
"is_unsigned"
:
True
,
"np_dtype_str"
:
"uint8"
,
"mgb_dtype"
:
{
"name"
:
"Quantized4Asymm"
,
"qmin"
:
0
,
"qmax"
:
15
,},
},
"qint4"
:
{
"is_unsigned"
:
False
,
"np_dtype_str"
:
"int8"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS4"
,
"qmin"
:
-
8
,
"qmax"
:
7
,},
},
"qint32"
:
{
"is_unsigned"
:
False
,
"np_dtype_str"
:
"int32"
,
"mgb_dtype"
:
{
"name"
:
"QuantizedS32"
,
"qmin"
:
-
(
2
**
31
),
"qmax"
:
2
**
31
-
1
,},
},
}
}
...
@@ -64,25 +53,48 @@ def get_zero_point(dtype):
...
@@ -64,25 +53,48 @@ def get_zero_point(dtype):
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
qmin
=
_metadata_dict
[
dtype_str
]
[
"mgb_dtype"
][
"qmin"
]
qmin
=
_metadata_dict
[
dtype_str
]
.
qmin
qmax
=
_metadata_dict
[
dtype_str
]
[
"mgb_dtype"
][
"qmax"
]
qmax
=
_metadata_dict
[
dtype_str
]
.
qmax
if
zp
<
qmin
or
zp
>
qmax
:
if
zp
<
qmin
or
zp
>
qmax
:
raise
ValueError
(
raise
ValueError
(
"zero_point should be within [{}, {}] for {}"
.
format
(
qmin
,
qmax
,
dtype_str
)
"zero_point should be within [{}, {}] for {}"
.
format
(
qmin
,
qmax
,
dtype_str
)
)
)
def
_get_dtype
(
dtype_str
:
str
,
scale
,
zp
):
def
get_quantized_dtype
(
dtype_str
:
str
,
scale
:
float
,
zp
:
Union
[
int
,
None
]):
if
zp
is
not
None
:
r
"""
if
int
(
zp
)
!=
zp
:
Get quantized dtype with metadata attribute according to _metadata_dict.
Note that unsigned dtype must have ``zero_point`` and signed dtype must
not have ``zero_point``, to be consistent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.
:param dtype_str: a string indicating which dtype to return
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata
=
_metadata_dict
[
dtype_str
]
np_dtype_str
=
metadata
.
np_dtype_str
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
zp
is
None
or
int
(
zp
)
!=
zp
:
raise
ValueError
(
"zero_point should be an integer"
)
raise
ValueError
(
"zero_point should be an integer"
)
zp
=
int
(
zp
)
zp
=
int
(
zp
)
_check_zero_point
(
zp
,
dtype_str
)
_check_zero_point
(
zp
,
dtype_str
)
metadata
=
_metadata_dict
[
dtype_str
][
"mgb_dtype"
]
np_dtype_str
=
_metadata_dict
[
dtype_str
][
"np_dtype_str"
]
return
np
.
dtype
(
return
np
.
dtype
(
np_dtype_str
,
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
**
metadata
,
"scale"
:
float
(
scale
),
"zero_point"
:
zp
,}},
metadata
=
{
"mgb_dtype"
:
{
"name"
:
metadata
.
name
,
"scale"
:
float
(
scale
),
"zero_point"
:
zp
,
}
},
)
else
:
return
np
.
dtype
(
np_dtype_str
,
metadata
=
{
"mgb_dtype"
:
{
"name"
:
metadata
.
name
,
"scale"
:
float
(
scale
)}},
)
)
...
@@ -92,7 +104,7 @@ def quint8(scale, zero_point):
...
@@ -92,7 +104,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
float_val = scale * (uint8_val - zero_point)
"""
"""
return
_get
_dtype
(
"quint8"
,
scale
,
zero_point
)
return
get_quantized
_dtype
(
"quint8"
,
scale
,
zero_point
)
def
qint8
(
scale
):
def
qint8
(
scale
):
...
@@ -100,7 +112,7 @@ def qint8(scale):
...
@@ -100,7 +112,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
represented by a qint8 data type is float_val = scale * int8_val
"""
"""
return
_get
_dtype
(
"qint8"
,
scale
,
None
)
return
get_quantized
_dtype
(
"qint8"
,
scale
,
None
)
def
qint32
(
scale
):
def
qint32
(
scale
):
...
@@ -108,7 +120,7 @@ def qint32(scale):
...
@@ -108,7 +120,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
represented by a qint32 data type is float_val = scale * int32_val
"""
"""
return
_get
_dtype
(
"qint32"
,
scale
,
None
)
return
get_quantized
_dtype
(
"qint32"
,
scale
,
None
)
def
quint4
(
scale
,
zero_point
):
def
quint4
(
scale
,
zero_point
):
...
@@ -117,7 +129,7 @@ def quint4(scale, zero_point):
...
@@ -117,7 +129,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
float_val = scale * (uint4_val - zero_point)
"""
"""
return
_get
_dtype
(
"quint4"
,
scale
,
zero_point
)
return
get_quantized
_dtype
(
"quint4"
,
scale
,
zero_point
)
def
qint4
(
scale
):
def
qint4
(
scale
):
...
@@ -125,17 +137,17 @@ def qint4(scale):
...
@@ -125,17 +137,17 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
represented by a qint4 data type is float_val = scale * int4_val
"""
"""
return
_get
_dtype
(
"qint4"
,
scale
,
None
)
return
get_quantized
_dtype
(
"qint4"
,
scale
,
None
)
def
_convert_to_dtype
(
arr
:
np
.
ndarray
,
dtype
:
np
.
dtype
,
dtype_str
:
str
):
def
_convert_to_
quantized_
dtype
(
arr
:
np
.
ndarray
,
dtype
:
np
.
dtype
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
]
[
"mgb_dtype"
]
metadata
=
_metadata_dict
[
dtype_str
]
arr_metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
arr_metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
if
not
isinstance
(
arr
,
np
.
ndarray
):
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
[
"name"
]
:
if
not
is_quantize
(
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
.
name
:
raise
ValueError
(
"dtype parameter should be a {} dtype"
.
format
(
dtype_str
))
raise
ValueError
(
"dtype parameter should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
_metadata_dict
[
dtype_str
][
"is_unsigned"
]
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
is_unsigned
:
scale
,
zp
=
(
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"scale"
],
...
@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
...
@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
)
)
return
(
return
(
(
np
.
round
(
arr
/
scale
)
+
zp
)
(
np
.
round
(
arr
/
scale
)
+
zp
)
.
clip
(
metadata
[
"qmin"
],
metadata
[
"qmax"
]
)
.
clip
(
metadata
.
qmin
,
metadata
.
qmax
)
.
astype
(
dtype
)
.
astype
(
dtype
)
)
)
else
:
else
:
# don't trick to combine with is_unsigned
for consistency with cpp interface
# don't trick to combine with is_unsigned
, see ``get_quantized_dtype``
scale
=
arr_metadata
[
"scale"
]
scale
=
arr_metadata
[
"scale"
]
return
(
return
np
.
round
(
arr
/
scale
).
clip
(
metadata
.
qmin
,
metadata
.
qmax
).
astype
(
dtype
)
np
.
round
(
arr
/
scale
).
clip
(
metadata
[
"qmin"
],
metadata
[
"qmax"
]).
astype
(
dtype
)
)
def
_convert_from_dtype
(
arr
:
np
.
ndarray
,
dtype_str
:
str
):
def
_convert_from_
quantized_
dtype
(
arr
:
np
.
ndarray
,
dtype_str
:
str
):
metadata
=
_metadata_dict
[
dtype_str
]
[
"mgb_dtype"
]
metadata
=
_metadata_dict
[
dtype_str
]
arr_metadata
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
]
arr_metadata
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
]
if
not
isinstance
(
arr
,
np
.
ndarray
):
if
not
isinstance
(
arr
,
np
.
ndarray
):
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
raise
ValueError
(
"arr parameter should be instance of np.ndarray"
)
if
not
is_quantize
(
arr
.
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
[
"name"
]
:
if
not
is_quantize
(
arr
.
dtype
)
or
arr_metadata
[
"name"
]
!=
metadata
.
name
:
raise
ValueError
(
"arr's dtype should be a {} dtype"
.
format
(
dtype_str
))
raise
ValueError
(
"arr's dtype should be a {} dtype"
.
format
(
dtype_str
))
is_unsigned
=
_metadata_dict
[
dtype_str
][
"is_unsigned"
]
is_unsigned
=
metadata
.
is_unsigned
if
is_unsigned
:
if
is_unsigned
:
scale
,
zp
=
(
scale
,
zp
=
(
arr_metadata
[
"scale"
],
arr_metadata
[
"scale"
],
...
@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
...
@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
)
)
return
(
arr
.
astype
(
np
.
float32
)
-
zp
)
*
scale
return
(
arr
.
astype
(
np
.
float32
)
-
zp
)
*
scale
else
:
else
:
# don't trick to combine with is_unsigned
for consistency with cpp interface
# don't trick to combine with is_unsigned
, see ``get_quantized_dtype``
scale
=
arr_metadata
[
"scale"
]
scale
=
arr_metadata
[
"scale"
]
return
(
arr
.
astype
(
np
.
float32
))
*
scale
return
(
arr
.
astype
(
np
.
float32
))
*
scale
...
@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
...
@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
:param q: Target data type, should be a quint8.
"""
"""
return
_convert_to_dtype
(
arr
,
q
,
"quint8"
)
return
_convert_to_
quantized_
dtype
(
arr
,
q
,
"quint8"
)
def
convert_from_quint8
(
arr
:
np
.
ndarray
):
def
convert_from_quint8
(
arr
:
np
.
ndarray
):
...
@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray):
...
@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray):
:param arr: Input ndarray.
:param arr: Input ndarray.
"""
"""
return
_convert_from_dtype
(
arr
,
"quint8"
)
return
_convert_from_
quantized_
dtype
(
arr
,
"quint8"
)
def
convert_to_qint8
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
def
convert_to_qint8
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
...
@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
:param q: Target data type, should be a qint8.
"""
"""
return
_convert_to_dtype
(
arr
,
q
,
"qint8"
)
return
_convert_to_
quantized_
dtype
(
arr
,
q
,
"qint8"
)
def
convert_from_qint8
(
arr
:
np
.
ndarray
):
def
convert_from_qint8
(
arr
:
np
.
ndarray
):
...
@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray):
...
@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray):
:param arr: Input ndarray.
:param arr: Input ndarray.
"""
"""
return
_convert_from_dtype
(
arr
,
"qint8"
)
return
_convert_from_
quantized_
dtype
(
arr
,
"qint8"
)
def
convert_to_qint32
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
def
convert_to_qint32
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
...
@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param arr: Input ndarray.
:param q: Target data type, should be a qint32.
:param q: Target data type, should be a qint32.
"""
"""
return
_convert_to_dtype
(
arr
,
q
,
"qint32"
)
return
_convert_to_
quantized_
dtype
(
arr
,
q
,
"qint32"
)
def
convert_from_qint32
(
arr
):
def
convert_from_qint32
(
arr
):
...
@@ -228,7 +238,7 @@ def convert_from_qint32(arr):
...
@@ -228,7 +238,7 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
:param arr: Input ndarray.
"""
"""
return
_convert_from_dtype
(
arr
,
"qint32"
)
return
_convert_from_
quantized_
dtype
(
arr
,
"qint32"
)
def
convert_to_quint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
def
convert_to_quint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
...
@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
:param q: Target data type, should be a quint4.
"""
"""
return
_convert_to_dtype
(
arr
,
q
,
"quint4"
)
return
_convert_to_
quantized_
dtype
(
arr
,
q
,
"quint4"
)
def
convert_from_quint4
(
arr
:
np
.
ndarray
):
def
convert_from_quint4
(
arr
:
np
.
ndarray
):
...
@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray):
...
@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray):
:param arr: Input ndarray.
:param arr: Input ndarray.
"""
"""
return
_convert_from_dtype
(
arr
,
"quint4"
)
return
_convert_from_
quantized_
dtype
(
arr
,
"quint4"
)
def
convert_to_qint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
def
convert_to_qint4
(
arr
:
np
.
ndarray
,
q
:
np
.
dtype
):
...
@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
...
@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
:param q: Target data type, should be a qint4.
"""
"""
return
_convert_to_dtype
(
arr
,
q
,
"qint4"
)
return
_convert_to_
quantized_
dtype
(
arr
,
q
,
"qint4"
)
def
convert_from_qint4
(
arr
:
np
.
ndarray
):
def
convert_from_qint4
(
arr
:
np
.
ndarray
):
...
@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray):
...
@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray):
:param arr: Input ndarray.
:param arr: Input ndarray.
"""
"""
return
_convert_from_dtype
(
arr
,
"qint4"
)
return
_convert_from_
quantized_
dtype
(
arr
,
"qint4"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录