From 888895e971705000339fc68c3ba5daca4a0b011d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 8 Apr 2020 10:45:24 +0800 Subject: [PATCH] refactor(mge/test): make test_parampack more stable GitOrigin-RevId: d82230ea07b49dd08a9e968974de6950e1987a08 --- .../test/integration/test_parampack.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/python_module/test/integration/test_parampack.py b/python_module/test/integration/test_parampack.py index 8895a3f36..6b73c9f88 100644 --- a/python_module/test/integration/test_parampack.py +++ b/python_module/test/integration/test_parampack.py @@ -105,9 +105,15 @@ def test_static_graph_parampack(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow @@ -140,9 +146,15 @@ def test_nopack_parampack(): losses.append(loss.numpy()) assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow @@ -178,9 +190,15 @@ def test_dynamic_graph_parampack(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow -- GitLab