提交 fc4dedc0 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/random): add lower bound and higher bound for uniform sampling

GitOrigin-RevId: 2a2c56fd17ea53b805a4a13906a9f9dd8ab0e133
上级 712b87c8
......@@ -62,12 +62,16 @@ def gaussian(
@wrap_io_tensor
def uniform(
shape: Iterable[int],
low: float = 0,
high: float = 1,
comp_node: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None,
) -> Tensor:
r"""Random variable with uniform distribution $U(0, 1)$
:param shape: Output tensor shape
:param low: Lower range
:param high: Upper range
:param comp_node: The comp node output on, default to None
:param comp_graph: The graph in which output is, default to None
:return: The output tensor
......@@ -91,6 +95,6 @@ def uniform(
"""
comp_node, comp_graph = _use_default_if_none(comp_node, comp_graph)
seed = _random_seed_generator().__next__()
return mgb.opr.uniform_rng(
return low + (high - low) * mgb.opr.uniform_rng(
shape, seed=seed, comp_node=comp_node, comp_graph=comp_graph
)
......@@ -59,6 +59,50 @@ def test_random_dynamic_same_result():
assert np.all(a.numpy() == b.numpy())
def test_range_uniform_static_diff_result():
@jit.trace(symbolic=True)
def graph_a():
return R.uniform(5, low=-2, high=2)
@jit.trace(symbolic=True)
def graph_b():
return R.uniform(5, low=-2, high=2)
a = graph_a()
b = graph_b()
assert np.any(a.numpy() != b.numpy())
def test_range_uniform_static_same_result():
@jit.trace(symbolic=True)
def graph_a():
R.manual_seed(731)
return R.uniform(5, low=-2, high=2)
@jit.trace(symbolic=True)
def graph_b():
R.manual_seed(731)
return R.uniform(5, low=-2, high=2)
a = graph_a()
b = graph_b()
assert np.all(a.numpy() == b.numpy())
def test_range_uniform_dynamic_diff_result():
a = R.uniform(5, low=-2, high=2)
b = R.uniform(5, low=-2, high=2)
assert np.any(a.numpy() != b.numpy())
def test_range_uniform_dynamic_same_result():
R.manual_seed(0)
a = R.uniform(5, low=-2, high=2)
R.manual_seed(0)
b = R.uniform(5, low=-2, high=2)
assert np.all(a.numpy() == b.numpy())
def test_dropout_dynamic_diff_result():
x = mge.ones(10)
a = F.dropout(x, 0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册