Unverified commit 8ca0a8a8, authored by Zhou Wei, committed by GitHub

fix tensor detach to zero copy (#27921)

* fix tensor detach to zero copy

* fix tensor detach to zero copy
Parent 13828db3
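In short: before this change, `detach` returned a copy of the underlying tensor (`self.NewVarBase(tensor.place(), false)`); after it, the returned `VarBase` shares storage with the source via `ShareDataWith`, so no data is copied. A minimal Python sketch of the user-visible behavior this commit targets, using only calls that appear in the docstring examples below:

```python
import paddle

x = paddle.to_tensor(1.0, stop_gradient=False)
detach_x = x.detach()     # zero-copy: detach_x shares storage with x

detach_x[:] = 10.0        # an in-place write through the detached view...
print(x.numpy())          # ...is visible in x: [10.]

y = x**2
y.backward()
print(x.grad)             # [20.0]; x still participates in autograd
print(detach_x.grad)      # None; the detached view is cut from the graph
```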
@@ -643,44 +643,82 @@ void BindImperative(py::module *m_ptr) {
          return TensorToPyArray(tensor, true);
        },
        R"DOC(
        Returns a numpy array that shows the value of the current Tensor.

        Returns:
            ndarray: The numpy value of the current Tensor.

        Returns type:
            ndarray: dtype is the same as the current Tensor

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                linear = paddle.nn.Linear(32, 64)
                data = paddle.to_tensor(data)
                x = linear(data)
                print(x.numpy())
       )DOC")
.def("detach", .def("detach",
[](const imperative::VarBase &self) { [](const imperative::VarBase
const auto &tensor = self.Var().Get<framework::LoDTensor>(); &self) -> std::shared_ptr<imperative::VarBase> {
PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( self.Var().IsInitialized(), true,
"%s has not been initialized", self.Name())); platform::errors::InvalidArgument(
return self.NewVarBase(tensor.place(), false); "Tensor %s has not been initialized!", self.Name()));
PADDLE_ENFORCE_EQ(
self.Var().IsType<framework::LoDTensor>() ||
self.Var().IsType<framework::SelectedRows>(),
true,
platform::errors::InvalidArgument(
"Type of Tensor[%s] must be LoDTensor or SelectedRows!",
self.Name()));
auto detach_var = std::make_shared<imperative::VarBase>(
true, "detach_" + self.Name());
detach_var->SetPersistable(self.Persistable());
detach_var->SetType(self.Type());
detach_var->SetDataType(self.DataType());
if (self.Var().IsType<framework::LoDTensor>()) {
const auto &origin_tensor =
self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
origin_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_tensor =
detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
detach_tensor->ShareDataWith(origin_tensor);
} else {
const auto &origin_selected_rows =
self.Var().Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(
origin_selected_rows.value().IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_selected_rows =
detach_var->MutableVar()
->GetMutable<framework::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
     },
     py::return_value_policy::take_ownership, R"DOC(
        Returns a new Tensor, detached from the current graph.
        It shares data with the original Tensor and never makes a copy.
        In addition, the detached Tensor doesn't provide gradient propagation.

        Returns: The detached Tensor.
@@ -688,10 +726,31 @@ void BindImperative(py::module *m_ptr) {
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(1.0, stop_gradient=False)
                detach_x = x.detach()
                detach_x[:] = 10.0
                print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                          #        [10.])

                y = x**2
                y.backward()
                print(x.grad)         # [20.0]
                print(detach_x.grad)  # None, 'stop_gradient=True' by default

                detach_x.stop_gradient = False  # set stop_gradient to False to enable auto-grad
                z = detach_x**3
                z.backward()
                print(x.grad)         # [20.0]; detach_x is detached from x's graph, so they do not affect each other
                print(detach_x.grad)  # [300.0]; detach_x has its own graph

                # Because data is shared with the original Tensor, some operations are unsafe:
                y = 2 * x
                detach_x[:] = 5.0
                y.backward()
                # This raises an error:
                #   one of the variables needed for gradient computation has been
                #   modified by an inplace operation.
)DOC") )DOC")
.def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC( .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
...
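Note the contrast with `numpy()` in the first hunk above: that binding still deep-copies (the `true` argument to `TensorToPyArray` appears to be the copy flag), while the new `detach` shares storage. A hedged sketch of the difference, assuming `numpy()` indeed returns an independent copy:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])

arr = x.numpy()      # a copy: mutating the array leaves x untouched
arr[0] = -1.0
print(x.numpy())     # [1. 2. 3.]

view = x.detach()    # a shared-storage view: mutating it changes x
view[:] = 7.0
print(x.numpy())     # [7. 7. 7.]
```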
@@ -200,6 +200,31 @@ class TestVarBase(unittest.TestCase):
var = fluid.dygraph.to_variable(t)
self.assertTrue(np.array_equal(t, var.numpy()))
def test_detach(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False)
detach_x = x.detach()
self.assertTrue(detach_x.stop_gradient)
detach_x[:] = 10.0
self.assertTrue(np.array_equal(x.numpy(), [10.0]))
y = x**2
y.backward()
self.assertTrue(np.array_equal(x.grad, [20.0]))
self.assertEqual(detach_x.grad, None)
detach_x.stop_gradient = False  # set stop_gradient to False to enable auto-grad
z = 3 * detach_x**2
z.backward()
self.assertTrue(np.array_equal(x.grad, [20.0]))
self.assertTrue(np.array_equal(detach_x.grad, [60.0]))
# Because data is shared with the original Tensor, some operations are unsafe:
# with self.assertRaises(RuntimeError):
# y = 2 * x
# detach_x[:] = 5.0
# y.backward()
def test_write_property(self):
    with fluid.dygraph.guard():
        var = fluid.dygraph.to_variable(self.array)
...
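The commented-out block in `test_detach` above corresponds to the unsafe in-place case documented in the `detach` docstring. A sketch of what that check could look like if re-enabled; since the PR itself leaves it disabled, whether (and with which exception type) it raises may depend on the Paddle build, so the assertion below is deliberately broad:

```python
import unittest
import paddle

class TestDetachUnsafeInplace(unittest.TestCase):
    def test_inplace_write_then_backward(self):
        x = paddle.to_tensor(1.0, stop_gradient=False)
        detach_x = x.detach()   # shares storage with x
        y = 2 * x               # records x in the autograd graph
        detach_x[:] = 5.0       # in-place write through the shared view
        # Expected per the docstring: backward() complains that a variable
        # needed for gradient computation was modified in place.
        with self.assertRaises(Exception):
            y.backward()

if __name__ == "__main__":
    unittest.main()
```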