Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c70a49ed
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
395
Star
4704
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c70a49ed
编写于
1月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): correct trace outputs when grad does copy
GitOrigin-RevId: 65c8956a7df80bea78aa1ff9baa26c31490e2492
上级
d4ada69d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
35 addition
and
11 deletion
+35
-11
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+26
-11
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+6
-0
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+1
-0
imperative/python/src/trace_info.h
imperative/python/src/trace_info.h
+2
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
c70a49ed
...
...
@@ -163,9 +163,9 @@ class trace:
self
.
_graph
=
None
self
.
_need_reset_nodes
=
None
self
.
_lazy_eval_graph
=
None
self
.
_lazy_eval_tensors
=
set
()
self
.
_lazy_eval_tensors
=
{}
self
.
_lazy_eval_links
=
None
self
.
_active_tensors
=
set
()
self
.
_active_tensors
=
{}
self
.
_tensor_remaps
=
None
self
.
_inputs_to_restore
=
None
self
.
_arg_bindings
=
None
...
...
@@ -249,8 +249,8 @@ class trace:
y
.
_compiled_info
=
CompiledTensorProxy
(
h
)
y
.
mixin_handle
=
h
outputs
+=
[
y
]
self
.
_active_tensors
[
h
]
=
TensorWeakRef
(
y
)
self
.
_output_handles
.
update
(
ohandles
)
self
.
_active_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
return
outputs
def
_apply_const
(
self
,
value
,
dtype
,
device
):
...
...
@@ -303,9 +303,11 @@ class trace:
x
.
mixin_handle
=
h
x
.
recording
=
True
x
.
_trace_mixin_info
=
info
self
.
_active_tensors
[
h
]
=
TensorWeakRef
(
x
)
if
self
.
_symbolic
:
self
.
_lazy_eval_tensors
[
h
]
=
TensorWeakRef
(
x
)
self
.
_seq
.
append
((
op
,
tuple
(
ihandles
),
tuple
(
ohandles
)))
self
.
_active_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
def
_record_const
(
self
,
outputs
):
if
skip_tracing
:
...
...
@@ -327,6 +329,8 @@ class trace:
x
.
mixin_handle
=
h
x
.
recording
=
True
x
.
_trace_mixin_info
=
info
if
self
.
_symbolic
:
self
.
_lazy_eval_tensors
[
h
]
=
TensorWeakRef
(
x
)
self
.
_seq
.
append
((
"Const"
,
tuple
(),
tuple
(
ohandles
)))
def
_set_active
(
self
,
active
:
bool
):
...
...
@@ -346,12 +350,12 @@ class trace:
self
.
_lazy_eval_links
=
()
def
_take_escaped_tensors
(
self
):
escaped_tensors
=
tuple
(
filter
(
lambda
x
:
x
()
is
not
None
,
self
.
_active_tensors
))
escaped_tensors
=
tuple
(
filter
(
lambda
x
:
x
()
is
not
None
,
self
.
_active_tensors
.
values
()
))
self
.
_active_tensors
.
clear
()
return
escaped_tensors
def
_lazy_eval
(
self
,
lazy_eval_graph
,
lazy_eval_tensors
,
lazy_eval_links
):
lazy_eval_tensors
=
list
(
filter
(
lambda
x
:
x
()
is
not
None
,
lazy_eval_tensors
))
lazy_eval_tensors
=
list
(
filter
(
lambda
x
:
x
()
is
not
None
,
lazy_eval_tensors
.
values
()
))
readers
=
[
G
.
OutputNode
(
x
().
_varnode
).
outputs
[
0
]
for
x
in
lazy_eval_tensors
]
self
.
_apply_graph_options
(
lazy_eval_graph
)
# FIXME
...
...
@@ -401,7 +405,7 @@ class trace:
# eval lazy eval tensors
self
.
_lazy_eval
(
self
.
_lazy_eval_graph
,
tuple
(
self
.
_lazy_eval_tensors
)
,
self
.
_lazy_eval_tensors
,
self
.
_lazy_eval_links
,
)
self
.
_lazy_eval_graph
=
None
...
...
@@ -433,9 +437,10 @@ class trace:
if
not
self
.
_untraced
and
self
.
_pc
!=
len
(
self
.
_seq
):
raise
TraceMismatchError
(
"premature end"
)
if
not
self
.
_symbolic
or
not
self
.
_untraced
:
for
x
in
self
.
_active_tensors
:
for
x
in
self
.
_active_tensors
.
values
()
:
if
x
()
is
not
None
:
x
().
_dev_tensor
()
x
().
_reset_varnode
()
x
().
mixin_handle
=
-
1
x
().
recording
=
False
...
...
@@ -459,7 +464,7 @@ class trace:
if
self
.
_untraced
:
# conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read
for
x
in
self
.
_active_tensors
:
for
x
in
self
.
_active_tensors
.
values
()
:
info
=
self
.
_tinfo
[
x
().
mixin_handle
]
info
.
exported
=
True
info
.
data_read
=
True
...
...
@@ -626,8 +631,20 @@ class trace:
if
self
.
_capture_as_const
:
self
.
_process_inputs
(
*
args
,
**
kwargs
)
outputs
=
self
.
__wrapped__
(
*
args
,
**
kwargs
)
transform
=
False
if
outputs
is
not
None
:
if
not
isinstance
(
outputs
,
collections
.
abc
.
Sequence
):
transform
=
True
outputs
=
(
outputs
,)
for
o
in
outputs
:
if
o
.
_copied
:
self
.
_active_tensors
[
o
.
mixin_handle
]
=
TensorWeakRef
(
o
)
if
self
.
_untraced
and
self
.
_symbolic
:
self
.
_lazy_eval_tensors
[
o
.
mixin_handle
]
=
TensorWeakRef
(
o
)
if
self
.
_capture_as_const
:
self
.
_process_outputs
(
outputs
)
if
transform
:
outputs
=
outputs
[
0
]
return
outputs
def
dump
(
...
...
@@ -1031,7 +1048,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
if
require_links
:
active_trace
.
_lazy_eval_links
=
(
G
.
VarNode
(
outputs
[
0
].
_varnode
),)
active_trace
.
_lazy_eval_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
return
outputs
...
...
@@ -1042,7 +1058,6 @@ def apply_const_symbolic_mode(value, dtype, device):
ret
=
RawTensor
(
graph
.
make_const
(
value
,
dtype
=
dtype
,
device
=
device
))
if
np
.
array
(
value
).
ndim
==
0
:
setscalar
(
ret
)
active_trace
.
_lazy_eval_tensors
.
add
(
TensorWeakRef
(
ret
))
return
(
ret
,)
...
...
imperative/python/src/tensor.cpp
浏览文件 @
c70a49ed
...
...
@@ -284,6 +284,11 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC
PyObject
*
TensorWrapper
::
copied
()
{
return
py
::
cast
(
m_tensor
->
m_trace_info
.
copied
).
release
().
ptr
();
}
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
return m_tensor->m_trace_info.member; \
...
...
@@ -740,6 +745,7 @@ void init_tensor(py::module m) {
.
def
<&
TensorWrapper
::
_drop
>
(
"_drop"
)
.
def
<&
TensorWrapper
::
reset_varnode
>
(
"_reset_varnode"
)
.
def_getset
<&
TensorWrapper
::
varnode
>
(
"_varnode"
)
.
def_getset
<&
TensorWrapper
::
copied
>
(
"_copied"
)
.
def_getset
<&
TensorWrapper
::
mixin_handle
,
&
TensorWrapper
::
set_mixin_handle
>
(
"mixin_handle"
)
.
def_getset
<&
TensorWrapper
::
recording
,
&
TensorWrapper
::
set_recording
>
(
"recording"
)
.
def_getset
<&
TensorWrapper
::
handle
,
&
TensorWrapper
::
set_handle
>
(
"_handle"
)
...
...
imperative/python/src/tensor.h
浏览文件 @
c70a49ed
...
...
@@ -161,6 +161,7 @@ struct TensorWrapper {
PyObject
*
mixin_handle
();
PyObject
*
recording
();
PyObject
*
copied
();
void
set_mixin_handle
(
PyObject
*
);
void
set_recording
(
PyObject
*
);
...
...
imperative/python/src/trace_info.h
浏览文件 @
c70a49ed
...
...
@@ -17,6 +17,7 @@ namespace mgb::imperative::python {
struct
TraceInfo
{
int64_t
mixin_handle
=
-
1
;
bool
recording
=
false
;
bool
copied
=
false
;
PyObject
*
compiled_info
=
nullptr
;
PyObject
*
trace_mixin_info
=
nullptr
;
...
...
@@ -32,6 +33,7 @@ struct TraceInfo {
trace_mixin_info
=
that
.
trace_mixin_info
;
Py_XINCREF
(
trace_mixin_info
);
copied
=
true
;
return
*
this
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录