Commit 6742a58b authored by Megvii Engine Team, committed by Xu Xinran

fix(quant): observer do not use cond_take

GitOrigin-RevId: a954814bcbbd81736da5dd383d9492efc1ef8ed1
Parent 9e876203
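The core of the change: instead of gating the first update with `cond_take` and a `first_flag` buffer, `min_val`/`max_val` are initialized to the float32 extremes, so an unconditional running minimum/maximum is already correct on the first forward pass. A minimal NumPy sketch of that idea (the class and names below are illustrative, not MegEngine API):

```python
import numpy as np

class RunningMinMax:
    """Track a running min/max without a "first batch" flag.

    Initializing with the float32 extremes makes the update correct from the
    start: min(float32.max, x) == x and max(float32.min, x) == x.
    """

    def __init__(self):
        self.min_val = np.float32(np.finfo(np.float32).max)
        self.max_val = np.float32(np.finfo(np.float32).min)

    def update(self, batch):
        # Plain running extremes; no conditional "first time" branch needed.
        self.min_val = np.minimum(self.min_val, batch.min())
        self.max_val = np.maximum(self.max_val, batch.max())


tracker = RunningMinMax()
tracker.update(np.array([0.5, 2.0, -1.0], dtype=np.float32))
assert tracker.min_val == -1.0 and tracker.max_val == 2.0
```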
@@ -99,21 +99,9 @@ class MinMaxObserver(Observer):
     def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
         super().__init__(dtype)
         self.mode = mode
-        self.min_val = Buffer(0.0, dtype=np.float32)
-        self.max_val = Buffer(0.0, dtype=np.float32)
+        self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32)
+        self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32)
         self.scale_limit = eps
-        # flag is used by cond_take, first time will be first flag, and after will be set as not_flag
-        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
-        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))
-
-    def set_min_max(self, tmp_min, tmp_max):
-        # FIXME: cond_take will destory shape, use reshape to reset shape
-        tmp_min = tmp_min.reshape(1)
-        tmp_max = tmp_max.reshape(1)
-        F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
-        F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
-        F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)

     def _calculate_qparams(self, inp_min_val, inp_max_val):
         min_val = F.minimum(0.0, inp_min_val)
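For context, the retained line `min_val = F.minimum(0.0, inp_min_val)` clamps the observed range so it always contains zero, which keeps real 0.0 exactly representable after quantization. Below is a generic, textbook-style sketch of deriving qparams from such a clamped range; it is not MegEngine's `_calculate_qparams`, and the helper name and `qmin`/`qmax` values are assumptions for `qint8`:

```python
import numpy as np

def calc_qparams_sketch(min_val, max_val, qmin=-128, qmax=127, eps=1e-5):
    """Textbook affine-quantization parameters from an observed range."""
    # Force the range to include 0 so that real 0.0 maps to an exact integer.
    min_val = min(0.0, min_val)
    max_val = max(0.0, max_val)
    scale = max((max_val - min_val) / (qmax - qmin), eps)  # eps mirrors scale_limit
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

print(calc_qparams_sketch(0.3, 6.0))   # range is widened to [0.0, 6.0]
print(calc_qparams_sketch(-2.0, 4.0))  # range already contains zero
```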
@@ -144,13 +132,20 @@ class MinMaxObserver(Observer):
         # stop gradient
         x = F.zero_grad(x_orig)
         # find max and min
-        tmp_min, _ = F.cond_take(
-            self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
+        F.add_update(
+            self.min_val,
+            F.minimum(self.min_val, x.min()),
+            alpha=0.0,
+            beta=1.0,
+            bias=0.0,
         )
-        tmp_max, _ = F.cond_take(
-            self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
+        F.add_update(
+            self.max_val,
+            F.maximum(self.max_val, x.max()),
+            alpha=0.0,
+            beta=1.0,
+            bias=0.0,
         )
-        self.set_min_max(tmp_min, tmp_max)
         return x_orig
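All updates go through `F.add_update` with `alpha=0.0, beta=1.0, bias=0.0`. To my reading of the old MegEngine API, `add_update` computes `dest <- alpha * dest + beta * delta + bias`, so this parameter choice turns it into a plain in-place assignment of the new running extreme; treat that formula as an assumption. A small sketch of the arithmetic:

```python
import numpy as np

def add_update_sketch(dest, delta, alpha=1.0, beta=1.0, bias=0.0):
    """Mimic the assumed semantics of F.add_update: dest <- alpha*dest + beta*delta + bias."""
    dest[...] = alpha * dest + beta * delta + bias
    return dest

min_val = np.array(np.finfo(np.float32).max, dtype=np.float32)
batch_min = np.float32(-1.5)
# alpha=0, beta=1, bias=0 reduces the update to a plain assignment:
add_update_sketch(min_val, np.minimum(min_val, batch_min), alpha=0.0, beta=1.0, bias=0.0)
print(min_val)  # -1.5
```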
@@ -160,6 +155,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
     ):
         super().__init__(mode, eps, dtype)
         self.momentum = Buffer(momentum)
+        self.runtime_momentum = Buffer(0.0)

     def set_momentum(self, momentum):
         self.momentum.set_value(momentum)
@@ -169,25 +165,19 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         # stop gradient
         x = F.zero_grad(x_orig)
         # Exponential Moving Average
-        tmp_min, _ = F.cond_take(
-            self.first_flag,
-            F.concat(
-                [
-                    x.min(),
-                    self.momentum * self.min_val + (1 - self.momentum) * x.min(),
-                ]
-            ),
+        tmp_min = (
+            self.min_val * self.runtime_momentum
+            + (1 - self.runtime_momentum) * x.min()
         )
-        tmp_max, _ = F.cond_take(
-            self.first_flag,
-            F.concat(
-                [
-                    x.max(),
-                    self.momentum * self.max_val + (1 - self.momentum) * x.max(),
-                ]
-            ),
+        tmp_max = (
+            self.max_val * self.runtime_momentum
+            + (1 - self.runtime_momentum) * x.max()
         )
-        self.set_min_max(tmp_min, tmp_max)
+        F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
+        F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
+        F.add_update(
+            self.runtime_momentum, self.momentum, alpha=0.0, beta=1.0, bias=0.0
+        )
         return x_orig
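The EMA observer keeps the same sentinel initialization but adds `runtime_momentum`, which starts at 0.0 and is only set to the configured `momentum` after the first update. The first forward pass therefore stores the raw batch min/max (the sentinel term is multiplied by zero), and every later pass is a true exponential moving average. A self-contained sketch of that behaviour (names are illustrative, not MegEngine classes):

```python
import numpy as np

class EmaMinMaxSketch:
    """Illustrative EMA min/max tracker with a warm-start momentum.

    runtime_momentum is 0.0 on the first update, so the stored statistics are
    simply the first batch's min/max (the float32-extreme initial values are
    multiplied by zero); all later updates use `momentum`.
    """

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.runtime_momentum = 0.0
        self.min_val = np.float32(np.finfo(np.float32).max)
        self.max_val = np.float32(np.finfo(np.float32).min)

    def update(self, batch):
        m = self.runtime_momentum
        self.min_val = m * self.min_val + (1 - m) * batch.min()
        self.max_val = m * self.max_val + (1 - m) * batch.max()
        self.runtime_momentum = self.momentum  # all later updates are true EMAs


obs = EmaMinMaxSketch(momentum=0.9)
obs.update(np.array([-1.0, 2.0], dtype=np.float32))  # stores -1.0 / 2.0
obs.update(np.array([-3.0, 1.0], dtype=np.float32))  # min: 0.9*-1.0 + 0.1*-3.0 = -1.2
print(obs.min_val, obs.max_val)
```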