未验证 提交 76530a2a 编写于 作者: J JYChen 提交者: GitHub

add IndexPutGradInfermeta to fix backward error in static-mode (#55602)

* add IndexPutGradInfermeta to fix backward error in static-mode

* codestyle
上级 38fbbe6b
......@@ -1092,8 +1092,7 @@
args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
output : Tensor(x_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, value]
func : IndexPutGradInferMeta
kernel :
func : index_put_grad
data_type : out_grad
......
......@@ -1202,6 +1202,21 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}
// Infers the metadata of the two gradient outputs of index_put's backward op.
//
// Every requested (non-null) gradient output takes its meta from the forward
// input it corresponds to: `x_grad` from `x`, `value_grad` from `value`.
// `indices`, `out_grad` and `accumulate` are part of the signature but are
// not read by this function.
void IndexPutGradInferMeta(const MetaTensor& x,
                           const std::vector<const MetaTensor*>& indices,
                           const MetaTensor& value,
                           const MetaTensor& out_grad,
                           bool accumulate,
                           MetaTensor* x_grad,
                           MetaTensor* value_grad) {
  // A null output pointer means that gradient was not requested; skip it.
  if (x_grad != nullptr) {
    x_grad->share_meta(x);
  }
  if (value_grad != nullptr) {
    value_grad->share_meta(value);
  }
}
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
......
......@@ -467,4 +467,11 @@ void IndexAddGradInferMeta(const MetaTensor& index,
MetaTensor* x_grad,
MetaTensor* add_tensor_grad);
// Meta inference for the backward pass of index_put.
// `x_grad` and `value_grad` are output pointers; either may be null, in which
// case that output is left untouched. When present, `x_grad` receives the meta
// of `x` and `value_grad` receives the meta of `value`.
void IndexPutGradInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
const MetaTensor& out_grad,
bool accumulate,
MetaTensor* x_grad,
MetaTensor* value_grad);
} // namespace phi
......@@ -854,6 +854,39 @@ class TestIndexPutAPIBackward(unittest.TestCase):
atol=1e-7,
)
def test_backward_in_static(self):
    """Check index_put forward and backward results in static-graph mode.

    Builds a small static program, appends the backward pass, runs it and
    compares the fetched output and gradients against NumPy references.
    """
    paddle.enable_static()
    # Static mode is process-global state: always switch it back off, even
    # when building/running the program or an assertion fails, so that
    # later tests are not left running in static mode.
    try:
        exe = paddle.static.Executor()
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(train_program, startup_program):
            x = paddle.zeros((4, 2, 5))
            x.stop_gradient = False

            y = x + 1
            index = paddle.to_tensor([0, 1, 3])

            value = paddle.ones((5,))
            value.stop_gradient = False

            z = paddle.index_put(y, (index,), value)
            loss = z.sum()
            paddle.static.append_backward(loss)
            # No program is passed explicitly, so run while the guard is
            # still active.
            res = exe.run(fetch_list=[z, x.grad_name, value.grad_name])

        expected_z = np.ones((4, 2, 5))
        expected_z[[0, 1, 3]] = np.ones((5,))

        # Rows overwritten by index_put contribute no gradient to x.
        expected_x_grad = np.ones((4, 2, 5))
        expected_x_grad[[0, 1, 3]] = 0

        # value has shape (5,) and is broadcast over the 3 indexed rows of
        # shape (2, 5), so each element accumulates 3 * 2 unit gradients.
        expected_v_grad = np.ones((5,)) * 3 * 2

        np.testing.assert_allclose(expected_z, res[0])
        np.testing.assert_allclose(expected_x_grad, res[1])
        np.testing.assert_allclose(expected_v_grad, res[2])
    finally:
        paddle.disable_static()
class TestIndexPutAPIMixedIndices(TestIndexPutAPIBase):
def init_dtype_type(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册