From 34c705fcbfc1e7c462bd1cb51dc82e5f58c6f1f4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 15 Dec 2020 18:24:27 +0800
Subject: [PATCH] refactor(mge/imperative): move detach into C++

GitOrigin-RevId: 8c0d86cbbfdde275885b50417c80d28045871493
---
 imperative/python/megengine/tensor.py               |  7 -------
 imperative/python/src/tensor.cpp                    | 13 ++++++++++---
 imperative/python/src/tensor.h                      |  1 +
 .../python/test/unit/functional/test_functional.py |  2 +-
 4 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 6c13d9df6..ae66437d3 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -118,13 +118,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
     def __setstate__(self, state):
         self.q_dict = state.pop("qdict")
 
-    def detach(self):
-        r"""
-        Returns a new tensor sharing the same data memory, which is treated as a constant
-        during backward gradient calcuation, i.e. its gradient is zero.
-        """
-        Wrapper = type(self)
-        return Wrapper(self)
 
 
 tensor = Tensor

diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 2eefbc0dc..b123f73e6 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -68,9 +68,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
         return nullptr;
     }
     auto* op = args[0];
-    if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){
-        op = PyObject_CallMethod(op,"to_c","");
-    }
 
     PyTypeObject* pytype = args[1]->ob_type;
     ++args;
@@ -195,6 +192,15 @@ void TensorWrapper::reset(PyObject* tensor) {
     m_tensor = t->m_tensor;
 }
 
+PyObject* TensorWrapper::detach() {
+    PyObject* self = wrap_t::pycast(this);
+    PyTypeObject* pytype = self->ob_type;
+    auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
+    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
+    return ret.release().ptr();
+
+}
+
 PyObject* TensorWrapper::isscalar() {
     if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
         Py_RETURN_TRUE;
@@ -233,6 +239,7 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::reset>("_reset")
         .def<&TensorWrapper::isscalar>("isscalar")
         .def<&TensorWrapper::setscalar>("setscalar")
+        .def<&TensorWrapper::detach>("detach")
         .finalize();
     if (!tensor_type) throw py::error_already_set();
     py::setattr(m, "Tensor", tensor_type);

diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 73a436c46..26054a76f 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -128,6 +128,7 @@ struct TensorWrapper {
     PyObject* device();
     PyObject* numpy();
     void reset(PyObject*);
+    PyObject* detach();
     PyObject* isscalar();
     void setscalar();
 };

diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index da894936a..49e3e8174 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -166,7 +166,7 @@ def test_interpolate():
 
 
 def _save_to(self, name="grad"):
-    def callback(tensor, grad):
+    def callback(grad):
         setattr(self, name, grad)
 
     return callback
-- 
GitLab
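
Editor's note (not part of the patch): the new C++ TensorWrapper::detach is meant as a drop-in replacement for the deleted Python method. It wraps the tensor's existing m_handle in a fresh wrapper rather than copying data, so the result shares storage with the source tensor but participates in autodiff as a constant, exactly as the deleted docstring described. Below is a minimal usage sketch from the Python side. It assumes MegEngine's public API of this period (megengine.tensor and megengine.autodiff.GradManager); none of it comes from the patch itself.

    # Sketch only: assumes megengine and megengine.autodiff.GradManager,
    # which are not shown in this patch.
    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    x = mge.tensor(np.ones((3,), dtype="float32"))

    gm = GradManager()
    gm.attach([x])              # track gradients w.r.t. x
    with gm:
        y = x.detach()          # same data as x, but a constant to autodiff
        loss = (x * y).sum()
        gm.backward(loss)

    # Gradient flows only through x; y acts as a constant factor,
    # so d(loss)/dx == y, i.e. all ones. y itself receives no gradient.
    print(x.grad.numpy())

Whether detach is written in Python (old) or C++ (new), this sketch should behave identically; the refactor only moves the wrapper-around-same-handle construction below the Python layer.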