diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc
index 9e9231415fb4a2a8b515c420f33c4c32d688441a..59f21a1e1face6f6a95312c4ef0b75af705e017b 100644
--- a/paddle/fluid/pybind/eager_py_layer.cc
+++ b/paddle/fluid/pybind/eager_py_layer.cc
@@ -43,8 +43,7 @@ namespace py = ::pybind11;
 PyTypeObject* p_pylayer_type;
 extern PyTypeObject* p_tensor_type;
 
-std::set<paddle::experimental::Tensor*> GetNonDifferentiableNames(
-    PyObject* obj) {
+std::set<paddle::experimental::Tensor*> GetTensorsFromPyObject(PyObject* obj) {
   std::set<paddle::experimental::Tensor*> result;
   if (obj == nullptr) {
     return result;
@@ -298,8 +297,7 @@ PyObject* pylayer_method_apply(PyObject* cls, PyObject* args,
     VLOG(6) << "PyLayer forward function finish...";
 
     if (require_any_grad && trace_backward) {
-      auto non_differentiable =
-          GetNonDifferentiableNames(ctx->non_differentiable);
+      auto non_differentiable = GetTensorsFromPyObject(ctx->non_differentiable);
       for (size_t i = 0; i < outputs_autograd_meta.size(); i++) {
         for (size_t j = 0; j < outputs_autograd_meta[i].size(); j++) {
           if (non_differentiable.find(outputs_tensor[i][j]) !=
@@ -311,7 +309,22 @@ PyObject* pylayer_method_apply(PyObject* cls, PyObject* args,
         }
       }
 
-      // TODO(pangyoki) add inplace, inplaced tensor is ctx->dirty_tensors
+      // add inplace strategy, inplaced tensor is ctx->dirty_tensors
+      auto dirty_tensors = GetTensorsFromPyObject(ctx->dirty_tensors);
+      for (auto it = dirty_tensors.begin(); it != dirty_tensors.end(); ++it) {
+        auto dirty_tensor = *it;
+        auto dirty_tensor_autograd_meta =
+            egr::EagerUtils::autograd_meta(dirty_tensor);
+        PADDLE_ENFORCE_EQ(!dirty_tensor_autograd_meta->StopGradient() &&
+                              egr::egr_utils_api::IsLeafTensor(*dirty_tensor),
+                          false, paddle::platform::errors::InvalidArgument(
+                                     "Leaf Var (%s) that doesn't stop gradient "
+                                     "can't use inplace strategy.",
+                                     dirty_tensor->name()));
+        dirty_tensor->bump_inplace_version();
+        VLOG(3) << "Tensor(" << dirty_tensor->name()
+                << ") uses Inplace Strategy.";
+      }
 
       auto grad_node = std::make_shared<egr::GradNodePyLayer>(
           reinterpret_cast<PyObject*>(ctx), outputs_autograd_meta.size(),
diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py
index 786f4cb7a74d48bb3192668bc09f8caee249dd30..91e7b5d00e1a74023d6fd06c2f314af996ef3d05 100644
--- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py
@@ -424,7 +424,7 @@ class TestPyLayer(unittest.TestCase):
             self.func_test_pylayer_bk_return_none()
         self.func_test_pylayer_bk_return_none()
 
-    def test_pylayer_inplace(self):
+    def func_test_pylayer_inplace(self):
         class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
             @staticmethod
             def forward(ctx, x):
@@ -452,10 +452,115 @@ class TestPyLayer(unittest.TestCase):
         z.backward()
         self.assertTrue(data.grad is not None)
 
-    def test_pylayer_inplace_and_leaf_exception(self):
+    def test_pylayer_inplace(self):
+        with _test_eager_guard():
+            self.func_test_pylayer_inplace()
+        self.func_test_pylayer_inplace()
+
+    def test_pylayer_inplace_backward_error(self):
+        with _test_eager_guard():
+
+            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    ctx.mark_dirty(x)
+                    return x
+
+                @staticmethod
+                def backward(ctx, dy):
+                    return dy
+
+            class Layer(paddle.nn.Layer):
+                def __init__(self):
+                    super(Layer, self).__init__()
+
+                def forward(self, data):
+                    var_b = data**2
+                    var_c = var_b**2
+                    z = cus_tanh.apply(var_b)
+                    loss = paddle.nn.functional.relu(var_c)
+                    return loss
+
+            data = paddle.ones([2, 3], dtype="float64")
+            data.stop_gradient = False
+            layer = Layer()
+            z = layer(data)
+            with self.assertRaisesRegexp(
+                    RuntimeError,
+                    "received current_inplace_version:{} != inplace_version_snapshot_:{}".
+                    format(1, 0)):
+                z.backward()
+
+    def test_pylayer_inplace_backward_success_1(self):
+        with _test_eager_guard():
+
+            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    ctx.mark_dirty(x)
+                    return x
+
+                @staticmethod
+                def backward(ctx, dy):
+                    return dy
+
+            class Layer(paddle.nn.Layer):
+                def __init__(self):
+                    super(Layer, self).__init__()
+
+                def forward(self, data):
+                    var_b = data**2
+                    var_c = cus_tanh.apply(var_b)
+                    var_d = var_c**2
+                    loss = var_d.sum()
+                    return loss
+
+            for i in range(2):
+                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
+                data.stop_gradient = False
+                layer = Layer()
+                z = layer(data)
+                z.backward()
+                self.assertTrue(data.grad is not None)
+
+    def test_pylayer_inplace_backward_success_2(self):
+        with _test_eager_guard():
+
+            class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
+                @staticmethod
+                def forward(ctx, x):
+                    ctx.mark_dirty(x)
+                    return x
+
+                @staticmethod
+                def backward(ctx, dy):
+                    return dy
+
+            class Layer(paddle.nn.Layer):
+                def __init__(self):
+                    super(Layer, self).__init__()
+
+                def forward(self, data):
+                    var_b = data**2
+                    var_c = cus_tanh.apply(var_b)
+                    var_d = var_c + var_c
+                    loss = var_d.sum()
+                    return loss
+
+            for i in range(2):
+                data = paddle.ones([2, 3], dtype="float64") / (i + 1)
+                data.stop_gradient = False
+                layer = Layer()
+                z = layer(data)
+                z.backward()
+                self.assertTrue(data.grad is not None)
+
+    def func_test_pylayer_inplace_and_leaf_exception(self):
         class cus_pylayer_op(EagerPyLayer if in_dygraph_mode() else PyLayer):
             @staticmethod
             def forward(ctx, x):
+                if in_dygraph_mode():
+                    ctx.mark_dirty(x)
                 return x
 
             @staticmethod
@@ -478,6 +583,11 @@ class TestPyLayer(unittest.TestCase):
             with self.assertRaises(ValueError):
                 z = layer(data)
 
+    def test_pylayer_inplace_and_leaf_exception(self):
+        with _test_eager_guard():
+            self.func_test_pylayer_inplace_and_leaf_exception()
+        self.func_test_pylayer_inplace_and_leaf_exception()
+
     def func_test_backward_in_backward(self):
         class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
             @staticmethod
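
For reference, a minimal standalone sketch of the usage pattern the new tests exercise; it is not part of the patch and assumes the same test-era imports (PyLayer/EagerPyLayer from paddle.autograd, in_dygraph_mode and _test_eager_guard from paddle.fluid.framework):

import paddle
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid.framework import in_dygraph_mode, _test_eager_guard

with _test_eager_guard():

    class CusTanh(EagerPyLayer if in_dygraph_mode() else PyLayer):
        @staticmethod
        def forward(ctx, x):
            # Declaring x as modified in place makes the C++ side bump its
            # inplace version; leaf tensors that require grad are rejected.
            ctx.mark_dirty(x)
            return x

        @staticmethod
        def backward(ctx, dy):
            return dy

    data = paddle.ones([2, 3], dtype="float64")
    data.stop_gradient = False
    var_b = data ** 2                 # non-leaf input, so mark_dirty is allowed
    var_c = CusTanh.apply(var_b)
    loss = (var_c ** 2).sum()
    loss.backward()
    print(data.grad is not None)      # True, mirroring the success tests above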