提交 878ce911 编写于 作者: M Megvii Engine Team

fix(mge/test): replace equal with allclose to fix rng test for ci

GitOrigin-RevId: 12758cf5d5e22c883d0de893fca9e8acc6ec0556
上级 c53cad20
...@@ -226,7 +226,7 @@ def test_UniformRNG(): ...@@ -226,7 +226,7 @@ def test_UniformRNG():
out2 = m2.uniform(size=(100,)) out2 = m2.uniform(size=(100,))
out3 = m3.uniform(size=(100,)) out3 = m3.uniform(size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
...@@ -254,7 +254,7 @@ def test_NormalRNG(): ...@@ -254,7 +254,7 @@ def test_NormalRNG():
out2 = m2.normal(size=(100,)) out2 = m2.normal(size=(100,))
out3 = m3.normal(size=(100,)) out3 = m3.normal(size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
...@@ -283,7 +283,7 @@ def test_GammaRNG(): ...@@ -283,7 +283,7 @@ def test_GammaRNG():
out2 = m2.gamma(2, size=(100,)) out2 = m2.gamma(2, size=(100,))
out3 = m3.gamma(2, size=(100,)) out3 = m3.gamma(2, size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
...@@ -316,7 +316,7 @@ def test_BetaRNG(): ...@@ -316,7 +316,7 @@ def test_BetaRNG():
out2 = m2.beta(2, 1, size=(100,)) out2 = m2.beta(2, 1, size=(100,))
out3 = m3.beta(2, 1, size=(100,)) out3 = m3.beta(2, 1, size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
...@@ -351,7 +351,7 @@ def test_PoissonRNG(): ...@@ -351,7 +351,7 @@ def test_PoissonRNG():
out2 = m2.poisson(lam.to("xpu1"), size=(100,)) out2 = m2.poisson(lam.to("xpu1"), size=(100,))
out3 = m3.poisson(lam.to("xpu0"), size=(100,)) out3 = m3.poisson(lam.to("xpu0"), size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
...@@ -381,7 +381,7 @@ def test_PermutationRNG(symbolic): ...@@ -381,7 +381,7 @@ def test_PermutationRNG(symbolic):
out2 = m2.permutation(1000) out2 = m2.permutation(1000)
out3 = m3.permutation(1000) out3 = m3.permutation(1000)
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
...@@ -443,7 +443,7 @@ def test_ShuffleRNG(): ...@@ -443,7 +443,7 @@ def test_ShuffleRNG():
m2.shuffle(out2) m2.shuffle(out2)
m3.shuffle(out3) m3.shuffle(out3)
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6)
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
...@@ -465,7 +465,7 @@ def test_seed(): ...@@ -465,7 +465,7 @@ def test_seed():
set_global_seed(10) set_global_seed(10)
out3 = uniform(size=[10, 10]) out3 = uniform(size=[10, 10])
np.testing.assert_equal(out1.numpy(), out3.numpy()) np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6)
set_global_seed(11) set_global_seed(11)
out4 = uniform(size=[10, 10]) out4 = uniform(size=[10, 10])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册