提交 2174fd12 编写于 作者: M Megvii Engine Team

refactor(mge/test): make test_converge more stable

GitOrigin-RevId: 710e3ede406119d9a497c4c92f1e6ca768e05128
上级 dd41e2e4
......@@ -103,6 +103,12 @@ def test_training_converge():
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
data, _ = next(train_dataset)
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)
pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册