Commit 4755400e authored by Megvii Engine Team, committed by Xu Xinran

refactor(mge/quantization): add `narrow_range` to control quant dtype's lower bound

GitOrigin-RevId: 92389341be7b1b79fbb818b975d8db2d7de8607a
Parent c8a9094b
@@ -37,15 +37,14 @@ class QATModule(Module):
         Set quantization related configs with ``qconfig``, including
         observer and fake_quant for weight and activation.
         """
-        self.weight_observer = qconfig.weight_observer()
-        self.act_observer = qconfig.act_observer()
-
-        if qconfig.fake_quant is None:
-            self.weight_fake_quant = None
-            self.act_fake_quant = None
-        else:
-            self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
-            self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
+        def safe_call(func):
+            return func() if func is not None else None
+
+        self.weight_observer = safe_call(qconfig.weight_observer)
+        self.act_observer = safe_call(qconfig.act_observer)
+        self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
+        self.act_fake_quant = safe_call(qconfig.act_fake_quant)

     def _apply_fakequant_with_observer(
         self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
...
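The new convention above makes every field of a ``QConfig`` either ``None`` or a zero-argument factory, so observers and fake_quant modules can be created uniformly. A minimal, self-contained sketch of that pattern (``make_fake_quant`` is a hypothetical stand-in, not MegEngine's ``FakeQuantize``):

    from functools import partial

    def safe_call(func):
        # Every QConfig field is either None or a zero-argument factory.
        return func() if func is not None else None

    def make_fake_quant(dtype, narrow_range):
        # Hypothetical stand-in for a FakeQuantize-like module.
        return ("fake_quant", dtype, narrow_range)

    weight_fake_quant = partial(make_fake_quant, dtype="qint8", narrow_range=True)
    act_fake_quant = None  # e.g. a calibration-only config leaves fake_quant unset

    print(safe_call(weight_fake_quant))  # ('fake_quant', 'qint8', True)
    print(safe_call(act_fake_quant))     # None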
@@ -19,7 +19,7 @@ from .observer import ObserverMode, Round

 class _FakeQuantize(Module):
-    def __init__(self, dtype: str, enable: bool = True):
+    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
             raise ValueError(
@@ -28,7 +28,10 @@ class _FakeQuantize(Module):
             )
         )
         self.dtype = dtype
-        self.qmin = _metadata_dict[dtype].qmin
+        self.narrow_range = narrow_range
+        self.qmin = (
+            -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
+        )
         self.qmax = _metadata_dict[dtype].qmax
         self.enabled = enable
@@ -90,12 +93,12 @@ class TQT_Function(Function):

 class TQT(_FakeQuantize):
     """
     TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
     for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
     """

-    def __init__(self, dtype: str, enable: bool = True):
-        super().__init__(dtype, enable)
+    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
+        super().__init__(dtype, narrow_range, enable)
         self.scale = Parameter(0.0, dtype=np.float32)

     def fake_quant_forward(self, inp, q_dict):
@@ -116,6 +119,11 @@ class TQT(_FakeQuantize):

 class FakeQuantize(_FakeQuantize):
     r"""
     A module to do quant and dequant according to observer's scale and zero_point.
+
+    :param dtype: A string indicating the target quantization type of input.
+    :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
+        instead of 1 greater. Usually True for weight and False for activation.
+    :param enable: Whether to do ``normal_forward`` or ``fake_quant_forward``.
     """

     def fake_quant_forward(self, inp, q_dict):
...
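To make the new ``narrow_range`` rule concrete, here is a standalone sketch of the bound computation added above, assuming the qint8 metadata is ``qmin=-128``, ``qmax=127`` (consistent with the "1 greater" wording in the docstring):

    QMIN, QMAX = -128, 127  # assumed _metadata_dict["qint8"] bounds

    def quant_bounds(narrow_range: bool):
        # Mirrors the qmin/qmax logic in _FakeQuantize and Observer above.
        qmin = -QMAX if narrow_range else QMIN
        return qmin, QMAX

    assert quant_bounds(narrow_range=False) == (-128, 127)  # full range, typical for activations
    assert quant_bounds(narrow_range=True) == (-127, 127)   # symmetric range, typical for weights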
@@ -31,9 +31,11 @@ class Observer(Module):
     A base class for Observer Module.

     :param dtype: a string indicating to collect scale and zero_point of which dtype
+    :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
+        instead of 1 greater. Usually True for weight and False for activation.
     """

-    def __init__(self, dtype="qint8"):
+    def __init__(self, dtype: str, narrow_range: bool = False):
         super().__init__()
         if dtype not in _metadata_dict.keys():
             raise ValueError(
@@ -42,7 +44,10 @@ class Observer(Module):
             )
         )
         self.dtype = dtype
-        self.qmin = _metadata_dict[dtype].qmin
+        self.narrow_range = narrow_range
+        self.qmin = (
+            -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
+        )
         self.qmax = _metadata_dict[dtype].qmax
         self.enabled = True
@@ -96,8 +101,14 @@ def create_observer_dict(mode):

 class MinMaxObserver(Observer):
-    def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
-        super().__init__(dtype)
+    def __init__(
+        self,
+        mode=ObserverMode.SYMMERTIC,
+        eps=0.00001,
+        dtype="qint8",
+        narrow_range: bool = False,
+    ):
+        super().__init__(dtype, narrow_range)
         self.mode = mode
         self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32)
         self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32)
@@ -153,9 +164,14 @@ class MinMaxObserver(Observer):

 class ExponentialMovingAverageObserver(MinMaxObserver):
     def __init__(
-        self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
+        self,
+        momentum=0.9,
+        mode=ObserverMode.SYMMERTIC,
+        eps=0.00001,
+        dtype="qint8",
+        narrow_range: bool = False,
     ):
-        super().__init__(mode, eps, dtype)
+        super().__init__(mode, eps, dtype, narrow_range)
         self.momentum = Buffer(momentum)
         self.runtime_momentum = Buffer(0.0)
@@ -188,11 +204,12 @@ class HistogramObserver(MinMaxObserver):
         self,
         bins=2048,
         upsample_rate=128,
-        dtype="qint8",
         mode=ObserverMode.SYMMERTIC,
         eps=0.00001,
+        dtype="qint8",
+        narrow_range: bool = False,
     ):
-        super().__init__(mode, eps, dtype)
+        super().__init__(mode, eps, dtype, narrow_range)
         self.bins = bins
         self.upsample_rate = upsample_rate
         self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
...
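A hedged usage sketch of the widened observer constructors (the module path ``megengine.quantization.observer`` is an assumption; the printed bounds follow from the qint8 assumption above):

    from megengine.quantization.observer import (
        ExponentialMovingAverageObserver,
        MinMaxObserver,
    )

    # Weight observer gets the symmetric, narrowed range; the activation observer keeps the full range.
    weight_obs = MinMaxObserver(dtype="qint8", narrow_range=True)
    act_obs = ExponentialMovingAverageObserver(momentum=0.9, dtype="qint8", narrow_range=False)

    print(weight_obs.qmin, weight_obs.qmax)  # expected: -127 127
    print(act_obs.qmin, act_obs.qmax)        # expected: -128 127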
@@ -5,6 +5,8 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from functools import partial
+
 from ..module import Module
 from .fake_quant import TQT, FakeQuantize
 from .observer import (
@@ -22,9 +24,9 @@ class QConfig:
     :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
         how to collect scales and zero_point of weight.
     :param act_observer: similar to ``weight_observer`` but toward activation.
-    :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
-        how to do fake_quant calculation. can be invoked multi times to get different
-        instance for each target tensor, for better control on enable and disable.
+    :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
+        how to do fake_quant calculation.
+    :param act_fake_quant: similar to ``weight_fake_quant`` but toward activation.

     Examples:
@@ -32,14 +34,24 @@ class QConfig:

         # Default EMA QConfig for QAT.
         ema_fakequant_qconfig = QConfig(
-            weight_observer=MinMaxObserver,
-            act_observer=ExponentialMovingAverageObserver,
-            fake_quant=FakeQuantize,
+            weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
+            act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
+            weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+            act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
         )
+
+    Each parameter is a ``class`` rather than an instance, and we recommend using ``functools.partial``
+    to bind the initialization parameters of the ``class``, so that no extra parameters need to be
+    passed in :meth:`~.QATModule.set_qconfig`.
+
+    Usually we set ``narrow_range`` of weight-related parameters to ``True`` and of activation-related
+    parameters to ``False``. For a multiply-add result such as ``a * b + c * d``, if all four variables
+    are -128 of dtype ``qint8``, the result will be ``2^15`` and overflow. Weights commonly enter such
+    calculations, so their range needs to be narrowed.
     """
     def __init__(
-        self, act_observer, weight_observer, fake_quant,
+        self, weight_observer, act_observer, weight_fake_quant, act_fake_quant
     ):
         if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
             raise ValueError(
@@ -47,30 +59,42 @@ class QConfig:
                 " class generator using `partial(Observer, ...)` instead. Use"
                 " partial(MyObserver, x=1) to override arguments to constructor if needed"
             )
-        self.act_observer = act_observer
         self.weight_observer = weight_observer
-        self.fake_quant = fake_quant
+        self.act_observer = act_observer
+        self.weight_fake_quant = weight_fake_quant
+        self.act_fake_quant = act_fake_quant


 tqt_quant_qconfig = QConfig(
-    weight_observer=ExponentialMovingAverageObserver,
-    act_observer=ExponentialMovingAverageObserver,
-    fake_quant=TQT,
+    weight_observer=partial(
+        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True
+    ),
+    act_observer=partial(
+        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
+    ),
+    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
 )

+# Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
-    weight_observer=MinMaxObserver,
-    act_observer=MinMaxObserver,
-    fake_quant=FakeQuantize,
+    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )

 ema_fakequant_qconfig = QConfig(
-    weight_observer=MinMaxObserver,
-    act_observer=ExponentialMovingAverageObserver,
-    fake_quant=FakeQuantize,
+    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(
+        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
+    ),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )

 calibration_qconfig = QConfig(
-    weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None,
+    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),
+    weight_fake_quant=None,
+    act_fake_quant=None,
 )
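Following the docstring's recommendation, here is a hedged end-to-end sketch: a custom QConfig built with ``functools.partial`` and attached to a QAT network (the import paths and the ``net.set_qconfig`` call are illustrative assumptions), plus the overflow arithmetic that motivates narrowing the weight range:

    from functools import partial

    from megengine.quantization.fake_quant import FakeQuantize
    from megengine.quantization.observer import ExponentialMovingAverageObserver, MinMaxObserver
    from megengine.quantization.qconfig import QConfig

    my_qconfig = QConfig(
        weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
        act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
        weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
        act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
    )

    # net = ...                    # some QATModule-based network (illustrative)
    # net.set_qconfig(my_qconfig)  # no extra arguments needed; partial already bound them

    # Why narrow_range=True for weights: with the full qint8 range, a * b + c * d with all
    # four operands at -128 is 2 * 128 * 128 = 2**15, one past the int16 maximum of 32767.
    print(2 * 128 * 128, 2 ** 15)  # 32768 32768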