Commit 23b9a98f authored by Megvii Engine Team

fix(mge): fix sublinear cuda and mem leak

GitOrigin-RevId: 82091ec9a6d9a00ecffad9505b59d54e3127c783
Parent: c70a49ed
@@ -350,12 +350,16 @@ class trace:
self._lazy_eval_links = ()
def _take_escaped_tensors(self):
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values()))
escaped_tensors = tuple(
filter(lambda x: x() is not None, self._active_tensors.values())
)
self._active_tensors.clear()
return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values()))
lazy_eval_tensors = list(
filter(lambda x: x() is not None, lazy_eval_tensors.values())
)
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph)
# FIXME
@@ -443,6 +447,7 @@ class trace:
x()._reset_varnode()
x().mixin_handle = -1
x().recording = False
x()._trace_mixin_info = None
try:
do_enter()
......
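Reading these two hunks together, the functional Python-side change appears to be the added `x()._trace_mixin_info = None`: when traced tensors escape the trace, their per-tensor trace info is now explicitly dropped, which (via the setter change in the next file) lets the C++ side release the stored Python object instead of keeping it alive with the tensor. The re-wrapped `filter(lambda x: x() is not None, ...)` calls are formatting-only; they keep the existing weak-reference filtering over `self._active_tensors` and `lazy_eval_tensors`.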
@@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() {
return m_tensor->m_trace_info.member; \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
if (dest == Py_None) { \
Py_XDECREF(m_tensor->m_trace_info.member); \
m_tensor->m_trace_info.member = nullptr; \
} else { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
} \
}
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
@@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));
Py_DECREF(m_tensor->m_trace_info.compiled_info);
m_tensor->m_trace_info.compiled_info = nullptr;
return dev_tensor;
}
......
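Before this change the generated `set_##member()` unconditionally did `Py_INCREF(dest)` and overwrote the slot, so assigning `None` from Python kept a reference to `None` and never released the object previously stored there. The `_dev_tensor()` hunk applies the same rule by hand, dropping `compiled_info` once the device tensor has been handed to the interpreter. A minimal sketch of the ownership rule the new branch enforces (`PyObjectSlot` is an illustrative name, not MegEngine API; the GIL is assumed to be held):

```cpp
#include <Python.h>

// Sketch of the setter pattern (illustrative, not MegEngine code): the slot
// owns exactly one reference, and assigning Py_None means "clear the slot".
struct PyObjectSlot {
    PyObject* obj = nullptr;  // owned reference, or nullptr when empty

    void set(PyObject* dest) {
        if (dest == Py_None) {
            Py_XDECREF(obj);  // release whatever was stored before
            obj = nullptr;    // keep no reference to None itself
        } else {
            Py_INCREF(dest);  // take our own reference to the new object
            obj = dest;       // like the patch, assumes the slot was empty
        }
    }

    ~PyObjectSlot() { Py_XDECREF(obj); }  // balance the reference on teardown
};
```

Like the patched macro, the `else` branch assumes the slot has already been cleared (by assigning `None`) before a different object is stored; a fully general setter would also `Py_XDECREF` the old value there.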
@@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) {
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
py::tuple args(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(
std::move(std::shared_ptr<Tensor>(ctx.args[i])))
.release();
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr));
......
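The removed assignment built each tuple element from `TensorWrapper::make(...).release()`. `.release()` hands over the reference owned by the temporary `py::object`, and pybind11's tuple item assignment additionally takes its own reference for the slot, so every traced argument ended up one reference over-counted per call. Assigning the `py::object` directly lets its destructor drop the temporary reference; the switch to `ctx.args[i]->shared_from_this()` likewise appears to reuse the tensor's existing C++ ownership instead of wrapping the raw pointer in a fresh `std::shared_ptr`. A sketch of the balanced pattern, with made-up names (`build_call_args`, `tensors`):

```cpp
#include <vector>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative only, not MegEngine code: build (op, *tensors) for a Python
// callback. Tuple item assignment INCREFs the value for the slot, so the
// temporaries keep -- and later drop -- their own references; calling
// .release() on them here would leak one reference per element.
// The caller must hold the GIL.
py::tuple build_call_args(const py::object& op, const std::vector<py::object>& tensors) {
    py::tuple args(tensors.size() + 1);
    args[0] = op;  // the slot takes its own reference; `op` keeps its reference
    for (size_t i = 0; i < tensors.size(); ++i) {
        args[i + 1] = tensors[i];  // no .release(): refcounts stay balanced
    }
    return args;
}
```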
@@ -28,10 +28,10 @@ struct TraceInfo {
mixin_handle = that.mixin_handle;
recording = that.recording;
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info);
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
copied = true;
return *this;
@@ -39,7 +39,7 @@ struct TraceInfo {
~TraceInfo() {
Py_XDECREF(trace_mixin_info);
// Py_XDECREF(compiled_info);
Py_XDECREF(compiled_info);
}
private:
......
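With the destructor's `Py_XDECREF(compiled_info)` commented out, every `TraceInfo` that held a `compiled_info` leaked it once the struct was destroyed; the patch re-enables that decref and has copy-assignment take a matching `Py_XINCREF`, so each copy owns exactly one reference per member. A sketch of this owning-`PyObject*` pattern (illustrative struct name `PyMemberOwner`; the real `TraceInfo` carries more fields and flags):

```cpp
#include <Python.h>

// Illustrative owner of optional PyObject* members (not the real TraceInfo):
// copy operations and the destructor keep reference counts balanced so each
// instance owns exactly one reference to every non-null member it stores.
struct PyMemberOwner {
    PyObject* compiled_info = nullptr;
    PyObject* trace_mixin_info = nullptr;

    PyMemberOwner() = default;

    PyMemberOwner(const PyMemberOwner& that)
            : compiled_info(that.compiled_info),
              trace_mixin_info(that.trace_mixin_info) {
        Py_XINCREF(compiled_info);
        Py_XINCREF(trace_mixin_info);
    }

    PyMemberOwner& operator=(const PyMemberOwner& that) {
        // INCREF first, then DECREF, so self-assignment cannot free early.
        Py_XINCREF(that.compiled_info);
        Py_XINCREF(that.trace_mixin_info);
        Py_XDECREF(compiled_info);
        Py_XDECREF(trace_mixin_info);
        compiled_info = that.compiled_info;
        trace_mixin_info = that.trace_mixin_info;
        return *this;
    }

    ~PyMemberOwner() {
        Py_XDECREF(compiled_info);  // the decref the old destructor skipped
        Py_XDECREF(trace_mixin_info);
    }
};
```

Unlike the visible hunk, the sketch also releases any previously held references inside `operator=`; the patch only shows the incref side, so that extra step is an assumption about the surrounding code.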