From 652ec9f2516a4435a32c83b3f5188c344c5e21a3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 8 Mar 2021 14:07:08 +0800 Subject: [PATCH] fix(mgb/dnn): fix backward computation of tqt GitOrigin-RevId: 850d11a5ce03026e0100b38da65ba911348a809e --- dnn/src/cuda/tqt/kern.cuh | 2 +- dnn/src/naive/tqt/opr_impl.cpp | 2 +- .../python/test/unit/quantization/test_fake_quant.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dnn/src/cuda/tqt/kern.cuh b/dnn/src/cuda/tqt/kern.cuh index 3be3b763..cd11b733 100644 --- a/dnn/src/cuda/tqt/kern.cuh +++ b/dnn/src/cuda/tqt/kern.cuh @@ -58,7 +58,7 @@ struct TQTBwdKernOp { ctype scaled = input[idx] / t; ctype rounded = round(scaled); rounded = fmaxf(fminf(rounded, qmax), qmin); - bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); bool mask_quant = !mask_clip; grad_x[idx] = diff[idx] * mask_quant; diff --git a/dnn/src/naive/tqt/opr_impl.cpp b/dnn/src/naive/tqt/opr_impl.cpp index c2746f67..c5d705cc 100644 --- a/dnn/src/naive/tqt/opr_impl.cpp +++ b/dnn/src/naive/tqt/opr_impl.cpp @@ -53,7 +53,7 @@ void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { T rounded = round(scaled); rounded = rounded <= qmin ? qmin : rounded; rounded = rounded >= qmax ? qmax : rounded; - bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); bool mask_quant = !mask_clip; *grad_x = *diff * mask_quant; diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index ccb2e570..a72aa6a4 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -69,8 +69,8 @@ def test_tqt(): def cb(grad): g.append(grad) - x = np.random.normal(size=(1, 2, 3, 4)) - s = np.random.rand(1) + 1 + x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") + s = np.random.rand(1) - 1 g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") n = TQT_numpy(-127, 127) @@ -85,9 +85,9 @@ def test_tqt(): grad(y, g_y) g_x, g_s = g - np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6) - np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6) - np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6) + np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5) -- GitLab