diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py index a5693386d92427941fb1e45c9eb0300098d6e68b..847cc814106316ec7d6178ab198d92f1e9f7f87a 100644 --- a/python_module/megengine/quantization/observer.py +++ b/python_module/megengine/quantization/observer.py @@ -99,21 +99,9 @@ class MinMaxObserver(Observer): def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): super().__init__(dtype) self.mode = mode - - self.min_val = Buffer(0.0, dtype=np.float32) - self.max_val = Buffer(0.0, dtype=np.float32) + self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) + self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) self.scale_limit = eps - # flag is used by cond_take, first time will be first flag, and after will be set as not_flag - self.first_flag = Buffer(np.array([1, 0], dtype=np.int32)) - self.not_flag = Buffer(np.array([0, 1], dtype=np.int32)) - - def set_min_max(self, tmp_min, tmp_max): - # FIXME: cond_take will destory shape, use reshape to reset shape - tmp_min = tmp_min.reshape(1) - tmp_max = tmp_max.reshape(1) - F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) - F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) - F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0) def _calculate_qparams(self, inp_min_val, inp_max_val): min_val = F.minimum(0.0, inp_min_val) @@ -144,13 +132,20 @@ class MinMaxObserver(Observer): # stop gradient x = F.zero_grad(x_orig) # find max and min - tmp_min, _ = F.cond_take( - self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())]) + F.add_update( + self.min_val, + F.minimum(self.min_val, x.min()), + alpha=0.0, + beta=1.0, + bias=0.0, ) - tmp_max, _ = F.cond_take( - self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())]) + F.add_update( + self.max_val, + F.maximum(self.max_val, x.max()), + alpha=0.0, + beta=1.0, + bias=0.0, ) - self.set_min_max(tmp_min, tmp_max) return x_orig @@ -160,6 +155,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ): super().__init__(mode, eps, dtype) self.momentum = Buffer(momentum) + self.runtime_momentum = Buffer(0.0) def set_momentum(self, momentum): self.momentum.set_value(momentum) @@ -169,25 +165,19 @@ class ExponentialMovingAverageObserver(MinMaxObserver): # stop gradient x = F.zero_grad(x_orig) # Exponential Moving Average - tmp_min, _ = F.cond_take( - self.first_flag, - F.concat( - [ - x.min(), - self.momentum * self.min_val + (1 - self.momentum) * x.min(), - ] - ), + tmp_min = ( + self.min_val * self.runtime_momentum + + (1 - self.runtime_momentum) * x.min() + ) + tmp_max = ( + self.max_val * self.runtime_momentum + + (1 - self.runtime_momentum) * x.max() ) - tmp_max, _ = F.cond_take( - self.first_flag, - F.concat( - [ - x.max(), - self.momentum * self.max_val + (1 - self.momentum) * x.max(), - ] - ), + F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) + F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) + F.add_update( + self.runtime_momentum, self.momentum, alpha=0.0, beta=1.0, bias=0.0 ) - self.set_min_max(tmp_min, tmp_max) return x_orig