From d7b6bfd56c642ff758732ddeb74b6a92278248f1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 14 Jul 2021 11:23:16 +0800 Subject: [PATCH] test(mge/fakequant): use fixed input for lsq test to temperarily avoid precision error GitOrigin-RevId: e91c71874eb16e8535497ec7c5840d08cf1a3d2b --- .../test/unit/quantization/test_fake_quant.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index 9f93a1759..612e03ab7 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -191,19 +191,34 @@ class LSQ_numpy: def test_lsq(): - def preprocess(scale, eps): - scale = np.array([0]) if scale < eps else scale - eps - return np.abs(scale) + eps - g = [] def cb(grad): g.append(grad) - x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") - s = np.random.rand(1) + # FIXME: use random number when LSQ is fixed + # x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") + # s = np.random.rand(1) + x = np.array( + [ + [ + [ + [4.0, 38.0, -121.0, 38.0], + [15.0, -115.0, -112.0, 24.0], + [23.0, -65.0, 109.0, -115.0], + ], + [ + [-66.0, -90.0, -45.0, -101.0], + [68.0, -98.0, 108.0, -79.0], + [54.0, 63.0, -10.0, -50.0], + ], + ] + ], + dtype="float32", + ) + s = np.array([0.02918224], dtype="float32") eps = np.array([1e-5], dtype="float32") - s = preprocess(s, eps) + s = np.abs(s) if np.abs(s) > eps else eps zero_point = np.array([1.0], dtype="float32") grad_s = np.array([2.0], dtype="float32") -- GitLab