Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
27346b0b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
27346b0b
编写于
9月 27, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(opr): add scalar check for opr_test
GitOrigin-RevId: dcfd7ad5d6b8a85027df796051a0521c7be48575
上级
22504523
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
56 addition
and
60 deletion
+56
-60
imperative/python/test/helpers/utils.py
imperative/python/test/helpers/utils.py
+13
-9
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+10
-9
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+33
-42
未找到文件。
imperative/python/test/helpers/utils.py
浏览文件 @
27346b0b
...
...
@@ -11,12 +11,12 @@ from megengine.utils.network_node import VarNode
def
_default_compare_fn
(
x
,
y
):
if
isinstance
(
x
,
np
.
ndarray
):
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-6
)
elif
isinstance
(
x
,
tensor
):
np
.
testing
.
assert_allclose
(
x
.
numpy
(),
y
,
rtol
=
1e-6
)
else
:
np
.
testing
.
assert_allclose
(
get_var_value
(
x
)
,
y
,
rtol
=
1e-6
)
if
isinstance
(
x
,
tensor
):
x
=
x
.
numpy
(
)
elif
not
isinstance
(
x
,
np
.
ndarray
):
x
=
get_var_value
(
x
)
assert
isinstance
(
x
,
np
.
ndarray
)
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-6
)
def
make_tensor
(
x
,
network
=
None
,
device
=
None
):
...
...
@@ -69,12 +69,16 @@ def opr_test(
"""
def
check_results
(
results
,
expected
):
def
check_results
(
results
,
expected
,
check_shape
=
True
):
if
not
isinstance
(
results
,
(
tuple
,
list
)):
results
=
(
results
,)
for
r
,
e
in
zip
(
results
,
expected
):
if
not
isinstance
(
r
,
(
tensor
,
VarNode
)):
r
=
tensor
(
r
)
if
check_shape
:
r_shape
=
r
.
numpy
().
shape
e_shape
=
e
.
shape
if
isinstance
(
e
,
np
.
ndarray
)
else
()
assert
r_shape
==
e_shape
compare_fn
(
r
,
e
)
def
get_param
(
cases
,
idx
):
...
...
@@ -127,10 +131,10 @@ def opr_test(
# assume #outputs == 1
loaded_results
=
list
(
infer_cg
.
run
(
inp_dict
=
inp_dict
).
values
())[
0
]
check_results
(
loaded_results
,
outp
)
check_results
(
loaded_results
,
outp
,
check_shape
=
False
)
# scalar info lost
results
=
func
(
*
inp_tensor
,
**
kwargs
)
check_results
(
results
,
outp
)
check_results
(
results
,
outp
,
check_shape
=
(
network
is
None
)
)
if
len
(
cases
)
==
0
:
raise
ValueError
(
"should give one case at least"
)
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
27346b0b
...
...
@@ -39,12 +39,6 @@ def test_where():
xv1
=
np
.
array
([[
1
,
np
.
inf
,
2
],
[
0
,
np
.
nan
,
4
],
[
1
,
5
,
7
]],
dtype
=
np
.
float32
)
yv1
=
np
.
array
([[
5
,
6
,
9
],
[
2
,
7
,
8
],
[
2
,
1
,
9
]],
dtype
=
np
.
float32
)
cases
=
[
{
"input"
:
[
maskv0
,
xv0
,
yv0
]},
{
"input"
:
[
maskv1
,
xv1
,
yv1
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
,
test_trace
=
False
)
maskv2
=
np
.
array
([
1
,
1
,
1
],
dtype
=
np
.
bool_
)
xv2
=
np
.
array
([
1
,
3
,
2
],
dtype
=
np
.
float32
)
yv2
=
np
.
array
([
5
,
6
,
9
],
dtype
=
np
.
float32
)
...
...
@@ -53,11 +47,18 @@ def test_where():
xv3
=
np
.
array
([
1
,
3
,
2
],
dtype
=
np
.
float32
)
yv3
=
np
.
array
([
5
,
6
,
9
],
dtype
=
np
.
float32
)
maskv4
=
np
.
array
(
1
,
dtype
=
np
.
bool_
)
xv4
=
np
.
array
(
1
,
dtype
=
np
.
float32
)
yv4
=
np
.
array
(
0
,
dtype
=
np
.
float32
)
cases
=
[
{
"input"
:
[
maskv0
,
xv0
,
yv0
]},
{
"input"
:
[
maskv1
,
xv1
,
yv1
]},
{
"input"
:
[
maskv2
,
xv2
,
yv2
]},
{
"input"
:
[
maskv3
,
xv3
,
yv3
]},
{
"input"
:
[
maskv4
,
xv4
,
yv4
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
,
test_trace
=
Fals
e
)
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
,
test_trace
=
Tru
e
)
def
test_dropout
():
...
...
@@ -618,12 +619,12 @@ def test_binary_cross_entropy():
np
.
random
.
seed
(
123
)
data1
=
np
.
random
.
uniform
(
size
=
data1_shape
).
astype
(
np
.
float32
)
label1
=
np
.
random
.
uniform
(
size
=
label1_shape
).
astype
(
np
.
float32
)
expect1
=
np
.
array
(
[
0.6361
]
,
dtype
=
np
.
float32
)
expect1
=
np
.
array
(
0.6361
,
dtype
=
np
.
float32
)
np
.
random
.
seed
(
123
)
data2
=
np
.
random
.
uniform
(
size
=
data2_shape
).
astype
(
np
.
float32
)
label2
=
np
.
random
.
uniform
(
size
=
label2_shape
).
astype
(
np
.
float32
)
expect2
=
np
.
array
(
[
0.6750
]
,
dtype
=
np
.
float32
)
expect2
=
np
.
array
(
0.6750
,
dtype
=
np
.
float32
)
cases
=
[
{
"input"
:
[
data1
,
label1
],
"output"
:
expect1
,},
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
27346b0b
...
...
@@ -335,18 +335,18 @@ def test_reshape_shape_inference(is_varnode):
source
=
output
.
shape
if
isinstance
(
source
,
tensor
):
source
=
source
.
numpy
()
np
.
testing
.
assert_equal
(
source
,
target
)
np
.
testing
.
assert_equal
(
source
,
target
.
shape
)
def
func
(
x
,
target_shape
):
return
x
.
reshape
(
target_shape
)
cases
=
[
{
"input"
:
[
x_shape_known
,
tshp_unknown
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_unknown
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_known
,
tshp_known
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_known
,
tshp_known_unspec
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_known
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_known_unspec
],
"output"
:
[
(
2
,
2
),]},
{
"input"
:
[
x_shape_known
,
tshp_unknown
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_unknown
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
{
"input"
:
[
x_shape_known
,
tshp_known
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
{
"input"
:
[
x_shape_known
,
tshp_known_unspec
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_known
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
{
"input"
:
[
x_shape_unknown
,
tshp_known_unspec
],
"output"
:
[
np
.
zeros
((
2
,
2
)
),]},
]
opr_test
(
cases
,
func
,
compare_fn
=
check_shape
,
test_trace
=
True
,
network
=
network
)
if
is_varnode
:
...
...
@@ -533,46 +533,30 @@ def test_flatten(is_varnode):
data0
=
np
.
random
.
random
(
data0_shape
).
astype
(
np
.
float32
)
data1
=
np
.
random
.
random
(
data1_shape
).
astype
(
np
.
float32
)
def
compare_fn
(
x
,
y
):
assert
x
.
_tuple_shape
[
0
]
==
y
output0
=
(
2
*
3
*
4
*
5
,)
output1
=
(
4
*
5
*
6
*
7
,)
cases
=
[
{
"input"
:
data0
,
"output"
:
output0
},
{
"input"
:
data1
,
"output"
:
output1
},
{
"input"
:
data0
,
"output"
:
data0
.
flatten
()
},
{
"input"
:
data1
,
"output"
:
data1
.
flatten
()
},
]
opr_test
(
cases
,
F
.
flatten
,
compare_fn
=
compare_fn
,
network
=
network
)
opr_test
(
cases
,
F
.
flatten
,
network
=
network
)
output0
=
(
2
,
3
*
4
*
5
)
output1
=
(
4
,
5
*
6
*
7
)
cases
=
[
{
"input"
:
data0
,
"output"
:
output0
},
{
"input"
:
data1
,
"output"
:
output1
},
{
"input"
:
data0
,
"output"
:
data0
.
reshape
(
2
,
-
1
)
},
{
"input"
:
data1
,
"output"
:
data1
.
reshape
(
4
,
-
1
)
},
]
opr_test
(
cases
,
F
.
flatten
,
compare_fn
=
compare_fn
,
start_axis
=
1
,
network
=
network
)
opr_test
(
cases
,
F
.
flatten
,
start_axis
=
1
,
network
=
network
)
output0
=
(
2
,
3
,
4
*
5
)
output1
=
(
4
,
5
,
6
*
7
)
cases
=
[
{
"input"
:
data0
,
"output"
:
output0
},
{
"input"
:
data1
,
"output"
:
output1
},
{
"input"
:
data0
,
"output"
:
data0
.
reshape
(
2
,
3
,
-
1
)
},
{
"input"
:
data1
,
"output"
:
data1
.
reshape
(
4
,
5
,
-
1
)
},
]
opr_test
(
cases
,
F
.
flatten
,
compare_fn
=
compare_fn
,
start_axis
=
2
,
network
=
network
)
opr_test
(
cases
,
F
.
flatten
,
start_axis
=
2
,
network
=
network
)
output0
=
(
2
,
3
*
4
,
5
)
output1
=
(
4
,
5
*
6
,
7
)
cases
=
[
{
"input"
:
data0
,
"output"
:
output0
},
{
"input"
:
data1
,
"output"
:
output1
},
{
"input"
:
data0
,
"output"
:
data0
.
reshape
(
2
,
-
1
,
5
)
},
{
"input"
:
data1
,
"output"
:
data1
.
reshape
(
4
,
-
1
,
7
)
},
]
opr_test
(
cases
,
F
.
flatten
,
compare_fn
=
compare_fn
,
start_axis
=
1
,
end_axis
=
2
,
network
=
network
,
cases
,
F
.
flatten
,
start_axis
=
1
,
end_axis
=
2
,
network
=
network
,
)
...
...
@@ -595,15 +579,22 @@ def test_broadcast(is_varnode):
output3_shape
=
(
10
,
10
)
data3
=
np
.
random
.
random
(
input3_shape
).
astype
(
np
.
float32
)
def
compare_fn
(
x
,
y
):
assert
x
.
_tuple_shape
[
0
]
==
y
cases
=
[
{
"input"
:
[
data1
,
output1_shape
],
"output"
:
output1_shape
},
{
"input"
:
[
data2
,
output2_shape
],
"output"
:
output2_shape
},
{
"input"
:
[
data3
,
output3_shape
],
"output"
:
output3_shape
},
{
"input"
:
[
data1
,
output1_shape
],
"output"
:
np
.
broadcast_to
(
data1
,
output1_shape
),
},
{
"input"
:
[
data2
,
output2_shape
],
"output"
:
np
.
broadcast_to
(
data2
,
output2_shape
),
},
{
"input"
:
[
data3
,
output3_shape
],
"output"
:
np
.
broadcast_to
(
data3
,
output3_shape
),
},
]
opr_test
(
cases
,
F
.
broadcast_to
,
compare_fn
=
compare_fn
,
network
=
network
)
opr_test
(
cases
,
F
.
broadcast_to
,
network
=
network
)
x
=
F
.
ones
((
2
,
1
,
3
))
with
pytest
.
raises
(
RuntimeError
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录