Unverified · Commit b10ecd9d · authored by liym27 · committed by GitHub

[inplace] Add SharePlaceholderWith for class Variable and SharePlaceholderWith in VarBase.detach() to share the same Tensor/SelectedRows (#29267)

Parent: 9ad800eb
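For context, `detach()` returns an alias rather than a copy: the detached tensor and the original see each other's writes. A minimal sketch of that user-visible behavior (dygraph mode; the variable names are illustrative, not from the patch):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = paddle.to_tensor(np.ones((2, 3)).astype(np.float32))
    d = x.detach()    # `d` aliases the same storage as `x`
    d[:] = 5.0        # a write through the detached alias ...
    print(x.numpy())  # ... is visible in the original: all entries are 5.0
```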
```diff
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -69,6 +69,16 @@ class Variable {
     return holder_->Type();
   }
 
+  /**
+   * The internals of two Variables share the same Placeholder, whose type
+   * can be Tensor, LoDTensor, SelectedRows, LoDTensorArray, etc.
+   *
+   * NOTE(liym27): In dynamic mode, sharing the same Placeholder also means
+   * sharing the same TensorInplaceVersion, which is very important for
+   * inplace operations.
+   */
+  void SharePlaceholderWith(const Variable& var);
+
  private:
   // This method hides type T, so it doesn't appear as a template parameter of
   // Variable.
@@ -113,6 +123,14 @@ class Variable {
   std::shared_ptr<Placeholder> holder_;
 };
 
+inline void Variable::SharePlaceholderWith(const Variable& var) {
+  PADDLE_ENFORCE_EQ(var.IsInitialized(), true,
+                    platform::errors::PreconditionNotMet(
+                        "Variable holds no memory. "
+                        "Call Variable::GetMutable() first."));
+  holder_ = var.holder_;
+}
+
 inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
   framework::TensorInplaceVersion* version_counter_ptr(nullptr);
   if (IsType<framework::LoDTensor>()) {
```
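Because `VarBase.detach()` now shares the whole Placeholder, the detached tensor also shares the `TensorInplaceVersion` counter with its source. A minimal sketch of the effect, mirroring the `test_forward_version` case added below (dygraph mode assumed):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
    detached = var.detach()          # shares the Placeholder, and with it
                                     # the TensorInplaceVersion counter
    var[0] = 1.1                     # in-place write through the original
    print(var.inplace_version)       # 1
    print(detached.inplace_version)  # 1 -- the counter is shared, not copied
```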
```diff
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -689,13 +689,15 @@ void BindImperative(py::module *m_ptr) {
                 x = linear(data)
                 print(x.numpy())
        )DOC")
-      .def("detach",
-           [](const imperative::VarBase
-                  &self) -> std::shared_ptr<imperative::VarBase> {
+      .def(
+          "detach",
+          [](const imperative::VarBase &self)
+              -> std::shared_ptr<imperative::VarBase> {
             PADDLE_ENFORCE_EQ(
                 self.Var().IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor %s has not been initialized!", self.Name()));
+
             PADDLE_ENFORCE_EQ(
                 self.Var().IsType<framework::LoDTensor>() ||
                     self.Var().IsType<framework::SelectedRows>(),
@@ -703,38 +705,23 @@ void BindImperative(py::module *m_ptr) {
                 platform::errors::InvalidArgument(
                     "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
                     self.Name()));
+
             auto detach_var = std::make_shared<imperative::VarBase>(
                 true, "detach_" + self.Name());
+
             detach_var->SetPersistable(self.Persistable());
             detach_var->SetType(self.Type());
             detach_var->SetDataType(self.DataType());
-            if (self.Var().IsType<framework::LoDTensor>()) {
-              const auto &origin_tensor =
-                  self.Var().Get<framework::LoDTensor>();
-              PADDLE_ENFORCE_EQ(
-                  origin_tensor.IsInitialized(), true,
-                  platform::errors::InvalidArgument(
-                      "Tensor %s has not been initialized!", self.Name()));
-
-              auto *detach_tensor =
-                  detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
-              detach_tensor->ShareDataWith(origin_tensor);
-            } else {
-              const auto &origin_selected_rows =
-                  self.Var().Get<framework::SelectedRows>();
-              PADDLE_ENFORCE_EQ(
-                  origin_selected_rows.value().IsInitialized(), true,
-                  platform::errors::InvalidArgument(
-                      "Tensor %s has not been initialized!", self.Name()));
-
-              auto *detach_selected_rows =
-                  detach_var->MutableVar()
-                      ->GetMutable<framework::SelectedRows>();
-              detach_selected_rows->set_height(origin_selected_rows.height());
-              detach_selected_rows->set_rows(origin_selected_rows.rows());
-              detach_selected_rows->mutable_value()->ShareDataWith(
-                  origin_selected_rows.value());
-            }
+
+            // NOTE(liym27):
+            // Call Variable::SharePlaceholderWith rather than
+            // Tensor::ShareDataWith or Tensor::ShareBufferWith, because
+            // `detach_var` should share the same TensorInplaceVersion as
+            // `self`, and only SharePlaceholderWith also shares the
+            // TensorInplaceVersion, which is used to check whether inplace
+            // operations are correct.
+            detach_var->MutableVar()->SharePlaceholderWith(self.Var());
+
             VLOG(3) << "The detached Tensor(" << detach_var->Name()
                     << ") shares data with " << self.Name();
             return detach_var;
```
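The shared counter is what lets autograd detect stale saved tensors: if an op saves `var_b` for its backward pass and `var_b` is later modified in place through a detached alias, backward should fail loudly instead of silently computing wrong gradients. A condensed sketch of that failure mode, adapted from the `test_backward_error` case below (an untested variant of the test; dygraph mode assumed):

```python
import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = paddle.ones(shape=[4, 2, 3], dtype="float32")
    a.stop_gradient = False
    b = a**2
    c = b**2               # the backward of this op needs the saved value of b
    b.detach()[1:2] = 3.3  # in-place write via the detached alias bumps the
                           # shared version counter past c's saved snapshot
    c.backward()           # expected to raise RuntimeError:
                           # tensor_version:1 != wrapper_version_snapshot:0
```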
```diff
--- a/python/paddle/fluid/tests/unittests/test_detach.py
+++ b/python/paddle/fluid/tests/unittests/test_detach.py
@@ -15,8 +15,9 @@
 from __future__ import print_function
 
 import numpy as np
-import paddle.fluid as fluid
 
+import paddle
+import paddle.fluid as fluid
 from paddle.fluid.dygraph import Linear
 from paddle.fluid.dygraph.base import to_variable
@@ -161,5 +162,47 @@ class Test_Detach(unittest.TestCase):
         ) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
 
 
+class TestInplace(unittest.TestCase):
+    def test_forward_version(self):
+        with paddle.fluid.dygraph.guard():
+            var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
+            self.assertEqual(var.inplace_version, 0)
+
+            detach_var_1 = var.detach()
+            self.assertEqual(detach_var_1.inplace_version, 0)
+
+            var[0] = 1.1
+            self.assertEqual(var.inplace_version, 1)
+
+            detach_var_2 = var.detach()
+            self.assertEqual(detach_var_2.inplace_version, 1)
+
+            var[0] = 3
+            self.assertEqual(detach_var_1.inplace_version, 2)
+            self.assertEqual(detach_var_2.inplace_version, 2)
+
+    def test_backward_error(self):
+        # It raises an error because the inplace operator will result
+        # in incorrect gradient computation.
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones(shape=[4, 2, 3], dtype="float32")
+            var_a.stop_gradient = False
+
+            var_b = var_a**2
+
+            # Here, the gradient computation will use the value of var_b
+            var_c = var_b**2
+            detach_var_b = var_b.detach()
+            detach_var_b[1:2] = 3.3  # var_b is modified inplace
+
+            var_d = var_b**2
+
+            loss = paddle.nn.functional.relu(var_c + var_d)
+            with self.assertRaisesRegexp(
+                    RuntimeError,
+                    "received tensor_version:{} != wrapper_version_snapshot:{}".
+                    format(1, 0)):
+                loss.backward()
+
+
 if __name__ == '__main__':
     unittest.main()
```
```diff
--- a/python/paddle/fluid/tests/unittests/test_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_var_base.py
@@ -243,11 +243,12 @@ class TestVarBase(unittest.TestCase):
             z.backward()
             self.assertTrue(np.array_equal(x.grad, [20.0]))
             self.assertTrue(np.array_equal(detach_x.grad, [60.0]))
+
             # Due to sharing of data with origin Tensor, There are some unsafe operations:
-            # with self.assertRaises(RuntimeError):
-            #     y = 2 * x
-            #     detach_x[:] = 5.0
-            #     y.backward()
+            with self.assertRaises(RuntimeError):
+                y = 2**x
+                detach_x[:] = 5.0
+                y.backward()
 
     def test_write_property(self):
         with fluid.dygraph.guard():
```

Note the switch from `y = 2 * x` to `y = 2**x` in the now-enabled check: presumably because the backward of `2 * x` never reads the saved value of `x`, so an in-place write through `detach_x` would go undetected, whereas the backward of `2**x` does read it and therefore trips the version check.