提交 888895e9 编写于 作者: M Megvii Engine Team

refactor(mge/test): make test_parampack more stable

GitOrigin-RevId: d82230ea07b49dd08a9e968974de6950e1987a08
上级 da522568
...@@ -105,9 +105,15 @@ def test_static_graph_parampack(): ...@@ -105,9 +105,15 @@ def test_static_graph_parampack():
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
data, _ = next(train_dataset) ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"
@pytest.mark.slow @pytest.mark.slow
...@@ -140,9 +146,15 @@ def test_nopack_parampack(): ...@@ -140,9 +146,15 @@ def test_nopack_parampack():
losses.append(loss.numpy()) losses.append(loss.numpy())
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
data, _ = next(train_dataset) ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"
@pytest.mark.slow @pytest.mark.slow
...@@ -178,9 +190,15 @@ def test_dynamic_graph_parampack(): ...@@ -178,9 +190,15 @@ def test_dynamic_graph_parampack():
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
data, _ = next(train_dataset) ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"
@pytest.mark.slow @pytest.mark.slow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册