From 7fedf26b8778c5e1da1facfb32bd17ac9ca9f0a0 Mon Sep 17 00:00:00 2001 From: FDInSky <48318485+FDInSky@users.noreply.github.com> Date: Tue, 12 May 2020 17:01:35 +0800 Subject: [PATCH] add linear interpolate operator (#23357) * test=develop add linear interpolate operator --- paddle/fluid/operators/interpolate_op.cc | 173 ++++++- paddle/fluid/operators/interpolate_op.cu | 323 +++++++++++- paddle/fluid/operators/interpolate_op.h | 277 ++++++++++- python/paddle/fluid/layers/nn.py | 186 ++++++- .../tests/unittests/test_linear_interp_op.py | 462 ++++++++++++++++++ python/paddle/nn/__init__.py | 4 +- python/paddle/nn/functional/common.py | 21 +- python/paddle/nn/layer/__init__.py | 2 +- python/paddle/nn/layer/common.py | 235 ++++++++- 9 files changed, 1597 insertions(+), 86 deletions(-) create mode 100755 python/paddle/fluid/tests/unittests/test_linear_interp_op.py diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 2a7059c1d8..49da719880 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -21,6 +21,85 @@ namespace operators { using framework::Tensor; using DataLayout = framework::DataLayout; +static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) { + auto dim_x = ctx->GetInputDim("X"); + auto interp_method = ctx->Attrs().Get("interp_method"); + + PADDLE_ENFORCE_EQ("linear", interp_method, + platform::errors::InvalidArgument( + "Interpolation method can only be \"linear\" when" + "Input(X) dimension is 3, but got method = %s .", + interp_method)); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + if (ctx->HasInputs("SizeTensor")) { + // top prority size + auto inputs_name = ctx->Inputs("SizeTensor"); + PADDLE_ENFORCE_EQ( + inputs_name.size(), 1, + platform::errors::InvalidArgument( + "Input(SizeTensor)'size of Op(interpolate) must be 1. " + "Attr(out_shape)'s length must be 1 for 3-D input tensor, but got " + "size = %d .", + inputs_name.size())); + int out_w = ctx->Attrs().Get("out_w"); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {dim_x[0], dim_x[1], out_w}; + } else { + dim_out = {dim_x[0], out_w, dim_x[2]}; + } + ctx->SetOutputDim("Out", dim_out); + + return; + } + + int out_w; + if (ctx->HasInput("Scale")) { + auto scale_tensor = ctx->GetInputDim("Scale"); + PADDLE_ENFORCE_EQ( + scale_tensor.size(), 1, + platform::errors::InvalidArgument( + "Scale's dimension size must be 1, but got dimension = %d .", + scale_tensor.size())); + out_w = -1; + } else { + float scale = ctx->Attrs().Get("scale"); + if (scale > 0) { + // round down + out_w = (data_layout == DataLayout::kNCHW + ? static_cast(dim_x[2] * scale) + : static_cast(dim_x[1] * scale)); + // protect when input shape is -1 + out_w = out_w > 0 ? out_w : -1; + } else { + out_w = ctx->Attrs().Get("out_w"); + } + } + + if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { + auto out_size_dim = ctx->GetInputDim("OutSize"); + PADDLE_ENFORCE_EQ( + out_size_dim.size(), 1, + platform::errors::InvalidArgument( + "OutSize's dimension size must be 1, but got dimention = %d .", + out_size_dim.size())); + PADDLE_ENFORCE_EQ(out_size_dim[0], 1, platform::errors::InvalidArgument( + "OutSize's dim[0] must be 1")); + ctx->ShareLoD("X", "Out"); + return; + } + + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {dim_x[0], dim_x[1], out_w}; + } else { + dim_out = {dim_x[0], out_w, dim_x[2]}; + } + ctx->SetOutputDim("Out", dim_out); +} + static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { auto dim_x = ctx->GetInputDim("X"); auto interp_method = ctx->Attrs().Get("interp_method"); @@ -29,7 +108,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { "bilinear" == interp_method || "nearest" == interp_method || "bicubic" == interp_method, "Interpolation method can only be \"bilinear\" or \"nearest\" when " - "Input(X) dimension is 4"); + "Input(X) dimension is 4, but got method = %s .", + interp_method); const DataLayout data_layout = framework::StringToDataLayout( ctx->Attrs().Get("data_layout")); @@ -38,8 +118,11 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { auto inputs_name = ctx->Inputs("SizeTensor"); PADDLE_ENFORCE_EQ( inputs_name.size(), 2, - "Input(SizeTensor)'size of Op(interpolate) must be 2. " - "Attr(out_shape)'s length must be 2 for 4-D input tensor."); + platform::errors::InvalidArgument( + "Input(SizeTensor)'size of Op(interpolate) must be 2. " + "Attr(out_shape)'s length must be 2 for 4-D input " + "tensor, but got size = %d .", + inputs_name.size())); int out_h = ctx->Attrs().Get("out_h"); int out_w = ctx->Attrs().Get("out_w"); framework::DDim dim_out; @@ -56,8 +139,11 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { int out_h, out_w; if (ctx->HasInput("Scale")) { auto scale_tensor = ctx->GetInputDim("Scale"); - PADDLE_ENFORCE_EQ(scale_tensor.size(), 1, - "Scale's dimension size must be 1."); + PADDLE_ENFORCE_EQ( + scale_tensor.size(), 1, + platform::errors::InvalidArgument( + "Scale's dimension size must be 1, but got dimension = %d .", + scale_tensor.size())); out_h = -1; out_w = -1; } else { @@ -81,9 +167,16 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { auto out_size_dim = ctx->GetInputDim("OutSize"); - PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, - "OutSize's dimension size must be 1"); - PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + PADDLE_ENFORCE_EQ( + out_size_dim.size(), 1, + platform::errors::InvalidArgument( + "OutSize's dimension size must be 1, but got dimension = %d .", + out_size_dim.size())); + PADDLE_ENFORCE_EQ( + out_size_dim[0], 2, + platform::errors::InvalidArgument( + "OutSize's dim[0] must be 2, but got dimention = %d .", + out_size_dim[0])); ctx->ShareLoD("X", "Out"); return; } @@ -101,9 +194,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { auto dim_x = ctx->GetInputDim("X"); auto interp_method = ctx->Attrs().Get("interp_method"); - PADDLE_ENFORCE("trilinear" == interp_method, - "Interpolation method can only be \"trilinear\" when Input(X) " - "dimension is 5"); + PADDLE_ENFORCE_EQ( + "trilinear", interp_method, + platform::errors::InvalidArgument( + "Interpolation method can only be \"trilinear\" when Input(X) " + "dimension is 5, but got method = %s .", + interp_method)); const DataLayout data_layout = framework::StringToDataLayout( ctx->Attrs().Get("data_layout")); @@ -112,8 +208,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { auto inputs_name = ctx->Inputs("SizeTensor"); PADDLE_ENFORCE_EQ( inputs_name.size(), 3, - "Input(SizeTensor)'s size of Op(interpolate) must be 3. " - "Attr(out_shape)'s length must be 3 for 5-D input tensor."); + platform::errors::InvalidArgument( + "Input(SizeTensor)'s size of Op(interpolate) must be 3. " + "Attr(out_shape)'s length must be 3 for 5-D input " + "tensor, but got size = %d .", + inputs_name.size())); int out_d = ctx->Attrs().Get("out_d"); int out_h = ctx->Attrs().Get("out_h"); int out_w = ctx->Attrs().Get("out_w"); @@ -131,8 +230,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { int out_d, out_h, out_w; if (ctx->HasInput("Scale")) { auto scale_tensor = ctx->GetInputDim("Scale"); - PADDLE_ENFORCE_EQ(scale_tensor.size(), 1, - "Scale's dimension size must be 1"); + PADDLE_ENFORCE_EQ( + scale_tensor.size(), 1, + platform::errors::InvalidArgument( + "Scale's dimension size must be 1, but got size = %d .", + scale_tensor.size())); out_d = -1; out_h = -1; out_w = -1; @@ -163,8 +265,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { auto out_size_dim = ctx->GetInputDim("OutSize"); PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, - "OutSize's dimension size must be 1"); - PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3"); + "OutSize's dimension size must be 1, but got size =%d .", + out_size_dim.size()); + PADDLE_ENFORCE_EQ(out_size_dim[0], 3, + "OutSize's dim[0] must be 3, but got size = %d .", + out_size_dim[0]); ctx->ShareLoD("X", "Out"); return; } @@ -190,10 +295,16 @@ class InterpolateOp : public framework::OperatorWithKernel { "Output(Out) of InterpolationOp should not be null."); auto dim_x = ctx->GetInputDim("X"); // NCHW format - PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5, - "Input(X) dimension must be 4 or 5"); - - if (dim_x.size() == 4) { + PADDLE_ENFORCE( + dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5, + platform::errors::Unimplemented( + "Input(X) dimension must be 3, 4 or 5, but got dimension = %d .", + dim_x.size())); + + if (dim_x.size() == 3) { + // shape check for 1D interpolate for input tensor shape NCHW + Interpolate1DInferShapeCheck(ctx); + } else if (dim_x.size() == 4) { // shape check for 2D interpolate for input tensor shape NCHW Interpolate2DInferShapeCheck(ctx); } else { // dim_x.size() == 5 @@ -262,7 +373,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("scale", "scale factor of interpolate op.").SetDefault(0.); AddAttr("interp_method", "(string, default \"bilinear\"), interpolation " - "method, can be \"bilinear\" for " + "method, can be \"linear\" for linear interpolation" + ",\"bilinear\" for " "bilinear interpolation, \"trilinear\" for trilinear " "interpolation and \"nearest\" for nearest " "neighbor interpolation, and \"bicubic\" for bicubic" @@ -284,12 +396,15 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { This operator samples input X to given output shape by using specified interpolation method, the interpolation methods can be \"nearest\" for nearest neighbor interpolation and \"bilinear\" for bilinear - interpolation. + interpolation and \"linear\" for linear interpolation.. Nearest neighbor interpolation is to perform nearest neighbor interpolation in both the 3rd dimension(in height direction) and the 4th dimension(in width direction) on input tensor. - + + Linear interpolation is the method of using a line connecting two known quantities + to determine the value of an unknown quantity between the two known quantities. + Bilinear interpolation is an extension of linear interpolation for interpolating functions of two variables (e.g. H-direction and W-direction in this op) on a rectilinear 2D grid. The key idea is @@ -512,6 +627,16 @@ REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel, ops::InterpolateKernel); REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel, ops::InterpolateGradKernel); +REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker, + ops::InterpolateGradMaker, + ops::InterpolateGradMaker); +REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad, + ops::InterpolateGradNoNeedBufferVarsInference); +REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel, + ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(linear_interp_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel, ops::InterpolateKernel); REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel, diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 5beaa64664..5f17f29605 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -112,6 +112,119 @@ __global__ void KeNearestNeighborInterpBw( } } +template +__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w, + const size_t input_w, T* out, + const size_t out_img_w, const size_t output_h, + const size_t output_w, + const size_t num_channels, const float ratio_w, + const bool align_corners, const int align_mode, + const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idx = tid % out_img_w; + } else { + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idx = align_flag + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id + + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + if (data_layout == DataLayout::kNCHW) { + const T* in_pos = + &in[out_id_h * out_id_w + channel_id * in_img_size + in_img_idx]; + // linear interpolation + out[out_id_h * output_w + out_id_w] = + w2lambda * in_pos[0] + w1lambda * in_pos[w_id]; + + } else { + const T* in_pos = + &in[out_id_h * input_w + in_img_idx * num_channels + channel_id]; + // linear interpolation + out[out_id_h * output_w + out_id_w] = + w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]; + } + } +} + +template +__global__ void KeLinearInterpBw(T* in, const size_t in_img_w, + const size_t input_w, const T* out, + const size_t out_img_w, const size_t output_h, + const size_t output_w, + const size_t num_channels, const T ratio_w, + const bool align_corners, const int align_mode, + const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idx = tid % out_img_w; + } else { + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 + : ratio_w * out_img_idx; + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id + + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + T* in_pos; + if (data_layout == DataLayout::kNCHW) { + in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx]; + } else { + in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id]; + } + const T* out_pos = &out[out_id_w]; + + if (data_layout == DataLayout::kNCHW) { + platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]); + } else { + platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[w_id * num_channels], + w1lambda * out_pos[0]); + } + } +} + template __global__ void KeBilinearInterpFw( const T* in, const size_t in_img_h, const size_t in_img_w, @@ -706,6 +819,84 @@ __global__ void KeBicubicInterpBw( } } +template +static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + auto* input_data = input.data(); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_w = ctx.Attr("out_w"); + + auto list_new_shape_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_shape_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_shape_tensor); + out_w = new_size[0]; + } else { + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_w = size_data[0]; + } + } + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {n, c, out_w}; + } else { + dim_out = {n, out_w, c}; + } + auto output_data = output->mutable_data(dim_out, ctx.GetPlace()); + + if (in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_w = 0.f; + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1.0) / (out_w - 1.0) + : static_cast(in_w) / out_w; + } + + int in_cw = c * in_w; + int out_cw = c * out_w; + int pixelNum = n * out_cw; + + platform::GpuLaunchConfig config = + platform::getGpuLaunchConfig(pixelNum, ctx); + + if ("linear" == interp_method) { + KeLinearInterpFw<<>>( + input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w, + align_corners, align_mode, data_layout); + } +} + template static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, const Tensor& input, Tensor* output) { @@ -751,12 +942,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, out_w = size_data[1]; } } - PADDLE_ENFORCE_GT( - out_h, 0, - "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_w, 0, - "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0."); + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); framework::DDim dim_out; if (data_layout == DataLayout::kNCHW) { @@ -859,15 +1050,15 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, out_w = size_data[2]; } } - PADDLE_ENFORCE_GT( - out_d, 0, - "out_d in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_h, 0, - "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_w, 0, - "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0."); + PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument( + "out_d in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); framework::DDim dim_out; if (data_layout == DataLayout::kNCHW) { @@ -917,6 +1108,84 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, } } +template +static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, const Tensor output_grad) { + auto* input = ctx.Input("X"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_w = ctx.Attr("out_w"); + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_w = size_data[0]; + } + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_size_tensor); + out_w = new_size[0]; + } + + auto* output_grad_data = output_grad.data(); + framework::DDim dim_grad; + if (data_layout == DataLayout::kNCHW) { + dim_grad = {n, c, in_w}; + } else { + dim_grad = {n, in_w, c}; + } + input_grad->mutable_data(dim_grad, ctx.GetPlace()); + auto* input_grad_data = input_grad->mutable_data(dim_grad, ctx.GetPlace()); + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_w = 0.f; + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + int in_cw = c * in_w; + int out_cw = c * out_w; + int pixelNum = n * out_cw; + + platform::GpuLaunchConfig config = + platform::getGpuLaunchConfig(pixelNum, ctx); + + if ("linear" == interp_method) { + KeLinearInterpBw<<>>( + input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c, + ratio_w, align_corners, align_mode, data_layout); + } +} + template static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, Tensor* input_grad, const Tensor output_grad) { @@ -1124,13 +1393,16 @@ template class InterpolateOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::NotFound("This kernel only runs on GPU device.")); auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto input_dims = input->dims(); - if (input_dims.size() == 4) { // 2D interpolation + if (input_dims.size() == 3) { // 1D interpolation + Interpolate1DCUDAFwd(ctx, *input, output); + } else if (input_dims.size() == 4) { // 2D interpolation Interpolate2DCUDAFwd(ctx, *input, output); } else if (input_dims.size() == 5) { // 3D interpolation Interpolate3DCUDAFwd(ctx, *input, output); @@ -1142,13 +1414,16 @@ template class InterpolateGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::NotFound("This kernel only runs on GPU device.")); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Out")); auto output_grad_dims = output_grad->dims(); - if (output_grad_dims.size() == 4) { // 2D interpolation + if (output_grad_dims.size() == 3) { // 1D interpolation + Interpolate1DCUDABwd(ctx, input_grad, *output_grad); + } else if (output_grad_dims.size() == 4) { // 2D interpolation Interpolate2DCUDABwd(ctx, input_grad, *output_grad); } else if (output_grad_dims.size() == 5) { // 3D interpolation Interpolate3DCUDABwd(ctx, input_grad, *output_grad); @@ -1178,6 +1453,12 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel, REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad, ops::InterpolateGradOpCUDAKernel, ops::InterpolateGradOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(linear_interp, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(linear_interp_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel, ops::InterpolateOpCUDAKernel, ops::InterpolateOpCUDAKernel); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 8acd54200e..5fbde701fc 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -32,12 +32,12 @@ inline std::vector get_new_shape( std::vector vec_new_shape; for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { auto tensor = list_new_shape_tensor[i]; - PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}), - "shape of dim tensor should be [1]"); + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + platform::errors::InvalidArgument("shape of dim tensor should be [1]")); if (platform::is_gpu_place(tensor->place())) { framework::Tensor temp; TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_new_shape.push_back(static_cast(*temp.data())); } else { vec_new_shape.push_back(static_cast(*tensor->data())); @@ -64,7 +64,13 @@ inline void ExtractNCDWH(const framework::DDim& dims, const DataLayout& data_layout, int* N, int* C, int* D, int* H, int* W) { *N = dims[0]; - if (dims.size() == 4) { + + if (dims.size() == 3) { + *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2]; + *D = 1; + *H = 1; + *W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1]; + } else if (dims.size() == 4) { *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3]; *D = 1; *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1]; @@ -107,6 +113,103 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, } } +template +static void LinearInterpolation(const Tensor& input, Tensor* output, + const float ratio_w, const int in_w, + const int n, const int c, const int out_w, + const bool align_corners, const bool align_mode, + const DataLayout data_layout) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + bool align_flag = (align_mode == 0 && !align_corners); + + std::vector vx_w, vx_e; + std::vector vd_w, vd_e; + vx_w.reserve(out_w); + vx_e.reserve(out_w); + vd_w.reserve(out_w); + vd_e.reserve(out_w); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int l = 0; l < out_w; l++) { + int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; // w + int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id + + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda + float d_e = 1.f - d_w; // w2lambda + { + vx_w[l] = x_w; + vx_e[l] = x_e; + vd_w[l] = d_w; + vd_e[l] = d_e; + } + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + for (int l = 0; l < out_w; l++) { + // linear interpolation + T out_t; + if (data_layout == DataLayout::kNCHW) { + out_t = input_t(i, j, vx_w[l]) * vd_e[l] + + input_t(i, j, vx_e[l]) * vd_w[l]; + output_t(i, j, l) = out_t; + } else { + out_t = input_t(i, vx_w[l], j) * vd_e[l] + + input_t(i, vx_e[l], j) * vd_w[l]; + output_t(i, l, j) = out_t; + } + } + } + } +} + +template +static void LinearInterpolationGrad(const Tensor& output_grad, + Tensor* input_grad, const float ratio_w, + const int in_w, const int n, const int c, + const int out_w, const bool align_corners, + const int align_mode, + const DataLayout data_layout) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + bool align_flag = (align_mode == 0 && !align_corners); + for (int l = 0; l < out_w; l++) { + int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; // w + int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id + + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda + float d_e = 1.f - d_w; // w2lambda + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // linear interpolation grad + if (data_layout == DataLayout::kNCHW) { + const T grad = output_grad_t(i, j, l); + input_grad_t(i, j, x_w) += static_cast(grad * d_e); + input_grad_t(i, j, x_e) += static_cast(grad * d_w); + } else { + const T grad = output_grad_t(i, l, j); + input_grad_t(i, x_w, j) += static_cast(grad * d_e); + input_grad_t(i, x_e, j) += static_cast(grad * d_w); + } + } + } + } +} + template static void BilinearInterpolation(const Tensor& input, Tensor* output, const float ratio_h, const float ratio_w, @@ -666,6 +769,69 @@ static void BicubicInterpolationGrad(const Tensor& output_grad, } } +template +static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_w = ctx.Attr("out_w"); + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_size_tensor); + out_w = new_size[0]; + } else { + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_w = out_size_data[0]; + } + } + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {n, c, out_w}; + } else { + dim_out = {n, out_w, c}; + } + output->mutable_data(dim_out, ctx.GetPlace()); + + if (in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_w = 0.f; + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + if ("linear" == interp_method) { + LinearInterpolation(input, output, ratio_w, in_w, n, c, out_w, + align_corners, align_mode, data_layout); + } +} + template static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, const Tensor& input, Tensor* output) { @@ -707,12 +873,12 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, out_w = out_size_data[1]; } } - PADDLE_ENFORCE_GT( - out_h, 0, - "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_w, 0, - "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0."); + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); framework::DDim dim_out; if (data_layout == DataLayout::kNCHW) { dim_out = {n, c, out_h, out_w}; @@ -795,15 +961,15 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, out_w = out_size_data[2]; } } - PADDLE_ENFORCE_GT( - out_d, 0, - "out_d in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_h, 0, - "out_h in Attr(out_shape) of Op(interpolate) should be greater than 0."); - PADDLE_ENFORCE_GT( - out_w, 0, - "out_w in Attr(out_shape) of Op(interpolate) should be greater than 0."); + PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument( + "out_d in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); framework::DDim dim_out; if (data_layout == DataLayout::kNCHW) { @@ -842,6 +1008,71 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, } } +template +static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, const Tensor& output_grad) { + auto* input = ctx.Input("X"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_w = ctx.Attr("out_w"); + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_w = out_size_data[0]; + } + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape(list_new_size_tensor); + out_w = new_size[0]; + } + + framework::DDim dim_grad; + if (data_layout == DataLayout::kNCHW) { + dim_grad = {n, c, in_w}; + } else { + dim_grad = {n, in_w, c}; + } + input_grad->mutable_data(dim_grad, ctx.GetPlace()); + + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_w = 0.f; + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + if ("linear" == interp_method) { + LinearInterpolationGrad(output_grad, input_grad, ratio_w, in_w, n, c, + out_w, align_corners, align_mode, data_layout); + } +} + template static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, Tensor* input_grad, const Tensor& output_grad) { @@ -1018,7 +1249,9 @@ class InterpolateKernel : public framework::OpKernel { auto* output = ctx.Output("Out"); auto input_dims = input->dims(); - if (input_dims.size() == 4) { // 2D interpolation + if (input_dims.size() == 3) { // 1D interpolation + Interpolate1DCPUFwd(ctx, *input, output); + } else if (input_dims.size() == 4) { // 2D interpolation Interpolate2DCPUFwd(ctx, *input, output); } else if (input_dims.size() == 5) { // 3D interpolation Interpolate3DCPUFwd(ctx, *input, output); @@ -1034,7 +1267,9 @@ class InterpolateGradKernel : public framework::OpKernel { auto* output_grad = ctx.Input(framework::GradVarName("Out")); auto output_grad_dims = output_grad->dims(); - if (output_grad_dims.size() == 4) { // 2D interpolation grad + if (output_grad_dims.size() == 3) { // 1D interpolation grad + Interpolate1DCPUBwd(ctx, input_grad, *output_grad); + } else if (output_grad_dims.size() == 4) { // 2D interpolation grad Interpolate2DCPUBwd(ctx, input_grad, *output_grad); } else if (output_grad_dims.size() == 5) { // 3D interpolation grad Interpolate3DCPUBwd(ctx, input_grad, *output_grad); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 90eb51e792..7e3f458531 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -95,6 +95,7 @@ __all__ = [ 'dice_loss', 'image_resize', 'image_resize_short', + 'resize_linear', 'resize_bilinear', 'resize_trilinear', 'resize_nearest', @@ -6889,7 +6890,8 @@ def image_resize(input, """ This op resizes a batch of images. - The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w) + The input must be a 3-D Tensor of the shape (num_batches, channels, in_w) + or a 4-D Tensor of the shape (num_batches, channels, in_h, in_w) or (num_batches, in_h, in_w, channels), or a 5-D Tensor of the shape (num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels), and the resizing only applies on the three dimensions(depth, height and width). @@ -6898,13 +6900,17 @@ def image_resize(input, future and only use :attr:`out_shape` instead. Supporting resample methods: + 'LINEAR' : Linear interpolation 'BILINEAR' : Bilinear interpolation 'TRILINEAR' : Trilinear interpolation 'NEAREST' : Nearest neighbor interpolation - + + Linear interpolation is the method of using a line connecting two known quantities + to determine the value of an unknown quantity between the two known quantities. + Nearest neighbor interpolation is to perform nearest neighbor interpolation in both the 3rd dimension(in height direction) and the 4th dimension(in width direction) on input tensor. @@ -6958,6 +6964,23 @@ def image_resize(input, H_out = round(H_{in} * scale_{factor}) W_out = round(W_{in} * scale_{factor}) + linear interpolation: + + if: + align_corners = False , align_mode = 0 + + input : (N,C,W_in) + output: (N,C,W_out) where: + + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + else: + + input : (N,C,W_in) + output: (N,C,H_out,W_out) where: + + W_out = W_{in} * scale_{factor} + Bilinear interpolation: if: @@ -7061,15 +7084,17 @@ def image_resize(input, TypeError: actual_shape should either be Variable or None. ValueError: The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' or 'NEAREST' currently. + ValueError: 'LINEAR' only support 3-D tensor. ValueError: 'BILINEAR' and 'NEAREST' only support 4-D tensor. ValueError: 'TRILINEAR' only support 5-D tensor. ValueError: One of out_shape and scale must not be None. + ValueError: out_shape length should be 1 for input 3-D tensor. ValueError: out_shape length should be 2 for input 4-D tensor. ValueError: out_shape length should be 3 for input 5-D tensor. ValueError: scale should be greater than zero. TypeError: align_corners should be a bool value ValueError: align_mode can only be '0' or '1' - ValueError: data_format can only be 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'. + ValueError: data_format can only be 'NCW', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'. Examples: .. code-block:: python @@ -7134,19 +7159,24 @@ def image_resize(input, """ resample_methods = { + 'LINEAR': 'linear', 'BILINEAR': 'bilinear', 'TRILINEAR': 'trilinear', 'NEAREST': 'nearest', + 'LINEAR': 'linear', } + resample = resample.upper() if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' " + "The 'resample' of image_resize can only be 'LINEAR', 'BILINEAR', 'TRILINEAR' " "or 'NEAREST' currently.") resample_type = resample_methods[resample] - if resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4: + if resample == 'LINEAR' and len(input.shape) != 3: + raise ValueError("'LINER only support 3-D tensor.") + elif resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4: raise ValueError("'BILINEAR' and 'NEAREST' only support 4-D tensor.") - if resample == 'TRILINEAR' and len(input.shape) != 5: + elif resample == 'TRILINEAR' and len(input.shape) != 5: raise ValueError("'TRILINEAR'only support 5-D tensor.") if not isinstance(align_corners, bool): @@ -7159,7 +7189,11 @@ def image_resize(input, helper = LayerHelper('{}_interp'.format(resample_type), **locals()) dtype = helper.input_dtype() - if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: + if len(input.shape) == 3 and data_format not in ['NCHW', 'NHWC']: + raise ValueError( + "Got wrong value for param `data_format`: " + data_format + + " received but only `NCHW` or `NHWC` supported for 3-D input.") + elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: raise ValueError( "Got wrong value for param `data_format`: " + data_format + " received but only `NCHW` or `NHWC` supported for 4-D input.") @@ -7223,7 +7257,16 @@ def image_resize(input, size_list.append(dim) inputs['SizeTensor'] = new_size_tensor - if len(input.shape) == 4: + if len(input.shape) == 3: + if len(out_shape) != 1: + raise ValueError("out_shape length should be 1 for " + "input 3-D tensor.") + if contain_var: + attrs['out_w'] = size_list[0] + else: + out_shape = list(map(int, out_shape)) + attrs['out_w'] = out_shape[0] + elif len(input.shape) == 4: if len(out_shape) != 2: raise ValueError("out_shape length should be 2 for " "input 4-D tensor.") @@ -7269,7 +7312,6 @@ def image_resize(input, inputs["OutSize"] = actual_shape elif actual_shape is not None: raise TypeError("actual_shape should either be Variable or None.") - out = helper.create_variable_for_type_inference(dtype) helper.append_op( type='{}_interp'.format(resample_type), @@ -7279,6 +7321,132 @@ def image_resize(input, return out +@templatedoc(op_type="linear_interp") +def resize_linear(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None, + align_corners=True, + align_mode=1, + data_format='NCHW'): + """ + This op resizes the input by performing linear interpolation based on given + output shape which specified by actual_shape, out_shape and scale + in priority order. + + **Warning:** the parameter :attr:`actual_shape` will be deprecated in + the future and only use :attr:`out_shape` instead. + + Align_corners and align_mode are optional parameters,the calculation + method of interpolation can be selected by them. + + Example: + + .. code-block:: text + + For scale: + + if align_corners = True && out_size > 1 : + + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + Linear interpolation: + + if: + align_corners = False , align_mode = 0 + + input : (N,C,W_in) + output: (N,C,W_out) where: + + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + else: + + input : (N,C,W_in) + output: (N,C,W_out) where: + W_out = W_{in} * scale_{factor} + + Parameters: + input(Variable): 3-D Tensor(NCHW), its data type is float32, float64, or uint8, + its data format is specified by :attr:`data_format`. + out_shape(list|tuple|Variable|None): Output shape of resize linear + layer, the shape is (out_w,). Default: None. If a list, each + element can be an integer or a Tensor Variable with shape: [1]. If a + Tensor Variable, its dimension size should be 1. + scale(float|Variable|None): The multiplier for the input height or width. At + least one of :attr:`out_shape` or :attr:`scale` must be set. + And :attr:`out_shape` has a higher priority than :attr:`scale`. + Default: None. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + :attr:`out_shape` if you want to specify output + shape dynamically, because :attr:`actual_shape` + will be deprecated. When using actual_shape to + specify output shape, one of :attr:`out_shape` + and :attr:`scale` should also be set, otherwise + errors would be occurred in graph constructing stage. + Default: None + align_corners(bool): ${align_corners_comment} + align_mode(bool): ${align_mode_comment} + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Variable: 3-D tensor(NCHW or NHWC). + + Examples: + .. code-block:: python + + #declarative mode + import paddle.fluid as fluid + import numpy as np + input = fluid.data(name="input", shape=[None,3,100]) + + output = fluid.layers.resize_linear(input=input,out_shape=[50,]) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + input_data = np.random.rand(1,3,100).astype("float32") + + output_data = exe.run(fluid.default_main_program(), + feed={"input":input_data}, + fetch_list=[output], + return_numpy=True) + + print(output_data[0].shape) + + # (1, 3, 50) + + #imperative mode + import paddle.fluid.dygraph as dg + + with dg.guard(place) as g: + input = dg.to_variable(input_data) + output = fluid.layers.resize_linear(input=input, out_shape=[50,]) + print(output.shape) + + # [1L, 3L, 50L] + + """ + + return image_resize(input, out_shape, scale, name, 'LINEAR', actual_shape, + align_corners, align_mode, data_format) + + @templatedoc(op_type="bilinear_interp") def resize_bilinear(input, out_shape=None, diff --git a/python/paddle/fluid/tests/unittests/test_linear_interp_op.py b/python/paddle/fluid/tests/unittests/test_linear_interp_op.py new file mode 100755 index 0000000000..aea87de235 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_linear_interp_op.py @@ -0,0 +1,462 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import platform +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.nn.functional import * + + +def linear_interp_np(input, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + align_mode=0, + data_layout='NCHW'): + if data_layout == "NHWC": + input = np.transpose(input, (0, 2, 1)) # NHWC => NCHW + if out_size is not None: + out_w = out_size[0] + if actual_shape is not None: + out_w = actual_shape[0] + batch_size, channel, in_w = input.shape + + ratio_w = 0.0 + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((batch_size, channel, out_w)) + + for j in range(out_w): + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (j + 0.5) - 0.5) + else: + w = int(ratio_w * j) + w = max(0, w) + wid = 1 if w < in_w - 1 else 0 + + if (align_mode == 0 and not align_corners): + idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w + else: + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, :, j] = w2lambda * input[:, :, w] + w1lambda * input[:, :, w + + wid] + + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 1)) # NCHW => NHWC + + return out.astype(input.dtype) + + +class TestLinearInterpOp(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "linear_interp" + input_np = np.random.random(self.input_shape).astype("float64") + + if self.data_layout == "NCHW": + in_w = self.input_shape[2] + else: + in_w = self.input_shape[1] + + if self.scale > 0: + out_w = int(in_w * self.scale) + else: + out_w = self.out_w + + output_np = linear_interp_np(input_np, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.align_mode, self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + + self.attrs = { + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + if platform.system() == "Linux": + self.check_output(atol=1e-7) + else: + self.check_output(atol=1e-5) + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'linear' + self.input_shape = [1, 3, 100] + self.out_w = 50 + self.scale = 0. + self.out_size = np.array([50, ]).astype("int32") + self.align_corners = False + self.align_mode = 1 + + +class TestLinearInterpOpDataLayout(TestLinearInterpOp): + def init_test_case(self): + self.interp_method = 'linear' + self.input_shape = [1, 3, 100] + self.out_w = 50 + self.scale = 0. + self.out_size = np.array([50, ]).astype("int32") + self.align_corners = False + self.align_mode = 1 + self.data_layout = 'NHWC' + + +class TestLinearInterpOpAlignMode(TestLinearInterpOp): + def init_test_case(self): + self.interp_method = 'linear' + self.input_shape = [1, 3, 100] + self.out_w = 50 + self.scale = 0. + self.out_size = np.array([50, ]).astype("int32") + self.align_corners = False + self.align_mode = 0 + + +class TestLinearInterpOpScale(TestLinearInterpOp): + def init_test_case(self): + self.interp_method = 'linear' + self.input_shape = [1, 3, 100] + self.out_w = 50 + self.scale = 0.5 + self.out_size = np.array([50, ]).astype("int32") + self.align_corners = False + self.align_mode = 0 + + +class TestLinearInterpOpSizeTensor(TestLinearInterpOp): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "linear_interp" + input_np = np.random.random(self.input_shape).astype("float64") + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + + if self.data_layout == "NCHW": + in_w = self.input_shape[2] + else: + in_w = self.input_shape[1] + + if self.scale > 0: + out_w = int(in_w * self.scale) + else: + out_w = self.out_w + + output_np = linear_interp_np(input_np, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.align_mode, self.data_layout) + + self.inputs = {'X': input_np} + if self.out_size is not None and self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.actual_shape is not None and self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.actual_shape + else: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs = { + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + +class TestLinearInterpOpAPI(unittest.TestCase): + def test_case(self): + x = fluid.data(name="x", shape=[1, 3, 128], dtype="float32") + shape_tensor = fluid.data(name="shape_tensor", shape=[1], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + dim = fluid.data(name="dim", shape=[1], dtype="int32") + actual_size = fluid.data(name='actual_size', shape=[1], dtype='int32') + + out1 = fluid.layers.resize_linear( + x, out_shape=[256, ], align_mode=1, align_corners=False) + out2 = fluid.layers.resize_linear( + x, out_shape=shape_tensor, align_mode=1, align_corners=False) + out3 = fluid.layers.resize_linear( + x, scale=scale_tensor, align_mode=1, align_corners=False) + out4 = fluid.layers.resize_linear( + x, out_shape=[dim, ], align_mode=1, align_corners=False) + out5 = fluid.layers.resize_linear( + x, + out_shape=[256, ], + actual_shape=actual_size, + align_mode=1, + align_corners=False) + + x_data = np.random.random((1, 3, 128)).astype("float32") + shape_data = np.array([256, ]).astype("int32") + scale_data = np.array([2.0, ]).astype("float32") + dim_data = np.array([256, ]).astype("int32") + actual_size_data = np.array([256, ]).astype("int32") + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "shape_tensor": shape_data, + "scale_tensor": scale_data, + "dim": dim_data, + 'actual_size': actual_size_data, + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = linear_interp_np( + x_data, out_w=256, align_mode=1, align_corners=False) + + for res in results: + self.assertTrue(np.allclose(res, expect_res)) + + +class TestLinearInterpOpAPI2_Func(unittest.TestCase): + def test_case(self): + x = fluid.data(name="x", shape=[1, 3, 128], dtype="float32") + shape_tensor = fluid.data(name="shape_tensor", shape=[1], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + dim = fluid.data(name="dim", shape=[1], dtype="int32") + actual_size = fluid.data(name='actual_size', shape=[1], dtype='int32') + + out1 = interpolate( + x, + out_shape=[256, ], + align_mode=1, + align_corners=False, + resample='LINEAR') + out2 = interpolate( + x, + out_shape=shape_tensor, + align_mode=1, + align_corners=False, + resample='LINEAR') + out3 = interpolate( + x, + scale=scale_tensor, + align_mode=1, + align_corners=False, + resample='LINEAR') + out4 = interpolate( + x, + out_shape=[dim, ], + align_mode=1, + align_corners=False, + resample='LINEAR') + out5 = interpolate( + x, + out_shape=[256, ], + actual_shape=actual_size, + align_mode=1, + align_corners=False, + resample='LINEAR') + + x_data = np.random.random((1, 3, 128)).astype("float32") + shape_data = np.array([256, ]).astype("int32") + scale_data = np.array([2.0, ]).astype("float32") + dim_data = np.array([256, ]).astype("int32") + actual_size_data = np.array([256, ]).astype("int32") + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "shape_tensor": shape_data, + "scale_tensor": scale_data, + "dim": dim_data, + 'actual_size': actual_size_data, + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = linear_interp_np( + x_data, out_w=256, align_mode=1, align_corners=False) + + for res in results: + self.assertTrue(np.allclose(res, expect_res)) + + +class TestLinearInterpOpAPI2_0(unittest.TestCase): + def test_case(self): + + # dygraph + x_data = np.random.random((1, 3, 128)).astype("float32") + us_1 = paddle.nn.UpSample( + out_shape=[64, ], + resample='LINEAR', + align_mode=1, + align_corners=False) + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(x_data) + interp = us_1(x) + + expect = linear_interp_np( + x_data, out_w=64, align_mode=1, align_corners=False) + + self.assertTrue(np.allclose(interp.numpy(), expect)) + + +class TestLinearInterpOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "linear_interp" + input_np = np.random.random(self.input_shape).astype("uint8") + + if self.scale > 0: + out_w = int(self.input_shape[3] * self.scale) + else: + out_w = self.out_w + + output_np = linear_interp_np(input_np, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.align_mode) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + + self.attrs = { + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + if platform.system() == "Linux": + self.check_output_with_place(place=core.CPUPlace(), atol=1e-7) + else: + self.check_output_with_place(place=core.CPUPlace(), atol=1e-5) + + def init_test_case(self): + self.interp_method = 'linear' + self.input_shape = [2, 3, 100] + self.out_w = 50 + self.scale = 0. + self.out_size = np.array([50, ]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestLinearInterpOpException(unittest.TestCase): + def test_exception(self): + def input_shape_error(): + x1 = fluid.data(name="x1", shape=[1], dtype="float32") + out = fluid.layers.resize_linear( + x1, out_shape=[256, ], data_format='NCW') + + def data_format_error(): + x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32") + out = fluid.layers.resize_linear( + x2, out_shape=[256, ], data_format='NHWCD') + + def out_shape_error(): + x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32") + out = fluid.layers.resize_linear( + x3, out_shape=[ + 256, + 256, + ], data_format='NHWC') + + self.assertRaises(ValueError, input_shape_error) + self.assertRaises(ValueError, data_format_error) + self.assertRaises(ValueError, out_shape_error) + + +class TestLinearInterpOpError(unittest.TestCase): + def test_error(self): + with program_guard(Program(), Program()): + + def input_shape_error(): + x1 = fluid.data(name="x1", shape=[1], dtype="float32") + out1 = paddle.nn.UpSample( + out_shape=[256, ], data_format='NCW', resample='LINEAR') + out1_res = out1(x1) + + def data_format_error(): + x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32") + out2 = paddle.nn.UpSample( + out_shape=[256, ], data_format='NHWCD', resample='LINEAR') + out2_res = out2(x2) + + def out_shape_error(): + x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32") + out3 = paddle.nn.UpSample( + out_shape=[ + 256, + 256, + ], + data_format='NHWC', + resample='LINEAR') + out3_res = out3(x3) + + self.assertRaises(ValueError, input_shape_error) + self.assertRaises(ValueError, data_format_error) + self.assertRaises(ValueError, out_shape_error) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c164bb7829..1c4d353bf8 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -17,10 +17,12 @@ from .layer import norm from .functional import extension +from .layer import common __all__ = [] __all__ += norm.__all__ __all__ += extension.__all__ +__all__ += common.__all__ # TODO: define alias in nn directory # from .clip import ErrorClipByValue #DEFINE_ALIAS @@ -64,7 +66,7 @@ from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS -# from .layer.common import UpSample #DEFINE_ALIAS +from .layer.common import UpSample #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS from .layer.conv import Conv2DTranspose #DEFINE_ALIAS from .layer.conv import Conv3D #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index f33c317681..154ca7a875 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -284,6 +284,7 @@ def interpolate(input, # [2L, 3L, 12L, 12L] """ resample_methods = { + 'LINEAR': 'linear', 'BILINEAR': 'bilinear', 'TRILINEAR': 'trilinear', 'NEAREST': 'nearest', @@ -291,10 +292,13 @@ def interpolate(input, } if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR', " + "The 'resample' of image_resize can only be 'LINEAR', 'BILINEAR', 'TRILINEAR', " " 'BICUBIC' or 'NEAREST' currently.") resample_type = resample_methods[resample] + if resample in ['LINEAR'] and len(input.shape) != 3: + raise ValueError("'LINEAR' only support 3-D tensor.") + if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4: raise ValueError( "'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.") @@ -311,7 +315,11 @@ def interpolate(input, helper = LayerHelper('{}_interp'.format(resample_type), **locals()) dtype = helper.input_dtype() - if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: + if len(input.shape) == 3 and data_format not in ['NCHW', 'NHWC']: + raise ValueError( + "Got wrong value for param `data_format`: " + data_format + + " received but only `NCHW` or `NHWC` supported for 3-D input.") + elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: raise ValueError( "Got wrong value for param `data_format`: " + data_format + " received but only `NCHW` or `NHWC` supported for 4-D input.") @@ -375,6 +383,15 @@ def interpolate(input, size_list.append(dim) inputs['SizeTensor'] = new_size_tensor + if len(input.shape) == 3: + if len(out_shape) != 1: + raise ValueError( + "out_shape length should be 2 for input 3-D tensor") + if contain_var: + attrs['out_w'] = size_list[0] + else: + out_shape = list(map(int, out_shape)) + attrs['out_w'] = out_shape[0] if len(input.shape) == 4: if len(out_shape) != 2: raise ValueError("out_shape length should be 2 for " diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 90bb6d0ada..cac6afd615 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -37,7 +37,7 @@ from .common import BilinearTensorProduct #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS -# from .common import UpSample #DEFINE_ALIAS +from .common import UpSample #DEFINE_ALIAS from .conv import Conv2D #DEFINE_ALIAS from .conv import Conv2DTranspose #DEFINE_ALIAS from .conv import Conv3D #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 02abf891ff..3fb4f43ab4 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -17,11 +17,232 @@ from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS from ...fluid.dygraph import Linear #DEFINE_ALIAS +from ...fluid.dygraph import layers +from .. import functional as F -__all__ = [ - 'BilinearTensorProduct', - 'Pool2D', - 'Embedding', - 'Linear', - # 'UpSample' -] +__all__ = ['BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample'] + + +class UpSample(layers.Layer): + """ + This op resizes a batch of images. + The input must be a 3-D Tensor of the shape (num_batches, channels, in_w) + or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape + (num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels), + and the resizing only applies on the three dimensions(depth, height and width). + **Warning:** the parameter :attr:`actual_shape` will be deprecated in the + future and only use :attr:`out_shape` instead. + Supporting resample methods: + 'LINEAR' : linear interpolation + 'BILINEAR' : Bilinear interpolation + 'TRILINEAR' : Trilinear interpolation + 'NEAREST' : Nearest neighbor interpolation + 'BICUBIC' : Bicubic interpolation + + Linear interpolation is the method of using a line connecting two known quantities + to determine the value of an unknown quantity between the two known quantities. + + Nearest neighbor interpolation is to perform nearest neighbor interpolation + in both the 3rd dimension(in height direction) and the 4th dimension(in width + direction) on input tensor. + + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + Trilinear interpolation is an extension of linear interpolation for + interpolating functions of three variables (e.g. D-direction, + H-direction and W-direction in this op) on a rectilinear 3D grid. + The linear interpolation is performed on three directions. + Align_corners and align_mode are optional parameters,the calculation method + of interpolation can be selected by them. + + Bicubic interpolation is an extension of cubic interpolation for interpolating + data points on a two-dimensional regular grid. The interpolated surface is + smoother than corresponding surfaces obtained by bilinear interpolation or + nearest-neighbor interpolation. + + Example: + + .. code-block:: text + + For scale: + + if align_corners = True && out_size > 1 : + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + + Nearest neighbor interpolation: + if: + align_corners = False + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = floor (H_{in} * scale_{factor}) + W_out = floor (W_{in} * scale_{factor}) + else: + align_corners = True + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = round(H_{in} * scale_{factor}) + W_out = round(W_{in} * scale_{factor}) + + Linear interpolation: + if: + align_corners = False , align_mode = 0 + input : (N,C,W_in) + output: (N,C,W_out) where: + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + input : (N,C,W_in) + output: (N,C,W_out) where: + W_out = W_{in} * scale_{factor} + + Bilinear interpolation: + if: + align_corners = False , align_mode = 0 + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + Bicubic interpolation: + if: + align_corners = False + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + Trilinear interpolation: + if: + align_corners = False , align_mode = 0 + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + D_out = (D_{in}+0.5) * scale_{factor} - 0.5 + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + else: + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + D_out = D_{in} * scale_{factor} + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Linear_interpolation. + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation. + For details of trilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Trilinear_interpolation. + For details of bicubic interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bicubic_interpolation + + Parameters: + input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8, + its data format is specified by :attr:`data_format`. + out_shape(list|tuple|Variable|None): Output shape of image resize + layer, the shape is (out_w, ) when input is 3-D Tensor , + the shape is (out_h, out_w) when input is a 4-D Tensor and is + (out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If + a list, each element can be an integer or a Tensor Variable of shape: [1]. + If a Tensor Variable, its dimensions size should be a 1. + scale(float|Variable|None): The multiplier for the input height or width. At + least one of :attr:`out_shape` or :attr:`scale` must be set. + And :attr:`out_shape` has a higher priority than :attr:`scale`. + Default: None. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + resample(str): The resample method. It supports 'LINEAR', 'BILINEAR', 'TRILINEAR' , + 'BICUBIC' and 'NEAREST' currently. Default: 'BILINEAR' + align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the + input and output tensors are aligned, preserving the values at the + corner pixels. + Default: True + align_mode(int) : An optional for bilinear interpolation. can be \'0\' + for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for + src_idx = scale*dst_index. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from:'NCW', `"NCHW"`, `"NHWC"`, `"NCDHW"`, + `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored + in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. + + Returns: + A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels), + A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels), + or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels). + + Raises: + TypeError: out_shape should be a list or tuple or Variable. + TypeError: actual_shape should either be Variable or None. + ValueError: The 'resample' of image_resize can only be 'BILINEAR', + 'TRILINEAR', 'BICUBIC', or 'NEAREST' currently. + ValueError: 'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor. + ValueError: 'TRILINEAR' only support 5-D tensor. + ValueError: One of out_shape and scale must not be None. + ValueError: out_shape length should be 2 for input 4-D tensor. + ValueError: out_shape length should be 3 for input 5-D tensor. + ValueError: scale should be greater than zero. + TypeError: align_corners should be a bool value + ValueError: align_mode can only be '0' or '1' + ValueError: data_format can only be 'NCW', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'. + + Examples: + .. code-block:: python + import paddle + import numpy as np + import paddle.fluid.dygraph as dg + upsample_op = paddle.nn.UpSample(out_shape=[12,12]) + input_data = np.random.rand(2,3,6,10).astype("float32") + place = paddle.fluid.CPUPlace() + with dg.guard(place) as g: + input = dg.to_variable(input_data) + output = upsample_op(input=input) + print(output.shape) + # [2L, 3L, 12L, 12L] + """ + + def __init__(self, + out_shape=None, + scale=None, + resample='BILINEAR', + align_corners=True, + align_mode=1, + data_format='NCHW'): + super(UpSample, self).__init__() + self.out_shape = out_shape + self.scale = scale + self.resample = resample + self.align_corners = align_corners + self.align_mode = align_mode + self.data_format = data_format + + def forward(self, input): + out = F.interpolate( + input, + out_shape=self.out_shape, + scale=self.scale, + resample=self.resample, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + + return out -- GitLab