From 2c02a580bb5f13cc6c8f7dc85458b380f0b2a53a Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Wed, 8 Dec 2021 12:56:41 +0800 Subject: [PATCH] add check whether tensor is inplace and leaf when calcute gradient (#37931) --- paddle/fluid/imperative/py_layer_fwd.h | 9 +++++++ .../fluid/tests/unittests/test_pylayer_op.py | 26 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h index 79251d7bf7a..159371970dc 100644 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -226,6 +226,15 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, } } if (if_inplace) { + // when pylayer forward is inplace strategy, check whether tensor is leaf + for (auto& t : input_vars) { + PADDLE_ENFORCE_EQ(t->IsLeaf() && !t->OverridedStopGradient(), false, + platform::errors::InvalidArgument( + "Leaf Var (%s) that doesn't stop gradient can't " + "use inplace strategy.", + t->Name())); + } + inplace_map["X"] = "Out"; } diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py index a852b4c9042..200273c6066 100644 --- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -406,6 +406,32 @@ class TestPyLayer(unittest.TestCase): z.backward() self.assertTrue(data.grad is not None) + def test_pylayer_inplace_and_leaf_exception(self): + class cus_pylayer_op(PyLayer): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, dy): + return dy + + class Layer(paddle.nn.Layer): + def __init__(self): + super(Layer, self).__init__() + + def forward(self, data): + z = cus_pylayer_op.apply(data) + return z.mean() + + for i in range(2): + data = paddle.ones([2, 3], dtype="float64") / (i + 1) + data.stop_gradient = False + layer = Layer() + + with self.assertRaises(ValueError): + z = layer(data) + def test_backward_in_backward(self): class cus_tanh(PyLayer): @staticmethod -- GitLab