diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py index 2bb41a19a855ffc23ce183484a6666cd74c4743c..4efefd3a2d146c6a32ec8b4c1a2a506d87185e72 100644 --- a/python_module/megengine/_internal/dtype.py +++ b/python_module/megengine/_internal/dtype.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -11,6 +10,34 @@ import numpy as np from .mgb import intb1, intb2, intb4 +_metadata_dict = { + "quint8": { + "is_unsigned": True, + "np_dtype_str": "uint8", + "mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,}, + }, + "qint8": { + "is_unsigned": False, + "np_dtype_str": "int8", + "mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,}, + }, + "quint4": { + "is_unsigned": True, + "np_dtype_str": "uint8", + "mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,}, + }, + "qint4": { + "is_unsigned": False, + "np_dtype_str": "int8", + "mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,}, + }, + "qint32": { + "is_unsigned": False, + "np_dtype_str": "int32", + "mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,}, + }, +} + def is_quantize(dtype): return ( @@ -36,26 +63,36 @@ def get_zero_point(dtype): return metadata["zero_point"] +def _check_zero_point(zp: int, dtype_str: str): + qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"] + qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"] + if zp < qmin or zp > qmax: + raise ValueError( + "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) + ) + + +def _get_dtype(dtype_str: str, scale, zp): + if zp is not None: + if int(zp) != zp: + raise ValueError("zero_point should be an integer") + zp = int(zp) + _check_zero_point(zp, dtype_str) + metadata = _metadata_dict[dtype_str]["mgb_dtype"] + np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"] + return np.dtype( + np_dtype_str, + metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}}, + ) + + def quint8(scale, zero_point): """ Consturct a quantized unsigned int8 data type with ``scale`` (float) and ``zero_point`` (uint8). The real value represented by a quint8 data type is float_val = scale * (uint8_val - zero_point) """ - int_zp = int(zero_point) - assert int_zp == zero_point, "zero_point should be an integer" - if int_zp < 0 or int_zp > 255: - raise ValueError("zero_point should be within [0, 255] for quint8") - return np.dtype( - np.uint8, - metadata={ - "mgb_dtype": { - "name": "Quantized8Asymm", - "scale": float(scale), - "zero_point": int(zero_point), - } - }, - ) + return _get_dtype("quint8", scale, zero_point) def qint8(scale): @@ -63,9 +100,7 @@ def qint8(scale): Construct a quantized int8 data type with ``scale`` (float). The real value represented by a qint8 data type is float_val = scale * int8_val """ - return np.dtype( - np.int8, metadata={"mgb_dtype": {"name": "QuantizedS8", "scale": float(scale)}} - ) + return _get_dtype("qint8", scale, None) def qint32(scale): @@ -73,10 +108,7 @@ def qint32(scale): Construct a quantized int32 data type with ``scale`` (float). The real value represented by a qint32 data type is float_val = scale * int32_val """ - return np.dtype( - np.int32, - metadata={"mgb_dtype": {"name": "QuantizedS32", "scale": float(scale)}}, - ) + return _get_dtype("qint32", scale, None) def quint4(scale, zero_point): @@ -85,20 +117,7 @@ def quint4(scale, zero_point): ``zero_point`` (uint8). The real value represented by a quint4 data type is float_val = scale * (uint4_val - zero_point) """ - int_zp = int(zero_point) - assert int_zp == zero_point, "zero_point should be an integer" - if int_zp < 0 or int_zp > 15: - raise ValueError("zero_point should be within [0, 15] for quint4") - return np.dtype( - np.uint8, - metadata={ - "mgb_dtype": { - "name": "Quantized4Asymm", - "scale": float(scale), - "zero_point": int(zero_point), - } - }, - ) + return _get_dtype("quint4", scale, zero_point) def qint4(scale): @@ -106,94 +125,101 @@ def qint4(scale): Construct a quantized int4 data type with ``scale`` (float). The real value represented by a qint4 data type is float_val = scale * int4_val """ - return np.dtype( - np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}} - ) - - -def convert_to_quint8(arr, q): + return _get_dtype("qint4", scale, None) + + +def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): + metadata = _metadata_dict[dtype_str]["mgb_dtype"] + arr_metadata = dtype.metadata["mgb_dtype"] + if not isinstance(arr, np.ndarray): + raise ValueError("arr parameter should be instance of np.ndarray") + if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]: + raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) + is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] + if is_unsigned: + scale, zp = ( + arr_metadata["scale"], + arr_metadata["zero_point"], + ) + return ( + (np.round(arr / scale) + zp) + .clip(metadata["qmin"], metadata["qmax"]) + .astype(dtype) + ) + else: + # don't trick to combine with is_unsigned for consistency with cpp interface + scale = arr_metadata["scale"] + return ( + np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype) + ) + + +def _convert_from_dtype(arr: np.ndarray, dtype_str: str): + metadata = _metadata_dict[dtype_str]["mgb_dtype"] + arr_metadata = arr.dtype.metadata["mgb_dtype"] + if not isinstance(arr, np.ndarray): + raise ValueError("arr parameter should be instance of np.ndarray") + if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]: + raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) + is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] + if is_unsigned: + scale, zp = ( + arr_metadata["scale"], + arr_metadata["zero_point"], + ) + return (arr.astype(np.float32) - zp) * scale + else: + # don't trick to combine with is_unsigned for consistency with cpp interface + scale = arr_metadata["scale"] + return (arr.astype(np.float32)) * scale + + +def convert_to_quint8(arr: np.ndarray, q: np.dtype): """ Quantize a float NumPy ndarray into a quint8 one with specified params. :param arr: Input ndarray. - :type arr: :class:`np.ndarray` :param q: Target data type, should be a quint8. - :type q: :class:`np.dtype` """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in q.metadata - and q.metadata["mgb_dtype"]["name"] == "Quantized8Asymm" - ), "q should be a quint8 dtype" - scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] - return (np.round(arr / scale) + zp).clip(0, 255).astype(q) + return _convert_to_dtype(arr, q, "quint8") -def convert_from_quint8(arr): +def convert_from_quint8(arr: np.ndarray): """ Dequantize a quint8 NumPy ndarray into a float one. :param arr: Input ndarray. """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in arr.dtype.metadata - and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized8Asymm" - ), "arr should be a ndarray with quint8 dtype" - scale, zp = ( - arr.dtype.metadata["mgb_dtype"]["scale"], - arr.dtype.metadata["mgb_dtype"]["zero_point"], - ) - return (arr.astype(np.float32) - zp) * scale + return _convert_from_dtype(arr, "quint8") -def convert_to_qint8(arr, q): +def convert_to_qint8(arr: np.ndarray, q: np.dtype): """ Quantize a float NumPy ndarray into a qint8 one with specified params. :param arr: Input ndarray. - :type arr: :class:`np.ndarray` :param q: Target data type, should be a qint8. - :type q: :class:`np.dtype` """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS8" - ), "q should be a qint8 dtype" - scale = q.metadata["mgb_dtype"]["scale"] - return (np.round(arr / scale)).clip(-128, 127).astype(q) + return _convert_to_dtype(arr, q, "qint8") -def convert_from_qint8(arr): +def convert_from_qint8(arr: np.ndarray): """ Dequantize a qint8 NumPy ndarray into a float one. :param arr: Input ndarray. """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in arr.dtype.metadata - and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS8" - ), "arr should be a ndarray with qint8 dtype" - scale = arr.dtype.metadata["mgb_dtype"]["scale"] - return arr.astype(np.float32) * scale + return _convert_from_dtype(arr, "qint8") -def convert_to_qint32(arr, q): +def convert_to_qint32(arr: np.ndarray, q: np.dtype): """ Quantize a float NumPy ndarray into a qint32 one with specified params. :param arr: Input ndarray. - :type arr: :class:`np.ndarray` :param q: Target data type, should be a qint8. - :type q: :class:`np.dtype` """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS32" - ), "q should be a qint32 dtype" - scale = q.metadata["mgb_dtype"]["scale"] - return (np.round(arr / scale)).clip(-(2 ** 31), 2 ** 31).astype(q) + return _convert_to_dtype(arr, q, "qint32") def convert_from_qint32(arr): @@ -202,78 +228,42 @@ def convert_from_qint32(arr): :param arr: Input ndarray. """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in arr.dtype.metadata - and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS32" - ), "arr should be a ndarray with qint8 dtype" - scale = arr.dtype.metadata["mgb_dtype"]["scale"] - return arr.astype(np.float32) * scale + return _convert_from_dtype(arr, "qint32") -def convert_to_quint4(arr, q): +def convert_to_quint4(arr: np.ndarray, q: np.dtype): """ Quantize a float NumPy ndarray into a quint4 one with specified params. :param arr: Input ndarray. - :type arr: :class:`np.ndarray` :param q: Target data type, should be a quint4. - :type q: :class:`np.dtype` """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in q.metadata - and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" - ), "q should be a quint4 dtype" - scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] - return (np.round(arr / scale) + zp).clip(0, 15).astype(q) + return _convert_to_dtype(arr, q, "quint4") -def convert_from_quint4(arr): +def convert_from_quint4(arr: np.ndarray): """ Dequantize a quint4 NumPy ndarray into a float one. :param arr: Input ndarray. """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in arr.dtype.metadata - and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" - ), "arr should be a ndarray with quint4 dtype" - scale, zp = ( - arr.dtype.metadata["mgb_dtype"]["scale"], - arr.dtype.metadata["mgb_dtype"]["zero_point"], - ) - return (arr.astype(np.float32) - zp) * scale + return _convert_from_dtype(arr, "quint4") -def convert_to_qint4(arr, q): +def convert_to_qint4(arr: np.ndarray, q: np.dtype): """ Quantize a float NumPy ndarray into a qint4 one with specified params. :param arr: Input ndarray. - :type arr: :class:`np.ndarray` :param q: Target data type, should be a qint4. - :type q: :class:`np.dtype` """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4" - ), "q should be a qint4 dtype" - scale = q.metadata["mgb_dtype"]["scale"] - return (np.round(arr / scale)).clip(-8, 7).astype(q) + return _convert_to_dtype(arr, q, "qint4") -def convert_from_qint4(arr): +def convert_from_qint4(arr: np.ndarray): """ Dequantize a qint4 NumPy ndarray into a float one. :param arr: Input ndarray. """ - assert isinstance(arr, np.ndarray) - assert ( - "mgb_dtype" in arr.dtype.metadata - and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4" - ), "arr should be a ndarray with qint4 dtype" - scale = arr.dtype.metadata["mgb_dtype"]["scale"] - return arr.astype(np.float32) * scale + return _convert_from_dtype(arr, "qint4")