提交 652ec9f2 作者: Megvii Engine Team

fix(mgb/dnn): fix backward computation of tqt

GitOrigin-RevId: 850d11a5ce03026e0100b38da65ba911348a809e
上级 27638461
......@@ -58,7 +58,7 @@ struct TQTBwdKernOp {
ctype scaled = input[idx] / t;
ctype rounded = round(scaled);
rounded = fmaxf(fminf(rounded, qmax), qmin);
bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax;
bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax);
bool mask_quant = !mask_clip;
grad_x[idx] = diff[idx] * mask_quant;
......
......@@ -53,7 +53,7 @@ void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) {
T rounded = round(scaled);
rounded = rounded <= qmin ? qmin : rounded;
rounded = rounded >= qmax ? qmax : rounded;
bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax;
bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax);
bool mask_quant = !mask_clip;
*grad_x = *diff * mask_quant;
......
......@@ -69,8 +69,8 @@ def test_tqt():
def cb(grad):
g.append(grad)
x = np.random.normal(size=(1, 2, 3, 4))
s = np.random.rand(1) + 1
x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
s = np.random.rand(1) - 1
g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")
n = TQT_numpy(-127, 127)
......@@ -85,9 +85,9 @@ def test_tqt():
grad(y, g_y)
g_x, g_s = g
np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6)
np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6)
np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6)
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册