From a6f456415f38885b3554da2f3531e24dddc40d43 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Apr 2020 20:57:45 +0800 Subject: [PATCH] refactor(mge/dtype): modify some interface name and enrich comments GitOrigin-RevId: f9217f6d27b2235aa1b541d0b7953f503dfd7d33 --- python_module/megengine/_internal/dtype.py | 140 +++++++++++---------- 1 file changed, 75 insertions(+), 65 deletions(-) diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py index 4efefd3a2..dc5a4220f 100644 --- a/python_module/megengine/_internal/dtype.py +++ b/python_module/megengine/_internal/dtype.py @@ -6,36 +6,25 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections +from typing import Union + import numpy as np from .mgb import intb1, intb2, intb4 +_QuantDtypeMetadata = collections.namedtuple( + "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] +) + _metadata_dict = { - "quint8": { - "is_unsigned": True, - "np_dtype_str": "uint8", - "mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,}, - }, - "qint8": { - "is_unsigned": False, - "np_dtype_str": "int8", - "mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,}, - }, - "quint4": { - "is_unsigned": True, - "np_dtype_str": "uint8", - "mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,}, - }, - "qint4": { - "is_unsigned": False, - "np_dtype_str": "int8", - "mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,}, - }, - "qint32": { - "is_unsigned": False, - "np_dtype_str": "int32", - "mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,}, - }, + "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), + "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), + "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), + "qint4": 
_QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), + "qint32": _QuantDtypeMetadata( + "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, + ), } @@ -64,26 +53,49 @@ def get_zero_point(dtype): def _check_zero_point(zp: int, dtype_str: str): - qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"] - qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"] + qmin = _metadata_dict[dtype_str].qmin + qmax = _metadata_dict[dtype_str].qmax if zp < qmin or zp > qmax: raise ValueError( "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) ) -def _get_dtype(dtype_str: str, scale, zp): - if zp is not None: - if int(zp) != zp: +def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): + r""" + Get quantized dtype with metadata attribute according to _metadata_dict. + + Note that unsigned dtype must have ``zero_point`` and signed dtype must + not have ``zero_point``, to be consistent with tensor generated by calling + compiled function from `CompGraph.compile(inputs, outspec)`. 
+ + :param dtype_str: a string indicating which dtype to return + :param scale: a number for scale to store in dtype's metadata + :param zp: a number for zero_point to store in dtype's metadata + """ + metadata = _metadata_dict[dtype_str] + np_dtype_str = metadata.np_dtype_str + is_unsigned = metadata.is_unsigned + if is_unsigned: + if zp is None or int(zp) != zp: raise ValueError("zero_point should be an integer") zp = int(zp) _check_zero_point(zp, dtype_str) - metadata = _metadata_dict[dtype_str]["mgb_dtype"] - np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"] - return np.dtype( - np_dtype_str, - metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}}, - ) + return np.dtype( + np_dtype_str, + metadata={ + "mgb_dtype": { + "name": metadata.name, + "scale": float(scale), + "zero_point": zp, + } + }, + ) + else: + return np.dtype( + np_dtype_str, + metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, + ) def quint8(scale, zero_point): @@ -92,7 +104,7 @@ def quint8(scale, zero_point): ``zero_point`` (uint8). The real value represented by a quint8 data type is float_val = scale * (uint8_val - zero_point) """ - return _get_dtype("quint8", scale, zero_point) + return get_quantized_dtype("quint8", scale, zero_point) def qint8(scale): @@ -100,7 +112,7 @@ def qint8(scale): Construct a quantized int8 data type with ``scale`` (float). The real value represented by a qint8 data type is float_val = scale * int8_val """ - return _get_dtype("qint8", scale, None) + return get_quantized_dtype("qint8", scale, None) def qint32(scale): @@ -108,7 +120,7 @@ def qint32(scale): Construct a quantized int32 data type with ``scale`` (float). The real value represented by a qint32 data type is float_val = scale * int32_val """ - return _get_dtype("qint32", scale, None) + return get_quantized_dtype("qint32", scale, None) def quint4(scale, zero_point): @@ -117,7 +129,7 @@ def quint4(scale, zero_point): ``zero_point`` (uint8). 
The real value represented by a quint4 data type is float_val = scale * (uint4_val - zero_point) """ - return _get_dtype("quint4", scale, zero_point) + return get_quantized_dtype("quint4", scale, zero_point) def qint4(scale): @@ -125,17 +137,17 @@ def qint4(scale): Construct a quantized int4 data type with ``scale`` (float). The real value represented by a qint4 data type is float_val = scale * int4_val """ - return _get_dtype("qint4", scale, None) + return get_quantized_dtype("qint4", scale, None) -def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): - metadata = _metadata_dict[dtype_str]["mgb_dtype"] +def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): + metadata = _metadata_dict[dtype_str] arr_metadata = dtype.metadata["mgb_dtype"] if not isinstance(arr, np.ndarray): raise ValueError("arr parameter should be instance of np.ndarray") - if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]: + if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) - is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] + is_unsigned = metadata.is_unsigned if is_unsigned: scale, zp = ( arr_metadata["scale"], @@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): ) return ( (np.round(arr / scale) + zp) - .clip(metadata["qmin"], metadata["qmax"]) + .clip(metadata.qmin, metadata.qmax) .astype(dtype) ) else: - # don't trick to combine with is_unsigned for consistency with cpp interface + # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` scale = arr_metadata["scale"] - return ( - np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype) - ) + return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) -def _convert_from_dtype(arr: np.ndarray, dtype_str: str): - metadata = _metadata_dict[dtype_str]["mgb_dtype"] +def 
_convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): + metadata = _metadata_dict[dtype_str] arr_metadata = arr.dtype.metadata["mgb_dtype"] if not isinstance(arr, np.ndarray): raise ValueError("arr parameter should be instance of np.ndarray") - if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]: + if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) - is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] + is_unsigned = metadata.is_unsigned if is_unsigned: scale, zp = ( arr_metadata["scale"], @@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str): ) return (arr.astype(np.float32) - zp) * scale else: - # don't trick to combine with is_unsigned for consistency with cpp interface + # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` scale = arr_metadata["scale"] return (arr.astype(np.float32)) * scale @@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a quint8. """ - return _convert_to_dtype(arr, q, "quint8") + return _convert_to_quantized_dtype(arr, q, "quint8") def convert_from_quint8(arr: np.ndarray): @@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_dtype(arr, "quint8") + return _convert_from_quantized_dtype(arr, "quint8") def convert_to_qint8(arr: np.ndarray, q: np.dtype): @@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint8. """ - return _convert_to_dtype(arr, q, "qint8") + return _convert_to_quantized_dtype(arr, q, "qint8") def convert_from_qint8(arr: np.ndarray): @@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray): :param arr: Input ndarray. 
""" - return _convert_from_dtype(arr, "qint8") + return _convert_from_quantized_dtype(arr, "qint8") def convert_to_qint32(arr: np.ndarray, q: np.dtype): @@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint8. """ - return _convert_to_dtype(arr, q, "qint32") + return _convert_to_quantized_dtype(arr, q, "qint32") def convert_from_qint32(arr): @@ -228,7 +238,7 @@ def convert_from_qint32(arr): :param arr: Input ndarray. """ - return _convert_from_dtype(arr, "qint32") + return _convert_from_quantized_dtype(arr, "qint32") def convert_to_quint4(arr: np.ndarray, q: np.dtype): @@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a quint4. """ - return _convert_to_dtype(arr, q, "quint4") + return _convert_to_quantized_dtype(arr, q, "quint4") def convert_from_quint4(arr: np.ndarray): @@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_dtype(arr, "quint4") + return _convert_from_quantized_dtype(arr, "quint4") def convert_to_qint4(arr: np.ndarray, q: np.dtype): @@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint4. """ - return _convert_to_dtype(arr, q, "qint4") + return _convert_to_quantized_dtype(arr, q, "qint4") def convert_from_qint4(arr: np.ndarray): @@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_dtype(arr, "qint4") + return _convert_from_quantized_dtype(arr, "qint4") -- GitLab