未验证 提交 e19736d7 编写于 作者: W wangguanzhong 提交者: GitHub

fix aligned in roi_align (#33444)

上级 dffc331f
...@@ -124,8 +124,10 @@ __global__ void GPUROIAlignForward( ...@@ -124,8 +124,10 @@ __global__ void GPUROIAlignForward(
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = max(roi_width, static_cast<T>(1.)); if (!continuous_coordinate) {
roi_height = max(roi_height, static_cast<T>(1.)); roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
...@@ -138,7 +140,7 @@ __global__ void GPUROIAlignForward( ...@@ -138,7 +140,7 @@ __global__ void GPUROIAlignForward(
: ceil(roi_height / pooled_height); : ceil(roi_height / pooled_height);
int roi_bin_grid_w = int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w; const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0; T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) { for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h + const T y = roi_ymin + ph * bin_size_h +
...@@ -180,9 +182,10 @@ __global__ void GPUROIAlignBackward( ...@@ -180,9 +182,10 @@ __global__ void GPUROIAlignBackward(
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = max(roi_width, static_cast<T>(1.)); if (!continuous_coordinate) {
roi_height = max(roi_height, static_cast<T>(1.)); roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......
...@@ -226,8 +226,10 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -226,8 +226,10 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
T roi_width = roi_xmax - roi_xmin; T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = std::max(roi_width, static_cast<T>(1.)); if (!aligned) {
roi_height = std::max(roi_height, static_cast<T>(1.)); roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
...@@ -239,7 +241,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -239,7 +241,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int roi_bin_grid_w = (sampling_ratio > 0) int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_width / pooled_width); : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w; const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1);
Tensor pre_pos; Tensor pre_pos;
Tensor pre_w; Tensor pre_w;
int pre_size = count * out_stride[1]; int pre_size = count * out_stride[1];
...@@ -362,6 +364,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -362,6 +364,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
T roi_height = roi_ymax - roi_ymin; T roi_height = roi_ymax - roi_ymin;
roi_width = std::max(roi_width, static_cast<T>(1.)); roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.)); roi_height = std::max(roi_height, static_cast<T>(1.));
if (!aligned) {
roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......
...@@ -129,8 +129,9 @@ class TestROIAlignOp(OpTest): ...@@ -129,8 +129,9 @@ class TestROIAlignOp(OpTest):
roi_width = roi_xmax - roi_xmin roi_width = roi_xmax - roi_xmin
roi_height = roi_ymax - roi_ymin roi_height = roi_ymax - roi_ymin
roi_width = max(roi_width, 1) if not self.aligned:
roi_height = max(roi_height, 1) roi_width = max(roi_width, 1)
roi_height = max(roi_height, 1)
bin_size_h = float(roi_height) / float(self.pooled_height) bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width) bin_size_w = float(roi_width) / float(self.pooled_width)
...@@ -138,7 +139,7 @@ class TestROIAlignOp(OpTest): ...@@ -138,7 +139,7 @@ class TestROIAlignOp(OpTest):
math.ceil(roi_height / self.pooled_height) math.ceil(roi_height / self.pooled_height)
roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \ roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_width / self.pooled_width) math.ceil(roi_width / self.pooled_width)
count = int(roi_bin_grid_h * roi_bin_grid_w) count = max(int(roi_bin_grid_h * roi_bin_grid_w), 1)
pre_size = count * self.pooled_width * self.pooled_height pre_size = count * self.pooled_width * self.pooled_height
bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin, bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin,
int(roi_bin_grid_h), int(roi_bin_grid_h),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册