提交 8023bc76 编写于 作者: B baiyfbupt

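Fix index in GPUROIPoolForward: use the loop variable `i` instead of the thread's initial `index` when computing pw/ph/c/n and when writing output_data/argmax_data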

上级 4d2a2e75
...@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward( ...@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward(
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) { for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width; int pw = i % pooled_width;
int ph = (index / pooled_width) % pooled_height; int ph = (i / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels; int c = (i / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels; int n = i / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize; const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n]; int roi_batch_ind = roi_batch_id_data[n];
...@@ -65,7 +65,6 @@ __global__ void GPUROIPoolForward( ...@@ -65,7 +65,6 @@ __global__ void GPUROIPoolForward(
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) * int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
static_cast<double>(roi_width) / static_cast<double>(roi_width) /
static_cast<double>(pooled_width))); static_cast<double>(pooled_width)));
hstart = min(max(hstart + roi_start_h, 0), height); hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width); wstart = min(max(wstart + roi_start_w, 0), width);
...@@ -85,9 +84,9 @@ __global__ void GPUROIPoolForward( ...@@ -85,9 +84,9 @@ __global__ void GPUROIPoolForward(
} }
} }
} }
output_data[index] = maxval; output_data[i] = maxval;
if (argmax_data) { if (argmax_data) {
argmax_data[index] = maxidx; argmax_data[i] = maxidx;
} }
} }
} }
...@@ -144,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -144,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return; if (rois_num == 0) return;
int output_size = out->numel(); int output_size = out->numel();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册