未验证 提交 b7560a59 编写于 作者: W wawltor 提交者: GitHub

fix the broadcast for the large second input (#30818)

fix the broadcast for the large second input 
上级 6e1e036a
......@@ -708,10 +708,10 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
int x_offset = b_i * post + b_j;
if (dy) {
dy[y_offset] =
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]);
}
}
if (dx) {
......@@ -1674,7 +1674,6 @@ void CommonElementwiseBroadcastBackward(
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
// for inplace strategy. memset will make dx and dout clear and get wrong
// result.
if (dx && dx->IsSharedBufferWith(dout)) {
......@@ -1762,7 +1761,6 @@ void ElemwiseGradComputeWithBroadcast(
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
......
......@@ -381,6 +381,16 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
self.axis = 2
class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
    # Broadcast case where x and y have the same rank but y is the larger
    # operand: x is (10, 1, 12) and y is (10, 3, 12), so x broadcasts along
    # the middle axis. Exercises the "large second input" gradient path.
    def init_input_output(self):
        x_shape, y_shape = (10, 1, 12), (10, 3, 12)
        self.x = np.random.rand(*x_shape).astype(self.dtype)
        self.y = np.random.rand(*y_shape).astype(self.dtype)
        self.out = self.x + self.y

    def init_axis(self):
        # Align operands starting from the first dimension.
        self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册