未验证 提交 67b8a300 编写于 作者: B baiyf 提交者: GitHub

Merge pull request #10700 from baiyfbupt/develop

fix roi_pool op bug
...@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward( ...@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward(
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) { for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width; int pw = i % pooled_width;
int ph = (index / pooled_width) % pooled_height; int ph = (i / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels; int c = (i / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels; int n = i / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize; const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n]; int roi_batch_ind = roi_batch_id_data[n];
...@@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward( ...@@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward(
int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
int hstart = static_cast<int>(floor(static_cast<double>(ph) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wstart = static_cast<int>(floor(static_cast<double>(pw) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
hstart = min(max(hstart + roi_start_h, 0), height); hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width); wstart = min(max(wstart + roi_start_w, 0), width);
...@@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward( ...@@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward(
} }
} }
} }
output_data[index] = maxval; output_data[i] = maxval;
if (argmax_data) { if (argmax_data) {
argmax_data[index] = maxidx; argmax_data[i] = maxidx;
} }
} }
} }
...@@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward( ...@@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward(
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) { for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width; int pw = i % pooled_width;
int ph = (index / pooled_width) % pooled_height; int ph = (i / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels; int c = (i / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels; int n = i / pooled_width / pooled_height / channels;
int roi_batch_ind = roi_batch_id_data[n]; int roi_batch_ind = roi_batch_id_data[n];
int input_offset = (roi_batch_ind * channels + c) * height * width; int input_offset = (roi_batch_ind * channels + c) * height * width;
...@@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return; if (rois_num == 0) return;
int output_size = out->numel(); int output_size = out->numel();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册