提交 adc49de8 编写于 作者: M Megvii Engine Team

feat(mge/imperative): expose c++ tensor reference count

GitOrigin-RevId: 1940881adc243bd0ec01bf104d50b01f31639bd6
上级 fe1680b3
...@@ -541,6 +541,7 @@ struct TensorWeakRef { ...@@ -541,6 +541,7 @@ struct TensorWeakRef {
} }
return py::none(); return py::none();
} }
int _use_cnt() { return wptr.use_count(); }
}; };
/* ============== convert inputs ============== */ /* ============== convert inputs ============== */
...@@ -774,6 +775,7 @@ void init_tensor(py::module m) { ...@@ -774,6 +775,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_in>("_swap_in") .def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop") .def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def<&TensorWrapper::_use_cnt>("_use_cnt")
.def_getset<&TensorWrapper::varnode>("_varnode") .def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied") .def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
...@@ -787,7 +789,8 @@ void init_tensor(py::module m) { ...@@ -787,7 +789,8 @@ void init_tensor(py::module m) {
py::class_<TensorWeakRef>(m, "TensorWeakRef") py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>()) .def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator()); .def("__call__", &TensorWeakRef::operator())
.def("_use_cnt", &TensorWeakRef::_use_cnt);
static PyMethodDef method_defs[] = { static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply), MGE_PY_INTERFACE(apply, py_apply),
......
...@@ -170,6 +170,7 @@ struct TensorWrapper { ...@@ -170,6 +170,7 @@ struct TensorWrapper {
void set_compiled_info(PyObject *); void set_compiled_info(PyObject *);
PyObject* trace_mixin_info(); PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *); void set_trace_mixin_info(PyObject *);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册