未验证 提交 8ca0a8a8 编写于 作者: Z Zhou Wei 提交者: GitHub

fix tensor detach to zero copy (#27921)

* fix tensor detach to zero copy

* fix tensor detach to zero copy
上级 13828db3
......@@ -643,44 +643,82 @@ void BindImperative(py::module *m_ptr) {
return TensorToPyArray(tensor, true);
},
R"DOC(
**Notes**:
**This API is ONLY available in Dygraph mode**
Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`
Returns a numpy array shows the value of current Tensor.
Returns:
ndarray: The numpy value of current Variable.
ndarray: The numpy value of current Tensor.
Returns type:
ndarray: dtype is same as current Variable
ndarray: dtype is same as current Tensor
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import Linear
import paddle
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
linear = Linear(32, 64)
data = to_variable(data)
linear = paddle.nn.Linear(32, 64)
data = paddle.to_tensor(data)
x = linear(data)
print(x.numpy())
)DOC")
.def("detach",
[](const imperative::VarBase &self) {
const auto &tensor = self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
[](const imperative::VarBase
&self) -> std::shared_ptr<imperative::VarBase> {
PADDLE_ENFORCE_EQ(
self.Var().IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
PADDLE_ENFORCE_EQ(
self.Var().IsType<framework::LoDTensor>() ||
self.Var().IsType<framework::SelectedRows>(),
true,
platform::errors::InvalidArgument(
"%s has not been initialized", self.Name()));
return self.NewVarBase(tensor.place(), false);
"Type of Tensor[%s] must be LoDTensor or SelectedRows!",
self.Name()));
auto detach_var = std::make_shared<imperative::VarBase>(
true, "detach_" + self.Name());
detach_var->SetPersistable(self.Persistable());
detach_var->SetType(self.Type());
detach_var->SetDataType(self.DataType());
if (self.Var().IsType<framework::LoDTensor>()) {
const auto &origin_tensor =
self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
origin_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_tensor =
detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
detach_tensor->ShareDataWith(origin_tensor);
} else {
const auto &origin_selected_rows =
self.Var().Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(
origin_selected_rows.value().IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_selected_rows =
detach_var->MutableVar()
->GetMutable<framework::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
},
py::return_value_policy::copy, R"DOC(
py::return_value_policy::take_ownership, R"DOC(
Returns a new Tensor, detached from the current graph.
It will share data with origin Tensor and always doesn't have a Tensor copy.
In addition, the detached Tensor doesn't provide gradient propagation.
Returns: The detached Tensor.
......@@ -688,10 +726,31 @@ void BindImperative(py::module *m_ptr) {
.. code-block:: python
import paddle
linear = Linear(32, 64)
data = paddle.uniform(shape=[30, 10, 32], -1, 1)
x = linear(data)
y = x.detach()
x = paddle.to_tensor(1.0, stop_gradient=False)
detach_x = x.detach()
detach_x[:] = 10.0
print(x) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
# [10.])
y = x**2
y.backward()
print(x.grad) # [20.0]
print(detach_x.grad) # None, 'stop_gradient=True' by default
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = detach_x**3
z.backward()
print(x.grad) # [20.0], detach_x is detached from x's graph, not affect each other
print(detach_x.grad) # [300.0], detach_x has its own graph
# Due to sharing of data with origin Tensor, There are some unsafe operations:
y = 2 * x
detach_x[:] = 5.0
y.backward()
# It will raise Error:
# one of the variables needed for gradient computation has been modified by an inplace operation.
)DOC")
.def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
......
......@@ -200,6 +200,31 @@ class TestVarBase(unittest.TestCase):
var = fluid.dygraph.to_variable(t)
self.assertTrue(np.array_equal(t, var.numpy()))
def test_detach(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False)
detach_x = x.detach()
self.assertTrue(detach_x.stop_gradient, True)
detach_x[:] = 10.0
self.assertTrue(np.array_equal(x.numpy(), [10.0]))
y = x**2
y.backward()
self.assertTrue(np.array_equal(x.grad, [20.0]))
self.assertEqual(detach_x.grad, None)
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = 3 * detach_x**2
z.backward()
self.assertTrue(np.array_equal(x.grad, [20.0]))
self.assertTrue(np.array_equal(detach_x.grad, [60.0]))
# Due to sharing of data with origin Tensor, There are some unsafe operations:
# with self.assertRaises(RuntimeError):
# y = 2 * x
# detach_x[:] = 5.0
# y.backward()
def test_write_property(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册