diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index d5562f6a94cfe32c21453db102b8e5afd2361e0a..e56cb54bfd2ccd6b97c1ea0e2a8d7d9e0654b50c 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -92,6 +92,25 @@ class QATModule(Module): else: return self.act_observer.get_dtype() + def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer): + if hasattr(fake_quant, "get_qparams"): + return fake_quant.get_qparams() + elif observer is not None: + return observer.get_qparams() + return None + + def get_weight_qparams(self): + r""" + Get weight's quantization parameters. + """ + return self._get_qparams(self.weight_fake_quant, self.weight_observer) + + def get_activation_qparams(self): + r""" + Get activation's quantization parameters. + """ + return self._get_qparams(self.act_fake_quant, self.act_observer) + @classmethod @abstractmethod def from_float_module(cls, float_module: Module): diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index b93e64530074117c66427233f04d4497c33152ec..82feced1cc3738598f6664f4079e23a6d4854f75 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -8,7 +8,7 @@ from .fake_quant import FakeQuantize from .internal_fake_quant import * -from .observer import HistogramObserver, Observer, ObserverMode +from .observer import HistogramObserver, Observer from .qconfig import ( QConfig, calibration_qconfig, @@ -16,3 +16,4 @@ from .qconfig import ( min_max_fakequant_qconfig, tqt_quant_qconfig, ) +from .utils import QuantMode diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py index 78577d9b7da11f28742bb28d6eadc5e73950b328..6e5a26f0ee70750bebb3e27a2a80c510db7f2aa8 100644 --- a/python_module/megengine/quantization/fake_quant.py +++ b/python_module/megengine/quantization/fake_quant.py @@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype from ..core import Buffer, Function, Parameter from ..jit import sideeffect from ..module import Module -from .observer import ObserverMode, Round +from .observer import Round +from .utils import QuantMode, get_qparam_dict class _FakeQuantize(Module): @@ -121,8 +122,18 @@ class TQT(_FakeQuantize): F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) return inp + def get_qparams(self): + qdict = get_qparam_dict(QuantMode.TQT) + qdict["scale"] = 2 ** self.scale + return qdict + def get_dtype(self): - return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None) + q_dict = self.get_qparams() + scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0] + zero_point = ( + None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0] + ) + return get_quantized_dtype(self.dtype, scale, zero_point) class FakeQuantize(_FakeQuantize): @@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): """ def fake_quant_forward(self, inp, q_dict): - if q_dict["mode"] == ObserverMode.SYMMERTIC: + if q_dict["mode"] == QuantMode.SYMMERTIC: scale = q_dict["scale"] # Quant oup = Round()(inp / scale) diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 476b73024b55ae560f9c49683a00547a986f87e9..0a1ddc11bb5b5bbffe80c4ab213adc3c59bd1812 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype from ..core import Buffer, Function, tensor from ..jit import sideeffect from ..module import Module +from .utils import QuantMode, get_qparam_dict class Round(Function): @@ -81,29 +82,10 @@ class Observer(Module): pass -class ObserverMode(Enum): - SYMMERTIC = 1 - ASYMMERTIC = 2 - - -def create_observer_dict(mode): - if mode == ObserverMode.SYMMERTIC: - return { - "mode": ObserverMode.SYMMERTIC, - "scale": None, - } - else: - return { - "mode": ObserverMode.ASYMMERTIC, - "scale": None, - "zero_point": None, - } - - class MinMaxObserver(Observer): def __init__( self, - mode=ObserverMode.SYMMERTIC, + mode=QuantMode.SYMMERTIC, eps=0.00001, dtype="qint8", narrow_range: bool = False, @@ -117,10 +99,10 @@ class MinMaxObserver(Observer): def _calculate_qparams(self, inp_min_val, inp_max_val): min_val = F.minimum(0.0, inp_min_val) max_val = F.maximum(0.0, inp_max_val) - q_dict = create_observer_dict(self.mode) + q_dict = get_qparam_dict(self.mode) q_dict["min_val"] = inp_min_val q_dict["max_val"] = inp_max_val - if self.mode == ObserverMode.SYMMERTIC: + if self.mode == QuantMode.SYMMERTIC: symmetric_max_vals = F.maximum(-min_val, max_val) # use maximun to avoid scale too small at the begin q_dict["scale"] = F.maximum( @@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): def __init__( self, momentum=0.9, - mode=ObserverMode.SYMMERTIC, + mode=QuantMode.SYMMERTIC, eps=0.00001, dtype="qint8", narrow_range: bool = False, @@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): self, bins=2048, upsample_rate=128, - mode=ObserverMode.SYMMERTIC, + mode=QuantMode.SYMMERTIC, eps=0.00001, dtype="qint8", narrow_range: bool = False, diff --git a/python_module/megengine/quantization/utils.py b/python_module/megengine/quantization/utils.py index dff2ddf973c632ea773dd82b1cb15cfceb584d44..470e39ae622084d74cea6cd7f0647350c770c931 100644 --- a/python_module/megengine/quantization/utils.py +++ b/python_module/megengine/quantization/utils.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from enum import Enum from functools import partial, update_wrapper, wraps @@ -21,3 +22,24 @@ def register_method_to_class(cls): return func return decorator + + +class QuantMode(Enum): + SYMMERTIC = 1 + ASYMMERTIC = 2 + TQT = 3 + + +qparam_dict = { + QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, + QuantMode.ASYMMERTIC: { + "mode": QuantMode.ASYMMERTIC, + "scale": None, + "zero_point": None, + }, + QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,}, +} + + +def get_qparam_dict(mode): + return qparam_dict.get(mode, None)