提交 884a07ff 编写于 作者: M Megvii Engine Team

fix(test/random): set a random seed for random unit test

GitOrigin-RevId: ad4b01eac7238d7a71f8b71075c37bd4b3e58235
上级 d7cc4628
......@@ -27,13 +27,16 @@ from megengine.core.ops.builtin import (
UniformRNG,
)
from megengine.device import get_device_count
from megengine.random import RNG, seed, uniform
from megengine.random import RNG
from megengine.random import seed as set_global_seed
from megengine.random import uniform
@pytest.mark.skipif(
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gaussian_op():
set_global_seed(1024)
shape = (
8,
9,
......@@ -64,6 +67,7 @@ def test_gaussian_op():
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_uniform_op():
set_global_seed(1024)
shape = (
8,
9,
......@@ -92,6 +96,7 @@ def test_uniform_op():
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_gamma_op():
set_global_seed(1024)
_shape, _scale = 2, 0.8
_expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
......@@ -120,6 +125,7 @@ def test_gamma_op():
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_beta_op():
set_global_seed(1024)
_alpha, _beta = 2, 0.8
_expected_mean = _alpha / (_alpha + _beta)
_expected_std = np.sqrt(
......@@ -151,6 +157,7 @@ def test_beta_op():
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_poisson_op():
set_global_seed(1024)
lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
op = PoissonRNG(seed=get_global_rng_seed())
(output,) = apply(op, lam)
......@@ -174,6 +181,7 @@ def test_poisson_op():
get_device_count("xpu") <= 2, reason="xpu counts need > 2",
)
def test_permutation_op():
set_global_seed(1024)
n = 1000
def test_permutation_op_dtype(dtype):
......@@ -390,22 +398,23 @@ def test_PermutationRNG():
def test_seed():
seed(10)
set_global_seed(10)
out1 = uniform(size=[10, 10])
out2 = uniform(size=[10, 10])
assert not (out1.numpy() == out2.numpy()).all()
seed(10)
set_global_seed(10)
out3 = uniform(size=[10, 10])
np.testing.assert_equal(out1.numpy(), out3.numpy())
seed(11)
set_global_seed(11)
out4 = uniform(size=[10, 10])
assert not (out1.numpy() == out4.numpy()).all()
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_rng_empty_tensor(is_symbolic):
set_global_seed(1024)
shapes = [
(0,),
(0, 0, 0),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册