提交 09af925f 编写于 作者: M Megvii Engine Team

fix(mge): fix cpp trace function release

GitOrigin-RevId: 924f945c211bc17596710410e616ab4b1e2e612e
上级 3975a54a
......@@ -72,7 +72,6 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import full_sync as _full_sync
from .core._imperative_rt.core2 import release_trace_apply_func
from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
......@@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(_full_sync)
atexit.register(release_trace_apply_func)
del release_trace_apply_func
del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins
......
......@@ -34,22 +34,15 @@ namespace mgb::imperative::python {
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
*cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
py::object cpp_apply_backward_varnode;
PyObject *cpp_apply_backward_varnode;
void release_trace_apply_func(){
cpp_apply_with_tracing.release();
cpp_apply_const_with_tracing.release();
cpp_apply_compiled_mode.release();
cpp_apply_const_compiled_mode.release();
cpp_apply_backward_varnode.release();
}
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \
mode = pyf.ptr(); \
}
REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
......@@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
// const op
if (is_const && is_tracing) {
py::object pyf;
PyObject *pyf;
if (is_compiled) {
pyf = cpp_apply_const_compiled_mode;
} else {
pyf = cpp_apply_const_with_tracing;
}
auto ret = pyf(*tup);
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, tup.ptr(), nullptr));
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = try_cast(py_ret[0].ptr())) {
m_tensor = t->m_tensor;
......@@ -744,8 +738,6 @@ void init_tensor(py::module m) {
},
py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func);
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to")
......
......@@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
void init_tensor(pybind11::module);
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;
extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
extern PyObject *cpp_apply_backward_varnode;
} // namespace mgb::imperative::python
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./trace.h"
......@@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
if (ctx.backward) {
// reach here when symbolic=True or compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
auto args = py::tuple(ctx.nargs);
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = py::cast(ctx.args[i]->m_var);
args[i + 1] = py::cast(ctx.args[i]->m_var);
}
py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args);
py::object ret = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!ret) {
throw py::value_error("invalid py object call");
}
......@@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode *>();
auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
return outputs;
}
py::object pyf;
PyObject* pyf;
if (is_compiled) {
// run apply in compiled mode, step 2, 3, etc
pyf = cpp_apply_compiled_mode;
......@@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) {
pyf = cpp_apply_with_tracing;
}
auto args = py::tuple(ctx.nargs);
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release();
args[i + 1] = TensorWrapper::make(
std::move(std::shared_ptr<Tensor>(ctx.args[i])))
.release();
}
auto ret = pyf(py::cast(ctx.op), *args);
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr));
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册