未验证 提交 2de5075a 编写于 作者: W wawltor 提交者: GitHub

Fix broadcasting in the backward pass of the elementwise ops (#24319)

* Remove the error log in the elementwise op; fall back to the generic (slow-path) mode to compute the gradient when the fast broadcast path does not apply
上级 5a29919a
...@@ -1040,22 +1040,21 @@ void CommonGradBroadcastCUDA( ...@@ -1040,22 +1040,21 @@ void CommonGradBroadcastCUDA(
// fallback // fallback
// to old fast path. // to old fast path.
// 2. if both x and y need broadcast, then do it one by one. // 2. if both x and y need broadcast, then do it one by one.
bool fast_broadcast = false;
if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
can_split_y = SplitDims(y_broadcast_pos, max_dim); can_split_y = SplitDims(y_broadcast_pos, max_dim);
if (can_split_y) { if (can_split_y) {
// only y need to do broadcast on h // only y need to do broadcast on h
if (y_broadcast_pos[0] == 0) { if (y_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(y_broadcast_pos, true); FastBroadCastHeightCUDAF(y_broadcast_pos, true);
} else { fast_broadcast = true;
LOG(ERROR) << "Error, broadcast should not into w broadcast";
} }
return;
} else if (y_broadcast_pos.size() == 1 || } else if (y_broadcast_pos.size() == 1 ||
CheckContiguousDims(y_broadcast_pos)) { // for only one dim and CheckContiguousDims(y_broadcast_pos)) { // for only one dim and
// contiguous broadcast. // contiguous broadcast.
// If cannot split, which means input has 3 parts // If cannot split, which means input has 3 parts
FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true); FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
return; fast_broadcast = true;
} }
} else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) { } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
// only x need broadcast // only x need broadcast
...@@ -1063,49 +1062,53 @@ void CommonGradBroadcastCUDA( ...@@ -1063,49 +1062,53 @@ void CommonGradBroadcastCUDA(
if (can_split_x) { if (can_split_x) {
if (x_broadcast_pos[0] == 0) { if (x_broadcast_pos[0] == 0) {
FastBroadCastHeightCUDAF(x_broadcast_pos, false); FastBroadCastHeightCUDAF(x_broadcast_pos, false);
} else { fast_broadcast = true;
// x need to do broadcast on w
LOG(ERROR) << "Error, broadcast should not into w broadcast";
} }
return;
} else if (x_broadcast_pos.size() == 1 || } else if (x_broadcast_pos.size() == 1 ||
CheckContiguousDims(x_broadcast_pos)) { CheckContiguousDims(x_broadcast_pos)) {
FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false); FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
return; fast_broadcast = true;
} }
} else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
// do x and y broadcast each. // do x and y broadcast each.
can_split_y = SplitDims(y_broadcast_pos, max_dim); can_split_y = SplitDims(y_broadcast_pos, max_dim);
bool fast_broadcast_x = false;
bool fast_broadcast_y = false;
if (can_split_y) { if (can_split_y) {
// begin at start. // begin at start.
if (y_broadcast_pos[0] == 0) { if (y_broadcast_pos[0] == 0) {
FastCommonCUDAF(y_broadcast_pos, true); FastCommonCUDAF(y_broadcast_pos, true);
} else { fast_broadcast_y = true;
// finish at end
LOG(ERROR) << "Error, broadcast should not into w broadcast";
} }
} else if (y_broadcast_pos.size() == 1) { } else if (y_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false); FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
can_split_y = true; can_split_y = true;
fast_broadcast_y = true;
} }
can_split_x = SplitDims(x_broadcast_pos, max_dim); can_split_x = SplitDims(x_broadcast_pos, max_dim);
if (can_split_x) { if (can_split_x) {
if (x_broadcast_pos[0] == 0) { if (x_broadcast_pos[0] == 0) {
FastCommonCUDAF(x_broadcast_pos, false); FastCommonCUDAF(x_broadcast_pos, false);
} else { fast_broadcast_x = true;
LOG(ERROR) << "Error, broadcast should not into w broadcast";
} }
} else if (x_broadcast_pos.size() == 1) { } else if (x_broadcast_pos.size() == 1) {
FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true); FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
can_split_x = true; can_split_x = true;
fast_broadcast_x = true;
} }
VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
<< " can_split_x:" << can_split_x; << " can_split_x:" << can_split_x;
// if both x and y into fast path then return // if both x and y into fast path then return
if (can_split_y && can_split_x) return; if (fast_broadcast_x && fast_broadcast_y) {
fast_broadcast = true;
}
if (can_split_y && can_split_x && fast_broadcast) return;
} }
// Should remove memory copy, use reg instead. // Should remove memory copy, use reg instead.
if (fast_broadcast) {
return;
}
int x_blocks = 0; int x_blocks = 0;
int x_threads = 0; int x_threads = 0;
ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks, ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
...@@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA( ...@@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA(
1, std::multiplies<int>()); 1, std::multiplies<int>());
int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads); int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads); int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
if (dx && !can_split_x) { if (dx) {
auto x_strides_order_tmp = memory::Alloc(ctx, bytes); auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
int *x_strides_order_gpu = int *x_strides_order_gpu =
reinterpret_cast<int *>(x_strides_order_tmp->ptr()); reinterpret_cast<int *>(x_strides_order_tmp->ptr());
...@@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA( ...@@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA(
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data, x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
dout_data, dx_data, out_size, max_dim, x_threads, dx_op); dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
} }
if (dy && !can_split_y) { if (dy) {
auto y_strides_order_tmp = memory::Alloc(ctx, bytes); auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
int *y_strides_order_gpu = int *y_strides_order_gpu =
reinterpret_cast<int *>(y_strides_order_tmp->ptr()); reinterpret_cast<int *>(y_strides_order_tmp->ptr());
......
...@@ -263,6 +263,13 @@ class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp): ...@@ -263,6 +263,13 @@ class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
self.out = self.x + self.y self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp):
    # Regression test added with this commit (#24319): x and y BOTH require
    # broadcasting — x over the leading two dims, y over the trailing two —
    # which presumably exercises the backward fallback path fixed here
    # (NOTE(review): inferred from the commit title; confirm against the
    # CUDA kernel change above).
    def init_input_output(self):
        # Broadcasting (1, 1, 20, 5) + (20, 5, 1, 1) yields out of shape
        # (20, 5, 20, 5). dtype comes from the TestElementwiseAddOp base class.
        self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
        self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
        self.out = self.x + self.y
class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp): class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype) self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册