未验证 提交 b10ecd9d 编写于 作者: L liym27 提交者: GitHub

[inplace] Add ShareHolderWith for class Variable and SharePlaceholderWith in...

[inplace] Add ShareHolderWith for class Variable and SharePlaceholderWith in VarBase.detach() to share the same Tensor/SelectedRows (#29267)
上级 9ad800eb
...@@ -69,6 +69,16 @@ class Variable { ...@@ -69,6 +69,16 @@ class Variable {
return holder_->Type(); return holder_->Type();
} }
/**
* The internal of two Variables share the same Placeholder whose type can be
* Tensor, LoDTensor, SelectedRows, LoDTensorArray, etc.
*
* NOTE(liym27): In dynamic mode, sharing the same Placeholder also means
* share the same TensorInplaceVersion, which is very important for inplace
* operations.
*/
void SharePlaceholderWith(const Variable& var);
private: private:
// This method hides type T, so it doesn't appear as a template parameter of // This method hides type T, so it doesn't appear as a template parameter of
// Variable. // Variable.
...@@ -113,6 +123,14 @@ class Variable { ...@@ -113,6 +123,14 @@ class Variable {
std::shared_ptr<Placeholder> holder_; std::shared_ptr<Placeholder> holder_;
}; };
inline void Variable::SharePlaceholderWith(const Variable& var) {
PADDLE_ENFORCE_EQ(var.IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Variable holds no memory. "
"Call Variable::GetMutable() firstly."));
holder_ = var.holder_;
}
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() { inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
framework::TensorInplaceVersion* version_counter_ptr(nullptr); framework::TensorInplaceVersion* version_counter_ptr(nullptr);
if (IsType<framework::LoDTensor>()) { if (IsType<framework::LoDTensor>()) {
......
...@@ -689,57 +689,44 @@ void BindImperative(py::module *m_ptr) { ...@@ -689,57 +689,44 @@ void BindImperative(py::module *m_ptr) {
x = linear(data) x = linear(data)
print(x.numpy()) print(x.numpy())
)DOC") )DOC")
.def("detach", .def(
[](const imperative::VarBase "detach",
&self) -> std::shared_ptr<imperative::VarBase> { [](const imperative::VarBase &self)
PADDLE_ENFORCE_EQ( -> std::shared_ptr<imperative::VarBase> {
self.Var().IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( self.Var().IsInitialized(), true,
"Tensor %s has not been initialized!", self.Name())); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ( "Tensor %s has not been initialized!", self.Name()));
self.Var().IsType<framework::LoDTensor>() ||
self.Var().IsType<framework::SelectedRows>(), PADDLE_ENFORCE_EQ(
true, self.Var().IsType<framework::LoDTensor>() ||
platform::errors::InvalidArgument( self.Var().IsType<framework::SelectedRows>(),
"Type of Tensor[%s] must be LoDTensor or SelectedRows!", true,
self.Name())); platform::errors::InvalidArgument(
auto detach_var = std::make_shared<imperative::VarBase>( "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
true, "detach_" + self.Name()); self.Name()));
detach_var->SetPersistable(self.Persistable());
detach_var->SetType(self.Type()); auto detach_var = std::make_shared<imperative::VarBase>(
detach_var->SetDataType(self.DataType()); true, "detach_" + self.Name());
if (self.Var().IsType<framework::LoDTensor>()) {
const auto &origin_tensor = detach_var->SetPersistable(self.Persistable());
self.Var().Get<framework::LoDTensor>(); detach_var->SetType(self.Type());
PADDLE_ENFORCE_EQ( detach_var->SetDataType(self.DataType());
origin_tensor.IsInitialized(), true,
platform::errors::InvalidArgument( // NOTE(liym27):
"Tensor %s has not been initialized!", self.Name())); // Call Variable::SharePlaceholderWith but not
// Tensor::ShareDataWith or Tensor::ShareBufferWith, because
auto *detach_tensor = // `detach_var` should share the same TensorInplaceVersion with
detach_var->MutableVar()->GetMutable<framework::LoDTensor>(); // `self`, and only SharePlaceholderWith can also share the same
detach_tensor->ShareDataWith(origin_tensor); // TensorInplaceVersion, which is used to check whether inplace
} else { // operations are correct.
const auto &origin_selected_rows = detach_var->MutableVar()->SharePlaceholderWith(self.Var());
self.Var().Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ( VLOG(3) << "The detached Tensor(" << detach_var->Name()
origin_selected_rows.value().IsInitialized(), true, << ") share data with " << self.Name();
platform::errors::InvalidArgument( return detach_var;
"Tensor %s has not been initialized!", self.Name())); },
py::return_value_policy::take_ownership, R"DOC(
auto *detach_selected_rows =
detach_var->MutableVar()
->GetMutable<framework::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
},
py::return_value_policy::take_ownership, R"DOC(
Returns a new Tensor, detached from the current graph. Returns a new Tensor, detached from the current graph.
It will share data with origin Tensor and always doesn't have a Tensor copy. It will share data with origin Tensor and always doesn't have a Tensor copy.
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
...@@ -161,5 +162,47 @@ class Test_Detach(unittest.TestCase): ...@@ -161,5 +162,47 @@ class Test_Detach(unittest.TestCase):
) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode" ) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
class TestInplace(unittest.TestCase):
def test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.assertEqual(var.inplace_version, 0)
detach_var_1 = var.detach()
self.assertEqual(detach_var_1.inplace_version, 0)
var[0] = 1.1
self.assertEqual(var.inplace_version, 1)
detach_var_2 = var.detach()
self.assertEqual(detach_var_2.inplace_version, 1)
var[0] = 3
self.assertEqual(detach_var_1.inplace_version, 2)
self.assertEqual(detach_var_2.inplace_version, 2)
def test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
var_a = paddle.ones(shape=[4, 2, 3], dtype="float32")
var_a.stop_gradient = False
var_b = var_a**2
# Here, the gradient computation will use the value of var_b
var_c = var_b**2
detach_var_b = var_b.detach()
detach_var_b[1:2] = 3.3 # var_b is modified inplace
var_d = var_b**2
loss = paddle.nn.functional.relu(var_c + var_d)
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -243,11 +243,12 @@ class TestVarBase(unittest.TestCase): ...@@ -243,11 +243,12 @@ class TestVarBase(unittest.TestCase):
z.backward() z.backward()
self.assertTrue(np.array_equal(x.grad, [20.0])) self.assertTrue(np.array_equal(x.grad, [20.0]))
self.assertTrue(np.array_equal(detach_x.grad, [60.0])) self.assertTrue(np.array_equal(detach_x.grad, [60.0]))
# Due to sharing of data with origin Tensor, There are some unsafe operations: # Due to sharing of data with origin Tensor, There are some unsafe operations:
# with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
# y = 2 * x y = 2**x
# detach_x[:] = 5.0 detach_x[:] = 5.0
# y.backward() y.backward()
def test_write_property(self): def test_write_property(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册