提交 a6f45641 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

refactor(mge/dtype): modify some interface name and enrich comments

GitOrigin-RevId: f9217f6d27b2235aa1b541d0b7953f503dfd7d33
上级 d3730036
......@@ -6,36 +6,25 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Union
import numpy as np
from .mgb import intb1, intb2, intb4
_QuantDtypeMetadata = collections.namedtuple(
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)
_metadata_dict = {
"quint8": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,},
},
"qint8": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,},
},
"quint4": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,},
},
"qint4": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,},
},
"qint32": {
"is_unsigned": False,
"np_dtype_str": "int32",
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,},
},
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
"qint32": _QuantDtypeMetadata(
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
),
}
......@@ -64,26 +53,49 @@ def get_zero_point(dtype):
def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"]
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"]
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
if zp < qmin or zp > qmax:
raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
)
def _get_dtype(dtype_str: str, scale, zp):
if zp is not None:
if int(zp) != zp:
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
r"""
Get quantized dtype with metadata attribute according to _metadata_dict.
Note that unsigned dtype must have ``zero_point`` and signed dtype must
not have ``zero_point``, to be consitent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.
:param dtype: a string indicating which dtype to return
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata = _metadata_dict[dtype_str]
np_dtype_str = metadata.np_dtype_str
is_unsigned = metadata.is_unsigned
if is_unsigned:
if zp is None or int(zp) != zp:
raise ValueError("zero_point should be an integer")
zp = int(zp)
_check_zero_point(zp, dtype_str)
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"]
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}},
)
return np.dtype(
np_dtype_str,
metadata={
"mgb_dtype": {
"name": metadata.name,
"scale": float(scale),
"zero_point": zp,
}
},
)
else:
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}},
)
def quint8(scale, zero_point):
......@@ -92,7 +104,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
return _get_dtype("quint8", scale, zero_point)
return get_quantized_dtype("quint8", scale, zero_point)
def qint8(scale):
......@@ -100,7 +112,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return _get_dtype("qint8", scale, None)
return get_quantized_dtype("qint8", scale, None)
def qint32(scale):
......@@ -108,7 +120,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return _get_dtype("qint32", scale, None)
return get_quantized_dtype("qint32", scale, None)
def quint4(scale, zero_point):
......@@ -117,7 +129,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
return _get_dtype("quint4", scale, zero_point)
return get_quantized_dtype("quint4", scale, zero_point)
def qint4(scale):
......@@ -125,17 +137,17 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return _get_dtype("qint4", scale, None)
return get_quantized_dtype("qint4", scale, None)
def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]:
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
is_unsigned = metadata.is_unsigned
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
......@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
)
return (
(np.round(arr / scale) + zp)
.clip(metadata["qmin"], metadata["qmax"])
.clip(metadata.qmin, metadata.qmax)
.astype(dtype)
)
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"]
return (
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype)
)
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = arr.dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]:
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
is_unsigned = metadata.is_unsigned
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
......@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
)
return (arr.astype(np.float32) - zp) * scale
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"]
return (arr.astype(np.float32)) * scale
......@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
"""
return _convert_to_dtype(arr, q, "quint8")
return _convert_to_quantized_dtype(arr, q, "quint8")
def convert_from_quint8(arr: np.ndarray):
......@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "quint8")
return _convert_from_quantized_dtype(arr, "quint8")
def convert_to_qint8(arr: np.ndarray, q: np.dtype):
......@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_dtype(arr, q, "qint8")
return _convert_to_quantized_dtype(arr, q, "qint8")
def convert_from_qint8(arr: np.ndarray):
......@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint8")
return _convert_from_quantized_dtype(arr, "qint8")
def convert_to_qint32(arr: np.ndarray, q: np.dtype):
......@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_dtype(arr, q, "qint32")
return _convert_to_quantized_dtype(arr, q, "qint32")
def convert_from_qint32(arr):
......@@ -228,7 +238,7 @@ def convert_from_qint32(arr):
:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint32")
return _convert_from_quantized_dtype(arr, "qint32")
def convert_to_quint4(arr: np.ndarray, q: np.dtype):
......@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
"""
return _convert_to_dtype(arr, q, "quint4")
return _convert_to_quantized_dtype(arr, q, "quint4")
def convert_from_quint4(arr: np.ndarray):
......@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "quint4")
return _convert_from_quantized_dtype(arr, "quint4")
def convert_to_qint4(arr: np.ndarray, q: np.dtype):
......@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
"""
return _convert_to_dtype(arr, q, "qint4")
return _convert_to_quantized_dtype(arr, q, "qint4")
def convert_from_qint4(arr: np.ndarray):
......@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray):
:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint4")
return _convert_from_quantized_dtype(arr, "qint4")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册