提交 ff341cb1 编写于 作者: M Megvii Engine Team

fix(mge/quantization): modify observer api

GitOrigin-RevId: 7b9c22be96c4cab9d8f14f659f3ca6ec4c9b3cb2
上级 27ef788f
......@@ -505,8 +505,8 @@ class QATModule(Module):
):
oup = self.apply_observer(target, obs)
if fq is not None:
scale, zero_point = obs.get_qparams()
oup = fq(oup, scale, zero_point)
q_dict = obs.get_qparams()
oup = fq(oup, q_dict)
return oup
def set_qat_mode(self, mode: QATMode):
......
......@@ -8,7 +8,7 @@
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..module import Module
from .observer import Round
from .observer import ObserverMode, Round
class FakeQuantize(Module):
......@@ -35,14 +35,25 @@ class FakeQuantize(Module):
def disable(self):
    """Turn this module off by clearing its ``enabled`` flag."""
    self.enabled = False
def forward(self, inp, q_dict):
    """Fake-quantize ``inp`` with the parameters carried in ``q_dict``.

    When the module is disabled the input is returned untouched.
    Otherwise the input is quantized, clipped to ``[self.qmin, self.qmax]``
    and de-quantized again.

    :param inp: value to fake-quantize.
    :param q_dict: dict with a ``"mode"`` and a ``"scale"`` entry, plus a
        ``"zero_point"`` entry when the mode is not symmetric.
    :return: the fake-quantized value, or ``inp`` itself when disabled.
    """
    if not self.enabled:
        return inp

    symmetric = q_dict["mode"] == ObserverMode.SYMMERTIC
    step = q_dict["scale"]

    if symmetric:
        # Symmetric mode: no zero point is added or subtracted.
        levels = Round()(inp / step)
        levels = F.minimum(F.maximum(levels, self.qmin), self.qmax)
        return levels * step

    # Asymmetric mode: shift by the zero point before clipping and undo
    # the shift when mapping back.
    offset = q_dict["zero_point"]
    levels = Round()(inp / step) + offset
    levels = F.minimum(F.maximum(levels, self.qmin), self.qmax)
    return (levels - offset) * step
......@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
from enum import Enum
import numpy as np
......@@ -46,9 +47,11 @@ class Observer(Module):
self.enabled = True
def get_dtype(self):
    """Return the quantized dtype described by the current qparams.

    The ``scale`` and ``zero_point`` entries of ``self.get_qparams()`` are
    converted with ``.numpy()[0]`` and handed to ``get_quantized_dtype``
    together with ``self.dtype``. An entry that is missing from the dict,
    or present but still ``None`` (the placeholder value used by
    ``create_observer_dict``), is forwarded as ``None``.
    """
    q_dict = self.get_qparams()
    # ``dict.get`` folds "key absent" and "value is None" into one check,
    # so a placeholder ``None`` never reaches ``.numpy()``.
    scale = q_dict.get("scale")
    zero_point = q_dict.get("zero_point")
    numpy_scale = None if scale is None else scale.numpy()[0]
    numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
    return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
def enable(self):
......@@ -73,13 +76,29 @@ class Observer(Module):
pass
class ObserverMode(Enum):
    """Quantization scheme selector used by observers.

    ``create_observer_dict`` gives symmetric dicts a ``scale`` entry only,
    and every other mode both ``scale`` and ``zero_point``.
    """

    # NOTE: the member spellings are referenced by callers; keep them as-is.
    SYMMERTIC = 1
    ASYMMERTIC = 2
def create_observer_dict(mode):
    """Create an empty quantization-parameter dict for ``mode``.

    :param mode: an ``ObserverMode`` member. Anything other than
        ``ObserverMode.SYMMERTIC`` yields the asymmetric layout.
    :return: ``{"mode", "scale"}`` for symmetric mode, otherwise
        ``{"mode", "scale", "zero_point"}``; all parameters start as ``None``.
    """
    is_symmetric = mode == ObserverMode.SYMMERTIC
    q_dict = {
        "mode": ObserverMode.SYMMERTIC if is_symmetric else ObserverMode.ASYMMERTIC,
        "scale": None,
    }
    if not is_symmetric:
        q_dict["zero_point"] = None
    return q_dict
class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs)
self.symmetric = symmetric
if self.symmetric:
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1'
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
super().__init__(dtype)
self.mode = mode
self.min_val = Buffer(0.0, dtype=np.float32)
self.max_val = Buffer(0.0, dtype=np.float32)
......@@ -99,22 +118,23 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val):
    """Build the quantization-parameter dict for an observed value range.

    The range is first widened so that it always contains 0.0. In
    symmetric mode only ``scale`` is filled in; in any other mode both
    ``scale`` and ``zero_point`` are filled in. The scale is never allowed
    to drop below ``self.scale_limit``.

    :param inp_min_val: observed minimum.
    :param inp_max_val: observed maximum.
    :return: the dict produced by ``create_observer_dict(self.mode)`` with
        its parameter entries populated.
    """
    min_val = F.minimum(0.0, inp_min_val)
    max_val = F.maximum(0.0, inp_max_val)
    q_dict = create_observer_dict(self.mode)
    if self.mode == ObserverMode.SYMMERTIC:
        symmetric_max_vals = F.maximum(-min_val, max_val)
        # use maximum to avoid a scale that is too small at the beginning
        q_dict["scale"] = F.maximum(
            symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
        )
    else:
        # use maximum to avoid a scale that is too small at the beginning
        scale = F.maximum(
            (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
        )
        q_dict["scale"] = scale
        # calculate zero_point; ``scale`` must be bound locally here,
        # otherwise this line raises NameError
        q_dict["zero_point"] = self.qmin - Round()(min_val / scale)

    return q_dict
def get_qparams(self):
    """Return the qparams computed from the stored ``min_val``/``max_val``."""
    observed_min, observed_max = self.min_val, self.max_val
    return self._calculate_qparams(observed_min, observed_max)
......@@ -135,8 +155,10 @@ class MinMaxObserver(Observer):
class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(
    self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
):
    """Set up the observer and store ``momentum`` as a ``Buffer``.

    :param momentum: value wrapped in a ``Buffer`` and kept on
        ``self.momentum``.
    :param mode: forwarded to the parent initializer.
    :param eps: forwarded to the parent initializer.
    :param dtype: forwarded to the parent initializer.
    """
    # Forward by keyword so each argument is matched to the parent's
    # parameter of the same name.
    super().__init__(mode=mode, eps=eps, dtype=dtype)
    self.momentum = Buffer(momentum)
def set_momentum(self, momentum):
......@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
class HistogramObserver(MinMaxObserver):
def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
super().__init__(dtype=dtype, *args, **kwargs)
def __init__(
self,
bins=2048,
upsample_rate=128,
dtype="qint8",
mode=ObserverMode.SYMMERTIC,
eps=0.00001,
):
super().__init__(mode, eps, dtype)
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册