未验证 提交 b7560a59 编写于 作者: W wawltor 提交者: GitHub

fix the broadcast for the large second input (#30818)

fix the broadcast for the large second input 
上级 6e1e036a
...@@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel( ...@@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
int x_offset = b_i * post + b_j; int x_offset = b_i * post + b_j;
if (dy) { if (dy) {
dy[y_offset] = dy[y_offset] =
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
} }
if (dx) { if (dx) {
val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
} }
} }
if (dx) { if (dx) {
...@@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward( ...@@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward(
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim, y_dims_array.data(), out_dims_array.data(), max_dim,
axis); axis);
// for inplace strategy. memset will make dx and dout clear and get wrong // for inplace strategy. memset will make dx and dout clear and get wrong
// result. // result.
if (dx && dx->IsSharedBufferWith(dout)) { if (dx && dx->IsSharedBufferWith(dout)) {
...@@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast( ...@@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast(
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast); &is_run_common_broadcast);
} }
// special case for common backward implementation. // special case for common backward implementation.
if (is_run_common_broadcast) { if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>( CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
......
...@@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp): ...@@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
self.axis = 2 self.axis = 2
class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 1, 12).astype(self.dtype)
self.y = np.random.rand(10, 3, 12).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase): class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册