未验证 提交 934934d8 编写于 作者: W wawltor 提交者: GitHub

fix the backward bug of cumsum (#50997)

上级 753fa844
......@@ -30,9 +30,20 @@ void CumsumGradKernel(const Context& dev_ctx,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
auto x_dims = x.dims();
// If the attribute of flatten is `True`, the cumsum kernel is compose of the
// operation of flatten and cumsum, need to flatten the tensor of input
// gradient, and last step need to unflatten the tensor
if (flatten) {
x_grad->Resize(out_grad.dims());
} else {
x_grad->Resize(x_dims);
}
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
if (flatten) {
x_grad->Resize(x_dims);
}
}
} // namespace phi
......
......@@ -44,9 +44,20 @@ void CumsumGradKernel(const Context& dev_ctx,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
auto x_dims = x.dims();
// If the attribute of flatten is `True`, the cumsum kernel is compose of the
// operation of flatten and cumsum, need to flatten the tensor of input
// gradient, and last step need to unflatten the tensor
if (flatten) {
x_grad->Resize(out_grad.dims());
} else {
x_grad->Resize(x_dims);
}
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
if (flatten) {
x_grad->Resize(x_dims);
}
}
} // namespace phi
......
......@@ -27,9 +27,20 @@ void CumsumGradKernel(const Context& dev_ctx,
bool exclusive,
bool reverse,
DenseTensor* x_grad) {
x_grad->Resize(x.dims());
auto x_dims = x.dims();
// If the attribute of flatten is `True`, the cumsum kernel is compose of the
// operation of flatten and cumsum, need to flatten the tensor of input
// gradient, and last step need to unflatten the tensor
if (flatten) {
x_grad->Resize(out_grad.dims());
} else {
x_grad->Resize(x_dims);
}
CumsumKernel<T, Context>(
dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
if (flatten) {
x_grad->Resize(x_dims);
}
}
} // namespace phi
......
......@@ -200,6 +200,20 @@ class TestSumOp5(OpTest):
self.check_grad(['X'], 'Out')
class TestSumOp6(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.attrs = {'axis': -1, 'flatten': True}
self.inputs = {'X': np.random.random((5, 6, 5)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum()}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestSumOp7(OpTest):
def setUp(self):
self.op_type = "cumsum"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册