提交 adc49de8 编写于 作者: M Megvii Engine Team

feat(mge/imperative): expose c++ tensor reference count

GitOrigin-RevId: 1940881adc243bd0ec01bf104d50b01f31639bd6
上级 fe1680b3
......@@ -541,6 +541,7 @@ struct TensorWeakRef {
}
return py::none();
}
int _use_cnt() { return wptr.use_count(); }
};
/* ============== convert inputs ============== */
......@@ -774,6 +775,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def<&TensorWrapper::_use_cnt>("_use_cnt")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
......@@ -787,7 +789,8 @@ void init_tensor(py::module m) {
py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator());
.def("__call__", &TensorWeakRef::operator())
.def("_use_cnt", &TensorWeakRef::_use_cnt);
static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply),
......
......@@ -170,6 +170,7 @@ struct TensorWrapper {
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册