From 1d7f91e867b598bc4cde7dd5f0ef0410ca569bd5 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Fri, 18 May 2018 02:14:09 +0000 Subject: [PATCH] fix roi_pool gpu_backward kernel --- paddle/fluid/operators/roi_pool_op.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 972dc36c11..50450b62f7 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -101,10 +101,10 @@ __global__ void GPUROIPoolBackward( int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; int roi_batch_ind = roi_batch_id_data[n]; int input_offset = (roi_batch_ind * channels + c) * height * width; -- GitLab