From c8e49be2f14ecee3c7a9f2f37ef099181ea9db0b Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 28 Oct 2019 11:40:37 +0800 Subject: [PATCH] Fix roi_perspective_transform op (#20764) --- .../detection/roi_perspective_transform_op.cc | 32 +++++++++---------- .../detection/roi_perspective_transform_op.cu | 23 ++++++------- .../test_roi_perspective_transform_op.py | 12 +++---- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 74756a2a22a..20f0012276a 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -187,17 +187,17 @@ void bilinear_interpolate(const T* in_data, const int channels, const int width, const int height, int in_n, int in_c, T in_w, T in_h, T* val) { // Deal with cases that source coords are out of feature map boundary - if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || - GT(in_h, height - 0.5)) { + if (GT_E(-0.5, in_w) || GT_E(in_w, width - 0.5) || + GT_E(-0.5, in_h) || GT_E(in_h, height - 0.5)) { // empty val[0] = 0.0; return; } - if (GT(0, in_w)) { + if (GT_E(0, in_w)) { in_w = 0; } - if (GT(0, in_h)) { + if (GT_E(0, in_h)) { in_h = 0; } @@ -301,10 +301,10 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { T in_w, in_h; get_source_coords(matrix, out_w, out_h, &in_w, &in_h); if (in_quad(in_w, in_h, roi_x, roi_y)) { - if (GT(-0.5, in_w) || - GT(in_w, static_cast(in_width - 0.5)) || - GT(-0.5, in_h) || - GT(in_h, static_cast(in_height - 0.5))) { + if (GT_E(-0.5, in_w) || + GT_E(in_w, static_cast(in_width - 0.5)) || + GT_E(-0.5, in_h) || + GT_E(in_h, static_cast(in_height - 0.5))) { output_data[out_index] = 0.0; mask_data[(n * transformed_height + out_h) * transformed_width + out_w] = 0; @@ -330,15 +330,15 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { template T get_feature_gradient(T xs, T ys, int w, int h, const int width, const int height) { - if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || - GT(ys, height - 0.5)) { + if (GT_E(-0.5, xs) || GT_E(xs, width - 0.5) || GT_E(-0.5, ys) || + GT_E(ys, height - 0.5)) { return 0; } - if (GT(0, xs)) { + if (GT_E(0, xs)) { xs = 0; } - if (GT(0, ys)) { + if (GT_E(0, ys)) { ys = 0; } @@ -441,10 +441,10 @@ class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel { T src_h; get_source_coords(matrix, out_w, out_h, &src_w, &src_h); if (in_quad(src_w, src_h, roi_x, roi_y)) { - if (GT(-0.5, src_w) || - GT(src_w, static_cast(in_width - 0.5)) || - GT(-0.5, src_h) || - GT(src_h, static_cast(in_height - 0.5))) { + if (GT_E(-0.5, src_w) || + GT_E(src_w, static_cast(in_width - 0.5)) || + GT_E(-0.5, src_h) || + GT_E(src_h, static_cast(in_height - 0.5))) { continue; } T weight = get_feature_gradient(src_w, src_h, in_w, in_h, diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu index 8c9ca9462c3..fe65162353e 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -120,16 +120,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels, int out_idx, int* out2in_idx, T* out2in_w) { // Deal with cases that source coords are out of feature map boundary - if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || - GT(in_h, height - 0.5)) { + if (GT_E(-0.5, in_w) || GT_E(in_w, width - 0.5) || + GT_E(-0.5, in_h) || GT_E(in_h, height - 0.5)) { val[0] = 0.0; return; } - if (GT(0, in_w)) { + if (GT_E(0, in_w)) { in_w = 0; } - if (GT(0, in_h)) { + if (GT_E(0, in_h)) { in_h = 0; } @@ -284,7 +284,6 @@ __global__ void RoiTransformKernel(const float* input_data, int* mask, T* transform_matrix) { int output_size = num_rois * transformed_height * transformed_width * channels; - CUDA_1D_KERNEL_LOOP(index, output_size) { // (n, c, out_h, out_w) is an element in the transformed output int out_w = idx4_4(index, num_rois, channels, transformed_height, @@ -318,8 +317,10 @@ __global__ void RoiTransformKernel(const float* input_data, get_source_coords(matrix, out_w, out_h, &in_w, &in_h); if (in_quad(in_w, in_h, roi_x, roi_y)) { - if (GT(-0.5, in_w) || GT(in_w, static_cast(in_width - 0.5)) || - GT(-0.5, in_h) || GT(in_h, static_cast(in_height - 0.5))) { + if (GT_E(-0.5, in_w) || + GT_E(in_w, static_cast(in_width - 0.5)) || + GT_E(-0.5, in_h) || + GT_E(in_h, static_cast(in_height - 0.5))) { // Skip if source coords is not in input image output_data[index] = 0.0; mask[(n * transformed_height + out_h) * transformed_width + out_w] = 0; @@ -409,15 +410,15 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { template __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width, const int height) { - if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || - GT(ys, height - 0.5)) { + if (GT_E(-0.5, xs) || GT_E(xs, width - 0.5) || GT_E(-0.5, ys) || + GT_E(ys, height - 0.5)) { return 0; } - if (GT(0, xs)) { + if (GT_E(0, xs)) { xs = 0; } - if (GT(0, ys)) { + if (GT_E(0, ys)) { ys = 0; } diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py index e742993c2bf..0a302f5efca 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -135,13 +135,13 @@ def bilinear_interpolate(in_data, in_n, in_c, in_w, in_h): height = in_data.shape[2] width = in_data.shape[3] - if gt(-0.5, in_w) or gt(in_w, width - 0.5) or gt(-0.5, in_h) or gt( + if gt_e(-0.5, in_w) or gt_e(in_w, width - 0.5) or gt_e(-0.5, in_h) or gt_e( in_h, height - 0.5): return 0.0 - if gt(0, in_w): + if gt_e(0, in_w): in_w = 0 - if gt(0, in_h): + if gt_e(0, in_h): in_h = 0 in_w_floor = floor(in_w) @@ -216,9 +216,9 @@ def roi_transform(in_data, rois, rois_lod, transformed_height, for out_w in range(transformed_width): in_w, in_h = get_source_coords(transform_matrix, out_w, out_h) - if in_quad(in_w, in_h, roi_x, roi_y) and gt_e( - in_w, -0.5) and lt_e(in_w, in_width - 0.5) and gt_e( - in_h, -0.5) and lt_e(in_h, in_height - 0.5): + if in_quad(in_w, in_h, roi_x, roi_y) and gt( + in_w, -0.5) and gt(in_width - 0.5, in_w) and gt( + in_h, -0.5) and gt(in_height - 0.5, in_h): out[n][c][out_h][out_w] = bilinear_interpolate( in_data, image_id, c, in_w, in_h) mask[n][0][out_h][out_w] = 1 -- GitLab