diff --git a/imperative/python/megengine/core/ops/builtin/__init__.py b/imperative/python/megengine/core/ops/builtin/__init__.py
index 2029f277f082571a590419a5ff8661045a34543b..13598687effbd6f542e2c7cc86951174b9675458 100644
--- a/imperative/python/megengine/core/ops/builtin/__init__.py
+++ b/imperative/python/megengine/core/ops/builtin/__init__.py
@@ -6,9 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import warnings
-from typing import Union
-
 from ..._imperative_rt import OpDef, ops
 
 __all__ = ["OpDef"]
diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py
index ccbdd90163748c86e4011f5546a87cc18a47280b..f68fde1215674aac35af500573ee2b258f62b508 100644
--- a/imperative/python/megengine/core/tensor/dtype.py
+++ b/imperative/python/megengine/core/tensor/dtype.py
@@ -5,22 +5,24 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
-import collections
+from collections import namedtuple
 from typing import Union
 
 import numpy as np
 
-# normal dtype related
-from .._imperative_rt import bfloat16, intb1, intb2, intb4
 from .._imperative_rt.common import (
+    bfloat16,
     get_scale,
     get_zero_point,
+    intb1,
+    intb2,
+    intb4,
     is_dtype_equal,
     is_quantize,
 )
 
+# normal dtype related
 def is_lowbit(dtype):
     return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
 
@@ -30,34 +32,80 @@ def is_bfloat16(dtype):
 
 
 # quantization dtype related
-_QuantDtypeMetadata = collections.namedtuple(
-    "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
-)
-_metadata_dict = {
-    "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
-    "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
-    "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
-    "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
-    "qint32": _QuantDtypeMetadata(
-        "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
+# use namedtuple to make class immutable, comparable and easy to print
+class QuantDtypeMeta(
+    namedtuple(
+        "QuantDtypeMeta",
+        ["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"],
+    )
+):
+    r"""
+    Store metadata for a quantized dtype. Can be used to create a custom quant
+    dtype for QAT when the network does not need to be converted for inference,
+    but only needs to export network metadata for third-party platform inference.
+
+    :param name: a unique name string.
+    :param cname: used in :func:`~.create_quantized_dtype` for model dump and inference.
+    :param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``.
+    :param qmin: an int number indicating the quant dtype's lower bound.
+    :param qmax: an int number indicating the quant dtype's upper bound.
+    :param is_unsigned: a helper value that can be inferred from np_dtype_str.
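+
+    For example, the builtin ``quint8`` entry defined below is
+    ``QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255)``.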
+    """
+
+    def __new__(
+        cls,
+        name: str,
+        cname: str,
+        np_dtype_str: str,
+        qmin: int,
+        qmax: int,
+        is_unsigned: bool = None,
+    ):
+        assert isinstance(np_dtype_str, str)
+        is_unsigned = np_dtype_str[0] == "u" if is_unsigned is None else is_unsigned
+        return super().__new__(cls, name, cname, np_dtype_str, qmin, qmax, is_unsigned)
+
+    def __copy__(self):
+        return self
+
+    def __deepcopy__(self, _):
+        """
+        Ignore deepcopy so that a dtype meta can be treated as singleton, for more
+        strict check in :meth:`~.FakeQuantize.fake_quant_forward`.
+        """
+        return self
+
+
+_builtin_quant_dtypes = {
+    "quint8": QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255),
+    "qint8": QuantDtypeMeta("qint8", "QuantizedS8", "int8", -128, 127),
+    "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
+    "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
+    "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
+    "qint32": QuantDtypeMeta(
+        "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
     ),
     # NOTE: int2 is not supported for model dump yet
-    "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3),
-    "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1),
+    "quint2": QuantDtypeMeta("quint2", None, "uint8", 0, 3),
+    "qint2": QuantDtypeMeta("qint2", None, "int8", -2, 1),
 }
 
 
-def _check_zero_point(zp: int, dtype_str: str):
-    qmin = _metadata_dict[dtype_str].qmin
-    qmax = _metadata_dict[dtype_str].qmax
+def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta):
+    qmin = dtype_meta.qmin
+    qmax = dtype_meta.qmax
     if zp < qmin or zp > qmax:
         raise ValueError(
-            "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
+            "zero_point should be within [{}, {}] for {}".format(
+                qmin, qmax, dtype_meta.name
+            )
         )
 
 
-def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
+def create_quantized_dtype(
+    dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
+):
     r"""
     Get quantized dtype with metadata attribute according to _metadata_dict.
 
@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
     not have ``zero_point``, to be consistent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
 
-    :param dtype: a string indicating which dtype to return
+    :param dtype_meta: a QuantDtypeMeta indicating which dtype to return. The
+        ``cname`` attribute cannot be ``None``.
     :param scale: a number for scale to store in dtype's metadata
     :param zp: a number for zero_point to store in dtype's metadata
     """
-    metadata = _metadata_dict[dtype_str]
-    np_dtype_str = metadata.np_dtype_str
-    is_unsigned = metadata.is_unsigned
-    if is_unsigned:
+    if dtype_meta.cname is None:
+        raise ValueError(
+            "dtype {} without cname attr is not supported.".format(dtype_meta.name)
+        )
+    if dtype_meta.is_unsigned:
         if zp is None or int(zp) != zp:
             raise ValueError("zero_point should be an integer")
         zp = int(zp)
-        _check_zero_point(zp, dtype_str)
+        _check_zero_point(zp, dtype_meta)
         return np.dtype(
-            np_dtype_str,
+            dtype_meta.np_dtype_str,
             metadata={
                 "mgb_dtype": {
-                    "name": metadata.name,
+                    "name": dtype_meta.cname,
                     "scale": float(scale),
                     "zero_point": zp,
                 }
             },
         )
     else:
+        # Don't try to merge this branch with the is_unsigned one: the metadata
+        # must not contain invalid fields, to keep consistent with the C dtype.
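+        # e.g. qint8(0.5) yields metadata
+        # {"mgb_dtype": {"name": "QuantizedS8", "scale": 0.5}} and no "zero_point" key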
return np.dtype( - np_dtype_str, - metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, + dtype_meta.np_dtype_str, + metadata={"mgb_dtype": {"name": dtype_meta.cname, "scale": float(scale)}}, ) @@ -100,7 +150,7 @@ def quint8(scale, zero_point): ``zero_point`` (uint8). The real value represented by a quint8 data type is float_val = scale * (uint8_val - zero_point) """ - return get_quantized_dtype("quint8", scale, zero_point) + return create_quantized_dtype(_builtin_quant_dtypes["quint8"], scale, zero_point) def qint8(scale): @@ -108,7 +158,7 @@ def qint8(scale): Construct a quantized int8 data type with ``scale`` (float). The real value represented by a qint8 data type is float_val = scale * int8_val """ - return get_quantized_dtype("qint8", scale, None) + return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None) def qint32(scale): @@ -116,7 +166,7 @@ def qint32(scale): Construct a quantized int32 data type with ``scale`` (float). The real value represented by a qint32 data type is float_val = scale * int32_val """ - return get_quantized_dtype("qint32", scale, None) + return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None) def quint4(scale, zero_point): @@ -125,7 +175,7 @@ def quint4(scale, zero_point): ``zero_point`` (uint8). The real value represented by a quint4 data type is float_val = scale * (uint4_val - zero_point) """ - return get_quantized_dtype("quint4", scale, zero_point) + return create_quantized_dtype(_builtin_quant_dtypes["quint4"], scale, zero_point) def qint4(scale): @@ -133,42 +183,48 @@ def qint4(scale): Construct a quantized int4 data type with ``scale`` (float). The real value represented by a qint4 data type is float_val = scale * int4_val """ - return get_quantized_dtype("qint4", scale, None) + return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) -def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): - metadata = _metadata_dict[dtype_str] - arr_metadata = dtype.metadata["mgb_dtype"] +def _convert_to_quantized_dtype( + arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta +): if not isinstance(arr, np.ndarray): raise ValueError("arr parameter should be instance of np.ndarray") - if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: - raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) - is_unsigned = metadata.is_unsigned - if is_unsigned: + if ( + not is_quantize(dtype) + or dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname + ): + raise ValueError("dtype parameter should be a {} dtype".format(dtype_meta)) + arr_metadata = dtype.metadata["mgb_dtype"] + if dtype_meta.is_unsigned: scale, zp = ( arr_metadata["scale"], arr_metadata["zero_point"], ) return ( (np.round(arr / scale) + zp) - .clip(metadata.qmin, metadata.qmax) + .clip(dtype_meta.qmin, dtype_meta.qmax) .astype(dtype) ) else: # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` scale = arr_metadata["scale"] - return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) + return ( + np.round(arr / scale).clip(dtype_meta.qmin, dtype_meta.qmax).astype(dtype) + ) -def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): - metadata = _metadata_dict[dtype_str] - arr_metadata = arr.dtype.metadata["mgb_dtype"] +def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta): if not isinstance(arr, np.ndarray): raise ValueError("arr parameter should be instance of np.ndarray") - if not 
is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: - raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) - is_unsigned = metadata.is_unsigned - if is_unsigned: + if ( + not is_quantize(arr.dtype) + or arr.dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname + ): + raise ValueError("arr's dtype should be a {} dtype".format(dtype_meta)) + arr_metadata = arr.dtype.metadata["mgb_dtype"] + if dtype_meta.is_unsigned: scale, zp = ( arr_metadata["scale"], arr_metadata["zero_point"], @@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a quint8. """ - return _convert_to_quantized_dtype(arr, q, "quint8") + return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"]) def convert_from_quint8(arr: np.ndarray): @@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_quantized_dtype(arr, "quint8") + return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"]) def convert_to_qint8(arr: np.ndarray, q: np.dtype): @@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint8. """ - return _convert_to_quantized_dtype(arr, q, "qint8") + return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"]) def convert_from_qint8(arr: np.ndarray): @@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_quantized_dtype(arr, "qint8") + return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"]) def convert_to_qint32(arr: np.ndarray, q: np.dtype): @@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint8. """ - return _convert_to_quantized_dtype(arr, q, "qint32") + return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"]) def convert_from_qint32(arr): @@ -234,7 +290,7 @@ def convert_from_qint32(arr): :param arr: Input ndarray. """ - return _convert_from_quantized_dtype(arr, "qint32") + return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"]) def convert_to_quint4(arr: np.ndarray, q: np.dtype): @@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a quint4. """ - return _convert_to_quantized_dtype(arr, q, "quint4") + return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"]) def convert_from_quint4(arr: np.ndarray): @@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray): :param arr: Input ndarray. """ - return _convert_from_quantized_dtype(arr, "quint4") + return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"]) def convert_to_qint4(arr: np.ndarray, q: np.dtype): @@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): :param arr: Input ndarray. :param q: Target data type, should be a qint4. """ - return _convert_to_quantized_dtype(arr, q, "qint4") + return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"]) def convert_from_qint4(arr: np.ndarray): @@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray): :param arr: Input ndarray. 
""" - return _convert_from_quantized_dtype(arr, "qint4") + return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 33a8484cf08c50957468972bfc8f8c01a5f2210f..22b2a5f9355fa340053b2a0c135f8c7c671478a4 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -203,7 +203,7 @@ def conv_transpose2d( assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" if groups != 1: - raise NotImplementedError("TODO") + raise NotImplementedError("group transposed conv2d is not supported yet.") stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 925cb8eea463984fdd042a8b90609553a812a1df..91c6e7d347aeaad2f8ca96755a43764a4ca1bb6a 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -13,7 +13,6 @@ import itertools import json import os import typing -import warnings import weakref import numpy as np diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index df857626f5477c1fc41230b1732cfef16d22e4ce..5293ad29738c3c17749d8619c80971733177e96b 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -5,7 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union @@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta): if "requires_grad" in kwargs: del kwargs["requires_grad"] - warnings.warn( + logger.warning( "Tensor currently has no requires_grad attribute " - "so requires_grad argument is ignored here", - DeprecationWarning, + "so requires_grad argument is ignored here" ) def predicate(obj) -> bool: @@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta): if "requires_grad" in kwargs: del kwargs["requires_grad"] - warnings.warn( + logger.warning( "Tensor currently has no requires_grad attribute " - "so requires_grad argument is ignored here", - DeprecationWarning, + "so requires_grad argument is ignored here" ) def predicate(obj) -> bool: diff --git a/imperative/python/megengine/module/qat/module.py b/imperative/python/megengine/module/qat/module.py index 82b2911cce169937b8377912b23ca4267db3429d..466ba960ddd9af4e50763478cdf2ae81c308bedc 100644 --- a/imperative/python/megengine/module/qat/module.py +++ b/imperative/python/megengine/module/qat/module.py @@ -7,7 +7,10 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
from abc import abstractmethod -from ...quantization import FakeQuantize, Observer, QConfig +# avoid circular reference +from ...quantization.fake_quant import FakeQuantize +from ...quantization.observer import Observer +from ...quantization.qconfig import QConfig from ...tensor import Tensor from ..module import Module @@ -73,19 +76,19 @@ class QATModule(Module): # do observer if observer is None: oup = target - q_dict = None + qparams = None else: oup = observer(target) - q_dict = observer.get_qparams() + qparams = observer.get_qparams() # do fake quant if fake_quant is not None: - oup = fake_quant(oup, q_dict) + oup = fake_quant(oup, qparams) # use qparams of fake_quant if have. if hasattr(fake_quant, "get_qparams"): - q_dict = fake_quant.get_qparams() + qparams = fake_quant.get_qparams() # set to tensor qparams. - if q_dict is not None: - oup.q_dict.update(q_dict) + if qparams is not None: + oup.qparams.update(qparams) return oup def apply_quant_weight(self, target: Tensor): @@ -118,7 +121,7 @@ class QATModule(Module): Get weight's quantization dtype as the method from ``qconfig``. """ return self._get_method_result( - "get_dtype", self.weight_fake_quant, self.weight_observer + "get_quantized_dtype", self.weight_fake_quant, self.weight_observer ) def get_activation_dtype(self): @@ -126,7 +129,7 @@ class QATModule(Module): Get activation's quantization dtype as the method from ``qconfig``. """ return self._get_method_result( - "get_dtype", self.act_fake_quant, self.act_observer + "get_quantized_dtype", self.act_fake_quant, self.act_observer ) def get_weight_qparams(self): diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py index b21b3a636e0bae42ea50000e4fd9a1d108980e6f..0407cd6a58c81f1293301152d33899d8e2730d49 100644 --- a/imperative/python/megengine/quantization/__init__.py +++ b/imperative/python/megengine/quantization/__init__.py @@ -7,8 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from .fake_quant import FakeQuantize -from .internal_fake_quant import * -from .observer import HistogramObserver, Observer +from .observer import Observer from .qconfig import ( QConfig, calibration_qconfig, @@ -20,4 +19,15 @@ from .qconfig import ( sync_ema_fakequant_qconfig, tqt_qconfig, ) -from .utils import QuantMode +from .quantize import ( + apply_easy_quant, + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, + propagate_qconfig, + quantize, + quantize_qat, + reset_qconfig, +) +from .utils import QParams, QuantMode, create_qparams diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index 15d584db29d080b64a00da5a8b25623ec20ae415..d115c64615e7db3b7c11025a72ac31834db1dcde 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -6,40 +6,48 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import math +from typing import Union from .. 
import functional as F -from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype +from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes +from ..logger import get_logger from ..module import Module -from ..tensor import Parameter, Tensor -from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward +from ..tensor import Parameter +from .utils import ( + QParams, + QParamsModuleMixin, + QuantMode, + create_qparams, + fake_quant_tensor, + tqt_forward, +) +logger = get_logger(__name__) -class _FakeQuantize(Module): - r""" - A Basic Fake Quant module. - - :param dtype: a string indicating the target quantization type of input. - :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``, - instead of 1 greater. Usually True for weight and False for activation. - :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. - """ +class _FakeQuantize(Module): def __init__( - self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs + self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs ): super().__init__() - if not dtype in _metadata_dict.keys(): - raise ValueError( - "unknown dtype: {}, only support {}".format( - dtype, _metadata_dict.keys() + if isinstance(dtype, str): + if not dtype in _builtin_quant_dtypes: + raise ValueError( + "unknown dtype: {}, only support {}".format( + dtype, _builtin_quant_dtypes.keys() + ) ) + dtype = _builtin_quant_dtypes[dtype] + if "narrow_range" in kwargs: + del kwargs["narrow_range"] + logger.warning( + "FakeQuantize currently has no narrow_range param " + "so it is ignored here", + exc_info=DeprecationWarning, ) self.dtype = dtype - self.narrow_range = narrow_range - self.qmin = ( - -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin - ) - self.qmax = _metadata_dict[dtype].qmax + self.qmin = dtype.qmin + self.qmax = dtype.qmax self.enabled = enable def enable(self): @@ -48,61 +56,64 @@ class _FakeQuantize(Module): def disable(self): self.enabled = False - def fake_quant_forward(self, inp, q_dict=None): - return inp + def fake_quant_forward(self, inp, qparams: QParams = None): + raise NotImplementedError - def normal_foward(self, inp, q_dict=None): + def normal_foward(self, inp, qparams: QParams = None): return inp - def forward(self, inp, q_dict=None): + def forward(self, inp, qparams: QParams = None): if self.enabled: - return self.fake_quant_forward(inp, q_dict=q_dict) + return self.fake_quant_forward(inp, qparams=qparams) else: - return self.normal_foward(inp, q_dict=q_dict) + return self.normal_foward(inp, qparams=qparams) -class TQT(_FakeQuantize): +class TQT(_FakeQuantize, QParamsModuleMixin): r""" TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. + + :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target + quantization dtype of input. + :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. 
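+
+    Example (as used by ``tqt_qconfig``): ``partial(TQT, dtype="qint8_narrow")``
+    for weights and ``partial(TQT, dtype="qint8")`` for activations.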
""" def __init__( - self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs + self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs ): - super().__init__(dtype, narrow_range, enable, **kwargs) + super().__init__(dtype, enable, **kwargs) self.scale = Parameter(0.0, dtype="float32") - def fake_quant_forward(self, inp, q_dict=None): + def fake_quant_forward(self, inp, qparams: QParams = None): # when enable, TQT will do fakequant forward, finetune the scale return tqt_forward(self.qmin, self.qmax, inp, self.scale) - def get_qparams(self): - q_dict = get_qparam_dict(QuantMode.SYMMERTIC) - q_dict["scale"] = 2 ** self.scale.detach() - return q_dict - - def set_qparams(self, q_dict): + def set_qparams(self, qparams: QParams): assert ( - q_dict["mode"] == QuantMode.SYMMERTIC + qparams.mode == QuantMode.SYMMERTIC ), "only symmetric quantization is supported by TQT" - if "scale" not in q_dict or q_dict["scale"] is None: + if qparams.scale is None: raise AssertionError("Can not get an initialized scale") - self.scale._reset(F.log(q_dict["scale"]) / math.log(2)) + self.scale[...] = F.log(qparams.scale) / math.log(2) - def get_dtype(self): - q_dict = self.get_qparams() - scale = None if "scale" not in q_dict else q_dict["scale"].numpy() - zero_point = ( - None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() - ) - return get_quantized_dtype(self.dtype, scale, zero_point) + def get_qparams(self): + return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale) class FakeQuantize(_FakeQuantize): r""" A module to do quant and dequant according to observer's scale and zero_point. + + :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target + quantization dtype of input. + :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. """ - def fake_quant_forward(self, inp, q_dict=None): - return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict) + def fake_quant_forward(self, inp, qparams: QParams = None): + assert ( + qparams.dtype_meta is self.dtype + ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( + qparams.dtype_meta, self.dtype + ) + return fake_quant_tensor(inp, qparams) diff --git a/imperative/python/megengine/quantization/internal_fake_quant.py b/imperative/python/megengine/quantization/internal_fake_quant.py index 013ddde73e87bffe502d5e4615a59723d237ad5c..a46b40d86f0e42d4cd2f96d2094a094cb6f74221 100644 --- a/imperative/python/megengine/quantization/internal_fake_quant.py +++ b/imperative/python/megengine/quantization/internal_fake_quant.py @@ -16,4 +16,6 @@ from ..autodiff import Function from .fake_quant import _FakeQuantize from .observer import MinMaxObserver from .qconfig import QConfig +from .utils import QParams + diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index bc4f0cf5f79236f85a45f60ba1c056e1bcf42c1b..dbeaf821b0b744cfbae347ae4f61a912c0a3fa71 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -8,51 +8,51 @@ import math from abc import abstractmethod from copy import deepcopy +from typing import Union import numpy as np from .. 
import functional as F
-from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
+from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
 from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
+from ..logger import get_logger
 from ..module import Module
 from ..tensor import Tensor
-from .utils import QuantMode, get_qparam_dict
+from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
 
+logger = get_logger(__name__)
 
-class Observer(Module):
+
+class Observer(Module, QParamsModuleMixin):
     r"""
     A base class for Observer Module.
 
     :param dtype: a string indicating to collect scale and zero_point of which dtype.
-    :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
-        instead of 1 greater. Usually True for weight and False for activation.
     """
 
-    def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
+    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
         super().__init__()
-        if dtype not in _metadata_dict.keys():
-            raise ValueError(
-                "unknown dtype: {}, only support {}".format(
-                    dtype, _metadata_dict.keys()
+        if isinstance(dtype, str):
+            if not dtype in _builtin_quant_dtypes:
+                raise ValueError(
+                    "unknown dtype: {}, only support {}".format(
+                        dtype, _builtin_quant_dtypes.keys()
+                    )
                 )
+            dtype = _builtin_quant_dtypes[dtype]
+        if "narrow_range" in kwargs:
+            del kwargs["narrow_range"]
+            logger.warning(
+                "Observer currently has no narrow_range param "
+                "so it is ignored here",
+                exc_info=DeprecationWarning,
             )
         self.dtype = dtype
-        self.narrow_range = narrow_range
-        self.qmin = (
-            -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
-        )
-        self.qmax = _metadata_dict[dtype].qmax
+        self.qmin = dtype.qmin
+        self.qmax = dtype.qmax
         self.enabled = True
 
-    def get_dtype(self):
-        q_dict = self.get_qparams()
-        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
-        numpy_zero_point = (
-            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
-        )
-        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
-
     def enable(self):
         self.enabled = True
 
@@ -70,21 +70,16 @@ class Observer(Module):
     def forward(self, x):
         pass
 
-    @abstractmethod
-    def get_qparams(self, **kwargs):
-        pass
-
 
 class MinMaxObserver(Observer):
     def __init__(
         self,
-        mode=QuantMode.SYMMERTIC,
-        eps=0.00001,
-        dtype="qint8",
-        narrow_range: bool = False,
+        mode: QuantMode = QuantMode.SYMMERTIC,
+        eps: float = 0.00001,
+        dtype: Union[str, QuantDtypeMeta] = "qint8",
         **kwargs
     ):
-        super().__init__(dtype, narrow_range, **kwargs)
+        super().__init__(dtype, **kwargs)
         self.mode = mode
         self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
         self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
@@ -93,26 +88,22 @@ class MinMaxObserver(Observer):
     def _calculate_qparams(self, inp_min_val, inp_max_val):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
-        q_dict = get_qparam_dict(self.mode)
-        q_dict["min_val"] = inp_min_val
-        q_dict["max_val"] = inp_max_val
-        q_dict["enable_observer"] = self.enable
         if self.mode == QuantMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
-            q_dict["scale"] = F.maximum(
+            scale = F.maximum(
                 symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
             )
-            # zero_point = self.zero_point
+            zero_point = None
         else:
             # use maximun to avoid scale too small at the begin
-
q_dict["scale"] = F.maximum( + scale = F.maximum( (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit ) # caculate zero_point - q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"]) + zero_point = self.qmin - F.round((min_val / scale)) - return q_dict + return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point) def get_qparams(self): return self._calculate_qparams(self.min_val, self.max_val) @@ -122,8 +113,8 @@ class MinMaxObserver(Observer): # stop gradient x = x_orig.detach() # find max and min - self.min_val._reset(F.minimum(self.min_val, x.min())) - self.max_val._reset(F.maximum(self.max_val, x.max())) + self.min_val[...] = F.minimum(self.min_val, x.min()) + self.max_val[...] = F.maximum(self.max_val, x.max()) return x_orig @@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver): else: min_x = x.min() max_x = x.max() - self.min_val._reset(F.minimum(self.min_val, min_x)) - self.max_val._reset(F.maximum(self.max_val, max_x)) + self.min_val[...] = F.minimum(self.min_val, min_x) + self.max_val[...] = F.maximum(self.max_val, max_x) return x_orig class ExponentialMovingAverageObserver(MinMaxObserver): def __init__( self, - momentum=0.9, - mode=QuantMode.SYMMERTIC, - eps=0.00001, - dtype="qint8", - narrow_range: bool = False, + momentum: float = 0.9, + mode: QuantMode = QuantMode.SYMMERTIC, + eps: float = 0.00001, + dtype: Union[str, QuantDtypeMeta] = "qint8", **kwargs ): - super().__init__(mode, eps, dtype, narrow_range, **kwargs) + super().__init__(mode, eps, dtype, **kwargs) self.momentum = Tensor(momentum, dtype="float32") + # used to avoid if-clauses in the first forward which is not supported + # in trace mode. self.runtime_momentum = Tensor(0.0) def set_momentum(self, momentum): - self.momentum = Tenosr(momentum, dtype="float32") + self.momentum = Tensor(momentum, dtype="float32") def forward(self, x_orig): if self.enabled: # stop gradient x = x_orig.detach() # Exponential Moving Average - self.min_val._reset( + self.min_val[...] = ( self.min_val * self.runtime_momentum + (1 - self.runtime_momentum) * x.min() ) - self.max_val._reset( + self.max_val[...] = ( self.max_val * self.runtime_momentum + (1 - self.runtime_momentum) * x.max() ) - self.runtime_momentum = self.momentum + self.runtime_momentum[...] = self.momentum return x_orig @@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): else: min_x = x.min() max_x = x.max() - self.min_val._reset( + self.min_val[...] = ( self.min_val * self.runtime_momentum + (1 - self.runtime_momentum) * min_x ) - self.max_val._reset( + self.max_val[...] = ( self.max_val * self.runtime_momentum + (1 - self.runtime_momentum) * max_x ) - self.runtime_momentum = self.momentum + self.runtime_momentum[...] 
= self.momentum return x_orig class HistogramObserver(MinMaxObserver): def __init__( self, - bins=2048, - upsample_rate=128, - mode=QuantMode.SYMMERTIC, - eps=0.00001, - dtype="qint8", - narrow_range: bool = False, + bins: int = 2048, + upsample_rate: int = 128, + mode: QuantMode = QuantMode.SYMMERTIC, + eps: float = 0.00001, + dtype: Union[str, QuantDtypeMeta] = "qint8", **kwargs ): - super().__init__(mode, eps, dtype, narrow_range, **kwargs) + super().__init__(mode, eps, dtype, **kwargs) self.bins = bins self.upsample_rate = upsample_rate - self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 + self.dst_nbins = ( + _builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1 + ) self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") def _non_linear_param_search(self): @@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver): class PassiveObserver(Observer): r""" - This class can be set :attr:`scale` derectly. + An Observer that supports setting :attr:`scale` directly. """ - def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): - super().__init__(dtype, narrow_range, **kwargs) - self.q_dict = None + def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): + super().__init__(dtype, **kwargs) + self.qparams = None self.orig_scale = None @property def scale(self): - return self.q_dict["scale"] + return self.qparams.scale @scale.setter - def scale(self, value): - assert value > 0 - self.q_dict["scale"][...] = Tensor(value) + def scale(self, value: np.ndarray): + assert np.all(value > 0) + self.qparams.scale[...] = Tensor(value) def get_qparams(self): - return self.q_dict + return self.qparams - def set_qparams(self, q_dict): - self.q_dict = deepcopy(q_dict) - if "scale" not in q_dict or q_dict["scale"] is None: + def set_qparams(self, qparams: QParams): + """ + :param qparams: used to set initial scale. + """ + self.qparams = deepcopy(qparams) + if qparams.scale is None: raise AssertionError("Can not get an initialized scale") - self.orig_scale = q_dict["scale"].numpy() + if qparams.dtype_meta is None: + qparams.dtype_meta = self.dtype + else: + assert ( + qparams.dtype_meta is self.dtype + ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( + qparams.dtype_meta, self.dtype + ) + self.orig_scale = qparams.scale.numpy() def forward(self, x): r""" - Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. + Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`. """ return x diff --git a/imperative/python/megengine/quantization/qconfig.py b/imperative/python/megengine/quantization/qconfig.py index a699b4b1d85043ec77d75d7d632a77e175b31882..5050045c2f62f852fe48936e02c21aba28b09f4a 100644 --- a/imperative/python/megengine/quantization/qconfig.py +++ b/imperative/python/megengine/quantization/qconfig.py @@ -5,6 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+from collections import namedtuple
 from functools import partial
 
 from ..module import Module
@@ -19,7 +20,13 @@ from .observer import (
 )
 
 
-class QConfig:
+# use namedtuple to make class immutable, comparable and easy to print
+class QConfig(
+    namedtuple(
+        "QConfig",
+        ["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
+    )
+):
     r"""
     A config class indicating how to do quantize toward :class:`~.QATModule`'s
     ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
@@ -37,90 +44,66 @@ class QConfig:
 
         # Default EMA QConfig for QAT.
         ema_fakequant_qconfig = QConfig(
-            weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
-            act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
-            weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
-            act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+            weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
+            act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
+            weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
+            act_fake_quant=partial(FakeQuantize, dtype="qint8"),
         )
 
     Each parameter is a ``class`` rather than an instance. And we recommend using
     ``functools.partial`` to add initialization parameters of the ``class``, so that
     you don't need to provide parameters in :meth:`~.QATModule.set_qconfig`.
 
-    Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related
-    parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if
-    four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
-    Weights are commonly calculated in this way, so needed to narrow the range.
+    Usually we choose the narrow version dtype (like ``qint8_narrow``) for weight related
+    parameters and the normal version for activation related ones. For the result of
+    multiplication and addition as ``a * b + c * d``, if four variables are all -128 of
+    dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
+    Weights are commonly calculated in this way, so we need to narrow qmin to -127.
     """
 
-    def __init__(
-        self, weight_observer, act_observer, weight_fake_quant, act_fake_quant
-    ):
+    def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
         if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
             raise ValueError(
                 "QConfig must not receive observer instance, please pass observer"
                 " class generator using `partial(Observer, ...)` instead. 
Use" " partial(MyObserver, x=1) to override arguments to constructor if needed" ) - self.weight_observer = weight_observer - self.act_observer = act_observer - self.weight_fake_quant = weight_fake_quant - self.act_fake_quant = act_fake_quant - - def __eq__(self, other): - def eq(a, b): - if isinstance(a, partial) and isinstance(b, partial): - return all( - [a.func == b.func, a.args == b.args, a.keywords == b.keywords] - ) - else: - return a == b - - return ( - eq(self.weight_observer, other.weight_observer) - and eq(self.act_observer, other.act_observer) - and eq(self.weight_fake_quant, other.weight_fake_quant) - and eq(self.act_fake_quant, other.act_fake_quant) + return super().__new__( + cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant ) min_max_fakequant_qconfig = QConfig( - weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), - act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), - weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), - act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), + weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), + act_observer=partial(MinMaxObserver, dtype="qint8"), + weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), + act_fake_quant=partial(FakeQuantize, dtype="qint8"), ) ema_fakequant_qconfig = QConfig( - weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), - act_observer=partial( - ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False - ), - weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), - act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), + weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), + act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"), + weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), + act_fake_quant=partial(FakeQuantize, dtype="qint8"), ) sync_ema_fakequant_qconfig = QConfig( - weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True), - act_observer=partial( - SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False - ), - weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), - act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), + weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"), + act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"), + weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), + act_fake_quant=partial(FakeQuantize, dtype="qint8"), ) ema_lowbit_fakequant_qconfig = QConfig( - weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False), - act_observer=partial( - ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False - ), - weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), - act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), + weight_observer=partial(MinMaxObserver, dtype="qint4"), + act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"), + weight_fake_quant=partial(FakeQuantize, dtype="qint4"), + act_fake_quant=partial(FakeQuantize, dtype="qint4"), ) calibration_qconfig = QConfig( - weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), - act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), + weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), + act_observer=partial(HistogramObserver, dtype="qint8"), 
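+    # calibration only collects statistics via observers, so both fake-quant
+    # entries below are left as None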
weight_fake_quant=None, act_fake_quant=None, ) @@ -128,15 +111,15 @@ calibration_qconfig = QConfig( tqt_qconfig = QConfig( weight_observer=None, act_observer=None, - weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), - act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), + weight_fake_quant=partial(TQT, dtype="qint8_narrow"), + act_fake_quant=partial(TQT, dtype="qint8"), ) passive_qconfig = QConfig( - weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True), - act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False), - weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), - act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), + weight_observer=partial(PassiveObserver, dtype="qint8_narrow"), + act_observer=partial(PassiveObserver, dtype="qint8"), + weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), + act_fake_quant=partial(FakeQuantize, dtype="qint8"), ) easyquant_qconfig = passive_qconfig diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 730947c9027ebfbc4c81ba7372706acec006f434..09fa9012abc3ea298f8902b6ddb5d90cc344fa31 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -18,6 +18,7 @@ from ..module import qat as QAT from ..module import quantized as Quantized from ..module.qat import QATModule from ..module.quantized import QuantizedModule +from ..tensor import Tensor from .qconfig import QConfig, ema_fakequant_qconfig @@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): if not inplace: module = deepcopy(module) - def safe_call(func, q_dict): + def safe_call(func, qparams): inst = func() if func is not None else None if inst is not None and getattr(inst, "set_qparams", None) is not None: - inst.set_qparams(q_dict) + inst.set_qparams(qparams) return inst def is_qat(mod: Module): @@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): for m in list(module._flatten(predicate=is_qat)): if m.with_weight: - weight_q_dict = m.get_weight_qparams() - m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict) - m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict) + weight_params = m.get_weight_qparams() + m.weight_observer = safe_call(qconfig.weight_observer, weight_params) + m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params) if m.with_act: - act_q_dict = m.get_activation_qparams() - m.act_observer = safe_call(qconfig.act_observer, act_q_dict) - m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict) + act_params = m.get_activation_qparams() + m.act_observer = safe_call(qconfig.act_observer, act_params) + m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params) return module @@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable): return hooks -def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40): +def apply_easy_quant( + module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40 +): r""" Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. Search for optimal scales. 
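# NOTE: an illustrative sketch of the refactored qparams API, assuming a
# MegEngine build with this patch applied; usage mirrors the updated unit tests:
#
#     import megengine as mge
#     from megengine.quantization import QuantMode, create_qparams
#     from megengine.quantization.utils import fake_quant_tensor
#
#     qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor(1.0))
#     x = mge.tensor([[-1.5, 0.3], [2.7, -0.8]])
#     # quantize-dequantize: round(x / scale) clipped to [qmin, qmax] taken from
#     # qparams.dtype_meta (qint8: -128..127), then scaled back to float
#     y = fake_quant_tensor(x, qparams)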
diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index de1bb0fc3ba42a3bc5dbf41089e3835702559bd9..127a134c2bcf368603834b93b0ce2986e2dd3c9d 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -5,9 +5,10 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import abc from enum import Enum from functools import partial, update_wrapper, wraps -from typing import Dict +from typing import Union import numpy as np @@ -15,7 +16,11 @@ from .. import functional as F from ..autodiff import Function from ..core._imperative_rt.core2 import apply from ..core.ops import builtin -from ..core.tensor.dtype import _metadata_dict +from ..core.tensor.dtype import ( + QuantDtypeMeta, + _builtin_quant_dtypes, + create_quantized_dtype, +) from ..tensor import Tensor @@ -61,37 +66,100 @@ class QuantMode(Enum): ASYMMERTIC = 2 -qparam_dict = { - QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, - QuantMode.ASYMMERTIC: { - "mode": QuantMode.ASYMMERTIC, - "scale": None, - "zero_point": None, - }, +class QParams: + """ + To standardize FakeQuant, Observer and Tensor's qparams format. If custom + qparams is needed, inherit this class and add custom ``__slots__``. + """ + + __slots__ = "mode", "dtype_meta", "scale", "zero_point" + + def __init__( + self, + mode: QuantMode, + dtype_meta: QuantDtypeMeta, + scale: Tensor, + zero_point: Tensor, + ): + self.mode = mode + self.dtype_meta = dtype_meta + self.scale = scale + self.zero_point = zero_point + + def update(self, qparams: "QParams"): + for key in self.__slots__: + setattr(self, key, getattr(qparams, key)) + + def __eq__(self, other): + if len(self.__slots__) != len(other.__slots__): + return False + for key in self.__slots__: + if not hasattr(other, key) or getattr(self, key) != getattr(other, key): + return False + return True + + def __repr__(self): + content = ", ".join( + ["{}={}".format(key, getattr(self, key)) for key in self.__slots__] + ) + return "QParams({})".format(content) + + +class QParamsModuleMixin(abc.ABC): + def get_quantized_dtype(self): + qparams = self.get_qparams() + dtype = qparams.dtype_meta + scale = float(qparams.scale.numpy()) if qparams.scale is not None else None + zero_point = ( + int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None + ) + return create_quantized_dtype(dtype, scale, zero_point) + + @abc.abstractmethod + def get_qparams(self) -> QParams: + pass + + +_builtin_qparams = { + QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC), + QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC), } -def get_qparam_dict(mode: QuantMode): +def create_qparams( + mode: QuantMode = QuantMode.SYMMERTIC, + dtype_meta: Union[str, QuantDtypeMeta] = None, + scale: Tensor = None, + zero_point: Tensor = None, +): """ - Return the quantization parameters dictionary according to the mode. + Return :class:`~.QParams` according to the mode. 
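+
+    Example (mirroring the updated unit tests):
+    ``create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))``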
""" - return qparam_dict.get(mode, None) + if isinstance(dtype_meta, str): + dtype_meta = _builtin_quant_dtypes[dtype_meta] + if mode is None: + return QParams(mode, dtype_meta, scale, zero_point) + assert isinstance(mode, QuantMode) + return _builtin_qparams[mode]( + dtype_meta=dtype_meta, scale=scale, zero_point=zero_point + ) -def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor: +def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor: """ Apply fake quantization to the inp tensor. :param inp: the input tensor which need to be faked. - :param qmin: the minimum value which the integer limit to. - :param qmax: the maximum value which the integer limit to. - :param q_dict: the quantization parameter dict. + :param qparams: to get mode, qmin, qmax, scale and zero_point from. """ - scale = q_dict["scale"] - zero_point = Tensor([0.0], dtype=np.float32) - if q_dict["mode"] == QuantMode.ASYMMERTIC: - zero_point = q_dict["zero_point"] + scale = qparams.scale + if qparams.mode == QuantMode.ASYMMERTIC: + zero_point = qparams.zero_point + else: + zero_point = Tensor([0.0], dtype=np.float32) + qmin = qparams.dtype_meta.qmin + qmax = qparams.dtype_meta.qmax op = builtin.FakeQuant(qmin=qmin, qmax=qmax) return apply(op, inp, scale, zero_point)[0] @@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: :param bias: the bias tensor which need to be faked. :param inp: the input tensor which contain the quantization parameters. - :param qmax: the weight tensor which contain the quantization parameters. + :param w_qat: the weight tensor which contain the quantization parameters. .. warning:: Only work for symmetric quantization method now. """ b_qat = bias - if hasattr(inp, "q_dict") and b_qat is not None: - if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None: - # use the same mode with weight. - b_dict = get_qparam_dict(w_qat.q_dict["mode"]) - b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"] - # TODO: add zero_point for ASYMMERTIC mode. - qmax = _metadata_dict["qint32"].qmax - qmin = _metadata_dict["qint32"].qmin - b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict) - b_qat.q_dict.update(b_dict) + if ( + getattr(inp, "qparams", None) is not None + and getattr(w_qat, "qparams", None) is not None + and bias is not None + ): + inp_params = inp.qparams + w_params = w_qat.qparams + if inp_params.scale is not None and w_params.scale is not None: + assert inp_params.mode == w_params.mode, "incompatible QuantMode" + # TODO: support quint8 dtype. + assert ( + inp_params.dtype_meta.np_dtype_str == "int8" + and w_params.dtype_meta.np_dtype_str == "int8" + ), "fake_quant_bias only support int8 like dtype now" + # use the same mode with weight. 
+            # TODO: avoid hardcode
+            b_dtype = _builtin_quant_dtypes["qint32"]
+            b_param = create_qparams(
+                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
+            )
+            b_qat = fake_quant_tensor(bias, b_param)
+            b_qat.qparams.update(b_param)
 
     return b_qat
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 5c94189371c35ab8f8bc90c3102482d40b5df023..8f6072271d42f93b607de5c8863f4e25baaca8ef 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -22,6 +22,8 @@ from .logger import get_logger
 from .utils.deprecation import deprecated
 from .utils.naming import auto_naming
 
+logger = get_logger(__name__)
+
 
 class Tensor(_Tensor, ArrayMethodMixin):
     r"""
@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
 
     grad = None
     dmap_callback = None
-    _q_dict = None
+    _qparams = None
 
     def __new__(
         cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
 
         if isinstance(data, _Tensor):
             if dtype is not None:
-                get_logger().warning(
+                logger.warning(
                     "dtype does not work when creating a new Tensor with another Tensor"
                 )
             obj = _Tensor.__new__(cls, data)
@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin):
         return super().dtype
 
     @property
-    def q_dict(self):
-        if self._q_dict is None:
-            self._q_dict = {"mode": None, "scale": None, "zero_point": None}
-        return self._q_dict
+    def qparams(self):
+        from .quantization.utils import create_qparams  # pylint: disable=all
+
+        if self._qparams is None:
+            self._qparams = create_qparams()
+        return self._qparams
 
     def numpy(self) -> np.ndarray:
         r"""
@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin):
 
     def __getstate__(self):
         r""" __getstate__ will be called for pickle serialization or deep copy """
-
         state = {
-            "qdict": self.q_dict,
+            "numpy": self.numpy(),
+            "dtype": self.dtype,
+            "device": self.device.logical_name,
         }
+        if self._qparams is not None:
+            state["qparams"] = self._qparams
         return state
 
     def __setstate__(self, state):
-        self._q_dict = state.pop("qdict")
+        from .quantization.utils import create_qparams  # pylint: disable=all
+
+        if "qdict" in state:
+            qparams = state.pop("qdict")
+            logger.warning(
+                "Tensor's 'qdict' state is deprecated. 
Use 'qparams' instead" + ) + elif "qparams" in state: + qparams = state.pop("qparams") + else: + qparams = None + self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device"))) + self._qparams = qparams tensor = Tensor diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index 20e53dd7bf7b98fee67406592f162fca3e308017..594d86e5abd3b3c3820785d391886d503dce320f 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -14,7 +14,7 @@ import pytest import megengine.core.tensor.megbrain_graph as G from megengine.core.ops import builtin as ops from megengine.core.tensor.dtype import ( - _metadata_dict, + _builtin_quant_dtypes, convert_from_qint4, convert_from_qint8, convert_from_quint4, @@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True): - metadata = _metadata_dict[dtype_str] + metadata = _builtin_quant_dtypes[dtype_str] assert "mgb_dtype" in oup.dtype.metadata assert is_quantize(oup.dtype) - np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name) + np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.cname) np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype)) if is_unsigned: np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype)) diff --git a/imperative/python/test/unit/core/test_serialization.py b/imperative/python/test/unit/core/test_serialization.py index 72e79c5c81a903012bce177292ecf2c4a6eacd34..edd5e6f7172401e239d0b18e5c4d9c51fa0e3bb4 100644 --- a/imperative/python/test/unit/core/test_serialization.py +++ b/imperative/python/test/unit/core/test_serialization.py @@ -65,9 +65,9 @@ def test_tensor_serialization(): with TemporaryFile() as f: a = Tensor(0) - a.q_dict["scale"] = Tensor(1.0) + a.qparams.scale = Tensor(1.0) pickle.dump(a, f) f.seek(0) b = pickle.load(f) - assert isinstance(b.q_dict["scale"], Tensor) - np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0) + assert isinstance(b.qparams.scale, Tensor) + np.testing.assert_equal(b.qparams.scale.numpy(), 1.0) diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 099dcccae6b1e1d47495b2ae478e13b69d131c1b..52fc08f5dab71fdbad9176288061161a5df994ea 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -6,6 +6,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+import copy + import numpy as np from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 @@ -86,3 +88,23 @@ def test_as_type(): b = a.astype(quint8(0.3, 128)) np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) np.testing.assert_equal(get_zero_point(b.dtype), 128) + + +def test_qparams(): + x = Tensor(1) + assert x.qparams.scale is None + x.qparams.scale = Tensor(1.0) + assert x.qparams.scale.numpy() == 1.0 + x2 = copy.copy(x) + assert x.qparams is x2.qparams and x2.qparams.scale.numpy() == 1.0 + x3 = copy.deepcopy(x) + assert x.qparams is not x3.qparams and x3.qparams.scale.numpy() == 1.0 + + +def test_name(): + x = Tensor(0) + assert x.name == "" + x.name = "x" + assert x.name == "x" + x = Tensor(0, name="x") + assert x.name == "x" diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index b2898c0fab029dee3980e7f0d33a66ba72b381f5..da6825ff30f2749070079cdc128b05dee70f3439 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -406,28 +406,3 @@ def test_copy_d2h(): def test_copy_d2d(): copy_test("gpu0", "gpu1") copy_test("gpu0:0", "gpu0:1") - - -def test_name(): - x = tensor(0) - assert x.name == "" - x.name = "x" - assert x.name == "x" - x = tensor(0, name="x") - assert x.name == "x" - - -def test_q_dict(): - x = tensor(1) - assert x.q_dict["scale"] is None - x.q_dict["scale"] = tensor(1.0) - - y = tensor(1) - assert y.q_dict["scale"] is None - y.q_dict["scale"] = tensor(2.0) - - assert x.q_dict["scale"].numpy() == 1.0 - assert y.q_dict["scale"].numpy() == 2.0 - - z = x + y - assert z.q_dict["scale"] is None diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index a11edff38eb1d5748bd2fdbbb1dcb2b76b005eda..ccb2e570438e021a2c57ba4420ca15481ddcc2c7 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -12,9 +12,15 @@ import pytest import megengine as mge from megengine import tensor from megengine.core.autodiff.grad import Function, Grad +from megengine.core.tensor.dtype import QuantDtypeMeta from megengine.core.tensor.utils import make_shape_tuple from megengine.quantization.internal_fake_quant import * -from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward +from megengine.quantization.utils import ( + QuantMode, + create_qparams, + fake_quant_tensor, + tqt_forward, +) class TQT_numpy: @@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax): def test_fakequant(): qmin = -126 qmax = 129 + test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax) def run(zero_point, scale): - q_dict = {} - q_dict["mode"] = QuantMode.ASYMMERTIC - q_dict["scale"] = scale - q_dict["zero_point"] = zero_point + qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point) inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32)) inp = tensor(inp_data, dtype=np.float32) # test forward - oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy() + oup = fake_quant_tensor(inp, qparams).numpy() oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy() assert np.allclose(oup, oup_gt) assert oup.shape == oup_gt.shape @@ -128,7 +132,7 @@ def test_fakequant(): # test backward x = tensor(inp_data, dtype=np.float32) grad = Grad().wrt(x, callback=_save_to(x)) - y = 
diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py
index a11edff38eb1d5748bd2fdbbb1dcb2b76b005eda..ccb2e570438e021a2c57ba4420ca15481ddcc2c7 100644
--- a/imperative/python/test/unit/quantization/test_fake_quant.py
+++ b/imperative/python/test/unit/quantization/test_fake_quant.py
@@ -12,9 +12,15 @@ import pytest
import megengine as mge
from megengine import tensor
from megengine.core.autodiff.grad import Function, Grad
+from megengine.core.tensor.dtype import QuantDtypeMeta
from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.internal_fake_quant import *
-from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward
+from megengine.quantization.utils import (
+    QuantMode,
+    create_qparams,
+    fake_quant_tensor,
+    tqt_forward,
+)


class TQT_numpy:
@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
def test_fakequant():
    qmin = -126
    qmax = 129
+    test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax)

    def run(zero_point, scale):
-        q_dict = {}
-        q_dict["mode"] = QuantMode.ASYMMERTIC
-        q_dict["scale"] = scale
-        q_dict["zero_point"] = zero_point
+        qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point)
        inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32))
        inp = tensor(inp_data, dtype=np.float32)
        # test forward
-        oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy()
+        oup = fake_quant_tensor(inp, qparams).numpy()
        oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy()
        assert np.allclose(oup, oup_gt)
        assert oup.shape == oup_gt.shape
@@ -128,7 +132,7 @@
        # test backward
        x = tensor(inp_data, dtype=np.float32)
        grad = Grad().wrt(x, callback=_save_to(x))
-        y = fake_quant_tensor(x, qmin, qmax, q_dict)
+        y = fake_quant_tensor(x, qparams)
        grad(y, tensor(F.ones_like(x)))

        x1 = tensor(inp_data, dtype=np.float32)
diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py
index 12e0d27c4d758376c2fa191e928bb8005b7467b0..cfdf03448a6777f4ba627abbccda487694118d4d 100644
--- a/imperative/python/test/unit/quantization/test_module.py
+++ b/imperative/python/test/unit/quantization/test_module.py
@@ -10,7 +10,13 @@ import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine import Parameter, Tensor
from megengine.core.tensor import dtype
-from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig
+from megengine.quantization import (
+    FakeQuantize,
+    MinMaxObserver,
+    QConfig,
+    QuantMode,
+    create_qparams,
+)
from megengine.quantization.quantize import (
    disable_fake_quant,
    disable_observer,
@@ -18,10 +24,10 @@
)

min_max_fakequant_qconfig = QConfig(
-    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
-    act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
-    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
-    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
+    act_observer=partial(MinMaxObserver, dtype="qint8"),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

inp_scale = np.float32(np.random.rand() + 1)
@@ -111,7 +117,7 @@ def test_dequant_stub():

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
-    x.q_dict["scale"] = inp_scale
+    x.qparams.scale = inp_scale

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
@@ -146,12 +152,12 @@ def test_elemwise(kind):
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant_act(x1, x1_scale)
-    x1.q_dict["scale"] = x1_scale
+    x1.qparams.scale = x1_scale

    x2_scale = np.float32(np.random.rand() + 1)
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant_act(x2, x2_scale)
-    x2.q_dict["scale"] = x2_scale
+    x2.qparams.scale = x2_scale

    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)
@@ -187,7 +193,7 @@ def test_linear():

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
-    x.q_dict["scale"] = inp_scale
+    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

    x_int8 = quant(x, inp_scale)
@@ -230,7 +236,7 @@ def test_conv(module):

    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
-    x.q_dict["scale"] = inp_scale
+    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))

    x_int8 = quant(x, inp_scale)
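One behavioral note on the qconfig rewrite above: the `narrow_range=True/False` flag disappears because the range now lives in the dtype itself; `qint8_narrow` is a builtin `QuantDtypeMeta` whose qmin is -127 instead of -128 but which dumps under the same `QuantizedS8` cname. A quick check of that claim against `_builtin_quant_dtypes` (assumes this patch is applied; `_builtin_quant_dtypes` is private and used here only for illustration):

    from megengine.core.tensor.dtype import _builtin_quant_dtypes

    narrow = _builtin_quant_dtypes["qint8_narrow"]
    full = _builtin_quant_dtypes["qint8"]
    assert (narrow.qmin, narrow.qmax) == (-127, 127)    # symmetric, narrowed range
    assert (full.qmin, full.qmax) == (-128, 127)        # full int8 range
    assert narrow.cname == full.cname == "QuantizedS8"  # same dump/inference name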
diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py
index 33e4f964b29c0b17278159c1b74f3ade79ebca74..bd306244cb68d50f7c50df919322c9c31f8d57c5 100644
--- a/imperative/python/test/unit/quantization/test_observer.py
+++ b/imperative/python/test/unit/quantization/test_observer.py
@@ -6,6 +6,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork
+from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
@@ -56,14 +57,14 @@ def test_histogram_observer():


def test_passive_observer():
-    q_dict = {"scale": mge.tensor(1.0)}
+    qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
    m = PassiveObserver("qint8")
-    m.set_qparams(q_dict)
+    m.set_qparams(qparams)
    assert m.orig_scale == 1.0
-    assert m.scale == 1.0
-    m.scale = 2.0
-    assert m.scale == 2.0
-    assert m.get_qparams() == {"scale": mge.tensor(2.0)}
+    assert m.scale.numpy() == 1.0
+    assert m.get_qparams().dtype_meta == qparams.dtype_meta
+    assert m.get_qparams().scale == qparams.scale
+    assert m.get_qparams() == qparams


@pytest.mark.require_ngpu(2)
diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
index 20e5b2fc608db5713f44a3b72e56ffefdad31f67..2095565622794ef07af84649fd9711ddb4ac4155 100644
--- a/imperative/python/test/unit/quantization/test_op.py
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -6,6 +6,7 @@ import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
+from megengine.quantization import QuantMode, create_qparams


def quant(x, scale):
@@ -26,13 +27,13 @@ def test_elemwise(kind):
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = fake_quant(x1, x1_scale)
-    x1.q_dict["scale"] = x1_scale
+    x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale))
    x1_int8 = quant(x1, x1_scale)

    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = fake_quant(x2, x2_scale)
-    x2.q_dict["scale"] = x2_scale
+    x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale))
    x2_int8 = quant(x2, x2_scale)

    output_scale = np.float32(np.random.rand() + 1)
diff --git a/imperative/python/test/unit/quantization/test_qconfig.py b/imperative/python/test/unit/quantization/test_qconfig.py
deleted file mode 100644
index 92b0150f5014d3daef5c6c37b2ee1f728f3795c2..0000000000000000000000000000000000000000
--- a/imperative/python/test/unit/quantization/test_qconfig.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from functools import partial
-
-from megengine.quantization import QConfig, tqt_qconfig
-from megengine.quantization.fake_quant import TQT
-
-
-def test_equal():
-    qconfig = QConfig(
-        weight_observer=None,
-        act_observer=None,
-        weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
-        act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
-    )
-    assert qconfig == tqt_qconfig
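`test_passive_observer` above doubles as the recipe for wiring an externally supplied scale: build a `QParams` via `create_qparams` and hand it to the observer, instead of passing a raw dict. A minimal sketch following that test (assumes this patch's `create_qparams(mode, dtype, scale)` signature, and keeps the codebase's `SYMMERTIC` spelling):

    import megengine as mge
    from megengine.quantization import QuantMode, create_qparams
    from megengine.quantization.observer import PassiveObserver

    qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
    obs = PassiveObserver("qint8")
    obs.set_qparams(qparams)           # observer now reports exactly these qparams
    assert obs.get_qparams() == qparams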
diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py
index ed561bc19760cb6ca916806c8d941e80bef2173f..c77ba7823cddb86e46fb4acae29d4b9b97ab81dd 100644
--- a/imperative/python/test/unit/quantization/test_quantize.py
+++ b/imperative/python/test/unit/quantization/test_quantize.py
@@ -33,7 +33,7 @@ from megengine.quantization.quantize import (
)


-class Net(Float.Module):
+class FloatNet(Float.Module):
    def __init__(self):
        super().__init__()
        self.quant = Float.QuantStub()
@@ -113,25 +113,25 @@ def test_reset_qconfig():
def test_enable_and_disable_observer():
    net = init_qat_net()
    enable_observer(net)
-    assert net.quant.act_observer.enabled == True
-    assert net.linear.weight_observer.enabled == True
-    assert net.linear.act_observer.enabled == True
+    assert net.quant.act_observer.enabled is True
+    assert net.linear.weight_observer.enabled is True
+    assert net.linear.act_observer.enabled is True
    disable_observer(net)
-    assert net.quant.act_observer.enabled == False
-    assert net.linear.weight_observer.enabled == False
-    assert net.linear.act_observer.enabled == False
+    assert net.quant.act_observer.enabled is False
+    assert net.linear.weight_observer.enabled is False
+    assert net.linear.act_observer.enabled is False


def test_enable_and_disable_fake_quant():
    net = init_qat_net()
    disable_fake_quant(net)
-    assert net.quant.act_fake_quant.enabled == False
-    assert net.linear.weight_fake_quant.enabled == False
-    assert net.linear.act_fake_quant.enabled == False
+    assert net.quant.act_fake_quant.enabled is False
+    assert net.linear.weight_fake_quant.enabled is False
+    assert net.linear.act_fake_quant.enabled is False
    enable_fake_quant(net)
-    assert net.quant.act_fake_quant.enabled == True
-    assert net.linear.weight_fake_quant.enabled == True
-    assert net.linear.act_fake_quant.enabled == True
+    assert net.quant.act_fake_quant.enabled is True
+    assert net.linear.weight_fake_quant.enabled is True
+    assert net.linear.act_fake_quant.enabled is True


def init_observer(module, data):
@@ -144,7 +144,7 @@ def init_observer(module, data):

def test_enable_and_disable_all():
    x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
-    net = Net()
+    net = FloatNet()
    y1 = net(x).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
@@ -162,7 +162,7 @@


def test_quantize_qat():
-    net = Net()
+    net = FloatNet()
    qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
    assert isinstance(qat_net.quant, QAT.QuantStub)
    assert isinstance(qat_net.linear, QAT.Linear)
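For completeness, the renamed `FloatNet` tests trace the usual QAT round trip; a condensed, hedged sketch (reuses `FloatNet` and `min_max_fakequant_qconfig` as defined in the test file above, so it only runs in that context, and the equality claim is inferred from `test_enable_and_disable_all` rather than guaranteed):

    import numpy as np
    from megengine import Tensor
    from megengine.quantization.quantize import (
        disable_fake_quant,
        enable_fake_quant,
        quantize_qat,
    )

    x = Tensor(np.random.rand(3, 3).astype(np.float32))
    net = FloatNet()
    y_float = net(x).numpy()

    qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
    disable_fake_quant(qat_net)  # with fake-quant off, QAT output should match float
    np.testing.assert_allclose(qat_net(x).numpy(), y_float, rtol=1e-6)
    enable_fake_quant(qat_net)   # restore quantization-error injection for training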