未验证 提交 aef8bf2a 编写于 作者: Z zyfncg 提交者: GitHub

setitem support passing stop_gradient from value to tensor (#37023)

上级 ac1d3571
......@@ -985,6 +985,12 @@ void BindImperative(py::module *m_ptr) {
auto value_tensor =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
ins.insert({"ValueTensor", {value_tensor}});
// pass the stop_gradient from value to tensor
if (!value_tensor->OverridedStopGradient() &&
self->OverridedStopGradient()) {
self->SetOverridedStopGradient(false);
}
} else if (py::isinstance<py::array>(value_obj)) {
auto value_tensor = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(false,
......
......@@ -1154,6 +1154,18 @@ class TestGradientTruncated(unittest.TestCase):
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))
# case 6: pass stop_gradient from value to x
x = paddle.zeros([8, 8], dtype='float32')
value = paddle.to_tensor([10], dtype='float32', stop_gradient=False)
self.assertTrue(x.stop_gradient)
self.assertTrue(x.is_leaf)
x[0, :] = value
self.assertTrue(~x.stop_gradient)
self.assertTrue(~x.is_leaf)
def test_static_graph(self):
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册