提交 abf82cfb 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/dtype): export `_metadata_dict` for consistent dtype property

GitOrigin-RevId: 6840c0b6b49df1deb8c6acd1895cf96fe96f771a
上级 f582c192
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -11,6 +10,34 @@ import numpy as np
from .mgb import intb1, intb2, intb4
_metadata_dict = {
"quint8": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,},
},
"qint8": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,},
},
"quint4": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,},
},
"qint4": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,},
},
"qint32": {
"is_unsigned": False,
"np_dtype_str": "int32",
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,},
},
}
def is_quantize(dtype):
return (
......@@ -36,26 +63,36 @@ def get_zero_point(dtype):
return metadata["zero_point"]
def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"]
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"]
if zp < qmin or zp > qmax:
raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
)
def _get_dtype(dtype_str: str, scale, zp):
if zp is not None:
if int(zp) != zp:
raise ValueError("zero_point should be an integer")
zp = int(zp)
_check_zero_point(zp, dtype_str)
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"]
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}},
)
def quint8(scale, zero_point):
"""
Consturct a quantized unsigned int8 data type with ``scale`` (float) and
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
int_zp = int(zero_point)
assert int_zp == zero_point, "zero_point should be an integer"
if int_zp < 0 or int_zp > 255:
raise ValueError("zero_point should be within [0, 255] for quint8")
return np.dtype(
np.uint8,
metadata={
"mgb_dtype": {
"name": "Quantized8Asymm",
"scale": float(scale),
"zero_point": int(zero_point),
}
},
)
return _get_dtype("quint8", scale, zero_point)
def qint8(scale):
......@@ -63,9 +100,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return np.dtype(
np.int8, metadata={"mgb_dtype": {"name": "QuantizedS8", "scale": float(scale)}}
)
return _get_dtype("qint8", scale, None)
def qint32(scale):
......@@ -73,10 +108,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return np.dtype(
np.int32,
metadata={"mgb_dtype": {"name": "QuantizedS32", "scale": float(scale)}},
)
return _get_dtype("qint32", scale, None)
def quint4(scale, zero_point):
......@@ -85,20 +117,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
int_zp = int(zero_point)
assert int_zp == zero_point, "zero_point should be an integer"
if int_zp < 0 or int_zp > 15:
raise ValueError("zero_point should be within [0, 15] for quint4")
return np.dtype(
np.uint8,
metadata={
"mgb_dtype": {
"name": "Quantized4Asymm",
"scale": float(scale),
"zero_point": int(zero_point),
}
},
)
return _get_dtype("quint4", scale, zero_point)
def qint4(scale):
......@@ -106,94 +125,101 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return np.dtype(
np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}}
return _get_dtype("qint4", scale, None)
def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
arr_metadata = dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
)
return (
(np.round(arr / scale) + zp)
.clip(metadata["qmin"], metadata["qmax"])
.astype(dtype)
)
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
scale = arr_metadata["scale"]
return (
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype)
)
def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
arr_metadata = arr.dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
)
return (arr.astype(np.float32) - zp) * scale
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
scale = arr_metadata["scale"]
return (arr.astype(np.float32)) * scale
def convert_to_quint8(arr, q):
def convert_to_quint8(arr: np.ndarray, q: np.dtype):
"""
Quantize a float NumPy ndarray into a quint8 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a quint8.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata
and q.metadata["mgb_dtype"]["name"] == "Quantized8Asymm"
), "q should be a quint8 dtype"
scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"]
return (np.round(arr / scale) + zp).clip(0, 255).astype(q)
return _convert_to_dtype(arr, q, "quint8")
def convert_from_quint8(arr):
def convert_from_quint8(arr: np.ndarray):
"""
Dequantize a quint8 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized8Asymm"
), "arr should be a ndarray with quint8 dtype"
scale, zp = (
arr.dtype.metadata["mgb_dtype"]["scale"],
arr.dtype.metadata["mgb_dtype"]["zero_point"],
)
return (arr.astype(np.float32) - zp) * scale
return _convert_from_dtype(arr, "quint8")
def convert_to_qint8(arr, q):
def convert_to_qint8(arr: np.ndarray, q: np.dtype):
"""
Quantize a float NumPy ndarray into a qint8 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint8.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS8"
), "q should be a qint8 dtype"
scale = q.metadata["mgb_dtype"]["scale"]
return (np.round(arr / scale)).clip(-128, 127).astype(q)
return _convert_to_dtype(arr, q, "qint8")
def convert_from_qint8(arr):
def convert_from_qint8(arr: np.ndarray):
"""
Dequantize a qint8 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS8"
), "arr should be a ndarray with qint8 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale
return _convert_from_dtype(arr, "qint8")
def convert_to_qint32(arr, q):
def convert_to_qint32(arr: np.ndarray, q: np.dtype):
"""
Quantize a float NumPy ndarray into a qint32 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint8.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS32"
), "q should be a qint32 dtype"
scale = q.metadata["mgb_dtype"]["scale"]
return (np.round(arr / scale)).clip(-(2 ** 31), 2 ** 31).astype(q)
return _convert_to_dtype(arr, q, "qint32")
def convert_from_qint32(arr):
......@@ -202,78 +228,42 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS32"
), "arr should be a ndarray with qint8 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale
return _convert_from_dtype(arr, "qint32")
def convert_to_quint4(arr, q):
def convert_to_quint4(arr: np.ndarray, q: np.dtype):
"""
Quantize a float NumPy ndarray into a quint4 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a quint4.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata
and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
), "q should be a quint4 dtype"
scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"]
return (np.round(arr / scale) + zp).clip(0, 15).astype(q)
return _convert_to_dtype(arr, q, "quint4")
def convert_from_quint4(arr):
def convert_from_quint4(arr: np.ndarray):
"""
Dequantize a quint4 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
), "arr should be a ndarray with quint4 dtype"
scale, zp = (
arr.dtype.metadata["mgb_dtype"]["scale"],
arr.dtype.metadata["mgb_dtype"]["zero_point"],
)
return (arr.astype(np.float32) - zp) * scale
return _convert_from_dtype(arr, "quint4")
def convert_to_qint4(arr, q):
def convert_to_qint4(arr: np.ndarray, q: np.dtype):
"""
Quantize a float NumPy ndarray into a qint4 one with specified params.
:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint4.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4"
), "q should be a qint4 dtype"
scale = q.metadata["mgb_dtype"]["scale"]
return (np.round(arr / scale)).clip(-8, 7).astype(q)
return _convert_to_dtype(arr, q, "qint4")
def convert_from_qint4(arr):
def convert_from_qint4(arr: np.ndarray):
"""
Dequantize a qint4 NumPy ndarray into a float one.
:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4"
), "arr should be a ndarray with qint4 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale
return _convert_from_dtype(arr, "qint4")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册