diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 88a2c740e08f7c2c6ac831c65a0ef992064b3a61..f58131e75bebc5633ae01299a1f2084b63e0c8b9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -118,9 +118,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) -paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) +paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) -paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) +paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu deleted file mode 100644 index 4c1971538495c6f111e9db18f4014786f6f0dd58..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/bilinear_interp_op.cu +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/bilinear_interp_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -__global__ void KeBilinearInterpFw( - const T* in, const size_t in_img_h, const size_t in_img_w, - const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; - - int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; - - const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - - // bilinear interpolation - out[out_id_h * output_w + out_id_w] = - h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + - h1lambda * (w2lambda * in_pos[h_id * in_img_w] + - w1lambda * in_pos[h_id * in_img_w + w_id]); - } -} - -template -__global__ void KeBilinearInterpBw( - T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, - const size_t input_w, const T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratioW) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; - int channel_id = out_id_w / out_img_size; - - int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; - - int out_img_idx = tid % out_img_w; - int in_img_idx = ratioW * out_img_idx; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratioW * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; - - T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; - const T* out_pos = &out[out_id_h * output_w + out_id_w]; - atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); - atomicAdd(&in_pos[h_id * in_img_w + w_id], - h1lambda * w1lambda * out_pos[0]); - } -} - -template -class BilinearInterpOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - auto* input_t = ctx.Input("X"); // float tensor - auto* output_t = ctx.Output("Out"); // float tensor - auto* input = input_t->data(); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_dims = output_t->dims(); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - auto* output = output_t->mutable_data( - {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); - - int batch_size = input_t->dims()[0]; - int channels = input_t->dims()[1]; - int in_h = input_t->dims()[2]; - int in_w = input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(output, input, input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpFw< - T><<>>( - input, in_h, in_w, batch_size, in_chw, output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); - } - } -}; - -template -class BilinearInterpGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_output = d_output_t->data(); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); - - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, d_input_t, static_cast(0.0)); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - - int batch_size = d_input_t->dims()[0]; - int channels = d_input_t->dims()[1]; - int in_h = d_input_t->dims()[2]; - int in_w = d_input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); - } else { - int threadNum = batch_size * out_chw; - int blocks = (threadNum + 1024 - 1) / 1024; - - KeBilinearInterpBw< - T><<>>( - d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w, - batch_size, out_chw, channels, ratio_h, ratio_w); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(bilinear_interp, - ops::BilinearInterpOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad, - ops::BilinearInterpGradOpCUDAKernel); diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h deleted file mode 100644 index 70847cb8c1abe2e94bc844ab8117d1f23fea533b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/bilinear_interp_op.h +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class BilinearInterpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input_t = ctx.Input("X"); // float tensor - auto* output_t = ctx.Output("Out"); // float tensor - auto out_dims = output_t->dims(); - auto* input = input_t->data(); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - auto out_size_data = out_size_t->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - auto* output = output_t->mutable_data( - {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); - int batch_size = input_t->dims()[0]; - int channels = input_t->dims()[1]; - int in_h = input_t->dims()[2]; - int in_w = input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(output, input, input_t->numel() * sizeof(T)); - } else { - for (int k = 0; k < batch_size; ++k) { // loop for batches - for (int i = 0; i < out_h; ++i) { // loop for images - int h = ratio_h * i; - int hid = (h < in_h - 1) ? 1 : 0; - float h1lambda = ratio_h * i - h; - float h2lambda = 1.f - h1lambda; - - for (int j = 0; j < out_w; ++j) { - int w = ratio_w * j; - int wid = (w < in_w - 1) ? 1 : 0; - float w1lambda = ratio_w * j - w; - float w2lambda = 1.f - w1lambda; - // calculate four position for bilinear interpolation - const T* in_pos = &input[k * in_chw + h * in_w + w]; - T* out_pos = &output[k * out_chw + i * out_w + j]; - - for (int c = 0; c < channels; ++c) { // loop for channels - // bilinear interpolation - out_pos[0] = static_cast( - h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + - h1lambda * (w2lambda * in_pos[hid * in_w] + - w1lambda * in_pos[hid * in_w + wid])); - in_pos += in_hw; - out_pos += out_hw; - } - } - } - } - } - } -}; - -template -class BilinearInterpGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_output = d_output_t->data(); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, d_input_t, static_cast(0.0)); - - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - - auto out_size_t = ctx.Input("OutSize"); - if (out_size_t != nullptr) { - auto out_size_data = out_size_t->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - - int batch_size = d_input_t->dims()[0]; - int channels = d_input_t->dims()[1]; - int in_h = d_input_t->dims()[2]; - int in_w = d_input_t->dims()[3]; - - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = channels * in_hw; - int out_chw = channels * out_hw; - - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; - - if (in_h == out_h && in_w == out_w) { - memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); - } else { - for (int k = 0; k < batch_size; ++k) { // loop for batches - for (int i = 0; i < out_h; ++i) { // loop for images - int h = ratio_h * i; - int hid = (h < in_h - 1) ? 1 : 0; - float h1lambda = ratio_h * i - h; - float h2lambda = 1 - h1lambda; - - for (int j = 0; j < out_w; ++j) { - int w = ratio_w * j; - int wid = (w < in_w - 1) ? 1 : 0; - float w1lambda = ratio_w * j - w; - float w2lambda = 1 - w1lambda; - T* in_pos = &d_input[k * in_chw + h * in_w + w]; - const T* out_pos = &d_output[k * out_chw + i * out_w + j]; - - for (int c = 0; c < channels; ++c) { // loop for channels - in_pos[0] += static_cast(h2lambda * w2lambda * out_pos[0]); - in_pos[wid] += static_cast(h2lambda * w1lambda * out_pos[0]); - in_pos[hid * in_w] += - static_cast(h1lambda * w2lambda * out_pos[0]); - in_pos[hid * in_w + wid] += - static_cast(h1lambda * w1lambda * out_pos[0]); - in_pos += in_hw; - out_pos += out_hw; - } - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/interpolate_op.cc similarity index 52% rename from paddle/fluid/operators/bilinear_interp_op.cc rename to paddle/fluid/operators/interpolate_op.cc index 2dc3399da183fbcf7664066f6f7ce12db3dc6d5e..8f979e05d31e5a85bc86784943f4588ab650f668 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -9,7 +9,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/bilinear_interp_op.h" +#include "paddle/fluid/operators/interpolate_op.h" +#include #include #include "paddle/fluid/framework/op_registry.h" @@ -18,27 +19,34 @@ namespace operators { using framework::Tensor; -class BilinearInterpOp : public framework::OperatorWithKernel { +class InterpolateOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of BilinearInterOp should not be null."); + "Input(X) of InterpolateOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of BilinearInterOp should not be null."); + "Output(Out) of InterpolationOp should not be null."); + + auto interp_method = ctx->Attrs().Get("interp_method"); + PADDLE_ENFORCE( + "bilinear" == interp_method || "nearest" == interp_method, + "Interpolation method can only be \"bilinear\" or \"nearest\"."); auto dim_x = ctx->GetInputDim("X"); // NCHW format int out_h = ctx->Attrs().Get("out_h"); int out_w = ctx->Attrs().Get("out_w"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); - if (ctx->HasInput("OutSize")) { + if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { auto out_size_dim = ctx->GetInputDim("OutSize"); PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, "OutSize's dimension size must be 1"); PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + ctx->ShareLoD("X", "Out"); + return; } std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); @@ -52,35 +60,53 @@ class BilinearInterpOp : public framework::OperatorWithKernel { } }; -class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { +class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of bilinear interpolation, " - "This is a 4-D tensor with shape of (N x C x h x w)"); + "The input tensor of interpolate operator, " + "This is a 4-D tensor with shape of [N, C, H, w]."); AddInput("OutSize", - "This is a 1-D tensor with two number. " + "This is a 1-D tensor with two numbers to specify output size. " "The first number is height and the second number is width.") .AsDispensable(); - AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); + AddOutput("Out", + "The output tensor of interpolate operator, " + "This is a 4-D tensor with shape of [N, C, H, W]."); - AddAttr("out_h", "output height of bilinear interpolation op."); - AddAttr("out_w", "output width of bilinear interpolation op."); + AddAttr("out_h", "output height of interpolate op."); + AddAttr("out_w", "output width of interpolate op."); + AddAttr( + "interp_method", + "(string), interpolation method, can be \"bilinear\" for " + "bilinear interpolation and \"nearest\" for nearest " + "neighbor interpolation."); AddComment(R"DOC( + This operator samples input X to given output shape by using specified + interpolation method, the interpolation methods can be \"nearest\" + for nearest neighbor interpolation and \"bilinear\" for bilinear + interpolation. + + Nearest neighbor interpolation is to perform nearest neighbor interpolation + in both the 3rd dimention(in height direction) and the 4th dimention(in width + direction) on input tensor. + Bilinear interpolation is an extension of linear interpolation for interpolating functions of two variables (e.g. H-direction and - W-direction in this op) on a rectilinear 2D grid. - - The key idea is to perform linear interpolation first in one - direction, and then again in the other direction. - - For details, please refer to Wikipedia: + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation + + For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation )DOC"); } }; -class BilinearInterpOpGrad : public framework::OperatorWithKernel { +class InterpolateOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -106,11 +132,11 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp, - ops::BilinearInterpOpMaker, +REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad); -REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel, - ops::BilinearInterpKernel); -REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, - ops::BilinearInterpGradKernel); +REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad); +REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel, + ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..190afbdac431f863c32e2a4a4b3ad83848e550fc --- /dev/null +++ b/paddle/fluid/operators/interpolate_op.cu @@ -0,0 +1,292 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/fluid/operators/interpolate_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +__global__ void KeNearestNeighborInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + + int out_img_idx = tid % out_img_w; + int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } +} + +template +__global__ void KeNearestNeighborInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + + int out_img_idx = tid % out_img_w; + int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T out_pos = out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(in_pos, out_pos); + } +} + +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratio_w * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id]); + } +} + +template +__global__ void KeBilinearInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratio_w) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratio_w * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T* out_pos = &out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[h_id * in_img_w], + h1lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id], + h1lambda * w1lambda * out_pos[0]); + } +} + +template +class InterpolateOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + auto* input_data = input->data(); + + auto interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + int n = input->dims()[0]; + int c = input->dims()[1]; + int in_h = input->dims()[2]; + int in_w = input->dims()[3]; + + auto* output_data = + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w); + } else if ("bilinear" == interp_method) { + KeBilinearInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w); + } + } +}; + +template +class InterpolateGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* output_grad_data = output_grad->data(); + auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + auto interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + int n = input_grad->dims()[0]; + int c = input_grad->dims()[1]; + int in_h = input_grad->dims()[2]; + int in_w = input_grad->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, + out_w, n, out_chw, c, ratio_h, ratio_w); + } else if ("bilinear" == interp_method) { + KeBilinearInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, + out_w, n, out_chw, c, ratio_h, ratio_w); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(interpolate_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7fdb3e1f5a2ff82284d89dd0759e357978e1d873 --- /dev/null +++ b/paddle/fluid/operators/interpolate_op.h @@ -0,0 +1,236 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +using EigenTensor = framework::EigenTensor; +using Tensor = framework::Tensor; + +template +static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, + const float ratio_h, const float ratio_w, + const int n, const int c, + const int out_h, const int out_w) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(ratio_h * k + 0.5); + + for (int l = 0; l < out_w; l++) { + int in_l = static_cast(ratio_w * l + 0.5); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + output_t(i, j, k, l) = input_t(i, j, in_k, in_l); + } + } + } + } +} + +template +static void BilinearInterpolation(const Tensor& input, Tensor* output, + const float ratio_h, const float ratio_w, + const int in_h, const int in_w, const int n, + const int c, const int out_h, + const int out_w) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int y_n = static_cast(ratio_h * k); + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float d_n = ratio_h * k - y_n; + float d_s = 1.f - d_n; + + for (int l = 0; l < out_w; l++) { + int x_w = static_cast(ratio_w * l); + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float d_w = ratio_w * l - x_w; + float d_e = 1.f - d_w; + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // bilinear interpolation + output_t(i, j, k, l) = input_t(i, j, y_n, x_w) * d_s * d_e + + input_t(i, j, y_s, x_w) * d_n * d_e + + input_t(i, j, y_n, x_e) * d_s * d_w + + input_t(i, j, y_s, x_e) * d_n * d_w; + } + } + } + } +} + +template +static void NearestNeighborInterpolateGrad(const Tensor& output_grad, + Tensor* input_grad, + const float ratio_h, + const float ratio_w, const int n, + const int c, const int out_h, + const int out_w) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = static_cast(ratio_h * k + 0.5); + + for (int l = 0; l < out_w; l++) { + int in_l = static_cast(ratio_w * l + 0.5); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); + } + } + } + } +} + +template +static void BilinearInterpolationGrad(const Tensor& output_grad, + Tensor* input_grad, const float ratio_h, + const float ratio_w, const int in_h, + const int in_w, const int n, const int c, + const int out_h, const int out_w) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + for (int k = 0; k < out_h; k++) { // loop for images + int y_n = static_cast(ratio_h * k); + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float d_n = ratio_h * k - y_n; + float d_s = 1.f - d_n; + + for (int l = 0; l < out_w; l++) { + int x_w = static_cast(ratio_w * l); + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float d_w = ratio_w * l - x_w; + float d_e = 1.f - d_w; + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + // bilinear interpolation grad + const T grad = output_grad_t(i, j, k, l); + input_grad_t(i, j, y_n, x_w) += static_cast(grad * d_s * d_e); + input_grad_t(i, j, y_s, x_w) += static_cast(grad * d_n * d_e); + input_grad_t(i, j, y_n, x_e) += static_cast(grad * d_s * d_w); + input_grad_t(i, j, y_s, x_e) += static_cast(grad * d_n * d_w); + } + } + } + } +} + +template +class InterpolateKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + std::string interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, output, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if ("bilinear" == interp_method) { + BilinearInterpolation(*input, output, ratio_h, ratio_w, in_h, in_w, n, + c, out_h, out_w); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolate(*input, output, ratio_h, ratio_w, n, c, + out_h, out_w); + } + } +}; + +template +class InterpolateGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + std::string interp_method = ctx.Attr("interp_method"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_h = + (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + float ratio_w = + (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if ("bilinear" == interp_method) { + BilinearInterpolationGrad(*output_grad, input_grad, ratio_h, ratio_w, + in_h, in_w, n, c, out_h, out_w); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolateGrad(*output_grad, input_grad, ratio_h, + ratio_w, n, c, out_h, out_w); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b0a8efd5edcf4de7053438d63ba0315997d4e280..595537ab1e572cf50cd7812c7d2d551124ae7560 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -101,6 +101,7 @@ __all__ = [ 'image_resize', 'image_resize_short', 'resize_bilinear', + 'resize_nearest', 'gather', 'scatter', 'sequence_scatter', @@ -5640,7 +5641,8 @@ def image_resize(input, out_shape=None, scale=None, name=None, - resample='BILINEAR'): + resample='BILINEAR', + actual_shape=None): """ **Resize a Batch of Images** @@ -5650,6 +5652,7 @@ def image_resize(input, Supporting resample methods: 'BILINEAR' : Bilinear interpolation + 'NEAREST' : Nearest neighbor interpolation Args: input (Variable): The input tensor of image resize layer, @@ -5664,25 +5667,51 @@ def image_resize(input, Default: None name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. - resample(str): The resample method. It can only be 'BILINEAR' currently. + resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' + currently. Default: 'BILINEAR' + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None Returns: Variable: The output is a 4-D tensor of the shape (num_batches, channls, out_h, out_w). + Raises: + TypeError: out_shape should be a list or tuple or Variable. + TypeError: actual_shape should either be Variable or None. + ValueError: The 'resample' of image_resize can only be 'BILINEAR' + or 'NEAREST' currently. + ValueError: One of out_shape and scale must not be None. + ValueError: out_shape length should be 2. + Examples: .. code-block:: python out = fluid.layers.image_resize(input, out_shape=[12, 12]) """ - resample_methods = {'BILINEAR': 'bilinear_interp'} + resample_methods = { + 'BILINEAR': 'bilinear', + 'NEAREST': 'nearest', + } if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR' currently.") + "The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." + ) if out_shape is None and scale is None: - raise ValueError("One of out_shape and scale must not be None") - helper = LayerHelper('bilinear_interp', **locals()) + raise ValueError("One of out_shape and scale must not be None.") + helper = LayerHelper('interpolate', **locals()) dtype = helper.input_dtype() def _is_list_or_turple_(data): @@ -5692,33 +5721,106 @@ def image_resize(input, out_w = 0 inputs = {"X": input} if out_shape is not None: - if not (_is_list_or_turple_(out_shape) and - len(out_shape) == 2) and not isinstance(out_shape, Variable): - raise ValueError('out_shape should be a list or tuple or variable') - if _is_list_or_turple_(out_shape): - out_shape = list(map(int, out_shape)) - out_h = out_shape[0] - out_w = out_shape[1] - else: + if isinstance(out_shape, Variable): + warnings.warn("out_shape as Variable type is deprecated, \ + it is recommended to use actual_shape instead of \ + out_shape to specify output shape dynamically.") inputs['OutSize'] = out_shape + elif not (_is_list_or_turple_(out_shape)): + raise TypeError("out_shape should be a list or tuple or Variable.") + elif len(out_shape) != 2: + raise ValueError("out_shape length should be 2.") + + out_shape = list(map(int, out_shape)) + out_h = out_shape[0] + out_w = out_shape[1] else: out_h = int(input.shape[2] * scale) out_w = int(input.shape[3] * scale) + if isinstance(actual_shape, Variable): + inputs["OutSize"] = actual_shape + elif actual_shape is not None: + raise TypeError("actual_shape should either be Variable or None.") + out = helper.create_variable_for_type_inference(dtype) helper.append_op( - type=resample_methods[resample], + type='interpolate', inputs=inputs, outputs={"Out": out}, - attrs={"out_h": out_h, - "out_w": out_w}) + attrs={ + "out_h": out_h, + "out_w": out_w, + "interp_method": resample_methods[resample] + }) return out -@templatedoc(op_type="bilinear_interp") -def resize_bilinear(input, out_shape=None, scale=None, name=None): +@templatedoc(op_type="interpolate") +def resize_bilinear(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None): """ - ${comment} + Resize input by performing bilinear interpolation based on given + output shape which specified by actual_shape, out_shape and scale + in priority order. + + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation + + Args: + input(${x_type}): ${x_comment}. + + out_shape(${out_size_type}): ${out_size_comment}. + + scale(float|None): The multiplier for the input height or width. At + least one of out_shape or scale must be set. And out_shape has + a higher priority than scale. Default: None. + + name(str|None): The output variable name. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None + + Returns: + ${out_comment}. + """ + + return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape) + + +@templatedoc(op_type="interpolate") +def resize_nearest(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None): + """ + Resize input by performing nearest neighbor interpolation in both the + 3rd dimention(in height direction) and the 4th dimention(in width + direction) based on given output shape which specified by actual_shape, + out_shape and scale in priority order. + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation Args: input(${x_type}): ${x_comment}. @@ -5730,12 +5832,25 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): a higher priority than scale. Default: None. name(str|None): The output variable name. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None Returns: ${out_comment}. """ - return image_resize(input, out_shape, scale, name, 'BILINEAR') + return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape) def image_resize_short(input, out_short_len, resample='BILINEAR'): diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py deleted file mode 100644 index bed847c3c168c906a89c32631b2a8f0ba2e6e7be..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -from op_test import OpTest -import paddle.fluid.core as core - - -def bilinear_interp_np(input, out_h, out_w, out_size): - if out_size is not None: - out_h = out_size[0] - out_w = out_size[1] - batch_size, channel, in_h, in_w = input.shape - if out_h > 1: - ratio_h = (in_h - 1.0) / (out_h - 1.0) - else: - ratio_h = 0.0 - if out_w > 1: - ratio_w = (in_w - 1.0) / (out_w - 1.0) - else: - ratio_w = 0.0 - - out = np.zeros((batch_size, channel, out_h, out_w)) - for i in range(out_h): - h = int(ratio_h * i) - hid = 1 if h < in_h - 1 else 0 - h1lambda = ratio_h * i - h - h2lambda = 1.0 - h1lambda - for j in range(out_w): - w = int(ratio_w * j) - wid = 1 if w < in_w - 1 else 0 - w1lambda = ratio_w * j - w - w2lambda = 1.0 - w1lambda - - out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + - w1lambda*input[:, :, h, w+wid]) + \ - h1lambda*(w2lambda*input[:, :, h+hid, w] + - w1lambda*input[:, :, h+hid, w+wid]) - return out.astype(input.dtype) - - -class TestBilinearInterpOp(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "bilinear_interp" - input_np = np.random.random(self.input_shape).astype("float32") - output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) - - def init_test_case(self): - self.input_shape = [2, 3, 4, 4] - self.out_h = 2 - self.out_w = 2 - self.out_size = np.array([3, 3]).astype("int32") - - -class TestCase1(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - - -class TestCase2(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - - -class TestCase3(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - - -class TestCase4(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - self.out_size = np.array([2, 2]).astype("int32") - - -class TestCase5(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - self.out_size = np.array([11, 11]).astype("int32") - - -class TestCase6(TestBilinearInterpOp): - def init_test_case(self): - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - self.out_size = np.array([65, 129]).astype("int32") - - -class TestBilinearInterpOpUint8(OpTest): - def setUp(self): - self.out_size = None - self.init_test_case() - self.op_type = "bilinear_interp" - input_np = np.random.randint( - low=0, high=256, size=self.input_shape).astype("uint8") - output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size) - self.inputs = {'X': input_np} - if self.out_size is not None: - self.inputs['OutSize'] = self.out_size - self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} - self.outputs = {'Out': output_np} - - def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) - - def init_test_case(self): - self.input_shape = [1, 3, 9, 6] - self.out_h = 10 - self.out_w = 9 - - -class TestCase1Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [2, 3, 128, 64] - self.out_h = 120 - self.out_w = 50 - - -class TestCase2Uint8(TestBilinearInterpOpUint8): - def init_test_case(self): - self.input_shape = [4, 1, 7, 8] - self.out_h = 5 - self.out_w = 13 - self.out_size = np.array([6, 15]).astype("int32") - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_interpolate_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9748d094cda6ee9dc649d95d1ca7f1c4b55d1031 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_interpolate_op.py @@ -0,0 +1,335 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def nearest_neighbor_interp_np(X, + out_h, + out_w, + out_size=None, + actual_shape=None): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + + out = np.zeros((n, c, out_h, out_w)) + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + + return out.astype(X.dtype) + + +def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): + """bilinear interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + batch_size, channel, in_h, in_w = input.shape + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 0.0 + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 0.0 + + out = np.zeros((batch_size, channel, out_h, out_w)) + for i in range(out_h): + h = int(ratio_h * i) + hid = 1 if h < in_h - 1 else 0 + h1lambda = ratio_h * i - h + h2lambda = 1.0 - h1lambda + for j in range(out_w): + w = int(ratio_w * j) + wid = 1 if w < in_w - 1 else 0 + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + + w1lambda*input[:, :, h, w+wid]) + \ + h1lambda*(w2lambda*input[:, :, h+hid, w] + + w1lambda*input[:, :, h+hid, w+wid]) + return out.astype(input.dtype) + + +INTERPOLATE_FUNCS = { + 'bilinear': bilinear_interp_np, + 'nearest': nearest_neighbor_interp_np, +} + + +class TestInterpolateOp(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "interpolate" + input_np = np.random.random(self.input_shape).astype("float32") + + output_np = INTERPOLATE_FUNCS[self.interp_method]( + input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 4, 4] + self.out_h = 2 + self.out_w = 2 + self.out_size = np.array([3, 3]).astype("int32") + + +class TestBilinearInterpCase1(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestBilinearInterpCase2(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestBilinearInterpCase3(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestBilinearInterpCase4(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestBilinearInterpCase5(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestBilinearInterpCase6(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +class TestBilinearInterpActualShape(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.out_size = np.array([66, 40]).astype("int32") + + +class TestBilinearInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') + + +class TestInterpolateOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "interpolate" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + output_np = INTERPOLATE_FUNCS[self.interp_method]( + input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(place=core.CPUPlace(), atol=1) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 3, 9, 6] + self.out_h = 10 + self.out_w = 9 + + +class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +class TestNearestNeighborInterpCase1(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestNearestNeighborInterpCase2(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestNearestNeighborInterpCase3(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestNearestNeighborInterpCase4(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestNearestNeighborInterpCase5(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestNearestNeighborInterpCase6(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +class TestNearestNeighborInterpActualShape(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.out_size = np.array([66, 40]).astype("int32") + + +class TestNearestNeighborInterpBigScale(TestInterpolateOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 4, 64, 32] + self.out_h = 100 + self.out_w = 50 + self.out_size = np.array([101, 51]).astype('int32') + + +class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 49ba41e6fc908e9713414120bbeb45ca715042c3..692f2375f2119c3def1751d92087613cd965bf25 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -496,6 +496,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_resize_nearest(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[3, 9, 6], dtype="float32") + output = layers.resize_nearest(x, out_shape=[12, 12]) + self.assertIsNotNone(output) + output = layers.resize_nearest(x, scale=3) + self.assertIsNotNone(output) + print(str(program)) + def test_polygon_box_transform(self): program = Program() with program_guard(program):