From b7560a59ab7cd5a17037d35acefcc6f3f05ed56f Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 3 Feb 2021 19:50:33 +0800 Subject: [PATCH] fix the broadcast for the large second input (#30818) fix the broadcast for the large second input --- .../operators/elementwise/elementwise_op_function.h | 6 ++---- .../fluid/tests/unittests/test_elementwise_add_op.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index bce22ca9a7c..46b477afeb5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel( int x_offset = b_i * post + b_j; if (dy) { dy[y_offset] = - dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); } if (dx) { - val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); } } if (dx) { @@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward( GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, axis); - // for inplace strategy. memset will make dx and dout clear and get wrong // result. if (dx && dx->IsSharedBufferWith(dout)) { @@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast( get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } - // special case for common backward implementation. if (is_run_common_broadcast) { CommonElementwiseBroadcastBackward( diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 6abc97fd583..fde7ea4b238 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp): self.axis = 2 +class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 1, 12).astype(self.dtype) + self.y = np.random.rand(10, 3, 12).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = 0 + + class TestElementwiseAddOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): -- GitLab