提交 35c71276 编写于 作者: M Megvii Engine Team

fix(mge/quant): fix TQT epoch scale change bug

GitOrigin-RevId: 6e39de9cecbfae2b4f1e0acdf34f821af0e339bb
上级 e6e41242
...@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule): ...@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule):
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
""" """
def __init__(self, method): with_weight = False
super().__init__(method)
self.with_weight = False
def forward(self, *inps): def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps)) return self.apply_quant_activation(super().forward(*inps))
......
...@@ -23,6 +23,9 @@ class QATModule(Module): ...@@ -23,6 +23,9 @@ class QATModule(Module):
:func:`~.quantize.quantize` further. :func:`~.quantize.quantize` further.
""" """
with_weight = True
with_act = True
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -32,9 +35,6 @@ class QATModule(Module): ...@@ -32,9 +35,6 @@ class QATModule(Module):
self.weight_fake_quant = None # type: FakeQuantize self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize
self.with_weight = True
self.with_act = True
def set_qconfig(self, qconfig: QConfig): def set_qconfig(self, qconfig: QConfig):
r""" r"""
Set quantization related configs with ``qconfig``, including Set quantization related configs with ``qconfig``, including
...@@ -51,29 +51,21 @@ class QATModule(Module): ...@@ -51,29 +51,21 @@ class QATModule(Module):
self.weight_observer = safe_call(qconfig.weight_observer) self.weight_observer = safe_call(qconfig.weight_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
def _enable_exec(self, with_module, func, enable):
if not with_module:
return
if enable:
func.enable()
else:
func.disable()
def set_fake_quant(self, enable): def set_fake_quant(self, enable):
if self.with_act: self._enable_exec(self.with_act, self.act_fake_quant, enable)
if enable: self._enable_exec(self.with_weight, self.weight_fake_quant, enable)
self.act_fake_quant.enable()
else:
self.act_fake_quant.disable()
if self.with_weight:
if enable:
self.weight_fake_quant.enable()
else:
self.weight_fake_quant.disable()
def set_observer(self, enable): def set_observer(self, enable):
if self.with_act: self._enable_exec(self.with_act, self.act_observer, enable)
if enable: self._enable_exec(self.with_weight, self.weight_observer, enable)
self.act_observer.enable()
else:
self.act_observer.disable()
if self.with_weight:
if enable:
self.weight_observer.enable()
else:
self.weight_observer.disable()
def _apply_fakequant_with_observer( def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
......
...@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule): ...@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule):
input after converted to :class:`~.QuantizedModule`. input after converted to :class:`~.QuantizedModule`.
""" """
def __init__(self): with_weight = False
super().__init__()
self.with_weight = False
def forward(self, inp): def forward(self, inp):
return self.apply_quant_activation(inp) return self.apply_quant_activation(inp)
...@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule): ...@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule):
input after converted to :class:`~.QuantizedModule`. input after converted to :class:`~.QuantizedModule`.
""" """
def __init__(self): with_weight = False
super().__init__() with_act = False
self.with_weight = False
self.with_act = False
def forward(self, inp): def forward(self, inp):
return inp return inp
......
...@@ -116,10 +116,11 @@ class TQT(_FakeQuantize): ...@@ -116,10 +116,11 @@ class TQT(_FakeQuantize):
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict): def normal_foward(self, inp, q_dict):
# when disable, TQT will do normal forward, initialize scale weight if q_dict["enable_observer"]:
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) # when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.log(tmp_scale / 127) / F.log(2) tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp return inp
def get_qparams(self): def get_qparams(self):
......
...@@ -102,6 +102,7 @@ class MinMaxObserver(Observer): ...@@ -102,6 +102,7 @@ class MinMaxObserver(Observer):
q_dict = get_qparam_dict(self.mode) q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC: if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val) symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin # use maximun to avoid scale too small at the begin
......
...@@ -14,6 +14,7 @@ from ..module import qat as QAT ...@@ -14,6 +14,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized from ..module import quantized as Quantized
from ..module.qat import QATModule from ..module.qat import QATModule
from ..module.quantized import QuantizedModule from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig from .qconfig import QConfig, ema_fakequant_qconfig
...@@ -119,6 +120,14 @@ def quantize_qat( ...@@ -119,6 +120,14 @@ def quantize_qat(
return module return module
def _propagate(module: Module, func_str: str, *args, **kargs):
def fn(mod: Module):
if isinstance(mod, QATModule):
getattr(mod, func_str)(*args, **kargs)
module.apply(fn)
def propagate_qconfig(module: QATModule, qconfig: QConfig): def propagate_qconfig(module: QATModule, qconfig: QConfig):
r""" r"""
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
...@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): ...@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
:param module: root module to traverse recursively. :param module: root module to traverse recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
""" """
_propagate(module, "set_qconfig", qconfig)
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qconfig(qconfig)
module.apply(fn)
def disable_fake_quant(module: Module): def disable_fake_quant(module: Module):
...@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module): ...@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module):
:param module: root module to do disable fake quantization recursively. :param module: root module to do disable fake quantization recursively.
""" """
def fn(mod: Module): _propagate(module, "set_fake_quant", False)
if isinstance(mod, QATModule):
mod.set_fake_quant(False)
module.apply(fn)
def disable_observer(module: Module): def disable_observer(module: Module):
...@@ -155,11 +155,7 @@ def disable_observer(module: Module): ...@@ -155,11 +155,7 @@ def disable_observer(module: Module):
:param module: root module to do disable observer recursively. :param module: root module to do disable observer recursively.
""" """
def fn(mod: Module): _propagate(module, "set_observer", False)
if isinstance(mod, QATModule):
self.set_observer(False)
module.apply(fn)
def enable_fake_quant(module: Module): def enable_fake_quant(module: Module):
...@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module): ...@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module):
:param module: root module to do enable fake quantization recursively. :param module: root module to do enable fake quantization recursively.
""" """
def fn(mod: Module): _propagate(module, "set_fake_quant", True)
if isinstance(mod, QATModule):
mod.set_fake_quant(True)
module.apply(fn)
def enable_observer(module: Module): def enable_observer(module: Module):
...@@ -183,8 +175,4 @@ def enable_observer(module: Module): ...@@ -183,8 +175,4 @@ def enable_observer(module: Module):
:param module: root module to do enable observer recursively. :param module: root module to do enable observer recursively.
""" """
def fn(mod: Module): _propagate(module, "set_observer", False)
if isinstance(mod, QATModule):
mod.set_observer(True)
module.apply(fn)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册