Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f4b16932
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f4b16932
编写于
10月 10, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(mgb/imperative): add adaptive pooling pytest
GitOrigin-RevId: c4dfed1f8047b3c6c6cca428e9790b5a2cff0b4a
上级
37e56f4b
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
54 addition
and
7 deletion
+54
-7
imperative/python/test/integration/test_correctness.py
imperative/python/test/integration/test_correctness.py
+54
-7
未找到文件。
imperative/python/test/integration/test_correctness.py
浏览文件 @
f4b16932
...
...
@@ -19,9 +19,17 @@ import megengine.autodiff as ad
import
megengine.functional
as
F
from
megengine
import
jit
from
megengine.core._trace_option
import
set_tensor_shape
from
megengine.core.tensor.utils
import
make_shape_tuple
from
megengine.functional.debug_param
import
set_conv_execution_strategy
from
megengine.jit
import
SublinearMemoryConfig
from
megengine.module
import
AvgPool2d
,
BatchNorm2d
,
Conv2d
,
Linear
,
Module
from
megengine.module
import
(
AdaptiveAvgPool2d
,
AvgPool2d
,
BatchNorm2d
,
Conv2d
,
Linear
,
Module
,
)
from
megengine.optimizer
import
SGD
from
megengine.tensor
import
Tensor
...
...
@@ -57,9 +65,12 @@ def get_xpu_name():
class
MnistNet
(
Module
):
def
__init__
(
self
,
has_bn
=
False
):
def
__init__
(
self
,
has_bn
=
False
,
use_adaptive_pooling
=
False
):
super
().
__init__
()
self
.
conv0
=
Conv2d
(
1
,
20
,
kernel_size
=
5
,
bias
=
True
)
if
use_adaptive_pooling
:
self
.
pool0
=
AdaptiveAvgPool2d
(
12
)
else
:
self
.
pool0
=
AvgPool2d
(
2
)
self
.
conv1
=
Conv2d
(
20
,
20
,
kernel_size
=
5
,
bias
=
True
)
self
.
pool1
=
AvgPool2d
(
2
)
...
...
@@ -134,7 +145,12 @@ def update_model(model_path):
def
run_train
(
model_path
,
use_jit
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
model_path
,
use_jit
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
use_adaptive_pooling
=
False
,
):
"""
...
...
@@ -146,7 +162,7 @@ def run_train(
Please think twice before you do so.
"""
net
=
MnistNet
(
has_bn
=
True
)
net
=
MnistNet
(
has_bn
=
True
,
use_adaptive_pooling
=
use_adaptive_pooling
)
checkpoint
=
mge
.
load
(
model_path
)
net
.
load_state_dict
(
checkpoint
[
"net_init"
])
lr
=
checkpoint
[
"sgd_lr"
]
...
...
@@ -181,7 +197,11 @@ def run_train(
def
run_eval
(
model_path
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
model_path
,
use_symbolic
,
sublinear_memory_config
=
None
,
max_err
=
None
,
use_adaptive_pooling
=
False
,
):
"""
...
...
@@ -193,7 +213,7 @@ def run_eval(
Please think twice before you do so.
"""
net
=
MnistNet
(
has_bn
=
True
)
net
=
MnistNet
(
has_bn
=
True
,
use_adaptive_pooling
=
use_adaptive_pooling
)
checkpoint
=
mge
.
load
(
model_path
)
net
.
load_state_dict
(
checkpoint
[
"net_init"
])
...
...
@@ -231,3 +251,30 @@ def test_correctness():
run_eval
(
model_path
,
False
,
max_err
=
1e-7
)
run_eval
(
model_path
,
True
,
max_err
=
1e-7
)
def
test_correctness_use_adaptive_pooling
():
if
mge
.
is_cuda_available
():
model_name
=
"mnist_model_with_test.mge"
else
:
model_name
=
"mnist_model_with_test_cpu.mge"
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
set_conv_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
run_train
(
model_path
,
False
,
False
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
)
run_train
(
model_path
,
True
,
False
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
)
run_train
(
model_path
,
True
,
True
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
)
# sublinear
config
=
SublinearMemoryConfig
(
genetic_nr_iter
=
10
)
run_train
(
model_path
,
True
,
True
,
sublinear_memory_config
=
config
,
max_err
=
1e-5
,
use_adaptive_pooling
=
True
,
)
run_eval
(
model_path
,
False
,
max_err
=
1e-7
,
use_adaptive_pooling
=
True
)
run_eval
(
model_path
,
True
,
max_err
=
1e-7
,
use_adaptive_pooling
=
True
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录