diff --git a/dnn/src/cuda/tqt/kern.cuh b/dnn/src/cuda/tqt/kern.cuh index 3be3b763c9d2105fb38bb43e150adf2929c887d6..cd11b733ebce9b5866acabada6910b82b76c2194 100644 --- a/dnn/src/cuda/tqt/kern.cuh +++ b/dnn/src/cuda/tqt/kern.cuh @@ -58,7 +58,7 @@ struct TQTBwdKernOp { ctype scaled = input[idx] / t; ctype rounded = round(scaled); rounded = fmaxf(fminf(rounded, qmax), qmin); - bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); bool mask_quant = !mask_clip; grad_x[idx] = diff[idx] * mask_quant; diff --git a/dnn/src/naive/tqt/opr_impl.cpp b/dnn/src/naive/tqt/opr_impl.cpp index c2746f67397c37f7598c63bda71b18185a0b9b1a..c5d705cc87298f611dc9a177a5b3547db68ca7a5 100644 --- a/dnn/src/naive/tqt/opr_impl.cpp +++ b/dnn/src/naive/tqt/opr_impl.cpp @@ -53,7 +53,7 @@ void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { T rounded = round(scaled); rounded = rounded <= qmin ? qmin : rounded; rounded = rounded >= qmax ? qmax : rounded; - bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); bool mask_quant = !mask_clip; *grad_x = *diff * mask_quant; diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index ccb2e570438e021a2c57ba4420ca15481ddcc2c7..a72aa6a48dcfe53980622d1e46aa440199bd61ad 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -69,8 +69,8 @@ def test_tqt(): def cb(grad): g.append(grad) - x = np.random.normal(size=(1, 2, 3, 4)) - s = np.random.rand(1) + 1 + x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") + s = np.random.rand(1) - 1 g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") n = TQT_numpy(-127, 127) @@ -85,9 +85,9 @@ def test_tqt(): grad(y, g_y) g_x, g_s = g - np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6) - np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6) - np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6) + np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5)