From ff341cb19b3c2a003eba1cec26ac17cfb4d55a1b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 25 May 2020 18:02:33 +0800
Subject: [PATCH] fix(mge/quantization): modify observer api

GitOrigin-RevId: 7b9c22be96c4cab9d8f14f659f3ca6ec4c9b3cb2
---
Review note (this region is ignored when the patch is applied):
MinMaxObserver._calculate_qparams, asymmetric branch, now divides by
q_dict["scale"]. The previous expression read a local named `scale`,
but this patch removes both assignments to that local, so the branch
would fail with NameError as soon as a non-symmetric mode was used.

 python_module/megengine/module/module.py |  4 +-
 .../megengine/quantization/fake_quant.py | 31 ++++++---
 .../megengine/quantization/observer.py   | 67 +++++++++++++------
 3 files changed, 71 insertions(+), 31 deletions(-)

diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py
index 1ee440aba..fdea96b51 100644
--- a/python_module/megengine/module/module.py
+++ b/python_module/megengine/module/module.py
@@ -505,8 +505,8 @@ class QATModule(Module):
     ):
         oup = self.apply_observer(target, obs)
         if fq is not None:
-            scale, zero_point = obs.get_qparams()
-            oup = fq(oup, scale, zero_point)
+            q_dict = obs.get_qparams()
+            oup = fq(oup, q_dict)
         return oup
 
     def set_qat_mode(self, mode: QATMode):
diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py
index 21652d68f..676633090 100644
--- a/python_module/megengine/quantization/fake_quant.py
+++ b/python_module/megengine/quantization/fake_quant.py
@@ -8,7 +8,7 @@
 from .. import functional as F
 from .._internal.dtype import _metadata_dict
 from ..module import Module
-from .observer import Round
+from .observer import ObserverMode, Round
 
 
 class FakeQuantize(Module):
@@ -35,14 +35,25 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False
 
-    def forward(self, inp, scale, zero_point):
+    def forward(self, inp, q_dict):
         if self.enabled:
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
-
+            if q_dict["mode"] == ObserverMode.SYMMERTIC:
+                scale = q_dict["scale"]
+                # Quant
+                oup = Round()(inp / scale)
+                # clip
+                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+                # DeQuant
+                oup = (oup) * scale
+                return oup
+            else:
+                scale = q_dict["scale"]
+                zero_point = q_dict["zero_point"]
+                # Quant
+                oup = Round()(inp / scale) + zero_point
+                # clip
+                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+                # DeQuant
+                oup = (oup - zero_point) * scale
+                return oup
         return inp
diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py
index bd85a9a97..9a338a870 100644
--- a/python_module/megengine/quantization/observer.py
+++ b/python_module/megengine/quantization/observer.py
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import math
 from abc import abstractmethod
+from enum import Enum
 
 import numpy as np
 
@@ -46,9 +47,11 @@ class Observer(Module):
         self.enabled = True
 
     def get_dtype(self):
-        scale, zero_point = self.get_qparams()
-        numpy_scale = None if scale is None else scale.numpy()[0]
-        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
+        q_dict = self.get_qparams()
+        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
+        numpy_zero_point = (
+            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
+        )
         return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
 
     def enable(self):
@@ -73,13 +76,29 @@ class Observer(Module):
         pass
 
 
+class ObserverMode(Enum):
+    SYMMERTIC = 1
+    ASYMMERTIC = 2
+
+
+def create_observer_dict(mode):
+    if mode == ObserverMode.SYMMERTIC:
+        return {
+            "mode": ObserverMode.SYMMERTIC,
+            "scale": None,
+        }
+    else:
+        return {
+            "mode": ObserverMode.ASYMMERTIC,
+            "scale": None,
+            "zero_point": None,
+        }
+
+
 class MinMaxObserver(Observer):
-    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.symmetric = symmetric
-        if self.symmetric:
-            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1'
-            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
+    def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
+        super().__init__(dtype)
+        self.mode = mode
         self.min_val = Buffer(0.0, dtype=np.float32)
         self.max_val = Buffer(0.0, dtype=np.float32)
 
@@ -99,22 +118,23 @@ class MinMaxObserver(Observer):
     def _calculate_qparams(self, inp_min_val, inp_max_val):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
-        if self.symmetric:
+        q_dict = create_observer_dict(self.mode)
+        if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
-            scale = F.maximum(
+            q_dict["scale"] = F.maximum(
                 symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
             )
-            zero_point = self.zero_point
+            # zero_point = self.zero_point
         else:
             # use maximun to avoid scale too small at the begin
-            scale = F.maximum(
+            q_dict["scale"] = F.maximum(
                 (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
             )
             # caculate zero_point
-            zero_point = self.qmin - Round()((min_val / scale))
+            q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
 
-        return scale, zero_point
+        return q_dict
 
     def get_qparams(self):
         return self._calculate_qparams(self.min_val, self.max_val)
@@ -135,8 +155,10 @@ class MinMaxObserver(Observer):
 
 
 class ExponentialMovingAverageObserver(MinMaxObserver):
-    def __init__(self, momentum=0.9, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
+    ):
+        super().__init__(mode, eps, dtype)
         self.momentum = Buffer(momentum)
 
     def set_momentum(self, momentum):
@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
 
 
 class HistogramObserver(MinMaxObserver):
-    def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
-        super().__init__(dtype=dtype, *args, **kwargs)
+    def __init__(
+        self,
+        bins=2048,
+        upsample_rate=128,
+        dtype="qint8",
+        mode=ObserverMode.SYMMERTIC,
+        eps=0.00001,
+    ):
+        super().__init__(mode, eps, dtype)
         self.bins = bins
         self.upsample_rate = upsample_rate
         self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
-- 
GitLab