diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index bce22ca9a7c20ed0fdeeeae4a45b98a20cca03d4..46b477afeb535fb53ef632b2e381f6a8eb5ae228 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel( int x_offset = b_i * post + b_j; if (dy) { dy[y_offset] = - dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); } if (dx) { - val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); } } if (dx) { @@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward( GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, axis); - // for inplace strategy. memset will make dx and dout clear and get wrong // result. if (dx && dx->IsSharedBufferWith(dout)) { @@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast( get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } - // special case for common backward implementation. if (is_run_common_broadcast) { CommonElementwiseBroadcastBackward( diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 6abc97fd583fb354877f6d86711e81da237201e9..fde7ea4b23801ed8b07ea72e078ed7646ec02aa7 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp): self.axis = 2 +class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 1, 12).astype(self.dtype) + self.y = np.random.rand(10, 3, 12).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = 0 + + class TestElementwiseAddOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()):