提交 23b9a98f 编写于 作者: M Megvii Engine Team

fix(mge): fix sublnear cuda and mem leak

GitOrigin-RevId: 82091ec9a6d9a00ecffad9505b59d54e3127c783
上级 c70a49ed
...@@ -350,12 +350,16 @@ class trace: ...@@ -350,12 +350,16 @@ class trace:
self._lazy_eval_links = () self._lazy_eval_links = ()
def _take_escaped_tensors(self): def _take_escaped_tensors(self):
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) escaped_tensors = tuple(
filter(lambda x: x() is not None, self._active_tensors.values())
)
self._active_tensors.clear() self._active_tensors.clear()
return escaped_tensors return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) lazy_eval_tensors = list(
filter(lambda x: x() is not None, lazy_eval_tensors.values())
)
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph) self._apply_graph_options(lazy_eval_graph)
# FIXME # FIXME
...@@ -443,6 +447,7 @@ class trace: ...@@ -443,6 +447,7 @@ class trace:
x()._reset_varnode() x()._reset_varnode()
x().mixin_handle = -1 x().mixin_handle = -1
x().recording = False x().recording = False
x()._trace_mixin_info = None
try: try:
do_enter() do_enter()
......
...@@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() { ...@@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() {
return m_tensor->m_trace_info.member; \ return m_tensor->m_trace_info.member; \
} \ } \
void TensorWrapper::set_##member(PyObject* dest) { \ void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \ if (dest == Py_None) { \
m_tensor->m_trace_info.member = dest; \ Py_XDECREF(m_tensor->m_trace_info.member); \
m_tensor->m_trace_info.member = nullptr; \
} else { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
} \
} }
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
...@@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){ ...@@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh)); m_tensor->m_handle = std::move(SharedHandle(sh));
Py_DECREF(m_tensor->m_trace_info.compiled_info);
m_tensor->m_trace_info.compiled_info = nullptr;
return dev_tensor; return dev_tensor;
} }
......
...@@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) {
auto args = py::tuple(ctx.nargs + 1); auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op); args[0] = py::cast(ctx.op);
py::tuple args(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) { for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make( args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
std::move(std::shared_ptr<Tensor>(ctx.args[i])))
.release();
} }
auto ret = py::reinterpret_steal<py::object>( auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr)); PyObject_Call(pyf, args.ptr(), nullptr));
......
...@@ -28,10 +28,10 @@ struct TraceInfo { ...@@ -28,10 +28,10 @@ struct TraceInfo {
mixin_handle = that.mixin_handle; mixin_handle = that.mixin_handle;
recording = that.recording; recording = that.recording;
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
trace_mixin_info = that.trace_mixin_info; trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info); Py_XINCREF(trace_mixin_info);
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
copied = true; copied = true;
return *this; return *this;
...@@ -39,7 +39,7 @@ struct TraceInfo { ...@@ -39,7 +39,7 @@ struct TraceInfo {
~TraceInfo() { ~TraceInfo() {
Py_XDECREF(trace_mixin_info); Py_XDECREF(trace_mixin_info);
// Py_XDECREF(compiled_info); Py_XDECREF(compiled_info);
} }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册