diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py index 45c19090ce8db7e9469e53400d516a71f9cc2433..e8878a06df66564e3e2f40b0bbd2e8b6ba72a10a 100644 --- a/imperative/python/megengine/random/rng.py +++ b/imperative/python/megengine/random/rng.py @@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: assert inp.size > 0, "size needs to be greater than 0" op = ShuffleRNG(seed=seed, handle=handle) output, _ = apply(op, inp) - inp._reset(output) + return output class RNG: @@ -554,12 +554,15 @@ class RNG: _seed = self._seed() if callable(self._seed) else self._seed return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) - def permutation(self, n: int, *, dtype: str = "int32"): - r"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`. + def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"): + r"""Randomly permute a sequence, or return a permuted range. + If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index. Args: - n: the upper bound. Must be larger than 0. - dtype: the output data type. int32, int16 and float32 are supported. Default: int32 + n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`. + If ``n`` is an tensor, make a copy and shuffle the elements randomly. + dtype: the output data type when ``n`` is an integer. + int32, int16 and float32 are supported. Default: int32 Returns: the output tensor. @@ -568,13 +571,18 @@ class RNG: .. testcode:: + import numpy as np import megengine as mge import megengine.random as rand - x = rand.permutation(n=10, dtype="int32") + x = rand.permutation(10, dtype="int32") + print(x.numpy()) + + x = rand.permutation(10, dtype="float32") print(x.numpy()) - x = rand.permutation(n=10, dtype="float32") + x = mge.tensor(np.arange(18)).reshape(6,3) + x = rand.permutation(x) print(x.numpy()) Outputs: @@ -584,11 +592,20 @@ class RNG: [4 5 0 7 3 8 6 1 9 2] [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] + [[12 13 14] + [ 3 4 5] + [15 16 17] + [ 0 1 2] + [ 9 10 11] + [ 6 7 8]] """ _seed = self._seed() if callable(self._seed) else self._seed - return _permutation( - n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype - ) + if isinstance(n, int): + return _permutation( + n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype + ) + assert isinstance(n, Tensor) + return _shuffle(inp=n, seed=_seed, handle=self._handle) def shuffle(self, inp: Tensor): r"""Modify a sequence in-place by shuffling its contents. @@ -627,7 +644,7 @@ class RNG: [ 6. 7. 8.]] """ _seed = self._seed() if callable(self._seed) else self._seed - _shuffle(inp=inp, seed=_seed, handle=self._handle) + inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle)) def __del__(self): if self._handle != 0: diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index 6df000bbed5d68e5e09ff945b4f10744b41a05b3..a33a58405dba55916b22275dd4fe5bcc475092d5 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -28,6 +28,7 @@ from megengine.core.ops.builtin import ( UniformRNG, ) from megengine.device import get_device_count +from megengine.jit import trace from megengine.random import RNG from megengine.random import seed as set_global_seed from megengine.random import uniform @@ -370,21 +371,22 @@ def test_PoissonRNG(): @pytest.mark.skipif( get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) -def test_PermutationRNG(): +@pytest.mark.parametrize("symbolic", [True, False]) +def test_PermutationRNG(symbolic): m1 = RNG(seed=111, device="xpu0") m2 = RNG(seed=111, device="xpu1") m3 = RNG(seed=222, device="xpu0") - out1 = m1.permutation(n=1000) + out1 = m1.permutation(1000) out1_ = m1.uniform(size=(1000,)) - out2 = m2.permutation(n=1000) - out3 = m3.permutation(n=1000) + out2 = m2.permutation(1000) + out3 = m3.permutation(1000) np.testing.assert_equal(out1.numpy(), out2.numpy()) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() - out = m1.permutation(n=1000) + out = m1.permutation(1000) out_shp = out.shape if isinstance(out_shp, tuple): assert out_shp == (1000,) @@ -397,6 +399,24 @@ def test_PermutationRNG(): assert sum_result(out, lambda x: x) < 500 assert sum_result(out, np.sort) == 1000 + def func(): + out = m1.permutation(Tensor(7)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (1,) + else: + assert all(out.shape.numpy() == np.array([1])) + n, m = 6, 3 + out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (n, m) + else: + assert all(out.shape.numpy() == np.array([n, m])) + + func = trace(symbolic=symbolic)(func) + func() + @pytest.mark.skipif( get_device_count("xpu") <= 1, reason="xpu counts need > 1", diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 02e91c699d14712c5c6fa9f0845ffbcd8afee287..928c51762e1d57c11ab0decde353d886f30ed980 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, const OperatorNodeConfig& config) : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { add_input({data}); - add_output(None)->dtype(data->dtype()); - add_output(None)->dtype(dtype::Int32{}); + add_output(None) + ->dtype(data->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None) + ->dtype(dtype::Int32{}) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); cg::add_workspace_output(this); add_equivalence_component>(this); } @@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() { }; void ShuffleRNGForward::scn_do_execute() { + auto&& ret = output(0); + if (ret->layout().is_empty()) { + mgb_assert(ret->dev_tensor().empty()); + return; + } m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), get_megdnn_workspace_from_var(output(2))); } +cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + prop->add_flag(NodeProp::Flag::IMPURE_FUNC); + for (auto i : input()) { + prop->add_dep_type_existing_var(i, + NodeProp::DepType::VALUE_ALLOW_EMPTY); + } + return prop; +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);