diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index cbffc2fa630fa17f5567e8dc4140787c83c77ed0..1cdda4cfe90c459b74fe9436654c88206e498b50 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -108,8 +108,10 @@ __global__ void KeBilinearInterpFw( : static_cast(ratio_h * out_img_idy); in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = (src_h > 0) ? src_h : 0; + T h1lambda = + align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; @@ -118,8 +120,10 @@ __global__ void KeBilinearInterpFw( : static_cast(ratio_w * out_img_idx); in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + @@ -156,8 +160,10 @@ __global__ void KeBilinearInterpBw( : ratio_h * out_img_idy; in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = (src_h > 0) ? src_h : 0; + T h1lambda = + align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; @@ -166,8 +172,10 @@ __global__ void KeBilinearInterpBw( : ratio_w * out_img_idx; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 652aec9a5385335ea2649c208a7b174aad744a71..bd33abb98f2f1a6ad75b64e37ca14b411a4a168e 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -73,8 +73,9 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = - align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; + float idx_src_y = ratio_h * (k + 0.5) - 0.5; + idx_src_y = (idx_src_y > 0) ? idx_src_y : 0; + float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; { vy_n[k] = y_n; @@ -99,8 +100,9 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = - align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; { vx_w[l] = x_w; @@ -170,8 +172,9 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = - align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; + float idx_src_y = ratio_h * (k + 0.5) - 0.5; + idx_src_y = (idx_src_y > 0) ? idx_src_y : 0; + float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { @@ -179,8 +182,9 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = - align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index e06c87d969f4c1146f1636a6402059f82341cff3..7e577229777e256b15a02232229f4127b0f877f5 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -59,7 +59,8 @@ def bilinear_interp_np(input, h = max(0, h) hid = 1 if h < in_h - 1 else 0 if (align_mode == 0 and not align_corners): - h1lambda = ratio_h * (i + 0.5) - 0.5 - h + idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0) + h1lambda = idx_src_h - h else: h1lambda = ratio_h * i - h h2lambda = 1.0 - h1lambda @@ -71,7 +72,8 @@ def bilinear_interp_np(input, w = max(0, w) wid = 1 if w < in_w - 1 else 0 if (align_mode == 0 and not align_corners): - w1lambda = ratio_w * (j + 0.5) - 0.5 - w + idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w else: w1lambda = ratio_w * j - w w2lambda = 1.0 - w1lambda @@ -335,5 +337,16 @@ class TestBilinearInterpScale3(TestBilinearInterpOp): self.align_mode = 1 +class TestBilinearInterpZero(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 0.2 + self.align_corners = False + self.align_mode = 0 + + if __name__ == "__main__": unittest.main()