From fe93cce387561a0b997276be5626d23bdea4a270 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Wed, 16 May 2018 08:58:04 +0000 Subject: [PATCH] fix roi_pool op bug --- paddle/fluid/operators/roi_pool_op.cu | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index f905d690f98..265dbb3a801 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -52,13 +52,19 @@ __global__ void GPUROIPoolForward( int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + int hstart = + static_cast(floor(static_cast(ph) * static_cast(roi_height) / + static_cast(pooled_height))); + int wstart = + static_cast(floor(static_cast(pw) * static_cast(roi_width) / + static_cast(pooled_width))); + int hend = static_cast(ceil(static_cast(ph + 1) * + static_cast(roi_height) / + static_cast(pooled_height))); + int wend = static_cast(ceil(static_cast(pw + 1) * + static_cast(roi_width) / + static_cast(pooled_width))); hstart = min(max(hstart + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height); -- GitLab