提交 a8292704 编写于 作者: M Megvii Engine Team

fix(mge/tracing): replace detach as fast path copy

GitOrigin-RevId: d765725d5ab0e41e7b7acb4825e01d087e4460f4
上级 8485eff1
......@@ -541,15 +541,9 @@ PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;
std::shared_ptr<Tensor> new_tensor;
if (m_tensor->m_handle.get()) {
new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
} else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
}
new_tensor->m_trace_info = m_tensor->m_trace_info;
new_tensor->m_flags = m_tensor->m_flags;
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
auto new_tensor = python::apply(op, m_tensor)[0];
new_tensor->m_grad_info_dict = {};
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();
}
......
......@@ -95,6 +95,19 @@ def test_output_copy_trace():
np.testing.assert_equal(ys[False][i], ys[True][i])
@pytest.mark.parametrize("trace_mode", [False, True])
def test_tensor_detach(trace_mode):
@trace(symbolic=True)
def f(x):
y = x.detach() ** 2
z = y.detach() + 1
return z.detach()
x = tensor([1, 2, 3, 4])
for _ in range(3):
f(x).numpy()
@pytest.mark.parametrize("trace_mode", [False, True])
def test_exclude_from_trace(trace_mode):
@trace(symbolic=trace_mode)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册