未验证 提交 50af0c2c 编写于 作者: W wangguanzhong 提交者: GitHub

fix roi_align, test=develop (#31479)

上级 e03e4673
...@@ -124,11 +124,9 @@ __global__ void GPUROIAlignForward( ...@@ -124,11 +124,9 @@ __global__ void GPUROIAlignForward(
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
if (!continuous_coordinate) {
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
...@@ -182,10 +180,9 @@ __global__ void GPUROIAlignBackward( ...@@ -182,10 +180,9 @@ __global__ void GPUROIAlignBackward(
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
if (!continuous_coordinate) { roi_width = max(roi_width, static_cast<T>(1.));
roi_width = max(roi_width, static_cast<T>(1.)); roi_height = max(roi_height, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......
...@@ -226,10 +226,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -226,10 +226,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
if (!aligned) { roi_width = std::max(roi_width, static_cast<T>(1.));
roi_width = std::max(roi_width, static_cast<T>(1.)); roi_height = std::max(roi_height, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
...@@ -362,11 +360,9 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -362,11 +360,9 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
if (!aligned) {
roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
......
...@@ -129,9 +129,9 @@ class TestROIAlignOp(OpTest): ...@@ -129,9 +129,9 @@ class TestROIAlignOp(OpTest):
roi_width = roi_xmax - roi_xmin roi_width = roi_xmax - roi_xmin
roi_height = roi_ymax - roi_ymin roi_height = roi_ymax - roi_ymin
if not self.aligned: roi_width = max(roi_width, 1)
roi_width = max(roi_width, 1) roi_height = max(roi_height, 1)
roi_height = max(roi_height, 1)
bin_size_h = float(roi_height) / float(self.pooled_height) bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width) bin_size_w = float(roi_width) / float(self.pooled_width)
roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \ roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册