Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
579a3d77
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
579a3d77
编写于
5月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/jit): error out if dump bn in training mode
GitOrigin-RevId: edc7ea2962da24c8a680c0c6fb2effcfaf3508c2
上级
68da973e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
1 deletion
+18
-1
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+5
-1
imperative/python/test/integration/test_trace_dump.py
imperative/python/test/integration/test_trace_dump.py
+13
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
579a3d77
...
...
@@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import (
)
from
..core._trace_option
import
set_symbolic_shape
from
..core._wrap
import
device
as
as_device
from
..core.ops.builtin
import
BackwardGraph
,
OpDef
from
..core.ops.builtin
import
BackwardGraph
,
BatchNorm
,
OpDef
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.utils
import
setscalar
...
...
@@ -833,6 +833,10 @@ class trace:
if
isinstance
(
op
,
BackwardGraph
):
ovars
=
G
.
apply_backward_varnode
(
op
,
*
ivars
)
else
:
if
isinstance
(
op
,
BatchNorm
):
assert
(
op
.
fwd_mode
==
BatchNorm
.
FwdMode
.
INFERENCE
),
"can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
AutoNaming
.
record_opnode
(
ovars
[
0
].
op
)
...
...
imperative/python/test/integration/test_trace_dump.py
浏览文件 @
579a3d77
...
...
@@ -11,6 +11,7 @@ import os
import
tempfile
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine.functional
as
F
...
...
@@ -140,3 +141,15 @@ def test_xornet_trace_dump():
with
mkstemp
()
as
out
:
pred_fun
.
dump
(
out
,
arg_names
=
[
"data"
],
output_names
=
[
"label"
])
def
test_dump_bn_train_mode
():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
bn_train
(
data
):
pred
=
M
.
BatchNorm2d
(
10
)(
data
).
sum
()
return
pred
data
=
mge
.
tensor
(
np
.
random
.
random
((
10
,
10
,
10
,
10
)))
bn_train
(
data
)
with
pytest
.
raises
(
AssertionError
):
bn_train
.
dump
(
"test.mge"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录