提交 35c71276 编写于 作者: M Megvii Engine Team

fix(mge/quant): fix TQT epoch scale change bug

GitOrigin-RevId: 6e39de9cecbfae2b4f1e0acdf34f821af0e339bb
上级 e6e41242
......@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule):
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""
def __init__(self, method):
super().__init__(method)
self.with_weight = False
with_weight = False
def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps))
......
......@@ -23,6 +23,9 @@ class QATModule(Module):
:func:`~.quantize.quantize` further.
"""
with_weight = True
with_act = True
def __init__(self):
super().__init__()
......@@ -32,9 +35,6 @@ class QATModule(Module):
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
self.with_weight = True
self.with_act = True
def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
......@@ -51,29 +51,21 @@ class QATModule(Module):
self.weight_observer = safe_call(qconfig.weight_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
def _enable_exec(self, with_module, func, enable):
if not with_module:
return
if enable:
func.enable()
else:
func.disable()
def set_fake_quant(self, enable):
if self.with_act:
if enable:
self.act_fake_quant.enable()
else:
self.act_fake_quant.disable()
if self.with_weight:
if enable:
self.weight_fake_quant.enable()
else:
self.weight_fake_quant.disable()
self._enable_exec(self.with_act, self.act_fake_quant, enable)
self._enable_exec(self.with_weight, self.weight_fake_quant, enable)
def set_observer(self, enable):
if self.with_act:
if enable:
self.act_observer.enable()
else:
self.act_observer.disable()
if self.with_weight:
if enable:
self.weight_observer.enable()
else:
self.weight_observer.disable()
self._enable_exec(self.with_act, self.act_observer, enable)
self._enable_exec(self.with_weight, self.weight_observer, enable)
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
......
......@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""
def __init__(self):
super().__init__()
self.with_weight = False
with_weight = False
def forward(self, inp):
return self.apply_quant_activation(inp)
......@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""
def __init__(self):
super().__init__()
self.with_weight = False
self.with_act = False
with_weight = False
with_act = False
def forward(self, inp):
return inp
......
......@@ -116,10 +116,11 @@ class TQT(_FakeQuantize):
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict):
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp
def get_qparams(self):
......
......@@ -102,6 +102,7 @@ class MinMaxObserver(Observer):
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin
......
......@@ -14,6 +14,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig
......@@ -119,6 +120,14 @@ def quantize_qat(
return module
def _propagate(module: Module, func_str: str, *args, **kargs):
def fn(mod: Module):
if isinstance(mod, QATModule):
getattr(mod, func_str)(*args, **kargs)
module.apply(fn)
def propagate_qconfig(module: QATModule, qconfig: QConfig):
r"""
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
......@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
:param module: root module to traverse recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qconfig(qconfig)
module.apply(fn)
_propagate(module, "set_qconfig", qconfig)
def disable_fake_quant(module: Module):
......@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module):
:param module: root module to do disable fake quantization recursively.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(False)
module.apply(fn)
_propagate(module, "set_fake_quant", False)
def disable_observer(module: Module):
......@@ -155,11 +155,7 @@ def disable_observer(module: Module):
:param module: root module to do disable observer recursively.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
self.set_observer(False)
module.apply(fn)
_propagate(module, "set_observer", False)
def enable_fake_quant(module: Module):
......@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module):
:param module: root module to do enable fake quantization recursively.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(True)
module.apply(fn)
_propagate(module, "set_fake_quant", True)
def enable_observer(module: Module):
......@@ -183,8 +175,4 @@ def enable_observer(module: Module):
:param module: root module to do enable observer recursively.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_observer(True)
module.apply(fn)
_propagate(module, "set_observer", False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册