diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 357832223c01f6a0c1836fd6b6e9460fe1c74ce5..de91ba6270ac2ed22c8380878c0a0037fb1629c0 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -110,7 +110,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { to perform linear interpolation first in one direction, and then again in the other direction. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 7595511cf57442811bff4e44e5204f668c439dcb..1dfd4947c6054c3da9a8e1c79542169b1727e9ad 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -94,6 +94,7 @@ __global__ void KeBilinearInterpFw( int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; @@ -102,25 +103,23 @@ __global__ void KeBilinearInterpFw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = (align_mode == 0 && !align_corners) + int in_img_idy = align_flag ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) : static_cast(ratio_h * out_img_idy); in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = (align_mode == 0 && !align_corners) + int in_img_idx = align_flag ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) : static_cast(ratio_w * out_img_idx); in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + @@ -144,6 +143,7 @@ __global__ void KeBilinearInterpBw( int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; @@ -152,26 +152,22 @@ __global__ void KeBilinearInterpBw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - : ratio_h * out_img_idy; + int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 + : ratio_h * out_img_idy; in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - : ratio_w * out_img_idx; + int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 + : ratio_w * out_img_idx; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index ab41ff781a5d5fa9791639984a2ff9babe593911..1ec0cb5025b343117b803dc4a0f8b03be57b31ac 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -56,15 +56,14 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, const bool align_mode) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); + bool align_flag = (align_mode == 0 && !align_corners); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = (align_mode == 0 && !align_corners) - ? static_cast(ratio_h * (k + 0.5) - 0.5) - : static_cast(ratio_h * k); + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = (align_mode == 0 && !align_corners) - ? ratio_h * (k + 0.5) - 0.5 - y_n - : ratio_h * k - y_n; + float d_n = + align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { @@ -73,9 +72,8 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = (align_mode == 0 && !align_corners) - ? ratio_w * (l + 0.5) - 0.5 - x_w - : ratio_w * l - x_w; + float d_w = + align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches @@ -126,26 +124,23 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, const int align_mode) { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); + bool align_flag = (align_mode == 0 && !align_corners); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = (align_mode == 0 && !align_corners) - ? static_cast(ratio_h * (k + 0.5) - 0.5) - : static_cast(ratio_h * k); + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = (align_mode == 0 && !align_corners) - ? ratio_h * (k + 0.5) - 0.5 - y_n - : ratio_h * k - y_n; + float d_n = + align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { - int x_w = (align_mode == 0 && !align_corners) - ? static_cast(ratio_w * (l + 0.5) - 0.5) - : static_cast(ratio_w * l); + int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = (align_mode == 0 && !align_corners) - ? ratio_w * (l + 0.5) - 0.5 - x_w - : ratio_w * l - x_w; + float d_w = + align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a5a3aa2f3a5905cac8a876a16a64a9b2c1cfe688..b398f5d20630265251ccaec112a4d1db0df9687c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6552,7 +6552,7 @@ def image_resize(input, to perform linear interpolation first in one direction, and then again in the other direction. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: @@ -6758,11 +6758,11 @@ def resize_bilinear(input, For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation - Align_corners and align_mode are optinal parameters,The calculation + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: