Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
afddefb6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
afddefb6
编写于
8月 31, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/imperative): add more trace test
GitOrigin-RevId: b02e420a8a4ef7290fa103aa45487f89ed83db0e
上级
a085b71c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
40 addition
and
4 deletion
+40
-4
imperative/python/test/integration/test_correctness.py
imperative/python/test/integration/test_correctness.py
+38
-2
imperative/src/test/opr_utility.cpp
imperative/src/test/opr_utility.cpp
+2
-2
未找到文件。
imperative/python/test/integration/test_correctness.py
浏览文件 @
afddefb6
...
...
@@ -16,6 +16,8 @@ import pytest
import
megengine
as
mge
import
megengine.functional
as
F
from
megengine
import
jit
from
megengine.core._trace_option
import
set_tensor_shape
from
megengine.functional.debug_param
import
set_conv_execution_strategy
from
megengine.module
import
AvgPool2d
,
BatchNorm2d
,
Conv2d
,
Linear
,
Module
from
megengine.optimizer
import
SGD
...
...
@@ -129,7 +131,7 @@ def update_model(model_path):
mge
.
save
(
checkpoint
,
model_path
)
def
run_t
est
(
def
run_t
rain
(
model_path
,
use_jit
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
):
...
...
@@ -175,6 +177,37 @@ def run_test(
assertTensorClose
(
param
[
1
],
param_ref
[
1
],
max_err
=
max_err
)
def
run_eval
(
model_path
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
):
"""
Load the model with test cases and run the training for one iter.
The loss and updated weights are compared with reference value to verify the correctness.
Dump a new file with updated result by calling update_model
if you think the test fails due to numerical rounding errors instead of bugs.
Please think twice before you do so.
"""
net
=
MnistNet
(
has_bn
=
True
)
checkpoint
=
mge
.
load
(
model_path
)
net
.
load_state_dict
(
checkpoint
[
"net_init"
])
data
=
Tensor
(
checkpoint
[
"data"
],
dtype
=
np
.
float32
)
def
eval_fun
(
data
,
*
,
net
=
None
):
pred
=
net
(
data
)
return
pred
refer_value
=
eval_fun
(
data
,
net
=
net
)
eval_fun
=
jit
.
trace
(
eval_fun
,
symbolic
=
use_symbolic
)
for
_
in
range
(
3
):
new_value
=
eval_fun
(
data
,
net
=
net
)
assertTensorClose
(
new_value
.
numpy
(),
refer_value
.
numpy
(),
max_err
=
max_err
)
def
test_correctness
():
if
mge
.
is_cuda_available
():
model_name
=
"mnist_model_with_test.mge"
...
...
@@ -183,7 +216,7 @@ def test_correctness():
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
set_conv_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
run_t
est
(
model_path
,
False
,
False
,
max_err
=
1e-5
)
run_t
rain
(
model_path
,
False
,
False
,
max_err
=
1e-5
)
# run_test(model_path, True, False)
# run_test(model_path, True, True)
...
...
@@ -192,3 +225,6 @@ def test_correctness():
# run_test(
# model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
# )
run_eval
(
model_path
,
False
,
max_err
=
1e-7
)
# run_eval(model_path, True, max_err=1e-7) # XXX: fix me
imperative/src/test/opr_utility.cpp
浏览文件 @
afddefb6
...
...
@@ -25,7 +25,7 @@ TEST(TestOprUtility, InputCallback) {
dv
.
copy_from
(
*
hv
).
sync
();
auto
graph
=
ComputingGraph
::
make
();
auto
callback
=
[
dv
]()
{
return
dv
;};
auto
outputs
=
opr
::
InputCallback
::
make
(
*
graph
,
callback
,
dv
.
comp_node
(),
dv
.
dtype
());
auto
outputs
=
opr
::
InputCallback
::
make
(
*
graph
,
callback
,
dv
.
comp_node
(),
dv
.
dtype
()
,
{
2
,
3
}
);
HostTensorND
hout
;
ComputingGraph
::
OutputSpec
outspec
{
make_callback_copy
(
outputs
[
0
],
hout
)};
...
...
@@ -99,7 +99,7 @@ TEST(TestOprUtility, CallbackChain) {
dev_x
.
storage
({});
return
ret
;
};
auto
out
=
opr
::
InputCallback
::
make
(
*
graph
,
callback
,
cn
,
dev_x
.
dtype
());
auto
out
=
opr
::
InputCallback
::
make
(
*
graph
,
callback
,
cn
,
dev_x
.
dtype
()
,
{
2
,
3
}
);
x
=
out
[
0
];
dummy
=
out
[
1
];
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录