diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index f905d690f984a20622c5fbcbcc813d888dfb19d9..50450b62f7b1c0b2b5abf01a43581a0e2d2cd01e 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward( int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (size_t i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; const int64_t* offset_input_rois = input_rois + n * kROISize; int roi_batch_ind = roi_batch_id_data[n]; @@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward( int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + int hstart = static_cast(floor(static_cast(ph) * + static_cast(roi_height) / + static_cast(pooled_height))); + int wstart = static_cast(floor(static_cast(pw) * + static_cast(roi_width) / + static_cast(pooled_width))); + int hend = static_cast(ceil(static_cast(ph + 1) * + static_cast(roi_height) / + static_cast(pooled_height))); + int wend = static_cast(ceil(static_cast(pw + 1) * + static_cast(roi_width) / + static_cast(pooled_width))); hstart = min(max(hstart + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height); wstart = min(max(wstart + roi_start_w, 0), width); @@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward( } } } - output_data[index] = maxval; + output_data[i] = maxval; if (argmax_data) { - argmax_data[index] = maxidx; + argmax_data[i] = maxidx; } } } @@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward( int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; int roi_batch_ind = roi_batch_id_data[n]; int input_offset = (roi_batch_ind * channels + c) * height * width; @@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel { int width = in_dims[3]; int rois_num = rois->dims()[0]; + if (rois_num == 0) return; int output_size = out->numel();