Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
884a07ff
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
884a07ff
编写于
9月 18, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(test/random): set a random seed for random unit test
GitOrigin-RevId: ad4b01eac7238d7a71f8b71075c37bd4b3e58235
上级
d7cc4628
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
13 addition
and
4 deletion
+13
-4
imperative/python/test/unit/random/test_rng.py
imperative/python/test/unit/random/test_rng.py
+13
-4
未找到文件。
imperative/python/test/unit/random/test_rng.py
浏览文件 @
884a07ff
...
...
@@ -27,13 +27,16 @@ from megengine.core.ops.builtin import (
UniformRNG
,
)
from
megengine.device
import
get_device_count
from
megengine.random
import
RNG
,
seed
,
uniform
from
megengine.random
import
RNG
from
megengine.random
import
seed
as
set_global_seed
from
megengine.random
import
uniform
@
pytest
.
mark
.
skipif
(
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_gaussian_op
():
set_global_seed
(
1024
)
shape
=
(
8
,
9
,
...
...
@@ -64,6 +67,7 @@ def test_gaussian_op():
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_uniform_op
():
set_global_seed
(
1024
)
shape
=
(
8
,
9
,
...
...
@@ -92,6 +96,7 @@ def test_uniform_op():
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_gamma_op
():
set_global_seed
(
1024
)
_shape
,
_scale
=
2
,
0.8
_expected_mean
,
_expected_std
=
_shape
*
_scale
,
np
.
sqrt
(
_shape
)
*
_scale
...
...
@@ -120,6 +125,7 @@ def test_gamma_op():
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_beta_op
():
set_global_seed
(
1024
)
_alpha
,
_beta
=
2
,
0.8
_expected_mean
=
_alpha
/
(
_alpha
+
_beta
)
_expected_std
=
np
.
sqrt
(
...
...
@@ -151,6 +157,7 @@ def test_beta_op():
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_poisson_op
():
set_global_seed
(
1024
)
lam
=
F
.
full
([
8
,
9
,
11
,
12
],
value
=
2
,
dtype
=
"float32"
)
op
=
PoissonRNG
(
seed
=
get_global_rng_seed
())
(
output
,)
=
apply
(
op
,
lam
)
...
...
@@ -174,6 +181,7 @@ def test_poisson_op():
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_permutation_op
():
set_global_seed
(
1024
)
n
=
1000
def
test_permutation_op_dtype
(
dtype
):
...
...
@@ -390,22 +398,23 @@ def test_PermutationRNG():
def
test_seed
():
seed
(
10
)
se
t_global_se
ed
(
10
)
out1
=
uniform
(
size
=
[
10
,
10
])
out2
=
uniform
(
size
=
[
10
,
10
])
assert
not
(
out1
.
numpy
()
==
out2
.
numpy
()).
all
()
seed
(
10
)
se
t_global_se
ed
(
10
)
out3
=
uniform
(
size
=
[
10
,
10
])
np
.
testing
.
assert_equal
(
out1
.
numpy
(),
out3
.
numpy
())
seed
(
11
)
se
t_global_se
ed
(
11
)
out4
=
uniform
(
size
=
[
10
,
10
])
assert
not
(
out1
.
numpy
()
==
out4
.
numpy
()).
all
()
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
,
False
,
True
])
def
test_rng_empty_tensor
(
is_symbolic
):
set_global_seed
(
1024
)
shapes
=
[
(
0
,),
(
0
,
0
,
0
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录