提交 de75dae8 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix get scale issue

GitOrigin-RevId: 99068d74220ce4030bfe4ca050df46d8dd1fd590
上级 2f3d185d
......@@ -500,7 +500,8 @@ class QATModule(Module):
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
return fq(oup, obs.scale, obs.zero_point)
scale, zero_point = obs.get_qparams()
return fq(oup, scale, zero_point)
def set_qat_mode(self, mode: QATMode):
r"""
......
......@@ -41,7 +41,6 @@ class Observer(Module):
self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax
self.zero_point, self.scale = None, None
self.enabled = True
def get_dtype(self):
......@@ -72,23 +71,6 @@ class Observer(Module):
pass
class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale:1 and zero_point:0, so the
    fake-quantize step it feeds becomes an identity mapping.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # BUG FIX: the two assignments were swapped — scale was zeros and
        # zero_point was ones. A scale of 0 would zero out every dequantized
        # value instead of passing it through, contradicting the docstring.
        self.scale = ones((1,), dtype="float32")
        self.zero_point = zeros((1,), dtype="float32")

    def forward(self, x):
        # Observation is a no-op; the input passes through untouched.
        return x

    def get_qparams(self):
        # Return (scale, zero_point) in the order QATModule expects.
        return self.scale, self.zero_point
class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -108,47 +90,28 @@ class MinMaxObserver(Observer):
# FIXME: cond_take will destroy shape, use reshape to reset shape
tmp_min = tmp_min.reshape(1)
tmp_max = tmp_max.reshape(1)
if self.training:
F.zero_grad(
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(
self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
)
)
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)
# FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
# mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further
# calculation in FakeQuant.
self.set_scale_zero_point(tmp_min, tmp_max)
def set_scale_zero_point(self, tmp_min, tmp_max):
def get_qparams(self):
if self.symmetric:
symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
# use maximum to avoid the scale being too small at the beginning
self.scale = F.maximum(
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
# zero_point = self.zero_point
zero_point = self.zero_point
else:
# use maximum to avoid the scale being too small at the beginning
self.scale = F.maximum(
(tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
scale = F.maximum(
(self.max_val - self.min_val) / (self.qmax - self.qmin),
self.scale_limit,
)
# calculate zero_point
self.zero_point = self.qmin - Round()((tmp_min / self.scale))
def get_qparams(self):
# scale and zero_point is runtime tensor rather than Buffer,
# so need to re-calc if min_val and max_val are loaded.
if self.scale is None:
self.set_scale_zero_point(self.min_val, self.max_val)
zero_point = self.qmin - Round()((self.min_val / scale))
return self.scale, self.zero_point
return scale, zero_point
def forward(self, x_orig):
if self.enabled:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册