diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 465df08392d91094d3586965dbcaa2a7661d3167..924a5f59d505567fe2fc50fce1c9cd0342d3a7f8 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1092,8 +1092,7 @@ args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false) output : Tensor(x_grad), Tensor(value_grad) infer_meta : - func : GeneralBinaryGradInferMeta - param : [x, value] + func : IndexPutGradInferMeta kernel : func : index_put_grad data_type : out_grad diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index b028fd15b1b9396940f0660f88ddd25986345095..d1078e2d176bc0383aa3fd6ad9ec9577affe6036 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1202,6 +1202,21 @@ void IndexAddGradInferMeta(const MetaTensor& index, } } +void IndexPutGradInferMeta(const MetaTensor& x, + const std::vector<const MetaTensor*>& indices, + const MetaTensor& value, + const MetaTensor& out_grad, + bool accumulate, + MetaTensor* x_grad, + MetaTensor* value_grad) { + if (x_grad) { + x_grad->share_meta(x); + } + if (value_grad) { + value_grad->share_meta(value); + } +} + void FusedRopeGradInferMeta(const MetaTensor& dout_q, const MetaTensor& dout_k, const MetaTensor& dout_v, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index cb923e16446af203a023e04240515ce6bfb5078d..c73e5ab7a4d9ecec63beb04c0ff4929df9253dae 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -467,4 +467,11 @@ void IndexAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* add_tensor_grad); +void IndexPutGradInferMeta(const MetaTensor& x, + const std::vector<const MetaTensor*>& indices, + const MetaTensor& value, + const MetaTensor& out_grad, + bool accumulate, + MetaTensor* x_grad, + MetaTensor* value_grad); } // namespace phi diff --git a/test/legacy_test/test_index_put_op.py
b/test/legacy_test/test_index_put_op.py index c4bf5d6f0fd401f608eefb5495bee4e4f776f7a6..f21f7b084bde41dbf01853d19d633529b357c894 100644 --- a/test/legacy_test/test_index_put_op.py +++ b/test/legacy_test/test_index_put_op.py @@ -854,6 +854,39 @@ class TestIndexPutAPIBackward(unittest.TestCase): atol=1e-7, ) + def test_backward_in_static(self): + paddle.enable_static() + exe = paddle.static.Executor() + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(train_program, startup_program): + x = paddle.zeros((4, 2, 5)) + x.stop_gradient = False + + y = x + 1 + index = paddle.to_tensor([0, 1, 3]) + + value = paddle.ones((5,)) + value.stop_gradient = False + + z = paddle.index_put(y, (index,), value) + l = z.sum() + paddle.static.append_backward(l) + res = exe.run(fetch_list=[z, x.grad_name, value.grad_name]) + + expected_z = np.ones((4, 2, 5)) + expected_z[[0, 1, 3]] = np.ones((5,)) + + expected_x_grad = np.ones((4, 2, 5)) + expected_x_grad[[0, 1, 3]] = 0 + + expected_v_grad = np.ones((5,)) * 3 * 2 + + np.testing.assert_allclose(expected_z, res[0]) + np.testing.assert_allclose(expected_x_grad, res[1]) + np.testing.assert_allclose(expected_v_grad, res[2]) + paddle.disable_static() + class TestIndexPutAPIMixedIndices(TestIndexPutAPIBase): def init_dtype_type(self):