提交 a6f45641 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

refactor(mge/dtype): modify some interface name and enrich comments

GitOrigin-RevId: f9217f6d27b2235aa1b541d0b7953f503dfd7d33
上级 d3730036
...@@ -6,36 +6,25 @@ ...@@ -6,36 +6,25 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Union
import numpy as np import numpy as np
from .mgb import intb1, intb2, intb4 from .mgb import intb1, intb2, intb4
_QuantDtypeMetadata = collections.namedtuple(
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)
_metadata_dict = { _metadata_dict = {
"quint8": { "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
"is_unsigned": True, "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
"np_dtype_str": "uint8", "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,}, "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
}, "qint32": _QuantDtypeMetadata(
"qint8": { "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
"is_unsigned": False, ),
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,},
},
"quint4": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,},
},
"qint4": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,},
},
"qint32": {
"is_unsigned": False,
"np_dtype_str": "int32",
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,},
},
} }
...@@ -64,25 +53,48 @@ def get_zero_point(dtype): ...@@ -64,25 +53,48 @@ def get_zero_point(dtype):
def _check_zero_point(zp: int, dtype_str: str): def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"] qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"] qmax = _metadata_dict[dtype_str].qmax
if zp < qmin or zp > qmax: if zp < qmin or zp > qmax:
raise ValueError( raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
) )
def _get_dtype(dtype_str: str, scale, zp): def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
if zp is not None: r"""
if int(zp) != zp: Get quantized dtype with metadata attribute according to _metadata_dict.
Note that unsigned dtype must have ``zero_point`` and signed dtype must
not have ``zero_point``, to be consitent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.
:param dtype: a string indicating which dtype to return
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata = _metadata_dict[dtype_str]
np_dtype_str = metadata.np_dtype_str
is_unsigned = metadata.is_unsigned
if is_unsigned:
if zp is None or int(zp) != zp:
raise ValueError("zero_point should be an integer") raise ValueError("zero_point should be an integer")
zp = int(zp) zp = int(zp)
_check_zero_point(zp, dtype_str) _check_zero_point(zp, dtype_str)
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"]
return np.dtype( return np.dtype(
np_dtype_str, np_dtype_str,
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}}, metadata={
"mgb_dtype": {
"name": metadata.name,
"scale": float(scale),
"zero_point": zp,
}
},
)
else:
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}},
) )
...@@ -92,7 +104,7 @@ def quint8(scale, zero_point): ...@@ -92,7 +104,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is ``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point) float_val = scale * (uint8_val - zero_point)
""" """
return _get_dtype("quint8", scale, zero_point) return get_quantized_dtype("quint8", scale, zero_point)
def qint8(scale): def qint8(scale):
...@@ -100,7 +112,7 @@ def qint8(scale): ...@@ -100,7 +112,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val represented by a qint8 data type is float_val = scale * int8_val
""" """
return _get_dtype("qint8", scale, None) return get_quantized_dtype("qint8", scale, None)
def qint32(scale): def qint32(scale):
...@@ -108,7 +120,7 @@ def qint32(scale): ...@@ -108,7 +120,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val represented by a qint32 data type is float_val = scale * int32_val
""" """
return _get_dtype("qint32", scale, None) return get_quantized_dtype("qint32", scale, None)
def quint4(scale, zero_point): def quint4(scale, zero_point):
...@@ -117,7 +129,7 @@ def quint4(scale, zero_point): ...@@ -117,7 +129,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is ``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point) float_val = scale * (uint4_val - zero_point)
""" """
return _get_dtype("quint4", scale, zero_point) return get_quantized_dtype("quint4", scale, zero_point)
def qint4(scale): def qint4(scale):
...@@ -125,17 +137,17 @@ def qint4(scale): ...@@ -125,17 +137,17 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val represented by a qint4 data type is float_val = scale * int4_val
""" """
return _get_dtype("qint4", scale, None) return get_quantized_dtype("qint4", scale, None)
def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"] metadata = _metadata_dict[dtype_str]
arr_metadata = dtype.metadata["mgb_dtype"] arr_metadata = dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray): if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray") raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]: if not is_quantize(dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] is_unsigned = metadata.is_unsigned
if is_unsigned: if is_unsigned:
scale, zp = ( scale, zp = (
arr_metadata["scale"], arr_metadata["scale"],
...@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): ...@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
) )
return ( return (
(np.round(arr / scale) + zp) (np.round(arr / scale) + zp)
.clip(metadata["qmin"], metadata["qmax"]) .clip(metadata.qmin, metadata.qmax)
.astype(dtype) .astype(dtype)
) )
else: else:
# don't trick to combine with is_unsigned for consistency with cpp interface # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"] scale = arr_metadata["scale"]
return ( return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype)
)
def _convert_from_dtype(arr: np.ndarray, dtype_str: str): def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"] metadata = _metadata_dict[dtype_str]
arr_metadata = arr.dtype.metadata["mgb_dtype"] arr_metadata = arr.dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray): if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray") raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]: if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] is_unsigned = metadata.is_unsigned
if is_unsigned: if is_unsigned:
scale, zp = ( scale, zp = (
arr_metadata["scale"], arr_metadata["scale"],
...@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str): ...@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
) )
return (arr.astype(np.float32) - zp) * scale return (arr.astype(np.float32) - zp) * scale
else: else:
# don't trick to combine with is_unsigned for consistency with cpp interface # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"] scale = arr_metadata["scale"]
return (arr.astype(np.float32)) * scale return (arr.astype(np.float32)) * scale
...@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): ...@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray. :param arr: Input ndarray.
:param q: Target data type, should be a quint8. :param q: Target data type, should be a quint8.
""" """
return _convert_to_dtype(arr, q, "quint8") return _convert_to_quantized_dtype(arr, q, "quint8")
def convert_from_quint8(arr: np.ndarray): def convert_from_quint8(arr: np.ndarray):
...@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray): ...@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray):
:param arr: Input ndarray. :param arr: Input ndarray.
""" """
return _convert_from_dtype(arr, "quint8") return _convert_from_quantized_dtype(arr, "quint8")
def convert_to_qint8(arr: np.ndarray, q: np.dtype): def convert_to_qint8(arr: np.ndarray, q: np.dtype):
...@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): ...@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray. :param arr: Input ndarray.
:param q: Target data type, should be a qint8. :param q: Target data type, should be a qint8.
""" """
return _convert_to_dtype(arr, q, "qint8") return _convert_to_quantized_dtype(arr, q, "qint8")
def convert_from_qint8(arr: np.ndarray): def convert_from_qint8(arr: np.ndarray):
...@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray): ...@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray):
:param arr: Input ndarray. :param arr: Input ndarray.
""" """
return _convert_from_dtype(arr, "qint8") return _convert_from_quantized_dtype(arr, "qint8")
def convert_to_qint32(arr: np.ndarray, q: np.dtype): def convert_to_qint32(arr: np.ndarray, q: np.dtype):
...@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): ...@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray. :param arr: Input ndarray.
:param q: Target data type, should be a qint8. :param q: Target data type, should be a qint8.
""" """
return _convert_to_dtype(arr, q, "qint32") return _convert_to_quantized_dtype(arr, q, "qint32")
def convert_from_qint32(arr): def convert_from_qint32(arr):
...@@ -228,7 +238,7 @@ def convert_from_qint32(arr): ...@@ -228,7 +238,7 @@ def convert_from_qint32(arr):
:param arr: Input ndarray. :param arr: Input ndarray.
""" """
return _convert_from_dtype(arr, "qint32") return _convert_from_quantized_dtype(arr, "qint32")
def convert_to_quint4(arr: np.ndarray, q: np.dtype): def convert_to_quint4(arr: np.ndarray, q: np.dtype):
...@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): ...@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray. :param arr: Input ndarray.
:param q: Target data type, should be a quint4. :param q: Target data type, should be a quint4.
""" """
return _convert_to_dtype(arr, q, "quint4") return _convert_to_quantized_dtype(arr, q, "quint4")
def convert_from_quint4(arr: np.ndarray): def convert_from_quint4(arr: np.ndarray):
...@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray): ...@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray):
:param arr: Input ndarray. :param arr: Input ndarray.
""" """
return _convert_from_dtype(arr, "quint4") return _convert_from_quantized_dtype(arr, "quint4")
def convert_to_qint4(arr: np.ndarray, q: np.dtype): def convert_to_qint4(arr: np.ndarray, q: np.dtype):
...@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): ...@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray. :param arr: Input ndarray.
:param q: Target data type, should be a qint4. :param q: Target data type, should be a qint4.
""" """
return _convert_to_dtype(arr, q, "qint4") return _convert_to_quantized_dtype(arr, q, "qint4")
def convert_from_qint4(arr: np.ndarray): def convert_from_qint4(arr: np.ndarray):
...@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray): ...@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray):
:param arr: Input ndarray. :param arr: Input ndarray.
""" """
return _convert_from_dtype(arr, "qint4") return _convert_from_quantized_dtype(arr, "qint4")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册