提交 34c705fc 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): move detach into C++

GitOrigin-RevId: 8c0d86cbbfdde275885b50417c80d28045871493
上级 147cef52
......@@ -118,13 +118,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
def __setstate__(self, state):
    # Restore pickled state: the quantization dict is serialized under the
    # "qdict" key (presumably written by a matching __getstate__ — not
    # visible in this view; TODO confirm against the full class).
    self.q_dict = state.pop("qdict")
def detach(self):
    r"""
    Returns a new tensor sharing the same data memory, which is treated as a
    constant during backward gradient calculation, i.e. its gradient is zero.
    """
    # Re-wrap ``self`` through its own concrete class so that subclasses of
    # Tensor produce detached instances of the same subclass.
    return type(self)(self)
tensor = Tensor
......
......@@ -68,9 +68,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
return nullptr;
}
auto* op = args[0];
if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){
op = PyObject_CallMethod(op,"to_c","");
}
PyTypeObject* pytype = args[1]->ob_type;
++args;
......@@ -195,6 +192,15 @@ void TensorWrapper::reset(PyObject* tensor) {
m_tensor = t->m_tensor;
}
/* Build a new Python tensor wrapper of the same concrete Python type as
 * this one, whose underlying Tensor is constructed from this tensor's
 * m_handle (so both refer to the same imperative handle). Returns a new
 * reference owned by the caller. */
PyObject* TensorWrapper::detach() {
    PyObject* const self_obj = wrap_t::pycast(this);
    auto detached = std::make_shared<Tensor>(m_tensor->m_handle);
    // make() yields an owning pybind object; release() transfers that
    // single reference to the caller as a raw PyObject*.
    return TensorWrapper::make(self_obj->ob_type, std::move(detached))
            .release()
            .ptr();
}
PyObject* TensorWrapper::isscalar() {
if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
Py_RETURN_TRUE;
......@@ -233,6 +239,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::detach>("detach")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
......
......@@ -128,6 +128,7 @@ struct TensorWrapper {
// Accessors exposed to Python (bodies defined in the .cpp).
PyObject* device();          // device of the wrapped tensor — defined elsewhere
PyObject* numpy();           // numpy view/copy of the data — defined elsewhere
// Rebind this wrapper to the tensor held by another Python tensor object.
void reset(PyObject*);
// New wrapper of the same Python type whose Tensor is built from this
// tensor's handle; exposed to Python as "detach".
PyObject* detach();
// Py_True iff the SCALAR flag is set on the wrapped tensor.
PyObject* isscalar();
void setscalar();            // marks the wrapped tensor as a scalar
};
......
......@@ -166,7 +166,7 @@ def test_interpolate():
def _save_to(self, name="grad"):
def callback(tensor, grad):
def callback(grad):
setattr(self, name, grad)
return callback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册