Unverified commit 2de5075a, authored by wawltor, committed by GitHub

Fix broadcasting in the backward pass of the elementwise ops (#24319)

* Remove the LOG(ERROR) early-return path in the elementwise op and fall back to the common (slow) broadcast mode to compute the gradients
Parent 5a29919a
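For context, the fallback ("backup") mode mentioned above is the general broadcast gradient path: when the fast kernels cannot cover a case, the gradient of a broadcasted elementwise add is obtained by summing the upstream gradient over the axes that were broadcast. A minimal NumPy sketch of that reduction, using the input shapes from the unit test added in this commit (reduce_grad_to_shape is a hypothetical helper for illustration, not Paddle code):

import numpy as np

def reduce_grad_to_shape(dout, shape):
    # Sum dout over the axes that were broadcast (size 1 in the input but not
    # in the output); assumes the input already has the same rank as dout.
    axes = tuple(i for i, (s, o) in enumerate(zip(shape, dout.shape))
                 if s == 1 and o != 1)
    return dout.sum(axis=axes, keepdims=True)

x = np.random.rand(1, 1, 20, 5).astype("float32")
y = np.random.rand(20, 5, 1, 1).astype("float32")
out = x + y                      # broadcasts to (20, 5, 20, 5)
dout = np.ones_like(out)         # upstream gradient d(loss)/d(out)

dx = reduce_grad_to_shape(dout, x.shape)   # -> (1, 1, 20, 5)
dy = reduce_grad_to_shape(dout, y.shape)   # -> (20, 5, 1, 1)
assert dx.shape == x.shape and dy.shape == y.shape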
@@ -1040,22 +1040,21 @@ void CommonGradBroadcastCUDA(
   // fallback
   // to old fast path.
   // 2. if both x and y need broadcast, then do it one by one.
+  bool fast_broadcast = false;
   if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
     can_split_y = SplitDims(y_broadcast_pos, max_dim);
     if (can_split_y) {
       // only y need to do broadcast on h
       if (y_broadcast_pos[0] == 0) {
         FastBroadCastHeightCUDAF(y_broadcast_pos, true);
-      } else {
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast = true;
       }
-      return;
     } else if (y_broadcast_pos.size() == 1 ||
                CheckContiguousDims(y_broadcast_pos)) {  // for only one dim and
                                                         // contiguous broadcast.
       // If cannot split, which means input has 3 parts
       FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
-      return;
+      fast_broadcast = true;
     }
   } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
     // only x need broadcast
@@ -1063,49 +1062,53 @@ void CommonGradBroadcastCUDA(
     if (can_split_x) {
       if (x_broadcast_pos[0] == 0) {
         FastBroadCastHeightCUDAF(x_broadcast_pos, false);
-      } else {
-        // x need to do broadcast on w
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast = true;
       }
-      return;
     } else if (x_broadcast_pos.size() == 1 ||
                CheckContiguousDims(x_broadcast_pos)) {
       FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
-      return;
+      fast_broadcast = true;
     }
   } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
     // do x and y broadcast each.
     can_split_y = SplitDims(y_broadcast_pos, max_dim);
+    bool fast_broadcast_x = false;
+    bool fast_broadcast_y = false;
     if (can_split_y) {
       // begin at start.
       if (y_broadcast_pos[0] == 0) {
         FastCommonCUDAF(y_broadcast_pos, true);
-      } else {
-        // finish at end
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast_y = true;
       }
     } else if (y_broadcast_pos.size() == 1) {
       FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
       can_split_y = true;
+      fast_broadcast_y = true;
     }
     can_split_x = SplitDims(x_broadcast_pos, max_dim);
     if (can_split_x) {
       if (x_broadcast_pos[0] == 0) {
         FastCommonCUDAF(x_broadcast_pos, false);
-      } else {
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast_x = true;
       }
     } else if (x_broadcast_pos.size() == 1) {
       FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
       can_split_x = true;
+      fast_broadcast_x = true;
     }
     VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
             << " can_split_x:" << can_split_x;
     // if both x and y into fast path then return
-    if (can_split_y && can_split_x) return;
+    if (fast_broadcast_x && fast_broadcast_y) {
+      fast_broadcast = true;
+    }
+    if (can_split_y && can_split_x && fast_broadcast) return;
   }
   // Should remove memory copy, use reg instead.
+  if (fast_broadcast) {
+    return;
+  }
   int x_blocks = 0;
   int x_threads = 0;
   ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
@@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA(
                                  1, std::multiplies<int>());
   int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
   int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
-  if (dx && !can_split_x) {
+  if (dx) {
     auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *x_strides_order_gpu =
         reinterpret_cast<int *>(x_strides_order_tmp->ptr());
@@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA(
         x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
         dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
   }
-  if (dy && !can_split_y) {
+  if (dy) {
     auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *y_strides_order_gpu =
         reinterpret_cast<int *>(y_strides_order_tmp->ptr());
......
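A note on the hunks above: before this change, several branches logged an error and returned early even when no fast kernel had produced the gradient, and the common kernels were additionally guarded by !can_split_x / !can_split_y, so a gradient could be skipped entirely. After the change, the fast kernels record whether they actually ran, and the code falls through to the common kernels otherwise. A simplified Python sketch of that dispatch (the function and argument names are illustrative, not Paddle APIs):

# Illustrative sketch only: the CUDA fast kernels are modelled as booleans
# saying whether a fast kernel actually ran for each input.
def choose_grad_path(x_needs_broadcast, y_needs_broadcast,
                     fast_x_ran, fast_y_ran):
    """Return 'fast' if the fast kernels fully covered dx/dy, else 'common'."""
    if x_needs_broadcast and y_needs_broadcast:
        # After the fix, the common kernel is skipped only when BOTH fast
        # kernels ran; before the fix, an early return could skip it even
        # when one of them had not, leaving a gradient uncomputed.
        fast_broadcast = fast_x_ran and fast_y_ran
    elif y_needs_broadcast:
        fast_broadcast = fast_y_ran
    elif x_needs_broadcast:
        fast_broadcast = fast_x_ran
    else:
        fast_broadcast = False
    return "fast" if fast_broadcast else "common"

# The case added to the unit tests below: both inputs need broadcasting but
# only one fast kernel applies, so the common kernel must still run.
assert choose_grad_path(True, True, fast_x_ran=True, fast_y_ran=False) == "common"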
@@ -263,6 +263,13 @@ class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
         self.out = self.x + self.y


+class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
+        self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
+        self.out = self.x + self.y
+
+
 class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
......
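The new TestElementwiseAddOp_broadcast_7 case above exercises the branch where both inputs need broadcasting: x of shape (1, 1, 20, 5) and y of shape (20, 5, 1, 1) broadcast to (20, 5, 20, 5), so neither gradient can be produced without a reduction. A repro sketch, assuming a CUDA build of Paddle from around this commit and its fluid dygraph API (the exact API surface is an assumption, not shown in this diff):

import numpy as np
import paddle.fluid as fluid

# Assumes a CUDA build of Paddle ~1.8 (fluid dygraph API); illustrative only.
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    x = fluid.dygraph.to_variable(np.random.rand(1, 1, 20, 5).astype("float32"))
    y = fluid.dygraph.to_variable(np.random.rand(20, 5, 1, 1).astype("float32"))
    x.stop_gradient = False
    y.stop_gradient = False
    out = fluid.layers.elementwise_add(x, y)   # shape (20, 5, 20, 5)
    loss = fluid.layers.reduce_sum(out)
    loss.backward()
    # Before this fix, the GPU backward could leave one of these gradients uncomputed.
    print(x.gradient().shape, y.gradient().shape)  # (1, 1, 20, 5) (20, 5, 1, 1)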