From 8023bc766b8d1c39f280295abaeac7a374face72 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Thu, 17 May 2018 14:58:44 +0000 Subject: [PATCH] fix index --- paddle/fluid/operators/roi_pool_op.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index a699d21542..972dc36c11 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward( int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (size_t i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; const int64_t* offset_input_rois = input_rois + n * kROISize; int roi_batch_ind = roi_batch_id_data[n]; @@ -65,7 +65,6 @@ __global__ void GPUROIPoolForward( int wend = static_cast(ceil(static_cast(pw + 1) * static_cast(roi_width) / static_cast(pooled_width))); - hstart = min(max(hstart + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height); wstart = min(max(wstart + roi_start_w, 0), width); @@ -85,9 +84,9 @@ __global__ void GPUROIPoolForward( } } } - output_data[index] = maxval; + output_data[i] = maxval; if (argmax_data) { - argmax_data[index] = maxidx; + argmax_data[i] = maxidx; } } } @@ -144,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel { int width = in_dims[3]; int rois_num = rois->dims()[0]; + if (rois_num == 0) return; int output_size = out->numel(); -- GitLab