From 76cab7519efe4649077c210a3a79271871932a3f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 8 Nov 2021 17:09:57 +0800 Subject: [PATCH] setitem support passing stop_gradient from value to tensor (#37028) att,Fix issue:36902 --- paddle/fluid/pybind/imperative.cc | 6 ++++++ .../fluid/tests/unittests/test_set_value_op.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 8b01f02ee2c..4403eb46972 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -985,6 +985,12 @@ void BindImperative(py::module *m_ptr) { auto value_tensor = value_obj.cast>(); ins.insert({"ValueTensor", {value_tensor}}); + + // pass the stop_gradient from value to tensor + if (!value_tensor->OverridedStopGradient() && + self->OverridedStopGradient()) { + self->SetOverridedStopGradient(false); + } } else if (py::isinstance(value_obj)) { auto value_tensor = std::shared_ptr( new imperative::VarBase(false, diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 21f506d03ce..e9809318cb3 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -1154,6 +1154,18 @@ class TestGradientTruncated(unittest.TestCase): msg="The gradient of input should be \n{},\n but reveived {}". format(value_grad, value.grad.numpy())) + # case 6: pass stop_gradient from value to x + x = paddle.zeros([8, 8], dtype='float32') + value = paddle.to_tensor([10], dtype='float32', stop_gradient=False) + + self.assertTrue(x.stop_gradient) + self.assertTrue(x.is_leaf) + + x[0, :] = value + + self.assertTrue(~x.stop_gradient) + self.assertTrue(~x.is_leaf) + def test_static_graph(self): paddle.enable_static() -- GitLab