Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
09af925f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
09af925f
编写于
1月 05, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): fix cpp trace function release
GitOrigin-RevId: 924f945c211bc17596710410e616ab4b1e2e612e
上级
3975a54a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
25 addition
and
30 deletion
+25
-30
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+0
-3
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+7
-15
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+2
-2
imperative/python/src/trace.cpp
imperative/python/src/trace.cpp
+16
-10
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
09af925f
...
@@ -72,7 +72,6 @@ if sys.platform == "win32":
...
@@ -72,7 +72,6 @@ if sys.platform == "win32":
kernel32
.
SetErrorMode
(
old_error_mode
)
kernel32
.
SetErrorMode
(
old_error_mode
)
from
.core._imperative_rt.core2
import
full_sync
as
_full_sync
from
.core._imperative_rt.core2
import
full_sync
as
_full_sync
from
.core._imperative_rt.core2
import
release_trace_apply_func
from
.core._imperative_rt.core2
import
sync
as
_sync
from
.core._imperative_rt.core2
import
sync
as
_sync
from
.core._imperative_rt.utils
import
_set_fork_exec_path_for_timed_func
from
.core._imperative_rt.utils
import
_set_fork_exec_path_for_timed_func
from
.device
import
*
from
.device
import
*
...
@@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
...
@@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins
.
reg
()
_persistent_cache_impl_ins
.
reg
()
atexit
.
register
(
_full_sync
)
atexit
.
register
(
_full_sync
)
atexit
.
register
(
release_trace_apply_func
)
del
release_trace_apply_func
del
_set_fork_exec_path_for_timed_func
del
_set_fork_exec_path_for_timed_func
del
_persistent_cache_impl_ins
del
_persistent_cache_impl_ins
...
...
imperative/python/src/tensor.cpp
浏览文件 @
09af925f
...
@@ -34,22 +34,15 @@ namespace mgb::imperative::python {
...
@@ -34,22 +34,15 @@ namespace mgb::imperative::python {
std
::
unique_ptr
<
interpreter
::
Interpreter
::
Channel
>
interpreter_for_py
;
std
::
unique_ptr
<
interpreter
::
Interpreter
::
Channel
>
interpreter_for_py
;
py
::
object
cpp_apply_with_tracing
,
cpp_apply_const_with_tracing
,
PyObject
*
cpp_apply_with_tracing
,
*
cpp_apply_const_with_tracing
,
cpp_apply_compiled_mode
,
cpp_apply_const_compiled_mode
;
*
cpp_apply_compiled_mode
,
*
cpp_apply_const_compiled_mode
;
py
::
object
cpp_apply_backward_varnode
;
PyObject
*
cpp_apply_backward_varnode
;
void
release_trace_apply_func
(){
cpp_apply_with_tracing
.
release
();
cpp_apply_const_with_tracing
.
release
();
cpp_apply_compiled_mode
.
release
();
cpp_apply_const_compiled_mode
.
release
();
cpp_apply_backward_varnode
.
release
();
}
#define REGISTE_APPLY_FUNC(mode) \
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
void set_##mode(py::object pyf) { \
mode = py
bind11::reinterpret_steal<py::object>(pyf);
\
mode = py
f.ptr();
\
}
}
REGISTE_APPLY_FUNC
(
cpp_apply_with_tracing
)
REGISTE_APPLY_FUNC
(
cpp_apply_with_tracing
)
...
@@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
// const op
// const op
if
(
is_const
&&
is_tracing
)
{
if
(
is_const
&&
is_tracing
)
{
py
::
object
pyf
;
PyObject
*
pyf
;
if
(
is_compiled
)
{
if
(
is_compiled
)
{
pyf
=
cpp_apply_const_compiled_mode
;
pyf
=
cpp_apply_const_compiled_mode
;
}
else
{
}
else
{
pyf
=
cpp_apply_const_with_tracing
;
pyf
=
cpp_apply_const_with_tracing
;
}
}
auto
ret
=
pyf
(
*
tup
);
auto
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
pyf
,
tup
.
ptr
(),
nullptr
));
auto
py_ret
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
auto
py_ret
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
if
(
auto
*
t
=
try_cast
(
py_ret
[
0
].
ptr
()))
{
if
(
auto
*
t
=
try_cast
(
py_ret
[
0
].
ptr
()))
{
m_tensor
=
t
->
m_tensor
;
m_tensor
=
t
->
m_tensor
;
...
@@ -744,8 +738,6 @@ void init_tensor(py::module m) {
...
@@ -744,8 +738,6 @@ void init_tensor(py::module m) {
},
},
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"release_trace_apply_func"
,
&
release_trace_apply_func
);
py
::
handle
grad_key_type
=
GradKeyWrapper
::
wrap_t
::
type
()
py
::
handle
grad_key_type
=
GradKeyWrapper
::
wrap_t
::
type
()
.
def
<&
GradKeyWrapper
::
attach
>
(
"attach"
)
.
def
<&
GradKeyWrapper
::
attach
>
(
"attach"
)
.
def
<&
GradKeyWrapper
::
is_attached_to
>
(
"is_attached_to"
)
.
def
<&
GradKeyWrapper
::
is_attached_to
>
(
"is_attached_to"
)
...
...
imperative/python/src/tensor.h
浏览文件 @
09af925f
...
@@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
...
@@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
void
init_tensor
(
pybind11
::
module
);
void
init_tensor
(
pybind11
::
module
);
extern
pybind11
::
object
cpp_apply_with_tracing
,
cpp_apply_compiled_mode
;
extern
PyObject
*
cpp_apply_with_tracing
,
*
cpp_apply_compiled_mode
;
extern
pybind11
::
object
cpp_apply_backward_varnode
;
extern
PyObject
*
cpp_apply_backward_varnode
;
}
// namespace mgb::imperative::python
}
// namespace mgb::imperative::python
...
...
imperative/python/src/trace.cpp
浏览文件 @
09af925f
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "./trace.h"
#include "./trace.h"
...
@@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
...
@@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
if
(
ctx
.
backward
)
{
if
(
ctx
.
backward
)
{
// reach here when symbolic=True or compiled=True
// reach here when symbolic=True or compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
// call megbrain_graph.py apply(BackwardGraph, *args)
auto
args
=
py
::
tuple
(
ctx
.
nargs
);
auto
args
=
py
::
tuple
(
ctx
.
nargs
+
1
);
args
[
0
]
=
py
::
cast
(
ctx
.
op
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
args
[
i
]
=
py
::
cast
(
ctx
.
args
[
i
]
->
m_var
);
args
[
i
+
1
]
=
py
::
cast
(
ctx
.
args
[
i
]
->
m_var
);
}
}
py
::
object
ret
=
cpp_apply_backward_varnode
(
py
::
cast
(
ctx
.
op
),
*
args
);
py
::
object
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
cpp_apply_backward_varnode
,
args
.
ptr
(),
nullptr
));
if
(
!
ret
)
{
if
(
!
ret
)
{
throw
py
::
value_error
(
"invalid py object call"
);
throw
py
::
value_error
(
"invalid py object call"
);
}
}
...
@@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
...
@@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
for
(
auto
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
for
(
auto
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
auto
pitem
=
tup
[
i
].
cast
<
cg
::
VarNode
*>
();
auto
pitem
=
tup
[
i
].
cast
<
cg
::
VarNode
*>
();
outputs
.
emplace_back
(
std
::
make_shared
<
Tensor
>
(
pitem
));
outputs
.
emplace_back
(
std
::
make_shared
<
Tensor
>
(
pitem
));
}
}
return
outputs
;
return
outputs
;
}
}
py
::
object
pyf
;
PyObject
*
pyf
;
if
(
is_compiled
)
{
if
(
is_compiled
)
{
// run apply in compiled mode, step 2, 3, etc
// run apply in compiled mode, step 2, 3, etc
pyf
=
cpp_apply_compiled_mode
;
pyf
=
cpp_apply_compiled_mode
;
...
@@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) {
...
@@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) {
pyf
=
cpp_apply_with_tracing
;
pyf
=
cpp_apply_with_tracing
;
}
}
auto
args
=
py
::
tuple
(
ctx
.
nargs
);
auto
args
=
py
::
tuple
(
ctx
.
nargs
+
1
);
args
[
0
]
=
py
::
cast
(
ctx
.
op
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
args
[
i
]
=
TensorWrapper
::
make
(
std
::
move
(
std
::
shared_ptr
<
Tensor
>
(
ctx
.
args
[
i
]))).
release
();
args
[
i
+
1
]
=
TensorWrapper
::
make
(
std
::
move
(
std
::
shared_ptr
<
Tensor
>
(
ctx
.
args
[
i
])))
.
release
();
}
}
auto
ret
=
pyf
(
py
::
cast
(
ctx
.
op
),
*
args
);
auto
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
pyf
,
args
.
ptr
(),
nullptr
));
// assumption: python function always returns PyList
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录