提交 b982be56 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add permutation support for the tensor

GitOrigin-RevId: 7ed0447bfe18d7744fa7771191313d6b45ec8522
上级 3977b7aa
...@@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: ...@@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
assert inp.size > 0, "size needs to be greater than 0" assert inp.size > 0, "size needs to be greater than 0"
op = ShuffleRNG(seed=seed, handle=handle) op = ShuffleRNG(seed=seed, handle=handle)
output, _ = apply(op, inp) output, _ = apply(op, inp)
inp._reset(output) return output
class RNG: class RNG:
...@@ -554,12 +554,15 @@ class RNG: ...@@ -554,12 +554,15 @@ class RNG:
_seed = self._seed() if callable(self._seed) else self._seed _seed = self._seed() if callable(self._seed) else self._seed
return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle)
def permutation(self, n: int, *, dtype: str = "int32"): def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"):
r"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`. r"""Randomly permute a sequence, or return a permuted range.
If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.
Args: Args:
n: the upper bound. Must be larger than 0. n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
dtype: the output data type. int32, int16 and float32 are supported. Default: int32 If ``n`` is an tensor, make a copy and shuffle the elements randomly.
dtype: the output data type when ``n`` is an integer.
int32, int16 and float32 are supported. Default: int32
Returns: Returns:
the output tensor. the output tensor.
...@@ -568,13 +571,18 @@ class RNG: ...@@ -568,13 +571,18 @@ class RNG:
.. testcode:: .. testcode::
import numpy as np
import megengine as mge import megengine as mge
import megengine.random as rand import megengine.random as rand
x = rand.permutation(n=10, dtype="int32") x = rand.permutation(10, dtype="int32")
print(x.numpy())
x = rand.permutation(10, dtype="float32")
print(x.numpy()) print(x.numpy())
x = rand.permutation(n=10, dtype="float32") x = mge.tensor(np.arange(18)).reshape(6,3)
x = rand.permutation(x)
print(x.numpy()) print(x.numpy())
Outputs: Outputs:
...@@ -584,11 +592,20 @@ class RNG: ...@@ -584,11 +592,20 @@ class RNG:
[4 5 0 7 3 8 6 1 9 2] [4 5 0 7 3 8 6 1 9 2]
[3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.]
[[12 13 14]
[ 3 4 5]
[15 16 17]
[ 0 1 2]
[ 9 10 11]
[ 6 7 8]]
""" """
_seed = self._seed() if callable(self._seed) else self._seed _seed = self._seed() if callable(self._seed) else self._seed
return _permutation( if isinstance(n, int):
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype return _permutation(
) n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
)
assert isinstance(n, Tensor)
return _shuffle(inp=n, seed=_seed, handle=self._handle)
def shuffle(self, inp: Tensor): def shuffle(self, inp: Tensor):
r"""Modify a sequence in-place by shuffling its contents. r"""Modify a sequence in-place by shuffling its contents.
...@@ -627,7 +644,7 @@ class RNG: ...@@ -627,7 +644,7 @@ class RNG:
[ 6. 7. 8.]] [ 6. 7. 8.]]
""" """
_seed = self._seed() if callable(self._seed) else self._seed _seed = self._seed() if callable(self._seed) else self._seed
_shuffle(inp=inp, seed=_seed, handle=self._handle) inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle))
def __del__(self): def __del__(self):
if self._handle != 0: if self._handle != 0:
......
...@@ -28,6 +28,7 @@ from megengine.core.ops.builtin import ( ...@@ -28,6 +28,7 @@ from megengine.core.ops.builtin import (
UniformRNG, UniformRNG,
) )
from megengine.device import get_device_count from megengine.device import get_device_count
from megengine.jit import trace
from megengine.random import RNG from megengine.random import RNG
from megengine.random import seed as set_global_seed from megengine.random import seed as set_global_seed
from megengine.random import uniform from megengine.random import uniform
...@@ -370,21 +371,22 @@ def test_PoissonRNG(): ...@@ -370,21 +371,22 @@ def test_PoissonRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_PermutationRNG(): @pytest.mark.parametrize("symbolic", [True, False])
def test_PermutationRNG(symbolic):
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1") m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0") m3 = RNG(seed=222, device="xpu0")
out1 = m1.permutation(n=1000) out1 = m1.permutation(1000)
out1_ = m1.uniform(size=(1000,)) out1_ = m1.uniform(size=(1000,))
out2 = m2.permutation(n=1000) out2 = m2.permutation(1000)
out3 = m3.permutation(n=1000) out3 = m3.permutation(1000)
np.testing.assert_equal(out1.numpy(), out2.numpy()) np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1" assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all()
out = m1.permutation(n=1000) out = m1.permutation(1000)
out_shp = out.shape out_shp = out.shape
if isinstance(out_shp, tuple): if isinstance(out_shp, tuple):
assert out_shp == (1000,) assert out_shp == (1000,)
...@@ -397,6 +399,24 @@ def test_PermutationRNG(): ...@@ -397,6 +399,24 @@ def test_PermutationRNG():
assert sum_result(out, lambda x: x) < 500 assert sum_result(out, lambda x: x) < 500
assert sum_result(out, np.sort) == 1000 assert sum_result(out, np.sort) == 1000
def func():
out = m1.permutation(Tensor(7))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (1,)
else:
assert all(out.shape.numpy() == np.array([1]))
n, m = 6, 3
out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (n, m)
else:
assert all(out.shape.numpy() == np.array([n, m]))
func = trace(symbolic=symbolic)(func)
func()
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
......
...@@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, ...@@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
const OperatorNodeConfig& config) const OperatorNodeConfig& config)
: Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
add_input({data}); add_input({data});
add_output(None)->dtype(data->dtype()); add_output(None)
add_output(None)->dtype(dtype::Int32{}); ->dtype(data->dtype())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
add_output(None)
->dtype(dtype::Int32{})
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
cg::add_workspace_output(this); cg::add_workspace_output(this);
add_equivalence_component<ScalarHash<void*>>(this); add_equivalence_component<ScalarHash<void*>>(this);
} }
...@@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() { ...@@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() {
}; };
void ShuffleRNGForward::scn_do_execute() { void ShuffleRNGForward::scn_do_execute() {
auto&& ret = output(0);
if (ret->layout().is_empty()) {
mgb_assert(ret->dev_tensor().empty());
return;
}
m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
get_megdnn_workspace_from_var(output(2))); get_megdnn_workspace_from_var(output(2)));
} }
cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
for (auto i : input()) {
prop->add_dep_type_existing_var(i,
NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return prop;
}
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册