提交 38a95744 编写于 作者: M Megvii Engine Team

refactor(mge/quantization): refactor qconfig, remove inp_observer and bias_fakequant

GitOrigin-RevId: e57f9edd1230ee4af4ee6ce3034102e7d05b75b7
上级 de75dae8
......@@ -476,21 +476,17 @@ class QATModule(Module):
self.quantizing = self.QATMode.DISABLED
self.scale = None
self.inp_observer = None # type: Observer
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.bias_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: "QConfig"):
self.inp_observer = qconfig.inp_observer()
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.bias_fake_quant = qconfig.bias_fake_quant()
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
def apply_observer(self, target: Tensor, obs: "Observer"):
......
......@@ -8,4 +8,11 @@
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_qat,
)
......@@ -15,21 +15,18 @@ from .observer import ExponentialMovingAverageObserver, MinMaxObserver
class QConfig:
"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation``, ``weight`` and ``bias``.
``activation`` and ``weight``.
And ``fake_quant`` parameter to indicate
See :meth:`~.QATModule.set_qconfig` for detail usage.
:param inp_observer: interface to instantiate an :class:`~.Observer` indicating
how to collect scales and zero_point of input.
:param weight_observer: similar to ``inp_observer`` but toward weight.
:param act_observer: similar to ``inp_observer`` but toward activation.
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
- how to collect scales and zero_point of wegiht.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. can be invoked multi times to get different
instance for each target tensor, for better control on enable and disable.
:param bias_fake_quant: similar to ``fake_quant``, but usually need to set ``dtype``
in advance, for bias's dtype is unable to be inferred from observer.
Examples:
......@@ -37,21 +34,16 @@ class QConfig:
# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
)
"""
def __init__(
self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
):
if (
isinstance(act_observer, Module)
or isinstance(weight_observer, Module)
or isinstance(inp_observer, Module)
self, act_observer, weight_observer, fake_quant,
):
if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
......@@ -59,24 +51,18 @@ class QConfig:
)
self.act_observer = act_observer
self.weight_observer = weight_observer
self.inp_observer = inp_observer
self.fake_quant = fake_quant
self.bias_fake_quant = bias_fake_quant
# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
inp_observer=MinMaxObserver,
weight_observer=MinMaxObserver,
act_observer=MinMaxObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
......@@ -64,7 +64,6 @@ def disable_fake_quant(module: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
mod.inp_fake_quant.disable()
module.apply(fn)
......@@ -79,6 +78,7 @@ def disable_observer(module: Module):
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.disable()
mod.weight_observer.disable()
module.apply(fn)
......@@ -94,7 +94,6 @@ def enable_fake_quant(module: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
mod.inp_fake_quant.enable()
module.apply(fn)
......@@ -109,5 +108,6 @@ def enable_observer(module: Module):
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.enable()
mod.weight_observer.enable()
module.apply(fn)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册