diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index cebc213fb61f1687a9c6b2699f87aee608392ad5..2a7059c1d86a0567c6dfc66400fc5aef94a321cf 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -26,7 +26,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { auto interp_method = ctx->Attrs().Get("interp_method"); PADDLE_ENFORCE( - "bilinear" == interp_method || "nearest" == interp_method, + "bilinear" == interp_method || "nearest" == interp_method || + "bicubic" == interp_method, "Interpolation method can only be \"bilinear\" or \"nearest\" when " "Input(X) dimension is 4"); const DataLayout data_layout = framework::StringToDataLayout( @@ -264,7 +265,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { "method, can be \"bilinear\" for " "bilinear interpolation, \"trilinear\" for trilinear " "interpolation and \"nearest\" for nearest " - "neighbor interpolation.") + "neighbor interpolation, and \"bicubic\" for bicubic" + "interpolation.") .SetDefault("bilinear"); AddAttr( "align_corners", @@ -299,6 +301,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { H-direction and W-direction in this op) on a rectilinear 3D grid. The linear interpolation is performed on three directions. + Bicubic interpolation is an extension of cubic interpolation for interpolating + data points on a two-dimensional regular grid. The interpolated surface is + smoother than corresponding surfaces obtained by bilinear interpolation or + nearest-neighbor interpolation. + Align_corners and align_mode are optional parameters,the calculation method of interpolation can be selected by them. @@ -376,7 +383,20 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { D_out = D_{in} * scale_{factor} H_out = H_{in} * scale_{factor} W_out = W_{in} * scale_{factor} - + + Bicubic interpolation: + + if: + align_corners = False + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation @@ -386,6 +406,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { For details of trilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Trilinear_interpolation + + For details of bicubic interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bicubic_interpolation )DOC"); } }; @@ -469,6 +492,11 @@ REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker, ops::InterpolateGradMaker); REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad, ops::InterpolateGradNoNeedBufferVarsInference); +REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker, + ops::InterpolateGradMaker, + ops::InterpolateGradMaker); +REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad, + ops::InterpolateGradNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel, ops::InterpolateKernel, ops::InterpolateKernel); @@ -484,3 +512,7 @@ REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel, ops::InterpolateKernel); REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel, ops::InterpolateGradKernel); +REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index d12a25f360eece7bda47411c5bb4cd8625cdf729..5beaa64664f69f1caa053b2f166ea8aced687290 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -506,6 +506,206 @@ __global__ void KeTrilinearInterpBw( } } +template +__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1, + const T x2, const T x3, + T t) { + T coeffs[4]; + T a = -0.75; + T x_1 = t; + T x_2 = 1.0 - t; + coeffs[0] = cubic_convolution2(x_1 + 1.0, a); + coeffs[1] = cubic_convolution1(x_1, a); + coeffs[2] = cubic_convolution1(x_2, a); + coeffs[3] = cubic_convolution2(x_2 + 1.0, a); + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +template +__global__ void KeBicubicInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idy, out_img_idx; + + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + T in_img_idy = align_corners + ? static_cast(ratio_h * out_img_idy) + : static_cast(ratio_h * (out_img_idy + 0.5) - 0.5); + int input_y = floorf(in_img_idy); + const T y_t = in_img_idy - input_y; + + T in_img_idx = align_corners + ? static_cast(ratio_w * out_img_idx) + : static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); + int input_x = floorf(in_img_idx); + const T x_t = in_img_idx - input_x; + + T coefficients[4]; + const T* in_pos_0; + const T* in_pos_1; + const T* in_pos_2; + const T* in_pos_3; + int access_x_0; + if (data_layout == DataLayout::kNCHW) { + for (int k = 0; k < 4; k++) { + int access_y = + max(min(input_y - 1 + k, static_cast(in_img_h - 1)), 0); + access_x_0 = max(min(input_x - 1, static_cast(in_img_w - 1)), 0); + int access_x_1 = + max(min(input_x + 0, static_cast(in_img_w - 1)), 0); + int access_x_2 = + max(min(input_x + 1, static_cast(in_img_w - 1)), 0); + int access_x_3 = + max(min(input_x + 2, static_cast(in_img_w - 1)), 0); + + in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size + + access_y * in_img_w + access_x_0]; + in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size + + access_y * in_img_w + access_x_1]; + in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size + + access_y * in_img_w + access_x_2]; + in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size + + access_y * in_img_w + access_x_3]; + + coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], + in_pos_2[0], in_pos_3[0], x_t); + } + + out[out_id_h * output_w + out_id_w] = + Kecubic_interp(coefficients[0], coefficients[1], coefficients[2], + coefficients[3], y_t); + + } else { + for (int k = 0; k < 4; k++) { + int access_y = + max(min(input_y - 1 + k, static_cast((in_img_h - 1))), 0); + int access_x_0 = + max(min(input_x - 1, static_cast((in_img_w - 1))), 0); + int access_x_1 = + max(min(input_x + 0, static_cast((in_img_w - 1))), 0); + int access_x_2 = + max(min(input_x + 1, static_cast((in_img_w - 1))), 0); + int access_x_3 = + max(min(input_x + 2, static_cast((in_img_w - 1))), 0); + + const T* in_pos_0 = + &in[out_id_h * input_w + access_y * in_img_w * num_channels + + access_x_0 * num_channels + channel_id]; + const T* in_pos_1 = + &in[out_id_h * input_w + access_y * in_img_w * num_channels + + access_x_1 * num_channels + channel_id]; + const T* in_pos_2 = + &in[out_id_h * input_w + access_y * in_img_w * num_channels + + access_x_2 * num_channels + channel_id]; + const T* in_pos_3 = + &in[out_id_h * input_w + access_y * in_img_w * num_channels + + access_x_3 * num_channels + channel_id]; + + coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0], + in_pos_3[0], x_t); + } + + out[out_id_h * output_w + out_id_w] = + static_cast(Kecubic_interp(coefficients[0], coefficients[1], + coefficients[2], coefficients[3], y_t)); + } + } +} + +template +__global__ void KeBicubicInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + T in_img_idy = align_corners + ? static_cast(ratio_h * out_img_idy) + : static_cast(ratio_h * (out_img_idy + 0.5) - 0.5); + int input_y = floorf(in_img_idy); + const T y_t = in_img_idy - input_y; + + T in_img_idx = align_corners + ? static_cast(ratio_w * out_img_idx) + : static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); + int input_x = floorf(in_img_idx); + + const T x_t = in_img_idx - input_x; + + T x_coeffs[4]; + T y_coeffs[4]; + + get_cubic_upsample_coefficients(x_coeffs, x_t); + get_cubic_upsample_coefficients(y_coeffs, y_t); + + const T* out_pos = &out[out_id_h * output_w + out_id_w]; + T* in_pos; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + int access_y = max(min(static_cast(input_y - 1 + j), + static_cast(in_img_h - 1)), + 0); + int access_x = max(min(static_cast(input_x - 1 + i), + static_cast(in_img_w - 1)), + 0); + if (data_layout == DataLayout::kNCHW) { + in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + access_y * in_img_w + access_x]; + } else { + in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels + + access_x * num_channels + channel_id]; + } + platform::CudaAtomicAdd(&in_pos[0], + (out_pos[0] * y_coeffs[j] * x_coeffs[i])); + } + } + } +} + template static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, const Tensor& input, Tensor* output) { @@ -602,6 +802,11 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, ctx.cuda_device_context().stream()>>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("bicubic" == interp_method) { + KeBicubicInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w, align_corners, data_layout); } } @@ -806,6 +1011,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("bicubic" == interp_method) { + KeBicubicInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, + n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); } } @@ -968,3 +1178,9 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel, REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad, ops::InterpolateGradOpCUDAKernel, ops::InterpolateGradOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(bicubic_interp_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index b107d1e6656bec8d2b2214a94dc7a44246771bbb..8acd54200ecbb999be76c26f7ef999a2a7120ff9 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -10,10 +10,12 @@ limitations under the License. */ #pragma once +#include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace operators { @@ -342,6 +344,106 @@ static void TrilinearInterpolation( } } +template +HOSTDEVICE inline T cubic_convolution1(T x, T A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +HOSTDEVICE inline T cubic_convolution2(T x, T A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) { + T A = -0.75; + + T x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + T x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) { + T coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +template +static void BicubicInterpolation(const Tensor& input, Tensor* output, + const float ratio_h, const float ratio_w, + const int in_h, const int in_w, const int n, + const int c, const int out_h, const int out_w, + const bool align_corners, + const DataLayout data_layout) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + + for (int k = 0; k < out_h; k++) { // loop for images + T y_n = align_corners ? static_cast(ratio_h * k) + : static_cast(ratio_h * (k + 0.5) - 0.5); + int input_y = floorf(y_n); + const T y_t = y_n - input_y; + + for (int l = 0; l < out_w; l++) { + T x_n = align_corners ? static_cast(ratio_w * l) + : static_cast(ratio_w * (l + 0.5) - 0.5); + int input_x = floorf(x_n); + const T x_t = x_n - input_x; + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + T coefficients[4]; + // interp 4 times in x direction + for (int ii = 0; ii < 4; ii++) { + int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1), + static_cast(0)); + int access_x_0 = + std::max(std::min(input_x - 1, in_w - 1), static_cast(0)); + int access_x_1 = + std::max(std::min(input_x + 0, in_w - 1), static_cast(0)); + int access_x_2 = + std::max(std::min(input_x + 1, in_w - 1), static_cast(0)); + int access_x_3 = + std::max(std::min(input_x + 2, in_w - 1), static_cast(0)); + if (data_layout == DataLayout::kNCHW) { + coefficients[ii] = + cubic_interp(input_t(i, j, access_y, access_x_0), + input_t(i, j, access_y, access_x_1), + input_t(i, j, access_y, access_x_2), + input_t(i, j, access_y, access_x_3), x_t); + } else { + coefficients[ii] = + cubic_interp(input_t(i, access_y, access_x_0, j), + input_t(i, access_y, access_x_1, j), + input_t(i, access_y, access_x_2, j), + input_t(i, access_y, access_x_3, j), x_t); + } + } + + // interp y direction + if (data_layout == DataLayout::kNCHW) { + output_t(i, j, k, l) = + cubic_interp(coefficients[0], coefficients[1], + coefficients[2], coefficients[3], y_t); + } else { + output_t(i, k, l, j) = + cubic_interp(coefficients[0], coefficients[1], + coefficients[2], coefficients[3], y_t); + } + } + } + } + } +} + template static void NearestNeighborInterpolateGrad( const Tensor& output_grad, Tensor* input_grad, const float ratio_h, @@ -509,6 +611,61 @@ static void TrilinearInterpolationGrad( } } +template +static void BicubicInterpolationGrad(const Tensor& output_grad, + Tensor* input_grad, const float ratio_h, + const float ratio_w, const int in_h, + const int in_w, const int n, const int c, + const int out_h, const int out_w, + const bool align_corners, + const DataLayout data_layout) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + + for (int k = 0; k < out_h; k++) { // loop for images + T y_n = align_corners ? static_cast(ratio_h * k) + : static_cast(ratio_h * (k + 0.5) - 0.5); + int input_y = floorf(y_n); + T y_t = y_n - input_y; + + for (int l = 0; l < out_w; l++) { + T x_n = align_corners ? static_cast(ratio_w * l) + : static_cast(ratio_w * (l + 0.5) - 0.5); + int input_x = floorf(x_n); + T x_t = x_n - input_x; + + T x_coeffs[4]; + T y_coeffs[4]; + + get_cubic_upsample_coefficients(x_coeffs, x_t); + get_cubic_upsample_coefficients(y_coeffs, y_t); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // bicubic interpolation grad + for (int ii = 0; ii < 4; ii++) { + for (int jj = 0; jj < 4; jj++) { + int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1), + static_cast(0)); + int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1), + static_cast(0)); + if (data_layout == DataLayout::kNCHW) { + T grad = output_grad_t(i, j, k, l); + input_grad_t(i, j, access_y, access_x) += + grad * y_coeffs[jj] * x_coeffs[ii]; + } else { + T grad = output_grad_t(i, k, l, j); + input_grad_t(i, access_y, access_x, j) += + grad * y_coeffs[jj] * x_coeffs[ii]; + } + } + } + } + } + } + } +} + template static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, const Tensor& input, Tensor* output) { @@ -587,6 +744,9 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, } else if ("nearest" == interp_method) { NearestNeighborInterpolate(input, output, ratio_h, ratio_w, n, c, out_h, out_w, align_corners, data_layout); + } else if ("bicubic" == interp_method) { + BicubicInterpolation(input, output, ratio_h, ratio_w, in_h, in_w, n, c, + out_h, out_w, align_corners, data_layout); } } @@ -759,6 +919,10 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_h, ratio_w, n, c, out_h, out_w, align_corners, data_layout); + } else if ("bicubic" == interp_method) { + BicubicInterpolationGrad(output_grad, input_grad, ratio_h, ratio_w, in_h, + in_w, n, c, out_h, out_w, align_corners, + data_layout); } } diff --git a/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py b/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a044e04f502cc98c8e8815abf36a172159763029 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py @@ -0,0 +1,459 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +import paddle.fluid as fluid +import paddle +from paddle.fluid import Program, program_guard +from paddle.nn.functional import * + + +def cubic_1(x, a): + return ((a + 2) * x - (a + 3)) * x * x + 1 + + +def cubic_2(x, a): + return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a + + +def cubic_interp1d(x0, x1, x2, x3, t): + param = [0, 0, 0, 0] + a = -0.75 + x_1 = t + x_2 = 1.0 - t + param[0] = cubic_2(x_1 + 1.0, a) + param[1] = cubic_1(x_1, a) + param[2] = cubic_1(x_2, a) + param[3] = cubic_2(x_2 + 1.0, a) + return x0 * param[0] + x1 * param[1] + x2 * param[2] + x3 * param[3] + + +def value_bound(input, w, h, x, y): + access_x = int(max(min(x, w - 1), 0)) + access_y = int(max(min(y, h - 1), 0)) + return input[:, :, access_y, access_x] + + +def bicubic_interp_np(input, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + data_layout='kNCHW'): + """trilinear interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + batch_size, channel, in_h, in_w = input.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((batch_size, channel, out_h, out_w)) + + for k in range(out_h): + if (align_corners): + h = ratio_h * k + else: + h = ratio_h * (k + 0.5) - 0.5 + input_y = np.floor(h) + y_t = h - input_y + for l in range(out_w): + if (align_corners): + w = ratio_w * l + else: + w = ratio_w * (l + 0.5) - 0.5 + input_x = np.floor(w) + x_t = w - input_x + for i in range(batch_size): + for j in range(channel): + coefficients = [0, 0, 0, 0] + for ii in range(4): + access_x_0 = int(max(min(input_x - 1, in_w - 1), 0)) + access_x_1 = int(max(min(input_x + 0, in_w - 1), 0)) + access_x_2 = int(max(min(input_x + 1, in_w - 1), 0)) + access_x_3 = int(max(min(input_x + 2, in_w - 1), 0)) + access_y = int(max(min(input_y - 1 + ii, in_h - 1), 0)) + + coefficients[ii] = cubic_interp1d( + input[i, j, access_y, access_x_0], + input[i, j, access_y, access_x_1], + input[i, j, access_y, access_x_2], + input[i, j, access_y, access_x_3], x_t) + out[i, j, k, l] = cubic_interp1d( + coefficients[0], coefficients[1], coefficients[2], + coefficients[3], y_t) + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + return out.astype(input.dtype) + + +class TestBicubicInterpOp(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "bicubic_interp" + input_np = np.random.random(self.input_shape).astype("float64") + + if self.data_layout == "NCHW": + in_h = self.input_shape[2] + in_w = self.input_shape[3] + else: + in_h = self.input_shape[1] + in_w = self.input_shape[2] + + if self.scale > 0: + out_h = int(in_h * self.scale) + out_w = int(in_w * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = bicubic_interp_np(input_np, out_h, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [2, 3, 5, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + + +class TestBicubicInterpCase1(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + + +class TestBicubicInterpCase2(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [3, 3, 9, 6] + self.out_h = 10 + self.out_w = 8 + self.scale = 0. + self.align_corners = True + + +class TestBicubicInterpCase3(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = False + + +class TestBicubicInterpCase4(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True + + +class TestBicubicInterpCase5(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [3, 3, 9, 6] + self.out_h = 11 + self.out_w = 11 + self.scale = 0. + self.out_size = np.array([6, 4]).astype("int32") + self.align_corners = False + + +class TestBicubicInterpCase6(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0 + self.out_size = np.array([64, 32]).astype("int32") + self.align_corners = False + + +class TestBicubicInterpSame(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + + +class TestBicubicInterpDataLayout(TestBicubicInterpOp): + def init_test_case(self): + self.interp_method = 'bicubic' + self.input_shape = [2, 5, 5, 3] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + self.data_layout = "NHWC" + + +class TestBicubicInterpOpAPI(unittest.TestCase): + def test_case(self): + + x_data = np.random.random((2, 3, 6, 6)).astype("float32") + dim_data = np.array([12]).astype("int32") + shape_data = np.array([12, 12]).astype("int32") + actual_size_data = np.array([12, 12]).astype("int32") + scale_data = np.array([2.0]).astype("float32") + + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + with fluid.program_guard(prog, startup_prog): + + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + + dim = fluid.data(name="dim", shape=[1], dtype="int32") + shape_tensor = fluid.data( + name="shape_tensor", shape=[2], dtype="int32") + actual_size = fluid.data( + name="actual_size", shape=[2], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + + out1 = interpolate( + x, out_shape=[12, 12], resample='BICUBIC', align_corners=False) + out2 = interpolate( + x, out_shape=[12, dim], resample='BICUBIC', align_corners=False) + out3 = interpolate( + x, + out_shape=shape_tensor, + resample='BICUBIC', + align_corners=False) + out4 = interpolate( + x, + out_shape=[4, 4], + actual_shape=actual_size, + resample='BICUBIC', + align_corners=False) + out5 = interpolate( + x, scale=scale_tensor, resample='BICUBIC', align_corners=False) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "dim": dim_data, + "shape_tensor": shape_data, + "actual_size": actual_size_data, + "scale_tensor": scale_data + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = bicubic_interp_np( + x_data, out_h=12, out_w=12, align_corners=False) + for res in results: + self.assertTrue(np.allclose(res, expect_res)) + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(x_data) + interp = interpolate( + x, out_shape=[12, 12], resample='BICUBIC', align_corners=False) + dy_result = interp.numpy() + expect = bicubic_interp_np( + x_data, out_h=12, out_w=12, align_corners=False) + self.assertTrue(np.allclose(dy_result, expect)) + + +class TestBicubicOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # the input of interpoalte must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, interpolate, x1) + + def test_mode_type(): + # mode must be "BILINEAR" "TRILINEAR" "NEAREST" "BICUBIC" + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + + out = interpolate( + x, + out_shape=[12, 12], + resample='UNKONWN', + align_corners=False) + + def test_input_shape(): + x = fluid.data(name="x", shape=[2], dtype="float32") + out = interpolate( + x, + out_shape=[12, 12], + resample='BICUBIC', + align_corners=False) + + def test_align_corcers(): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + interpolate( + x, out_shape=[12, 12], resample='BICUBIC', align_corners=3) + + def test_out_shape(): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + out = interpolate( + x, out_shape=[12], resample='BICUBIC', align_corners=False) + + def test_attr_data_format(): + # for 5-D input, data_format only can be NCDHW or NDHWC + input = fluid.data( + name="input", shape=[2, 3, 6, 9, 4], dtype="float32") + out = interpolate( + input, + out_shape=[4, 8, 4, 5], + resample='TRILINEAR', + data_format='NHWC') + + def test_actual_shape(): + # the actual_shape must be Variable. + x = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + out = interpolate( + x, + out_shape=[12, 12], + resample='BICUBIC', + align_corners=False) + + def test_scale_value(): + # the scale must be greater than zero. + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + out = interpolate( + x, + out_shape=None, + resample='BICUBIC', + align_corners=False, + scale=-2.0) + + def test_attr_5D_input(): + # for 5-D input, data_format only can be NCDHW or NDHWC + input = fluid.data( + name="input", shape=[2, 3, 6, 9, 4], dtype="float32") + out = interpolate( + input, + out_shape=[4, 8, 4, 5], + resample='TRILINEAR', + data_format='NDHWC') + + def test_scale_type(): + # the scale must be greater than zero. + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + scale = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + out = interpolate( + x, + out_shape=None, + resample='BICUBIC', + align_corners=False, + scale=scale) + + def test_align_mode(): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + out = interpolate( + x, + out_shape=None, + resample='NEAREST', + align_corners=False, + align_mode=2, + scale=1.0) + + def test_outshape_and_scale(): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + out = interpolate( + x, + out_shape=None, + resample='BICUBIC', + align_corners=False, + scale=None) + + self.assertRaises(ValueError, test_mode_type) + self.assertRaises(ValueError, test_input_shape) + self.assertRaises(TypeError, test_align_corcers) + self.assertRaises(ValueError, test_attr_data_format) + self.assertRaises(TypeError, test_actual_shape) + self.assertRaises(ValueError, test_scale_value) + self.assertRaises(ValueError, test_out_shape) + self.assertRaises(ValueError, test_attr_5D_input) + self.assertRaises(TypeError, test_scale_type) + self.assertRaises(ValueError, test_align_mode) + self.assertRaises(ValueError, test_outshape_and_scale) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py index e8e025ce6a8c6642736102297589c1f304aeea27..e27f5d3a2706354dc2e08acc531bc768200a6bd6 100755 --- a/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py @@ -19,6 +19,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core import paddle.fluid as fluid +from paddle.nn.functional import * def trilinear_interp_np(input, @@ -586,6 +587,15 @@ class TestTrilinearInterpAPI(unittest.TestCase): out4 = fluid.layers.resize_trilinear( x, out_shape=[4, 4, 8], actual_shape=actual_size) out5 = fluid.layers.resize_trilinear(x, scale=scale_tensor) + out6 = interpolate( + x, scale=scale_tensor, resample='TRILINEAR', data_format="NCDHW") + out7 = interpolate( + x, out_shape=[4, 4, 8], resample='TRILINEAR', data_format="NCDHW") + out8 = interpolate( + x, + out_shape=shape_tensor, + resample='TRILINEAR', + data_format="NCDHW") x_data = np.random.random((2, 3, 6, 9, 4)).astype("float32") dim_data = np.array([18]).astype("int32") diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7de9b7db0a81ad5d8b2338ff7bec6030a36f0386..b516946aa74dba206d0eee78a29d827baa789dd1 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -187,4 +187,4 @@ from .extension import row_conv #DEFINE_ALIAS # from .common import unfold #DEFINE_ALIAS # from .common import bilinear_tensor_product #DEFINE_ALIAS # from .common import assign #DEFINE_ALIAS -# from .common import interpolate #DEFINE_ALIAS +from .common import interpolate #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 17fc3142814d6d7166f207766a00c91399922953..cb4473d22f1fa1e6f1bfb6d29a6c390d98db3064 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.tensor import Variable, fill_constant + # TODO: define the common functions to build a neural network # __all__ = ['dropout', # 'embedding', @@ -25,3 +29,395 @@ # 'bilinear_tensor_product', # 'assign', # 'interpolate'] + +__all__ = ['interpolate'] + + +def interpolate(input, + out_shape=None, + scale=None, + name=None, + resample='BILINEAR', + actual_shape=None, + align_corners=True, + align_mode=1, + data_format='NCHW'): + """ + This op resizes a batch of images. + The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w) + or (num_batches, in_h, in_w, channels), or a 5-D Tensor of the shape + (num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels), + and the resizing only applies on the three dimensions(depth, height and width). + **Warning:** the parameter :attr:`actual_shape` will be deprecated in the + future and only use :attr:`out_shape` instead. + Supporting resample methods: + 'BILINEAR' : Bilinear interpolation + 'TRILINEAR' : Trilinear interpolation + 'NEAREST' : Nearest neighbor interpolation + 'BICUBIC' : Bicubic interpolation + + Nearest neighbor interpolation is to perform nearest neighbor interpolation + in both the 3rd dimension(in height direction) and the 4th dimension(in width + direction) on input tensor. + + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + Trilinear interpolation is an extension of linear interpolation for + interpolating functions of three variables (e.g. D-direction, + H-direction and W-direction in this op) on a rectilinear 3D grid. + The linear interpolation is performed on three directions. + Align_corners and align_mode are optional parameters,the calculation method + of interpolation can be selected by them. + + Bicubic interpolation is an extension of cubic interpolation for interpolating + data points on a two-dimensional regular grid. The interpolated surface is + smoother than corresponding surfaces obtained by bilinear interpolation or + nearest-neighbor interpolation. + + Example: + + .. code-block:: text + + For scale: + + if align_corners = True && out_size > 1 : + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + + Nearest neighbor interpolation: + + if: + align_corners = False + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = floor (H_{in} * scale_{factor}) + W_out = floor (W_{in} * scale_{factor}) + else: + align_corners = True + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = round(H_{in} * scale_{factor}) + W_out = round(W_{in} * scale_{factor}) + Bilinear interpolation: + if: + align_corners = False , align_mode = 0 + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + Bicubic interpolation: + + if: + align_corners = False + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + else: + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + Trilinear interpolation: + if: + align_corners = False , align_mode = 0 + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = (D_{in}+0.5) * scale_{factor} - 0.5 + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + D_out = D_{in} * scale_{factor} + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation. + For details of trilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Trilinear_interpolation. + For details of bicubic interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bicubic_interpolation + Parameters: + input (Variable): 4-D or 5-D Tensor, its data type is float32, float64, or uint8, + its data format is specified by :attr:`data_format`. + out_shape(list|tuple|Variable|None): Output shape of image resize + layer, the shape is (out_h, out_w) when input is a 4-D Tensor and is + (out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If + a list, each element can be an integer or a Tensor Variable of shape: [1]. + If a Tensor Variable, its dimensions size should be a 1. + scale(float|Variable|None): The multiplier for the input height or width. At + least one of :attr:`out_shape` or :attr:`scale` must be set. + And :attr:`out_shape` has a higher priority than :attr:`scale`. + Default: None. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + resample(str): The resample method. It supports 'BILINEAR', 'TRILINEAR' , + 'BICUBIC' and 'NEAREST' currently. Default: 'BILINEAR' + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + :attr:`out_shape` if you want to specify output + shape dynamically, because :attr:`actual_shape` + will be deprecated. When using actual_shape to + specify output shape, one of :attr:`out_shape` + and :attr:`scale` should also be set, otherwise + errors would be occurred in graph constructing stage. + Default: None + align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the + input and output tensors are aligned, preserving the values at the + corner pixels. + Default: True + align_mode(int) : An optional for bilinear interpolation. can be \'0\' + for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for + src_idx = scale*dst_index. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`, `"NCDHW"`, + `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored + in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. + Returns: + A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels), + or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels). + Raises: + TypeError: out_shape should be a list or tuple or Variable. + TypeError: actual_shape should either be Variable or None. + ValueError: The 'resample' of image_resize can only be 'BILINEAR', + 'TRILINEAR', 'BICUBIC', or 'NEAREST' currently. + ValueError: 'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor. + ValueError: 'TRILINEAR' only support 5-D tensor. + ValueError: One of out_shape and scale must not be None. + ValueError: out_shape length should be 2 for input 4-D tensor. + ValueError: out_shape length should be 3 for input 5-D tensor. + ValueError: scale should be greater than zero. + TypeError: align_corners should be a bool value + ValueError: align_mode can only be '0' or '1' + ValueError: data_format can only be 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'. + Examples: + .. code-block:: python + + #declarative mode + import paddle + import numpy as np + input = fluid.data(name="input", shape=[None,3,6,10]) + #1 + output = paddle.nn.functional.interpolate(input=input,out_shape=[12,12]) + #2 + #x = np.array([2]).astype("int32") + #dim1 = fluid.data(name="dim1", shape=[1], dtype="int32") + #fluid.layers.assign(input=x, output=dim1) + #output = paddle.nn.functional.interpolate(input=input,out_shape=[12,dim1]) + #3 + #x = np.array([3,12]).astype("int32") + #shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32") + #fluid.layers.assign(input=x, output=shape_tensor) + #output = paddle.nn.functional.interpolate(input=input,out_shape=shape_tensor) + #4 + #x = np.array([0.5]).astype("float32") + #scale_tensor = fluid.data(name="scale", shape=[1], dtype="float32") + #fluid.layers.assign(x,scale_tensor) + #output = paddle.nn.functional.interpolate(input=input,scale=scale_tensor) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + input_data = np.random.rand(2,3,6,10).astype("float32") + output_data = exe.run(fluid.default_main_program(), + feed={"input":input_data}, + fetch_list=[output], + return_numpy=True) + + print(output_data[0].shape) + #1 + # (2, 3, 12, 12) + #2 + # (2, 3, 12, 2) + #3 + # (2, 3, 3, 12) + #4 + # (2, 3, 3, 5) + #imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_data) + output = paddle.nn.functional.interpolate(input=input, out_shape=[12,12]) + print(output.shape) + # [2L, 3L, 12L, 12L] + """ + resample_methods = { + 'BILINEAR': 'bilinear', + 'TRILINEAR': 'trilinear', + 'NEAREST': 'nearest', + 'BICUBIC': 'bicubic', + } + if resample not in resample_methods: + raise ValueError( + "The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR', " + " 'BICUBIC' or 'NEAREST' currently.") + resample_type = resample_methods[resample] + + if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4: + raise ValueError( + "'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.") + if resample == 'TRILINEAR' and len(input.shape) != 5: + raise ValueError("'TRILINEAR'only support 5-D tensor.") + + if not isinstance(align_corners, bool): + raise TypeError("Attr align_corners should be a bool value") + if align_mode != 0 and align_mode != 1: + raise ValueError("align_mode can only be 0 or 1") + + if out_shape is None and scale is None: + raise ValueError("One of out_shape and scale must not be None.") + helper = LayerHelper('{}_interp'.format(resample_type), **locals()) + dtype = helper.input_dtype() + + if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: + raise ValueError( + "Got wrong value for param `data_format`: " + data_format + + " received but only `NCHW` or `NHWC` supported for 4-D input.") + elif len(input.shape) == 5 and data_format not in ['NCDHW', 'NDHWC']: + raise ValueError( + "Got wrong value for param `data_format`: " + data_format + + " received but only `NCDHW` or `NDHWC` supported for 5-D input.") + + def _is_list_or_turple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if data_format == 'NCHW' or data_format == 'NCDHW': + data_layout = 'NCHW' + if data_format == 'NHWC' or data_format == 'NDHWC': + data_layout = 'NHWC' + + inputs = {"X": input} + attrs = { + "out_d": -1, + "out_h": -1, + "out_w": -1, + "interp_method": resample_type, + "align_corners": align_corners, + "align_mode": align_mode, + "data_layout": data_layout + } + + if out_shape is not None: + if isinstance(out_shape, Variable): + out_shape.stop_gradient = True + inputs['OutSize'] = out_shape + else: + if not (_is_list_or_turple_(out_shape)): + raise TypeError( + "out_shape should be a list or tuple or Variable.") + # Validate the shape + contain_var = False + for dim_idx, dim_size in enumerate(out_shape): + if isinstance(dim_size, Variable): + contain_var = True + continue + assert dim_size > 0, ( + "Each dimension size given in out_shape must be greater than 0." + ) + + if contain_var: + new_size_tensor = [] + size_list = [] + for dim in out_shape: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_size_tensor.append(dim) + size_list.append(-1) + else: + assert (isinstance(dim, int)) + temp_out = helper.create_variable_for_type_inference( + 'int32') + fill_constant( + [1], 'int32', dim, force_cpu=True, out=temp_out) + new_size_tensor.append(temp_out) + size_list.append(dim) + inputs['SizeTensor'] = new_size_tensor + + if len(input.shape) == 4: + if len(out_shape) != 2: + raise ValueError("out_shape length should be 2 for " + "input 4-D tensor.") + if contain_var: + attrs['out_h'] = size_list[0] + attrs['out_w'] = size_list[1] + else: + out_shape = list(map(int, out_shape)) + attrs['out_h'] = out_shape[0] + attrs['out_w'] = out_shape[1] + if len(input.shape) == 5: + if len(out_shape) != 3: + raise ValueError("out_shape length should be 3 for " + "input 5-D tensor.") + if contain_var: + attrs['out_d'] = size_list[0] + attrs['out_h'] = size_list[1] + attrs['out_w'] = size_list[2] + else: + out_shape = list(map(int, out_shape)) + attrs['out_d'] = out_shape[0] + attrs['out_h'] = out_shape[1] + attrs['out_w'] = out_shape[2] + + else: + if isinstance(scale, Variable): + scale.stop_gradient = True + inputs["Scale"] = scale + elif isinstance(scale, float) or isinstance(scale, int): + if scale <= 0: + raise ValueError("Attr(scale) should be greater than zero.") + attrs['scale'] = float(scale) + else: + raise TypeError( + "Attr(scale)'s type should be float, int or Variable.") + + if isinstance(actual_shape, Variable): + warnings.warn( + "actual_shape will be deprecated, it is recommended to use " + "out_shape instead of actual_shape to specify output shape dynamically." + ) + actual_shape.stop_gradient = True + inputs["OutSize"] = actual_shape + elif actual_shape is not None: + raise TypeError("actual_shape should either be Variable or None.") + + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='{}_interp'.format(resample_type), + inputs=inputs, + outputs={"Out": out}, + attrs=attrs) + return out