提交 fe93cce3 编写于 作者: B baiyfbupt

fix roi_pool op bug

上级 ed748dae
...@@ -52,13 +52,19 @@ __global__ void GPUROIPoolForward( ...@@ -52,13 +52,19 @@ __global__ void GPUROIPoolForward(
int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h)); int hstart =
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w)); static_cast<int>(floor(static_cast<T>(ph) * static_cast<T>(roi_height) /
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h)); static_cast<T>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w)); int wstart =
static_cast<int>(floor(static_cast<T>(pw) * static_cast<T>(roi_width) /
static_cast<T>(pooled_width)));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) *
static_cast<T>(roi_height) /
static_cast<T>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) *
static_cast<T>(roi_width) /
static_cast<T>(pooled_width)));
hstart = min(max(hstart + roi_start_h, 0), height); hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册