diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 8b01f02ee2c3a6f734c85133755cfcdb54bb6cd6..4403eb469723a5afdbc602d3f6ce19966536079c 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -985,6 +985,12 @@ void BindImperative(py::module *m_ptr) {
           auto value_tensor =
               value_obj.cast<std::shared_ptr<imperative::VarBase>>();
           ins.insert({"ValueTensor", {value_tensor}});
+
+          // pass the stop_gradient from value to tensor
+          if (!value_tensor->OverridedStopGradient() &&
+              self->OverridedStopGradient()) {
+            self->SetOverridedStopGradient(false);
+          }
         } else if (py::isinstance<py::array>(value_obj)) {
           auto value_tensor = std::shared_ptr<imperative::VarBase>(
               new imperative::VarBase(false,
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 21f506d03ce68e7eb47d185c06aeab5f4ba4cabd..e9809318cb393163d1c9c154ea478bd54c942c83 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -1154,6 +1154,18 @@ class TestGradientTruncated(unittest.TestCase):
             msg="The gradient of input should be \n{},\n but reveived {}".
             format(value_grad, value.grad.numpy()))
 
+        # case 6: pass stop_gradient from value to x
+        x = paddle.zeros([8, 8], dtype='float32')
+        value = paddle.to_tensor([10], dtype='float32', stop_gradient=False)
+
+        self.assertTrue(x.stop_gradient)
+        self.assertTrue(x.is_leaf)
+
+        x[0, :] = value
+
+        self.assertFalse(x.stop_gradient)
+        self.assertFalse(x.is_leaf)
+
     def test_static_graph(self):
         paddle.enable_static()
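
Note on the assertions: the original test used self.assertTrue(~x.stop_gradient), but applying Python's bitwise ~ to a bool yields -2 or -1, both truthy, so that assertion passes regardless of the flag's value; self.assertFalse(x.stop_gradient) checks the intended condition.

For reference only (not part of the patch), the sketch below illustrates the behavior this change enables. It assumes a Paddle build that includes the patch, running in dynamic-graph (imperative) mode; the variable names and values are illustrative:

    import paddle

    x = paddle.zeros([8, 8], dtype='float32')   # x.stop_gradient is True by default
    value = paddle.to_tensor([10.0], dtype='float32', stop_gradient=False)

    # __setitem__ now propagates stop_gradient=False from value to x,
    # so x joins the autograd graph and is no longer a leaf tensor.
    x[0, :] = value

    loss = x.sum()
    loss.backward()      # gradient flows through the assigned slice into value
    print(value.grad)    # now populated, since x tracks gradients after the assignment

Before this change, x kept stop_gradient=True after the assignment, so the gradient path from loss back to value was cut at x.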