Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b982be56
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b982be56
编写于
9月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/imperative): add permutation support for the tensor
GitOrigin-RevId: 7ed0447bfe18d7744fa7771191313d6b45ec8522
上级
3977b7aa
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
74 addition
and
18 deletion
+74
-18
imperative/python/megengine/random/rng.py
imperative/python/megengine/random/rng.py
+28
-11
imperative/python/test/unit/random/test_rng.py
imperative/python/test/unit/random/test_rng.py
+25
-5
src/opr/impl/rand.cpp
src/opr/impl/rand.cpp
+21
-2
未找到文件。
imperative/python/megengine/random/rng.py
浏览文件 @
b982be56
...
...
@@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
assert
inp
.
size
>
0
,
"size needs to be greater than 0"
op
=
ShuffleRNG
(
seed
=
seed
,
handle
=
handle
)
output
,
_
=
apply
(
op
,
inp
)
inp
.
_reset
(
output
)
return
output
class
RNG
:
...
...
@@ -554,12 +554,15 @@ class RNG:
_seed
=
self
.
_seed
()
if
callable
(
self
.
_seed
)
else
self
.
_seed
return
_poisson
(
lam
=
lam
,
size
=
size
,
seed
=
_seed
,
handle
=
self
.
_handle
)
def
permutation
(
self
,
n
:
int
,
*
,
dtype
:
str
=
"int32"
):
r
"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`.
def
permutation
(
self
,
n
:
Union
[
int
,
Tensor
],
*
,
dtype
:
str
=
"int32"
):
r
"""Randomly permute a sequence, or return a permuted range.
If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index.
Args:
n: the upper bound. Must be larger than 0.
dtype: the output data type. int32, int16 and float32 are supported. Default: int32
n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`.
If ``n`` is an tensor, make a copy and shuffle the elements randomly.
dtype: the output data type when ``n`` is an integer.
int32, int16 and float32 are supported. Default: int32
Returns:
the output tensor.
...
...
@@ -568,13 +571,18 @@ class RNG:
.. testcode::
import numpy as np
import megengine as mge
import megengine.random as rand
x = rand.permutation(n=10, dtype="int32")
x = rand.permutation(10, dtype="int32")
print(x.numpy())
x = rand.permutation(10, dtype="float32")
print(x.numpy())
x = rand.permutation(n=10, dtype="float32")
x = mge.tensor(np.arange(18)).reshape(6,3)
x = rand.permutation(x)
print(x.numpy())
Outputs:
...
...
@@ -584,11 +592,20 @@ class RNG:
[4 5 0 7 3 8 6 1 9 2]
[3. 4. 9. 0. 6. 8. 7. 1. 5. 2.]
[[12 13 14]
[ 3 4 5]
[15 16 17]
[ 0 1 2]
[ 9 10 11]
[ 6 7 8]]
"""
_seed
=
self
.
_seed
()
if
callable
(
self
.
_seed
)
else
self
.
_seed
if
isinstance
(
n
,
int
):
return
_permutation
(
n
=
n
,
seed
=
_seed
,
device
=
self
.
_device
,
handle
=
self
.
_handle
,
dtype
=
dtype
)
assert
isinstance
(
n
,
Tensor
)
return
_shuffle
(
inp
=
n
,
seed
=
_seed
,
handle
=
self
.
_handle
)
def
shuffle
(
self
,
inp
:
Tensor
):
r
"""Modify a sequence in-place by shuffling its contents.
...
...
@@ -627,7 +644,7 @@ class RNG:
[ 6. 7. 8.]]
"""
_seed
=
self
.
_seed
()
if
callable
(
self
.
_seed
)
else
self
.
_seed
_shuffle
(
inp
=
inp
,
seed
=
_seed
,
handle
=
self
.
_handle
)
inp
.
_reset
(
_shuffle
(
inp
=
inp
,
seed
=
_seed
,
handle
=
self
.
_handle
)
)
def
__del__
(
self
):
if
self
.
_handle
!=
0
:
...
...
imperative/python/test/unit/random/test_rng.py
浏览文件 @
b982be56
...
...
@@ -28,6 +28,7 @@ from megengine.core.ops.builtin import (
UniformRNG
,
)
from
megengine.device
import
get_device_count
from
megengine.jit
import
trace
from
megengine.random
import
RNG
from
megengine.random
import
seed
as
set_global_seed
from
megengine.random
import
uniform
...
...
@@ -370,21 +371,22 @@ def test_PoissonRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_PermutationRNG
():
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
])
def
test_PermutationRNG
(
symbolic
):
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
m2
=
RNG
(
seed
=
111
,
device
=
"xpu1"
)
m3
=
RNG
(
seed
=
222
,
device
=
"xpu0"
)
out1
=
m1
.
permutation
(
n
=
1000
)
out1
=
m1
.
permutation
(
1000
)
out1_
=
m1
.
uniform
(
size
=
(
1000
,))
out2
=
m2
.
permutation
(
n
=
1000
)
out3
=
m3
.
permutation
(
n
=
1000
)
out2
=
m2
.
permutation
(
1000
)
out3
=
m3
.
permutation
(
1000
)
np
.
testing
.
assert_equal
(
out1
.
numpy
(),
out2
.
numpy
())
assert
out1
.
device
==
"xpu0"
and
out2
.
device
==
"xpu1"
assert
not
(
out1
.
numpy
()
==
out3
.
numpy
()).
all
()
assert
not
(
out1
.
numpy
()
==
out1_
.
numpy
()).
all
()
out
=
m1
.
permutation
(
n
=
1000
)
out
=
m1
.
permutation
(
1000
)
out_shp
=
out
.
shape
if
isinstance
(
out_shp
,
tuple
):
assert
out_shp
==
(
1000
,)
...
...
@@ -397,6 +399,24 @@ def test_PermutationRNG():
assert
sum_result
(
out
,
lambda
x
:
x
)
<
500
assert
sum_result
(
out
,
np
.
sort
)
==
1000
def
func
():
out
=
m1
.
permutation
(
Tensor
(
7
))
out_shp
=
out
.
shape
if
isinstance
(
out_shp
,
tuple
):
assert
out_shp
==
(
1
,)
else
:
assert
all
(
out
.
shape
.
numpy
()
==
np
.
array
([
1
]))
n
,
m
=
6
,
3
out
=
m1
.
permutation
(
Tensor
(
np
.
arange
(
n
*
m
),
dtype
=
"float32"
).
reshape
(
n
,
m
))
out_shp
=
out
.
shape
if
isinstance
(
out_shp
,
tuple
):
assert
out_shp
==
(
n
,
m
)
else
:
assert
all
(
out
.
shape
.
numpy
()
==
np
.
array
([
n
,
m
]))
func
=
trace
(
symbolic
=
symbolic
)(
func
)
func
()
@
pytest
.
mark
.
skipif
(
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
...
...
src/opr/impl/rand.cpp
浏览文件 @
b982be56
...
...
@@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
const
OperatorNodeConfig
&
config
)
:
Super
({
data
->
owner_graph
(),
config
,
"shuffle_rng"
,
{
data
}},
param
)
{
add_input
({
data
});
add_output
(
None
)
->
dtype
(
data
->
dtype
());
add_output
(
None
)
->
dtype
(
dtype
::
Int32
{});
add_output
(
None
)
->
dtype
(
data
->
dtype
())
.
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
add_output
(
None
)
->
dtype
(
dtype
::
Int32
{})
.
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
cg
::
add_workspace_output
(
this
);
add_equivalence_component
<
ScalarHash
<
void
*>>
(
this
);
}
...
...
@@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() {
};
void
ShuffleRNGForward
::
scn_do_execute
()
{
auto
&&
ret
=
output
(
0
);
if
(
ret
->
layout
().
is_empty
())
{
mgb_assert
(
ret
->
dev_tensor
().
empty
());
return
;
}
m_dnn_opr
->
exec
(
input
(
0
)
->
dev_tensor
().
as_megdnn
(),
output
(
0
)
->
dev_tensor
().
as_megdnn
(),
output
(
1
)
->
dev_tensor
().
as_megdnn
(),
get_megdnn_workspace_from_var
(
output
(
2
)));
}
cg
::
OperatorNodeBase
::
NodeProp
*
ShuffleRNGForward
::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
prop
->
add_flag
(
NodeProp
::
Flag
::
IMPURE_FUNC
);
for
(
auto
i
:
input
())
{
prop
->
add_dep_type_existing_var
(
i
,
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
}
return
prop
;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
ShuffleRNGForward
)
{
mgb_assert
(
out_grad
.
size
()
==
3
&&
wrt_idx
==
0
&&
!
out_grad
[
2
]);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录