diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 9381ef4f6be6e2dfbee1de953ccaac3a1fbeec8c..37315ee6640122ef11e3d797385cda5c990d8d67 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -37,15 +37,14 @@ class QATModule(Module): Set quantization related configs with ``qconfig``, including observer and fake_quant for weight and activation. """ - self.weight_observer = qconfig.weight_observer() - self.act_observer = qconfig.act_observer() - if qconfig.fake_quant is None: - self.weight_fake_quant = None - self.act_fake_quant = None - else: - self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) - self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) + def safe_call(func): + return func() if func is not None else None + + self.weight_observer = safe_call(qconfig.weight_observer) + self.act_observer = safe_call(qconfig.act_observer) + self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) + self.act_fake_quant = safe_call(qconfig.act_fake_quant) def _apply_fakequant_with_observer( self, target: Tensor, fake_quant: FakeQuantize, observer: Observer diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py index 7ac8889da363e4b7ebfe57af83fc3a9da9b28c51..9349f0b36f872d528c63d9682e47f3e53c05110d 100644 --- a/python_module/megengine/quantization/fake_quant.py +++ b/python_module/megengine/quantization/fake_quant.py @@ -19,7 +19,7 @@ from .observer import ObserverMode, Round class _FakeQuantize(Module): - def __init__(self, dtype: str, enable: bool = True): + def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): super().__init__() if not dtype in _metadata_dict.keys(): raise ValueError( @@ -28,7 +28,10 @@ class _FakeQuantize(Module): ) ) self.dtype = dtype - self.qmin = _metadata_dict[dtype].qmin + self.narrow_range = narrow_range + 
self.qmin = ( + -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin + ) self.qmax = _metadata_dict[dtype].qmax self.enabled = enable @@ -90,12 +93,12 @@ class TQT_Function(Function): class TQT(_FakeQuantize): """ - TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds + TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks """ - def __init__(self, dtype: str, enable: bool = True): - super().__init__(dtype, enable) + def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): + super().__init__(dtype, narrow_range, enable) self.scale = Parameter(0.0, dtype=np.float32) def fake_quant_forward(self, inp, q_dict): @@ -116,6 +119,11 @@ class TQT(_FakeQuantize): class FakeQuantize(_FakeQuantize): r""" A module to do quant and dequant according to observer's scale and zero_point. + + :param dtype: A string indicating the target quantization type of input. + :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, + instead of 1 greater. Usually True for weight and False for activation. + :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. """ def fake_quant_forward(self, inp, q_dict): diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 8f89b010316f06a0cfe45c227016f952f060e5d8..476b73024b55ae560f9c49683a00547a986f87e9 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -31,9 +31,11 @@ class Observer(Module): A base class for Observer Module. :param dtype: a string indicating to collect scale and zero_point of which dtype + :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, + instead of 1 greater. Usually True for weight and False for activation. 
""" - def __init__(self, dtype="qint8"): + def __init__(self, dtype: str, narrow_range: bool = False): super().__init__() if dtype not in _metadata_dict.keys(): raise ValueError( @@ -42,7 +44,10 @@ class Observer(Module): ) ) self.dtype = dtype - self.qmin = _metadata_dict[dtype].qmin + self.narrow_range = narrow_range + self.qmin = ( + -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin + ) self.qmax = _metadata_dict[dtype].qmax self.enabled = True @@ -96,8 +101,14 @@ def create_observer_dict(mode): class MinMaxObserver(Observer): - def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): - super().__init__(dtype) + def __init__( + self, + mode=ObserverMode.SYMMERTIC, + eps=0.00001, + dtype="qint8", + narrow_range: bool = False, + ): + super().__init__(dtype, narrow_range) self.mode = mode self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) @@ -153,9 +164,14 @@ class MinMaxObserver(Observer): class ExponentialMovingAverageObserver(MinMaxObserver): def __init__( - self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8" + self, + momentum=0.9, + mode=ObserverMode.SYMMERTIC, + eps=0.00001, + dtype="qint8", + narrow_range: bool = False, ): - super().__init__(mode, eps, dtype) + super().__init__(mode, eps, dtype, narrow_range) self.momentum = Buffer(momentum) self.runtime_momentum = Buffer(0.0) @@ -188,11 +204,12 @@ class HistogramObserver(MinMaxObserver): self, bins=2048, upsample_rate=128, - dtype="qint8", mode=ObserverMode.SYMMERTIC, eps=0.00001, + dtype="qint8", + narrow_range: bool = False, ): - super().__init__(mode, eps, dtype) + super().__init__(mode, eps, dtype, narrow_range) self.bins = bins self.upsample_rate = upsample_rate self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py index 
00d82429e7ca56f8e04506ce20ea7c7532616669..4a7b75ecb12dbdb746823609de12f9977219a86f 100644 --- a/python_module/megengine/quantization/qconfig.py +++ b/python_module/megengine/quantization/qconfig.py @@ -5,6 +5,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from functools import partial + from ..module import Module from .fake_quant import TQT, FakeQuantize from .observer import ( @@ -22,9 +24,9 @@ class QConfig: :param weight_observer: interface to instantiate an :class:`~.Observer` indicating how to collect scales and zero_point of wegiht. :param act_observer: similar to ``weight_observer`` but toward activation. - :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating - how to do fake_quant calculation. can be invoked multi times to get different - instance for each target tensor, for better control on enable and disable. + :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating + how to do fake_quant calculation. + :param act_fake_quant: similar to ``weight_fake_quant`` but toward activation. Examples: @@ -32,14 +34,24 @@ class QConfig: # Default EMA QConfig for QAT. ema_fakequant_qconfig = QConfig( - weight_observer=MinMaxObserver, - act_observer=ExponentialMovingAverageObserver, - fake_quant=FakeQuantize, + weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), + act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), + weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), + act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), ) + + Each parameter is a ``class`` rather than an instance. 
And we recommend using ``functools.partial`` + to add initialization parameters of the ``class``, so that there is no need to provide parameters in + :meth:`~.QATModule.set_qconfig`. + + Usually we set ``narrow_range`` of weight related parameters to ``True`` and of activation related + parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if + four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. + Weights are commonly calculated in this way, so their range needs to be narrowed. """ def __init__( - self, act_observer, weight_observer, fake_quant, + self, weight_observer, act_observer, weight_fake_quant, act_fake_quant ): if isinstance(act_observer, Module) or isinstance(weight_observer, Module): raise ValueError( @@ -47,30 +59,42 @@ class QConfig: " class generator using `partial(Observer, ...)` instead. Use" " partial(MyObserver, x=1) to override arguments to constructor if needed" ) - self.act_observer = act_observer self.weight_observer = weight_observer - self.fake_quant = fake_quant + self.act_observer = act_observer + self.weight_fake_quant = weight_fake_quant + self.act_fake_quant = act_fake_quant tqt_quant_qconfig = QConfig( - weight_observer=ExponentialMovingAverageObserver, - act_observer=ExponentialMovingAverageObserver, - fake_quant=TQT, + weight_observer=partial( + ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True + ), + act_observer=partial( + ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False + ), + weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), + act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), ) -# Default QAT QConfigs min_max_fakequant_qconfig = QConfig( - weight_observer=MinMaxObserver, - act_observer=MinMaxObserver, - fake_quant=FakeQuantize, + weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), + act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), + 
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), + act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), ) ema_fakequant_qconfig = QConfig( - weight_observer=MinMaxObserver, - act_observer=ExponentialMovingAverageObserver, - fake_quant=FakeQuantize, + weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), + act_observer=partial( + ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False + ), + weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), + act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), ) calibration_qconfig = QConfig( - weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None, + weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), + act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), + weight_fake_quant=None, + act_fake_quant=None, )