提交 13272eaa 编写于 作者: M Megvii Engine Team

fix(mge/trace): fix random op in symbolic trace

GitOrigin-RevId: 9a851cd177119ee43155b831e9622ce342423090
上级 1fed5929
......@@ -52,7 +52,8 @@ def normal(
size = (1,)
seed = _random_seed_generator().__next__()
op = GaussianRNG(seed=seed, mean=mean, std=std)
size = Tensor(size, dtype="int32")
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)
return output
......@@ -93,7 +94,8 @@ def uniform(
size = (1,)
seed = _random_seed_generator().__next__()
op = UniformRNG(seed=seed)
size = Tensor(size, dtype="int32")
_ref = Tensor([], dtype="int32")
size = utils.astensor1d(size, _ref, dtype="int32")
(output,) = apply(op, size)
return low + (high - low) * output
......@@ -23,6 +23,7 @@ from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.random import normal, uniform
def test_trace():
......@@ -431,3 +432,23 @@ def test_slice():
y = f(x)
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
y + y
def test_random():
def run_test(op):
for symbolic_shape in [True, False]:
@trace(symbolic=True, symbolic_shape=symbolic_shape)
def f():
out = op(size=[10, 10])
out_shape = out.shape
assert out_shape is not None
if not isinstance(out_shape, tuple):
assert out.shape.numpy() is not None
return out
for _ in range(3):
f()
run_test(uniform)
run_test(normal)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册