diff --git a/python_module/megengine/module/qat/elemwise.py b/python_module/megengine/module/qat/elemwise.py index 3385e774faa488b53b0de592ed3e8383299d483c..f99583bdeaf8d9a4739088920a99bb8ab7973e29 100644 --- a/python_module/megengine/module/qat/elemwise.py +++ b/python_module/megengine/module/qat/elemwise.py @@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule): :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. """ - def __init__(self, method): - super().__init__(method) - self.with_weight = False + with_weight = False def forward(self, *inps): return self.apply_quant_activation(super().forward(*inps)) diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 931747c41feb5df78f629da36e830d16c5a5f6b0..7eec68658751c34a735a1047ffe9705d9baf9c3a 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -23,6 +23,9 @@ class QATModule(Module): :func:`~.quantize.quantize` further. 
""" + with_weight = True + with_act = True + def __init__(self): super().__init__() @@ -32,9 +35,6 @@ class QATModule(Module): self.weight_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize - self.with_weight = True - self.with_act = True - def set_qconfig(self, qconfig: QConfig): r""" Set quantization related configs with ``qconfig``, including @@ -51,29 +51,21 @@ class QATModule(Module): self.weight_observer = safe_call(qconfig.weight_observer) self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) + def _enable_exec(self, with_module, func, enable): + if not with_module: + return + if enable: + func.enable() + else: + func.disable() + def set_fake_quant(self, enable): - if self.with_act: - if enable: - self.act_fake_quant.enable() - else: - self.act_fake_quant.disable() - if self.with_weight: - if enable: - self.weight_fake_quant.enable() - else: - self.weight_fake_quant.disable() + self._enable_exec(self.with_act, self.act_fake_quant, enable) + self._enable_exec(self.with_weight, self.weight_fake_quant, enable) def set_observer(self, enable): - if self.with_act: - if enable: - self.act_observer.enable() - else: - self.act_observer.disable() - if self.with_weight: - if enable: - self.weight_observer.enable() - else: - self.weight_observer.disable() + self._enable_exec(self.with_act, self.act_observer, enable) + self._enable_exec(self.with_weight, self.weight_observer, enable) def _apply_fakequant_with_observer( self, target: Tensor, fake_quant: FakeQuantize, observer: Observer diff --git a/python_module/megengine/module/qat/quant_dequant.py b/python_module/megengine/module/qat/quant_dequant.py index fb4018ffcc5b55b4a116ed59db90413d29d1e2e7..0baa3e1c7822085e520c01017d8219104905d6ec 100644 --- a/python_module/megengine/module/qat/quant_dequant.py +++ b/python_module/megengine/module/qat/quant_dequant.py @@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule): input after converted to :class:`~.QuantizedModule`. 
""" - def __init__(self): - super().__init__() - self.with_weight = False + with_weight = False def forward(self, inp): return self.apply_quant_activation(inp) @@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule): input after converted to :class:`~.QuantizedModule`. """ - def __init__(self): - super().__init__() - self.with_weight = False - self.with_act = False + with_weight = False + with_act = False def forward(self, inp): return inp diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py index 6e5a26f0ee70750bebb3e27a2a80c510db7f2aa8..b2e9d93939c0b2d50779ad4c1cadaa626ff11995 100644 --- a/python_module/megengine/quantization/fake_quant.py +++ b/python_module/megengine/quantization/fake_quant.py @@ -116,10 +116,11 @@ class TQT(_FakeQuantize): return TQT_Function(self.qmin, self.qmax)(inp, self.scale) def normal_foward(self, inp, q_dict): - # when disable, TQT will do normal forward, initialize scale weight - tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) - tmp_scale = F.log(tmp_scale / 127) / F.log(2) - F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) + if q_dict["enable_observer"]: + # when disable, TQT will do normal forward, initialize scale weight + tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) + tmp_scale = F.log(tmp_scale / 127) / F.log(2) + F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) return inp def get_qparams(self): diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index 0a1ddc11bb5b5bbffe80c4ab213adc3c59bd1812..33715574edb0cab8f45ffc4b076a86602b275101 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -102,6 +102,7 @@ class MinMaxObserver(Observer): q_dict = get_qparam_dict(self.mode) q_dict["min_val"] = inp_min_val q_dict["max_val"] = inp_max_val + 
q_dict["enable_observer"] = self.enable if self.mode == QuantMode.SYMMERTIC: symmetric_max_vals = F.maximum(-min_val, max_val) # use maximun to avoid scale too small at the begin diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index e36f53441a449213ecac16f879ee467d9a7a123c..bb8bb36ba26f77218f6871296c27cc384e906c20 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -14,6 +14,7 @@ from ..module import qat as QAT from ..module import quantized as Quantized from ..module.qat import QATModule from ..module.quantized import QuantizedModule +from .fake_quant import TQT from .qconfig import QConfig, ema_fakequant_qconfig @@ -119,6 +120,14 @@ def quantize_qat( return module +def _propagate(module: Module, func_str: str, *args, **kargs): + def fn(mod: Module): + if isinstance(mod, QATModule): + getattr(mod, func_str)(*args, **kargs) + + module.apply(fn) + + def propagate_qconfig(module: QATModule, qconfig: QConfig): r""" Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. @@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): :param module: root module to traverse recursively. :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. """ - - def fn(mod: Module): - if isinstance(mod, QATModule): - mod.set_qconfig(qconfig) - - module.apply(fn) + _propagate(module, "set_qconfig", qconfig) def disable_fake_quant(module: Module): @@ -141,11 +145,7 @@ def disable_fake_quant(module: Module): :param module: root module to do disable fake quantization recursively. """ - def fn(mod: Module): - if isinstance(mod, QATModule): - mod.set_fake_quant(False) - - module.apply(fn) + _propagate(module, "set_fake_quant", False) def disable_observer(module: Module): @@ -155,11 +155,7 @@ def disable_observer(module: Module): :param module: root module to do disable observer recursively. 
""" - def fn(mod: Module): - if isinstance(mod, QATModule): - self.set_observer(False) - - module.apply(fn) + _propagate(module, "set_observer", False) def enable_fake_quant(module: Module): @@ -169,11 +165,7 @@ def enable_fake_quant(module: Module): :param module: root module to do enable fake quantization recursively. """ - def fn(mod: Module): - if isinstance(mod, QATModule): - mod.set_fake_quant(True) - - module.apply(fn) + _propagate(module, "set_fake_quant", True) def enable_observer(module: Module): @@ -183,8 +175,4 @@ def enable_observer(module: Module): :param module: root module to do enable observer recursively. """ - def fn(mod: Module): - if isinstance(mod, QATModule): - mod.set_observer(True) - - module.apply(fn) + _propagate(module, "set_observer", True)