未验证 提交 8c1b2fa6 编写于 作者: F FlyingQianMM 提交者: GitHub

Reduce the number of threads per block of deformable_psroi_pooling to solve...

Reduce the number of threads per block of deformable_psroi_pooling to solve the bug where too many resources requested for launch (#42531)
上级 6ea2f049
...@@ -39,10 +39,10 @@ namespace operators { ...@@ -39,10 +39,10 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
const int CUDA_NUM_THREADS = 1024;
static inline int GET_BLOCKS(const int N) { static inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
} }
template <typename T> template <typename T>
...@@ -252,8 +252,8 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> { ...@@ -252,8 +252,8 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> {
T* top_data = out->mutable_data<T>(ctx.GetPlace()); T* top_data = out->mutable_data<T>(ctx.GetPlace());
T* top_count_data = top_count->mutable_data<T>(ctx.GetPlace()); T* top_count_data = top_count->mutable_data<T>(ctx.GetPlace());
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, DeformablePSROIPoolForwardKernel<<<
dev_ctx.stream()>>>( GET_BLOCKS(count), PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
count, bottom_data, (T)spatial_scale, channels, height, width, count, bottom_data, (T)spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_rois, bottom_trans, no_trans, pooled_height, pooled_width, bottom_rois, bottom_trans, no_trans,
(T)trans_std, sample_per_part, output_dim, group_height, group_width, (T)trans_std, sample_per_part, output_dim, group_height, group_width,
...@@ -344,6 +344,19 @@ __global__ void DeformablePSROIPoolBackwardAccKernel( ...@@ -344,6 +344,19 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
gw = min(max(gw, 0), group_width - 1); gw = min(max(gw, 0), group_width - 1);
gh = min(max(gh, 0), group_height - 1); gh = min(max(gh, 0), group_height - 1);
int c = (ctop * group_height + gh) * group_width + gw;
int bottom_index_base = c * height * width;
int bottom_index =
roi_batch_ind * channels * height * width + bottom_index_base;
int trans_index_x =
(((n * num_classes + class_id) * 2) * part_height + part_h) *
part_width +
part_w;
int trans_index_y =
(((n * num_classes + class_id) * 2 + 1) * part_height + part_h) *
part_width +
part_w;
// sampling in each bin // sampling in each bin
for (int ih = 0; ih < sample_per_part; ih++) { for (int ih = 0; ih < sample_per_part; ih++) {
for (int iw = 0; iw < sample_per_part; iw++) { for (int iw = 0; iw < sample_per_part; iw++) {
...@@ -354,7 +367,6 @@ __global__ void DeformablePSROIPoolBackwardAccKernel( ...@@ -354,7 +367,6 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
} }
w = min(max(w, 0.), width - 1.); w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.); h = min(max(h, 0.), height - 1.);
int c = (ctop * group_height + gh) * group_width + gw;
int x0 = floor(w); int x0 = floor(w);
int x1 = ceil(w); int x1 = ceil(w);
int y0 = floor(h); int y0 = floor(h);
...@@ -366,25 +378,20 @@ __global__ void DeformablePSROIPoolBackwardAccKernel( ...@@ -366,25 +378,20 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
T q01 = (1 - dist_x) * dist_y; T q01 = (1 - dist_x) * dist_y;
T q10 = dist_x * (1 - dist_y); T q10 = dist_x * (1 - dist_y);
T q11 = dist_x * dist_y; T q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
// compute gradient of input // compute gradient of input
if (bottom_data_diff) { if (bottom_data_diff) {
platform::CudaAtomicAdd( platform::CudaAtomicAdd(
bottom_data_diff + roi_batch_ind * channels * height * width + bottom_data_diff + bottom_index + y0 * width + x0,
bottom_index_base + y0 * width + x0,
q00 * diff_val); q00 * diff_val);
platform::CudaAtomicAdd( platform::CudaAtomicAdd(
bottom_data_diff + roi_batch_ind * channels * height * width + bottom_data_diff + bottom_index + y1 * width + x0,
bottom_index_base + y1 * width + x0,
q01 * diff_val); q01 * diff_val);
platform::CudaAtomicAdd( platform::CudaAtomicAdd(
bottom_data_diff + roi_batch_ind * channels * height * width + bottom_data_diff + bottom_index + y0 * width + x1,
bottom_index_base + y0 * width + x1,
q10 * diff_val); q10 * diff_val);
platform::CudaAtomicAdd( platform::CudaAtomicAdd(
bottom_data_diff + roi_batch_ind * channels * height * width + bottom_data_diff + bottom_index + y1 * width + x1,
bottom_index_base + y1 * width + x1,
q11 * diff_val); q11 * diff_val);
} }
...@@ -405,19 +412,8 @@ __global__ void DeformablePSROIPoolBackwardAccKernel( ...@@ -405,19 +412,8 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
u00 * (1 - dist_x)) * u00 * (1 - dist_x)) *
trans_std * diff_val; trans_std * diff_val;
diff_y *= roi_height; diff_y *= roi_height;
platform::CudaAtomicAdd( platform::CudaAtomicAdd(bottom_trans_diff + trans_index_x, diff_x);
bottom_trans_diff + platform::CudaAtomicAdd(bottom_trans_diff + trans_index_y, diff_y);
(((n * num_classes + class_id) * 2) * part_height + part_h) *
part_width +
part_w,
diff_x);
platform::CudaAtomicAdd(
bottom_trans_diff +
(((n * num_classes + class_id) * 2 + 1) * part_height +
part_h) *
part_width +
part_w,
diff_y);
} }
} }
} }
...@@ -520,8 +516,8 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> { ...@@ -520,8 +516,8 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes, memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream()); dev_ctx.stream());
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, DeformablePSROIPoolBackwardAccKernel<<<
0, dev_ctx.stream()>>>( GET_BLOCKS(count), PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
count, top_diff, top_count_data, num_rois, (T)spatial_scale, channels, count, top_diff, top_count_data, num_rois, (T)spatial_scale, channels,
height, width, pooled_height, pooled_width, output_dim, height, width, pooled_height, pooled_width, output_dim,
bottom_data_diff, bottom_trans_diff, bottom_data, bottom_rois, bottom_data_diff, bottom_trans_diff, bottom_data, bottom_rois,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册