Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
334eda87
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
334eda87
编写于
3年前
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge): test trace inside opr_test
GitOrigin-RevId: 2cf1135c1ccbdba234238d29465dd1eda6765a59
上级
2b8150ab
无相关合并请求
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
64 addition
and
8 deletion
+64
-8
imperative/python/megengine/utils/comp_graph_tools.py
imperative/python/megengine/utils/comp_graph_tools.py
+2
-2
imperative/python/test/helpers/utils.py
imperative/python/test/helpers/utils.py
+44
-2
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+6
-3
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+12
-1
未找到文件。
imperative/python/megengine/utils/comp_graph_tools.py
浏览文件 @
334eda87
...
...
@@ -315,9 +315,9 @@ class GraphInference:
inputs
=
get_dep_vars
(
output_nodes
,
"Host2DeviceCopy"
)
self
.
_inp_dict
=
OrderedDict
()
replace_dict
=
{}
for
i
in
inputs
:
for
i
dx
,
i
in
enumerate
(
inputs
)
:
inp_node
=
G
.
InputNode
(
device
=
"xpux"
,
dtype
=
inputs
[
0
].
dtype
,
graph
=
inputs
[
0
].
graph
device
=
"xpux"
,
dtype
=
inputs
[
idx
].
dtype
,
graph
=
inputs
[
0
].
graph
)
self
.
_inp_dict
[
i
.
name
]
=
inp_node
replace_dict
[
i
]
=
inp_node
.
outputs
[
0
]
...
...
This diff is collapsed.
Click to expand it.
imperative/python/test/helpers/utils.py
浏览文件 @
334eda87
import
io
import
numpy
as
np
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine
import
tensor
from
megengine.jit
import
trace
def
_default_compare_fn
(
x
,
y
):
if
isinstance
(
x
,
np
.
ndarray
):
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-6
)
else
:
np
.
testing
.
assert_allclose
(
x
.
numpy
(),
y
,
rtol
=
1e-6
)
def
opr_test
(
cases
,
func
,
compare_fn
=
_default_compare_fn
,
ref_fn
=
None
,
**
kwargs
):
def
opr_test
(
cases
,
func
,
compare_fn
=
_default_compare_fn
,
ref_fn
=
None
,
test_trace
=
True
,
**
kwargs
):
"""
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input,
...
...
@@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
if
not
isinstance
(
results
,
(
tuple
,
list
)):
results
=
(
results
,)
for
r
,
e
in
zip
(
results
,
expected
):
if
not
isinstance
(
r
,
tensor
):
r
=
tensor
(
r
)
compare_fn
(
r
,
e
)
def
get_param
(
cases
,
idx
):
...
...
@@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs)
inp
,
outp
=
get_param
(
cases
,
0
)
inp_tensor
=
[
tensor
(
inpi
)
for
inpi
in
inp
]
if
test_trace
:
copied_inp
=
inp_tensor
.
copy
()
for
symbolic
in
[
False
,
True
]:
traced_func
=
trace
(
symbolic
=
symbolic
)(
func
)
for
_
in
range
(
3
):
traced_results
=
traced_func
(
*
copied_inp
,
**
kwargs
)
check_results
(
traced_results
,
outp
)
dumped_func
=
trace
(
symbolic
=
True
,
capture_as_const
=
True
)(
func
)
dumped_results
=
dumped_func
(
*
copied_inp
,
**
kwargs
)
check_results
(
dumped_results
,
outp
)
file
=
io
.
BytesIO
()
dump_info
=
dumped_func
.
dump
(
file
)
file
.
seek
(
0
)
# arg_name has pattern arg_xxx, xxx is int value
def
take_number
(
arg_name
):
return
int
(
arg_name
.
split
(
"_"
)[
-
1
])
input_names
=
dump_info
[
4
]
inps_np
=
[
i
.
numpy
()
for
i
in
copied_inp
]
input_names
.
sort
(
key
=
take_number
)
inp_dict
=
dict
(
zip
(
input_names
,
inps_np
))
infer_cg
=
cgtools
.
GraphInference
(
file
)
# assume #outputs == 1
loaded_results
=
list
(
infer_cg
.
run
(
inp_dict
=
inp_dict
).
values
())[
0
]
check_results
(
loaded_results
,
outp
)
results
=
func
(
*
inp_tensor
,
**
kwargs
)
check_results
(
results
,
outp
)
This diff is collapsed.
Click to expand it.
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
334eda87
...
...
@@ -36,7 +36,7 @@ def test_where():
{
"input"
:
[
maskv0
,
xv0
,
yv0
]},
{
"input"
:
[
maskv1
,
xv1
,
yv1
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
)
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
,
test_trace
=
False
)
maskv2
=
np
.
array
([
1
,
1
,
1
],
dtype
=
np
.
bool_
)
xv2
=
np
.
array
([
1
,
3
,
2
],
dtype
=
np
.
float32
)
...
...
@@ -50,7 +50,7 @@ def test_where():
{
"input"
:
[
maskv2
,
xv2
,
yv2
]},
{
"input"
:
[
maskv3
,
xv3
,
yv3
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
)
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
,
test_trace
=
False
)
def
test_dropout
():
...
...
@@ -115,14 +115,17 @@ def test_matmul():
{
"input"
:
[
data4
,
data5
]},
]
for
_
in
range
(
0
,
batch_size
):
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
matmul
,
ref_fn
=
np
.
matmul
,
cases
,
F
.
matmul
,
test_trace
=
False
,
ref_fn
=
np
.
matmul
,
)
# FIXME: remove test_trace=False in the future
opr_test
(
[{
"input"
:
[
data1
,
data4
]}],
F
.
matmul
,
ref_fn
=
lambda
x
,
y
:
np
.
matmul
(
x
,
y
.
transpose
(
0
,
1
,
3
,
2
)),
test_trace
=
False
,
transpose_b
=
True
,
)
...
...
This diff is collapsed.
Click to expand it.
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
334eda87
...
...
@@ -162,20 +162,24 @@ def test_linspace():
{
"input"
:
[
1
,
9
,
9
]},
{
"input"
:
[
3
,
10
,
8
]},
]
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
linspace
,
ref_fn
=
lambda
start
,
end
,
step
:
np
.
linspace
(
start
,
end
,
step
,
dtype
=
np
.
float32
),
test_trace
=
False
,
)
cases
=
[
{
"input"
:
[
9
,
1
,
9
]},
{
"input"
:
[
10
,
3
,
8
]},
]
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
linspace
,
ref_fn
=
lambda
start
,
end
,
step
:
np
.
linspace
(
start
,
end
,
step
,
dtype
=
np
.
float32
),
test_trace
=
False
,
)
...
...
@@ -184,30 +188,36 @@ def test_arange():
{
"input"
:
[
1
,
9
,
1
]},
{
"input"
:
[
2
,
10
,
2
]},
]
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
arange
,
ref_fn
=
lambda
start
,
end
,
step
:
np
.
arange
(
start
,
end
,
step
,
dtype
=
np
.
float32
),
test_trace
=
False
,
)
cases
=
[
{
"input"
:
[
9
,
1
,
-
1
]},
{
"input"
:
[
10
,
2
,
-
2
]},
]
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
arange
,
ref_fn
=
lambda
start
,
end
,
step
:
np
.
arange
(
start
,
end
,
step
,
dtype
=
np
.
float32
),
test_trace
=
False
,
)
cases
=
[
{
"input"
:
[
9.3
,
1.2
,
-
0.5
]},
{
"input"
:
[
10.3
,
2.1
,
-
1.7
]},
]
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
arange
,
ref_fn
=
lambda
start
,
end
,
step
:
np
.
arange
(
start
,
end
,
step
,
dtype
=
np
.
float32
),
test_trace
=
False
,
)
...
...
@@ -279,7 +289,8 @@ def test_broadcast():
{
"input"
:
[
data1
,
output1_shape
],
"output"
:
output1_shape
},
{
"input"
:
[
data2
,
output2_shape
],
"output"
:
output2_shape
},
]
opr_test
(
cases
,
F
.
broadcast_to
,
compare_fn
=
compare_fn
)
# FIXME: remove test_trace=False in the future
opr_test
(
cases
,
F
.
broadcast_to
,
compare_fn
=
compare_fn
,
test_trace
=
False
)
x
=
F
.
ones
((
2
,
1
,
3
))
with
pytest
.
raises
(
RuntimeError
):
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部