Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7ac4dbc2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7ac4dbc2
编写于
9月 23, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/trace): finalize when exception raise
GitOrigin-RevId: b8ffd00a7ea29add26a2a3d6275ea9f29d877908
上级
2bd84d67
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
138 addition
and
62 deletion
+138
-62
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+105
-62
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+33
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
7ac4dbc2
...
...
@@ -125,6 +125,9 @@ class trace:
self
.
_graph_opt_level
=
opt_level
self
.
_tensor_shape
=
tensor_shape
self
.
_reset
()
def
_reset
(
self
):
self
.
_untraced
=
True
self
.
_tinfo
=
[]
# handle -> TensorInfo
self
.
_seq
=
[]
...
...
@@ -257,77 +260,117 @@ class trace:
def
_record_const
(
self
,
op
,
outputs
):
pass
@
contextlib
.
contextmanager
def
_setup
(
self
):
def
_set_active
(
self
,
active
:
bool
):
global
active_trace
if
active_trace
:
raise
NotImplementedError
(
"sorry, not implemented: nested trace"
)
active_trace
=
self
if
self
.
_untraced
:
apply
.
enable
(
apply_with_tracing
)
apply
.
enable
(
apply_const_with_tracing
)
if
self
.
_symbolic
:
apply
.
enable
(
apply_symbolic_mode
)
apply
.
enable
(
apply_const_symbolic_mode
)
self
.
_lazy_eval_graph
=
G
.
Graph
()
if
active
:
if
active_trace
:
raise
NotImplementedError
(
"sorry, not implemented: nested trace"
)
active_trace
=
self
else
:
apply
.
enable
(
apply_compiled_mode
)
if
self
.
_graph
is
None
:
self
.
_compile
()
self
.
_graph
.
execute
()
yield
assert
active_trace
is
self
active_trace
=
None
def
_init_trace
(
self
,
symbolic
:
bool
):
apply
.
enable
(
apply_with_tracing
)
apply
.
enable
(
apply_const_with_tracing
)
if
symbolic
:
apply
.
enable
(
apply_symbolic_mode
)
apply
.
enable
(
apply_const_symbolic_mode
)
self
.
_lazy_eval_graph
=
G
.
Graph
()
def
_take_escaped_tensors
(
self
):
escaped_tensors
=
tuple
(
self
.
_active_tensors
)
self
.
_active_tensors
.
clear
()
return
escaped_tensors
if
self
.
_untraced
:
for
x
in
escaped_tensors
:
info
=
self
.
_tinfo
[
x
.
_TraceMixin__handle
]
info
.
data_read
=
True
x
.
_TraceMixin__restore
()
if
self
.
_inputs_to_restore
:
for
x
in
self
.
_inputs_to_restore
:
def
_lazy_eval
(
self
,
lazy_eval_graph
,
lazy_eval_tensors
):
active_lazy_eval_tensors
=
[]
visited
=
set
()
readers
=
[]
for
x
in
lazy_eval_tensors
:
x
=
x
()
if
x
is
None
or
x
in
visited
:
continue
reader
=
G
.
OutputNode
(
x
.
_LazyEvalTensor__varnode
).
outputs
[
0
]
readers
.
append
(
reader
)
active_lazy_eval_tensors
.
append
(
x
)
visited
.
add
(
x
)
self
.
_apply_graph_options
(
lazy_eval_graph
)
lazy_eval_graph
.
compile
(
*
readers
)
lazy_eval_graph
()
for
r
,
x
in
zip
(
readers
,
active_lazy_eval_tensors
):
assign_raw_tensor
(
x
,
as_raw_tensor
(
r
.
op
.
get_value
()))
@
contextlib
.
contextmanager
def
_setup
(
self
):
interrupted
=
False
def
do_enter
():
self
.
_set_active
(
True
)
if
self
.
_untraced
:
self
.
_init_trace
(
self
.
_symbolic
)
else
:
apply
.
enable
(
apply_compiled_mode
)
if
self
.
_graph
is
None
:
self
.
_compile
()
self
.
_graph
.
execute
()
def
do_finalize
():
escaped_tensors
=
self
.
_take_escaped_tensors
()
if
self
.
_untraced
:
for
x
in
escaped_tensors
:
info
=
self
.
_tinfo
[
x
.
_TraceMixin__handle
]
info
.
data_read
=
True
x
.
_TraceMixin__restore
()
if
self
.
_symbolic
:
# eval lazy eval tensors
if
self
.
_lazy_eval_tensors
:
lazy_eval_tensors
=
[]
visited
=
set
()
readers
=
[]
for
x
in
self
.
_lazy_eval_tensors
:
x
=
x
()
if
x
is
None
or
x
in
visited
:
continue
reader
=
G
.
OutputNode
(
x
.
_LazyEvalTensor__varnode
).
outputs
[
0
]
readers
.
append
(
reader
)
lazy_eval_tensors
.
append
(
x
)
visited
.
add
(
x
)
self
.
_apply_graph_options
(
self
.
_lazy_eval_graph
)
self
.
_lazy_eval_graph
.
compile
(
*
readers
)
self
.
_lazy_eval_graph
()
for
r
,
x
in
zip
(
readers
,
lazy_eval_tensors
):
assign_raw_tensor
(
x
,
as_raw_tensor
(
r
.
op
.
get_value
()))
if
self
.
_inputs_to_restore
:
for
x
in
self
.
_inputs_to_restore
:
x
.
_TraceMixin__restore
()
if
self
.
_symbolic
and
self
.
_lazy_eval_tensors
:
# eval lazy eval tensors
self
.
_lazy_eval
(
self
.
_lazy_eval_graph
,
self
.
_lazy_eval_tensors
)
self
.
_lazy_eval_graph
=
None
self
.
_lazy_eval_tensors
=
None
self
.
_untraced
=
False
else
:
if
self
.
_pc
!=
len
(
self
.
_seq
):
raise
TraceMismatchError
(
"premature end"
)
for
x
in
escaped_tensors
:
assign_raw_tensor
(
x
,
as_raw_tensor
(
x
.
_dev_tensor
()))
self
.
_graph
.
wait
()
self
.
_reset_exec_env
()
self
.
_untraced
=
False
else
:
# compiled_tensor leaks
if
self
.
_pc
==
len
(
self
.
_seq
):
for
x
in
escaped_tensors
:
try
:
assign_raw_tensor
(
x
,
as_raw_tensor
(
x
.
_dev_tensor
()))
except
TraceMismatchError
:
# TraceMismatchError thrown in do_exit
pass
self
.
_graph
.
wait
()
self
.
_reset_exec_env
()
# reset status
self
.
_pc
=
0
self
.
_tensor_remaps
=
None
apply
.
disable
(
apply_with_tracing
)
apply
.
disable
(
apply_const_with_tracing
)
apply
.
disable
(
apply_symbolic_mode
)
apply
.
disable
(
apply_const_symbolic_mode
)
apply
.
disable
(
apply_compiled_mode
)
active_trace
=
None
self
.
_tensor_remaps
=
None
apply
.
disable
(
apply_with_tracing
)
apply
.
disable
(
apply_const_with_tracing
)
apply
.
disable
(
apply_symbolic_mode
)
apply
.
disable
(
apply_const_symbolic_mode
)
apply
.
disable
(
apply_compiled_mode
)
self
.
_set_active
(
False
)
def
do_exit
():
if
not
self
.
_untraced
and
self
.
_pc
!=
len
(
self
.
_seq
):
raise
TraceMismatchError
(
"premature end"
)
if
not
self
.
_symbolic
or
not
self
.
_untraced
:
for
x
in
self
.
_active_tensors
:
x
.
_dev_tensor
()
try
:
do_enter
()
yield
do_exit
()
except
:
interrupted
=
True
raise
finally
:
do_finalize
()
if
interrupted
:
self
.
_reset
()
def
_begin_excluded_region
(
self
):
if
self
.
_capture_as_const
:
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
7ac4dbc2
...
...
@@ -307,3 +307,36 @@ def test_trace_warp_perspective():
for
i
in
range
(
1
):
f
(
x
,
M
)
def
test_raise_on_trace
():
step_count
=
0
catch_count
=
0
bad_step
=
10
class
CatchMe
(
Exception
):
pass
a
=
tensor
([
1
,
2
,
3
,
4
])
b
=
tensor
([
5
,
6
,
7
,
8
])
c
=
tensor
([
9
,
0
,
1
,
2
])
@
trace
def
add_abc
(
a
,
b
,
c
):
print
(
"Hello"
)
ps
=
a
+
b
result
=
ps
+
c
if
step_count
==
bad_step
:
raise
CatchMe
(
"catch me"
)
return
result
for
i
in
range
(
100
):
try
:
d
=
add_abc
(
a
,
b
,
c
)
except
CatchMe
as
e
:
catch_count
+=
1
else
:
np
.
testing
.
assert_equal
(
d
.
numpy
(),
(
a
+
b
+
c
).
numpy
())
step_count
+=
1
assert
catch_count
==
1
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录