提交 ff341cb1 编写于 作者: M Megvii Engine Team

fix(mge/quantization): modify observer api

GitOrigin-RevId: 7b9c22be96c4cab9d8f14f659f3ca6ec4c9b3cb2
上级 27ef788f
...@@ -505,8 +505,8 @@ class QATModule(Module): ...@@ -505,8 +505,8 @@ class QATModule(Module):
): ):
oup = self.apply_observer(target, obs) oup = self.apply_observer(target, obs)
if fq is not None: if fq is not None:
scale, zero_point = obs.get_qparams() q_dict = obs.get_qparams()
oup = fq(oup, scale, zero_point) oup = fq(oup, q_dict)
return oup return oup
def set_qat_mode(self, mode: QATMode): def set_qat_mode(self, mode: QATMode):
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
from .. import functional as F from .. import functional as F
from .._internal.dtype import _metadata_dict from .._internal.dtype import _metadata_dict
from ..module import Module from ..module import Module
from .observer import Round from .observer import ObserverMode, Round
class FakeQuantize(Module): class FakeQuantize(Module):
...@@ -35,14 +35,25 @@ class FakeQuantize(Module): ...@@ -35,14 +35,25 @@ class FakeQuantize(Module):
def disable(self): def disable(self):
self.enabled = False self.enabled = False
def forward(self, inp, scale, zero_point): def forward(self, inp, q_dict):
if self.enabled: if self.enabled:
# Quant if q_dict["mode"] == ObserverMode.SYMMERTIC:
oup = Round()(inp / scale) + zero_point scale = q_dict["scale"]
# clip # Quant
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) oup = Round()(inp / scale)
# DeQuant # clip
oup = (oup - zero_point) * scale oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
return oup # DeQuant
oup = (oup) * scale
return oup
else:
scale = q_dict["scale"]
zero_point = q_dict["zero_point"]
# Quant
oup = Round()(inp / scale) + zero_point
# clip
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
# DeQuant
oup = (oup - zero_point) * scale
return oup
return inp return inp
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math import math
from abc import abstractmethod from abc import abstractmethod
from enum import Enum
import numpy as np import numpy as np
...@@ -46,9 +47,11 @@ class Observer(Module): ...@@ -46,9 +47,11 @@ class Observer(Module):
self.enabled = True self.enabled = True
def get_dtype(self): def get_dtype(self):
scale, zero_point = self.get_qparams() q_dict = self.get_qparams()
numpy_scale = None if scale is None else scale.numpy()[0] numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
numpy_zero_point = None if zero_point is None else zero_point.numpy()[0] numpy_zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
)
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point) return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
def enable(self): def enable(self):
...@@ -73,13 +76,29 @@ class Observer(Module): ...@@ -73,13 +76,29 @@ class Observer(Module):
pass pass
class ObserverMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2
def create_observer_dict(mode):
if mode == ObserverMode.SYMMERTIC:
return {
"mode": ObserverMode.SYMMERTIC,
"scale": None,
}
else:
return {
"mode": ObserverMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
}
class MinMaxObserver(Observer): class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
super().__init__(*args, **kwargs) super().__init__(dtype)
self.symmetric = symmetric self.mode = mode
if self.symmetric:
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1'
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
self.min_val = Buffer(0.0, dtype=np.float32) self.min_val = Buffer(0.0, dtype=np.float32)
self.max_val = Buffer(0.0, dtype=np.float32) self.max_val = Buffer(0.0, dtype=np.float32)
...@@ -99,22 +118,23 @@ class MinMaxObserver(Observer): ...@@ -99,22 +118,23 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val): def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val) min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val) max_val = F.maximum(0.0, inp_max_val)
if self.symmetric: q_dict = create_observer_dict(self.mode)
if self.mode == ObserverMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val) symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin # use maximun to avoid scale too small at the begin
scale = F.maximum( q_dict["scale"] = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
) )
zero_point = self.zero_point # zero_point = self.zero_point
else: else:
# use maximun to avoid scale too small at the begin # use maximun to avoid scale too small at the begin
scale = F.maximum( q_dict["scale"] = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit, (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
) )
# caculate zero_point # caculate zero_point
zero_point = self.qmin - Round()((min_val / scale)) q_dict["zero_point"] = self.qmin - Round()((min_val / scale))
return scale, zero_point return q_dict
def get_qparams(self): def get_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val) return self._calculate_qparams(self.min_val, self.max_val)
...@@ -135,8 +155,10 @@ class MinMaxObserver(Observer): ...@@ -135,8 +155,10 @@ class MinMaxObserver(Observer):
class ExponentialMovingAverageObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(self, momentum=0.9, *args, **kwargs): def __init__(
super().__init__(*args, **kwargs) self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
):
super().__init__(mode, eps, dtype)
self.momentum = Buffer(momentum) self.momentum = Buffer(momentum)
def set_momentum(self, momentum): def set_momentum(self, momentum):
...@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
class HistogramObserver(MinMaxObserver): class HistogramObserver(MinMaxObserver):
def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs): def __init__(
super().__init__(dtype=dtype, *args, **kwargs) self,
bins=2048,
upsample_rate=128,
dtype="qint8",
mode=ObserverMode.SYMMERTIC,
eps=0.00001,
):
super().__init__(mode, eps, dtype)
self.bins = bins self.bins = bins
self.upsample_rate = upsample_rate self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册