diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 5b51b509ebc70819990514860c719826b04ece73..166bc1ee7fc22eedcb44f642dd8a3e83e2dea1f7 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -541,15 +541,9 @@ PyObject* TensorWrapper::detach() { PyObject* self = wrap_t::pycast(this); PyTypeObject* pytype = self->ob_type; - std::shared_ptr new_tensor; - if (m_tensor->m_handle.get()) { - new_tensor = std::make_shared(m_tensor->m_handle); - } else { - new_tensor = std::make_shared(m_tensor->m_var); - } - new_tensor->m_trace_info = m_tensor->m_trace_info; - - new_tensor->m_flags = m_tensor->m_flags; + static std::shared_ptr op = std::shared_ptr(new FastpathCopy()); + auto new_tensor = python::apply(op, m_tensor)[0]; + new_tensor->m_grad_info_dict = {}; auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); return ret.release().ptr(); } diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 2f8331399b323f3dc5adf7963c55335225b90fc7..6e443e6c99f771f1e7c19f982edbcecffa96cf54 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -95,6 +95,19 @@ def test_output_copy_trace(): np.testing.assert_equal(ys[False][i], ys[True][i]) +@pytest.mark.parametrize("trace_mode", [False, True]) +def test_tensor_detach(trace_mode): + @trace(symbolic=True) + def f(x): + y = x.detach() ** 2 + z = y.detach() + 1 + return z.detach() + + x = tensor([1, 2, 3, 4]) + for _ in range(3): + f(x).numpy() + + @pytest.mark.parametrize("trace_mode", [False, True]) def test_exclude_from_trace(trace_mode): @trace(symbolic=trace_mode)