Unverified commit 934934d8, authored by wawltor, committed by GitHub

fix the backward bug of cumsum (#50997)

Parent 753fa844
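This commit changes the cumsum gradient kernels so that, when the `flatten` attribute is true, `x_grad` is first given the flattened shape of `out_grad` for the cumsum computation and only resized back to `x`'s shape at the end; previously it was resized to `x.dims()` unconditionally. A minimal dygraph sketch of the path this exercises, assuming the usual `paddle.cumsum` behaviour of flattening the input when no `axis` is given (the shapes are illustrative, not taken from the commit):

import paddle

# Illustrative sketch, not part of the commit: cumsum without an axis flattens
# the input, which drives the flatten=True branch in the grad kernel below.
x = paddle.rand([5, 6, 5])
x.stop_gradient = False
y = paddle.cumsum(x)                 # 1-D result over the flattened input
y.backward(paddle.ones_like(y))
print(x.grad.shape)                  # the gradient should come back as [5, 6, 5]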
@@ -30,9 +30,20 @@ void CumsumGradKernel(const Context& dev_ctx,
                       bool exclusive,
                       bool reverse,
                       DenseTensor* x_grad) {
-  x_grad->Resize(x.dims());
+  auto x_dims = x.dims();
+  // If the attribute `flatten` is `True`, the cumsum kernel is composed of a
+  // flatten operation followed by cumsum, so the incoming gradient tensor has
+  // to be flattened first and the result unflattened in the last step.
+  if (flatten) {
+    x_grad->Resize(out_grad.dims());
+  } else {
+    x_grad->Resize(x_dims);
+  }
   CumsumKernel<T, Context>(
       dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
+  if (flatten) {
+    x_grad->Resize(x_dims);
+  }
 }
 
 }  // namespace phi
......
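The grad kernel reuses the forward CumsumKernel with `!reverse` because the gradient of a (non-exclusive, non-reversed) cumulative sum is the reversed cumulative sum of the upstream gradient; with `flatten` set, that reversed cumsum runs over the flattened gradient and the result is reshaped back to the input's dims afterwards. A minimal NumPy sketch of that math for the default case (the helper name is made up for illustration):

import numpy as np

def cumsum_grad_flat(x, dy):
    # dL/dx_i = sum over k >= i of dL/dy_k, i.e. a reverse cumsum of dy
    g = np.cumsum(dy.ravel()[::-1])[::-1]   # cumsum over the flattened gradient
    return g.reshape(x.shape)               # "unflatten the tensor" in the last step

x = np.random.rand(5, 6, 5)
dy = np.ones(x.size)                        # upstream gradient of y = cumsum(x.ravel())
dx = cumsum_grad_flat(x, dy)
assert dx.shape == x.shape                  # x_grad is reported with x's shape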
@@ -44,9 +44,20 @@ void CumsumGradKernel(const Context& dev_ctx,
                       bool exclusive,
                       bool reverse,
                       DenseTensor* x_grad) {
-  x_grad->Resize(x.dims());
+  auto x_dims = x.dims();
+  // If the attribute `flatten` is `True`, the cumsum kernel is composed of a
+  // flatten operation followed by cumsum, so the incoming gradient tensor has
+  // to be flattened first and the result unflattened in the last step.
+  if (flatten) {
+    x_grad->Resize(out_grad.dims());
+  } else {
+    x_grad->Resize(x_dims);
+  }
   CumsumKernel<T, Context>(
       dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
+  if (flatten) {
+    x_grad->Resize(x_dims);
+  }
 }
 
 }  // namespace phi
......
@@ -27,9 +27,20 @@ void CumsumGradKernel(const Context& dev_ctx,
                       bool exclusive,
                       bool reverse,
                       DenseTensor* x_grad) {
-  x_grad->Resize(x.dims());
+  auto x_dims = x.dims();
+  // If the attribute `flatten` is `True`, the cumsum kernel is composed of a
+  // flatten operation followed by cumsum, so the incoming gradient tensor has
+  // to be flattened first and the result unflattened in the last step.
+  if (flatten) {
+    x_grad->Resize(out_grad.dims());
+  } else {
+    x_grad->Resize(x_dims);
+  }
   CumsumKernel<T, Context>(
       dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad);
+  if (flatten) {
+    x_grad->Resize(x_dims);
+  }
 }
 
 }  // namespace phi
......
@@ -200,6 +200,20 @@ class TestSumOp5(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestSumOp6(OpTest):
+    def setUp(self):
+        self.op_type = "cumsum"
+        self.attrs = {'axis': -1, 'flatten': True}
+        self.inputs = {'X': np.random.random((5, 6, 5)).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].cumsum()}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 class TestSumOp7(OpTest):
     def setUp(self):
         self.op_type = "cumsum"
......
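The new TestSumOp6 covers the flatten=True path that the kernel change above touches: NumPy's cumsum() without an axis also flattens its input, so the reference output is 1-D even though the input is 3-D, and check_grad now exercises the reshaped backward pass. A standalone check of that reference computation (not part of the test file):

import numpy as np

x = np.random.random((5, 6, 5)).astype("float64")
out = x.cumsum()              # no axis: NumPy flattens, matching flatten=True
assert out.shape == (150,)    # 5 * 6 * 5 elements in a 1-D cumulative sum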