提交 34c705fc 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): move detach into C++

GitOrigin-RevId: 8c0d86cbbfdde275885b50417c80d28045871493
上级 147cef52
...@@ -118,13 +118,6 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -118,13 +118,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
def __setstate__(self, state): def __setstate__(self, state):
self.q_dict = state.pop("qdict") self.q_dict = state.pop("qdict")
def detach(self):
r"""
Returns a new tensor sharing the same data memory, which is treated as a constant
during backward gradient calcuation, i.e. its gradient is zero.
"""
Wrapper = type(self)
return Wrapper(self)
tensor = Tensor tensor = Tensor
......
...@@ -68,9 +68,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ...@@ -68,9 +68,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
return nullptr; return nullptr;
} }
auto* op = args[0]; auto* op = args[0];
if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){
op = PyObject_CallMethod(op,"to_c","");
}
PyTypeObject* pytype = args[1]->ob_type; PyTypeObject* pytype = args[1]->ob_type;
++args; ++args;
...@@ -195,6 +192,15 @@ void TensorWrapper::reset(PyObject* tensor) { ...@@ -195,6 +192,15 @@ void TensorWrapper::reset(PyObject* tensor) {
m_tensor = t->m_tensor; m_tensor = t->m_tensor;
} }
PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;
auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();
}
PyObject* TensorWrapper::isscalar() { PyObject* TensorWrapper::isscalar() {
if(m_tensor->m_flags & Tensor::Flags::SCALAR) { if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
...@@ -233,6 +239,7 @@ void init_tensor(py::module m) { ...@@ -233,6 +239,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar") .def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar") .def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::detach>("detach")
.finalize(); .finalize();
if (!tensor_type) throw py::error_already_set(); if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type); py::setattr(m, "Tensor", tensor_type);
......
...@@ -128,6 +128,7 @@ struct TensorWrapper { ...@@ -128,6 +128,7 @@ struct TensorWrapper {
PyObject* device(); PyObject* device();
PyObject* numpy(); PyObject* numpy();
void reset(PyObject*); void reset(PyObject*);
PyObject* detach();
PyObject* isscalar(); PyObject* isscalar();
void setscalar(); void setscalar();
}; };
......
...@@ -166,7 +166,7 @@ def test_interpolate(): ...@@ -166,7 +166,7 @@ def test_interpolate():
def _save_to(self, name="grad"): def _save_to(self, name="grad"):
def callback(tensor, grad): def callback(grad):
setattr(self, name, grad) setattr(self, name, grad)
return callback return callback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册