From de75dae84cd9ad4edde611546662e02b1dba8645 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 12 May 2020 16:01:38 +0800
Subject: [PATCH] fix(mge/quantization): fix get scale issue

GitOrigin-RevId: 99068d74220ce4030bfe4ca050df46d8dd1fd590
---
 python_module/megengine/module/module.py     |  3 +-
 .../megengine/quantization/observer.py       | 61 ++++---------------
 2 files changed, 14 insertions(+), 50 deletions(-)

diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py
index 8c7bf8cdb..6ab0e6e20 100644
--- a/python_module/megengine/module/module.py
+++ b/python_module/megengine/module/module.py
@@ -500,7 +500,8 @@ class QATModule(Module):
         self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
     ):
         oup = self.apply_observer(target, obs)
-        return fq(oup, obs.scale, obs.zero_point)
+        scale, zero_point = obs.get_qparams()
+        return fq(oup, scale, zero_point)
 
     def set_qat_mode(self, mode: QATMode):
         r"""
diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py
index a1d9b84bc..3c4484e61 100644
--- a/python_module/megengine/quantization/observer.py
+++ b/python_module/megengine/quantization/observer.py
@@ -41,7 +41,6 @@ class Observer(Module):
         self.dtype = dtype
         self.qmin = _metadata_dict[dtype].qmin
         self.qmax = _metadata_dict[dtype].qmax
-        self.zero_point, self.scale = None, None
         self.enabled = True
 
     def get_dtype(self):
@@ -72,23 +71,6 @@ class Observer(Module):
         pass
 
 
-class IdentityObserver(Observer):
-    r"""
-    An test Observer that always return scale:1 and zero_point:0.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.zero_point = ones((1), dtype="float32")
-        self.scale = zeros((1), dtype="float32")
-
-    def forward(self, x):
-        return x
-
-    def get_qparams(self):
-        return self.scale, self.zero_point
-
-
 class MinMaxObserver(Observer):
     def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -108,47 +90,28 @@ class MinMaxObserver(Observer):
         # FIXME: cond_take will destory shape, use reshape to reset shape
         tmp_min = tmp_min.reshape(1)
         tmp_max = tmp_max.reshape(1)
-        if self.training:
-            F.zero_grad(
-                F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
-            )
-            F.zero_grad(
-                F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
-            )
-            F.zero_grad(
-                F.add_update(
-                    self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
-                )
-            )
+        F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
+        F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
+        F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)
 
-        # FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
-        # mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further
-        # calculation in FakeQuant.
-        self.set_scale_zero_point(tmp_min, tmp_max)
-
-    def set_scale_zero_point(self, tmp_min, tmp_max):
+    def get_qparams(self):
         if self.symmetric:
-            symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
+            symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
             # use maximun to avoid scale too small at the begin
-            self.scale = F.maximum(
+            scale = F.maximum(
                 symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
             )
-            # zero_point = self.zero_point
+            zero_point = self.zero_point
         else:
            # use maximun to avoid scale too small at the begin
-            self.scale = F.maximum(
-                (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
+            scale = F.maximum(
+                (self.max_val - self.min_val) / (self.qmax - self.qmin),
+                self.scale_limit,
            )
            # caculate zero_point
-            self.zero_point = self.qmin - Round()((tmp_min / self.scale))
-
-    def get_qparams(self):
-        # scale and zero_point is runtime tensor rather than Buffer,
-        # so need to re-calc if min_val and max_val are loaded.
-        if self.scale is None:
-            self.set_scale_zero_point(self.min_val, self.max_val)
+            zero_point = self.qmin - Round()((self.min_val / scale))
 
-        return self.scale, self.zero_point
+        return scale, zero_point
 
     def forward(self, x_orig):
         if self.enabled:
--
GitLab
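
Note (not part of the patch): a minimal standalone sketch of the scale/zero_point math that get_qparams now derives from the stored min_val/max_val buffers. It uses plain NumPy rather than the MegEngine API; the function name, the qint8-style qmin/qmax defaults, and the eps value are illustrative assumptions, not code from the repository.

import numpy as np


def get_qparams(min_val, max_val, qmin=-128, qmax=127, symmetric=True, eps=1e-5):
    """Derive (scale, zero_point) from observed min/max, mirroring the patched logic."""
    if symmetric:
        # symmetric case: the scale must cover the larger of |min_val| and max_val
        sym_max = max(-min_val, max_val)
        # clamp with eps so the scale never collapses to zero early in training
        scale = max(sym_max / ((qmax - qmin) / 2), eps)
        zero_point = 0  # fixed midpoint, assuming a symmetric qint8-like range
    else:
        # asymmetric case: the scale spans the whole observed range
        scale = max((max_val - min_val) / (qmax - qmin), eps)
        zero_point = qmin - int(np.round(min_val / scale))
    return scale, zero_point


# example: activations observed in [-0.5, 1.0]
print(get_qparams(-0.5, 1.0))                   # symmetric
print(get_qparams(-0.5, 1.0, symmetric=False))  # asymmetric

Because the parameters are recomputed from the min_val/max_val Buffers on every call, loading a checkpoint restores correct quantization parameters without any extra recalculation step, which is what the removed "re-calc if min_val and max_val are loaded" branch previously handled.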