diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc index 3bf34fc685ee8af39b66f444c35d606c4b5d8ffb..93f9e108723fbd56e0d3bf5d439614c2c20bb393 100644 --- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc +++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc @@ -41,13 +41,14 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel { int n = input->dims()[0]; int c = input->dims()[1]; - int h = input->dims()[2]; - int w = input->dims()[3]; - const int size[4] = {n, c, h, w}; + int out_h = grid->dims()[1]; + int out_w = grid->dims()[2]; + const int size[4] = {n, c, out_h, out_w}; const T* input_data = input->data(); const T* grid_data = grid->data(); - T* output_data = output->mutable_data({n, c, h, w}, ctx.GetPlace()); + T* output_data = + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); ScopedSpatialTransformerDescriptor st_desc; cudnnSpatialTransformerDescriptor_t cudnn_st_desc = @@ -97,7 +98,7 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel { const T* grid_data = grid->data(); const T* output_grad_data = output_grad->data(); T* input_grad_data = - input_grad->mutable_data(output_grad_dims, ctx.GetPlace()); + input_grad->mutable_data(input->dims(), ctx.GetPlace()); T* grid_grad_data = grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 5be490379642e8761a6821fa0dc0d332ca5b41ef..deb71b807128e5c0b173b517e60832894ced41e5 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/grid_sampler_op.h" #include +#include #include "paddle/fluid/framework/op_registry.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" @@ -58,21 +59,10 @@ class GridSampleOp : public framework::OperatorWithKernel { "Input(X) and Input(Grid) dimension[0] should be equal, but " "received X dimension[0](%d) != Grid dimension[0](%d)", x_dims[0], grid_dims[0])); - PADDLE_ENFORCE_EQ( - grid_dims[1], x_dims[2], - platform::errors::InvalidArgument( - "Input(X) dims[2] and Input(Grid) dims[1] should be equal, but " - "received X dimension[2](%d) != Grid dimension[1](%d)", - x_dims[2], grid_dims[1])); - PADDLE_ENFORCE_EQ( - grid_dims[2], x_dims[3], - platform::errors::InvalidArgument( - "Input(X) dims[3] and Input(Grid) dims[2] should be equal, but " - "received X dimension[3](%d) != Grid dimension[2](%d)", - x_dims[3], grid_dims[2])); } - ctx->SetOutputDim("Output", x_dims); + ctx->SetOutputDim("Output", + {x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]}); ctx->ShareLoD("X", "Output"); } @@ -108,15 +98,37 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default true) Only used in cudnn kernel, need install cudnn") .SetDefault(true); + AddAttr( + "align_corners", + "(bool, default true) If align_corners is true, it will project" + "-1 and 1 to the centers of the corner pixels. Otherwise, it will " + "project" + "-1 and 1 to the image edges.") + .SetDefault(true); + + AddAttr( + "mode", + "(bool, default true) The interpolation method which can be 'bilinear'" + " or 'nearest'.") + .SetDefault("bilinear"); + + AddAttr( + "padding_mode", + "(bool, default true) The padding method used when source" + "index is out of input images. 
It can be 'zeros', 'reflect' and " + "'border'.") + .SetDefault("zeros"); + AddComment(R"DOC( - This operation samples input X by using bilinear interpolation based on + This operation samples input X by using bilinear or nearest interpolation based on flow field grid, which is usually generated by affine_grid. The grid of shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates with shape [N, H, W] each, where grid_x is indexing the 4th dimension (in width dimension) of input data x and grid_y is indexing the 3rd dimension (in height dimension), finally results is the bilinear - interpolation value of 4 nearest corner points. + interpolation value or nearest value of 4 nearest corner points. + For bilinear interpolation mode: Step 1: Get (x, y) grid coordinates and scale to [0, H-1/W-1]. diff --git a/paddle/fluid/operators/grid_sampler_op.cu b/paddle/fluid/operators/grid_sampler_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc6258e507e9594beeae5534a896080fde53c435 --- /dev/null +++ b/paddle/fluid/operators/grid_sampler_op.cu @@ -0,0 +1,489 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/grid_sampler_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +static __forceinline__ __device__ bool in_bounds(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +template +static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH, + int sW, int H, int W, + T delta) { + if (in_bounds(h, w, H, W)) { + atomicAdd(data + h * sH + w * sW, delta); + } +} + +template +static __forceinline__ __device__ T _unnormalize(T coord, int size, + bool align_corners) { + if (align_corners) { + return ((coord + 1.f) / 2) * (size - 1); + } else { + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T clip_indexes(T in, int max_value) { + return min(static_cast(max_value), max(in, static_cast(0))); +} + +template +static __forceinline__ __device__ T reflect_indexes(T in, int twice_low, + int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = fabs(in - min); + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T compute_positions(T coord, int size, + PaddingMode padding_mode, + bool align_corners) { + coord = _unnormalize(coord, size, align_corners); + if (padding_mode == PaddingMode::border) { + coord = clip_indexes(coord, size - 1); + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = reflect_indexes(coord, 0, 2 * (size - 1)); + } else { + coord = reflect_indexes(coord, -1, 2 * size - 1); + } + coord = clip_indexes(coord, size - 1); + } + return coord; +} + +template +static 
__forceinline__ __device__ T _unnormalize_with_mask(T coord, int size, + bool align_corners, + T* grad_in) { + if (align_corners) { + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T clip_indexes_with_mask(T in, int clip_limit, + T* grad_in) { + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + T max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +template +static __forceinline__ __device__ T +reflect_indexes_with_mask(T in, int twice_low, int twice_high, T* grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T +compute_positions_with_mask(T coord, int size, PaddingMode padding_mode, + bool align_corners, T* grad_in) { + T grad_clip, grad_refl; + coord = _unnormalize_with_mask(coord, size, align_corners, grad_in); + if (padding_mode == PaddingMode::border) { + coord = clip_indexes_with_mask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = reflect_indexes_with_mask(coord, 0, 2 * (size - 1), &grad_refl); + } else { + coord = reflect_indexes_with_mask(coord, -1, 2 * size - 1, 
&grad_refl); + } + coord = clip_indexes_with_mask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + return coord; +} + +template +__global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c, + int out_h, int out_w, int in_h, + int in_w, const T* input, const T* grid, + T* output, const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + int inp_sN = out_c * in_h * in_w; + + int inp_sC = in_h * in_w; + int inp_sH = in_w; + int inp_sW = 1; + int grid_sN = out_h * out_w * 2; + int grid_sH = out_w * 2; + int grid_sW = 2; + int grid_sCoor = 1; + int out_sN = out_c * out_h * out_w; + int out_sC = out_h * out_w; + int out_sH = out_w; + int out_sW = 1; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_w; + const int h = (index / out_w) % out_h; + const int n = index / (out_h * out_w); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + ix = compute_positions(ix, in_w, padding_mode, align_corners); + iy = compute_positions(iy, in_h, padding_mode, align_corners); + + if (mode == Mode::bilinear) { + int ix_nw = static_cast(floor(ix)); + int iy_nw = static_cast(floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + auto inp_offset_NC = n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + *out_ptr_NCHW = static_cast(0); + if (in_bounds(iy_nw, ix_nw, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (in_bounds(iy_ne, ix_ne, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + 
iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (in_bounds(iy_sw, ix_sw, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (in_bounds(iy_se, ix_se, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } else if (mode == Mode::nearest) { + int ix_nearest = static_cast(round(ix)); + int iy_nearest = static_cast(round(iy)); + + auto inp_offset_NC = n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + if (in_bounds(iy_nearest, ix_nearest, in_h, in_w)) { + *out_ptr_NCHW = + input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast(0); + } + } + } + } +} + +template +class GridSampleOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.cuda_device_context(); + auto align_corners = ctx.Attr("align_corners"); + auto padding_mode_s = ctx.Attr("padding_mode"); + auto mode_s = ctx.Attr("mode"); + PaddingMode padding_mode; + Mode mode; + if (padding_mode_s == "border") { + padding_mode = PaddingMode::border; + } else if (padding_mode_s == "reflect") { + padding_mode = PaddingMode::reflect; + } else { + padding_mode = PaddingMode::zeros; + } + + if (mode_s == "nearest") { + mode = Mode::nearest; + } else { + mode = Mode::bilinear; + } + + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + const int n = grid->dims()[0]; + const int out_h = grid->dims()[1]; + const int out_w = grid->dims()[2]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h + << "; out_w: " << out_w; + auto* output = ctx.Output("Output"); + auto* output_data = output->mutable_data(ctx.GetPlace()); + + VLOG(3) << 
"set constant"; + math::SetConstant()( + dev_ctx, output, static_cast(0)); + int count = static_cast(n * out_h * out_w); + + auto cu_stream = dev_ctx.stream(); + + int block = 512; + int grid_size = (count + block - 1) / block; + grid_sample_cuda_kernel<<>>( + count, n, c, out_h, out_w, in_h, in_w, input->data(), + grid->data(), output_data, mode, padding_mode, align_corners); + } +}; + +template +__global__ void grid_sampler_cuda_backward_kernel( + const int nthreads, const T* grad_output, const T* input, const T* grid, + int n, int out_c, int out_h, int out_w, int in_h, int in_w, T* grad_input, + T* grad_grid, const Mode mode, const PaddingMode padding_mode, + bool align_corners) { + int inp_sN = out_c * in_h * in_w; + int inp_sC = in_h * in_w; + int inp_sH = in_w; + int inp_sW = 1; + int grid_sN = out_h * out_w * 2; + int grid_sH = out_w * 2; + int grid_sW = 2; + int grid_sCoor = 1; + + int gOut_sN = out_c * out_h * out_w; + int gOut_sC = out_h * out_w; + int gOut_sH = out_w; + int gOut_sW = 1; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_w; + const int h = (index / out_w) % out_h; + const int n = index / (out_h * out_w); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + T gix_mult, giy_mult; + ix = compute_positions_with_mask(ix, in_w, padding_mode, align_corners, + &gix_mult); + iy = compute_positions_with_mask(iy, in_h, padding_mode, align_corners, + &giy_mult); + + if (mode == Mode::bilinear) { + int ix_nw = static_cast(floor(ix)); + int iy_nw = static_cast(floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + T gix = static_cast(0), giy = static_cast(0); + int gOut_offset = n * gOut_sN + h * 
gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + int inp_offset_NC = n * inp_sN; + for (int c = 0; c < out_c; ++c, inp_offset_NC += inp_sC, + gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { + T gOut = grad_output[gOut_offset]; + + atomic_add(gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w, + nw * gOut); + atomic_add(gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w, + ne * gOut); + atomic_add(gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w, + sw * gOut); + atomic_add(gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w, + se * gOut); + + if (in_bounds(iy_nw, ix_nw, in_h, in_w)) { + T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (in_bounds(iy_ne, ix_ne, in_h, in_w)) { + T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (in_bounds(iy_sw, ix_sw, in_h, in_w)) { + T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (in_bounds(iy_se, ix_se, in_h, in_w)) { + T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } else if (mode == Mode::nearest) { + int ix_nearest = static_cast(::round(ix)); + int iy_nearest = static_cast(::round(iy)); + + int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (int c = 0; c < out_c; + ++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { + atomic_add(gInp_ptr_NC, iy_nearest, ix_nearest, inp_sH, inp_sW, in_h, + in_w, grad_output[gOut_offset]); + } + + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = 
static_cast(0); + gGrid_ptr_NHW[1] = static_cast(0); + } + } +} + +template +class GridSampleGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.cuda_device_context(); + auto align_corners = ctx.Attr("align_corners"); + auto padding_mode_s = ctx.Attr("padding_mode"); + auto mode_s = ctx.Attr("mode"); + + PaddingMode padding_mode; + Mode mode; + if (padding_mode_s == "border") { + padding_mode = PaddingMode::border; + } else if (padding_mode_s == "reflect") { + padding_mode = PaddingMode::reflect; + } else { + padding_mode = PaddingMode::zeros; + } + + if (mode_s == "nearest") { + mode = Mode::nearest; + } else { + mode = Mode::bilinear; + } + + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + auto* output_grad = ctx.Input(framework::GradVarName("Output")); + + const int n = grid->dims()[0]; + const int out_h = grid->dims()[1]; + const int out_w = grid->dims()[2]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + auto* input_grad = ctx.Output(framework::GradVarName("X")); + input_grad->mutable_data(ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), + input_grad, static_cast(0)); + auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); + grid_grad->mutable_data(ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), + grid_grad, static_cast(0)); + + int count = static_cast(n * out_h * out_w); + auto cu_stream = dev_ctx.stream(); + int block = 512; + int grid_size = (count + block - 1) / block; + grid_sampler_cuda_backward_kernel<<>>( + count, output_grad->data(), input->data(), grid->data(), n, c, + out_h, out_w, in_h, in_w, input_grad->data(), grid_grad->data(), + mode, padding_mode, align_corners); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + 
+REGISTER_OP_CUDA_KERNEL(grid_sampler, ops::GridSampleOpCUDAKernel, + ops::GridSampleOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(grid_sampler_grad, + ops::GridSampleGradOpCUDAKernel, + ops::GridSampleGradOpCUDAKernel); diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index 08a6043eb07a6e44d46428ee195f6cb28c2ee77c..eda800e78faf5da2bb379b8101e4823c5bc2d2f8 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/gather.h" @@ -22,6 +25,13 @@ limitations under the License. */ namespace paddle { namespace operators { +enum class Mode { + bilinear, + nearest, +}; + +enum class PaddingMode { zeros, border, reflect }; + using Tensor = framework::Tensor; template @@ -39,64 +49,229 @@ static inline bool isInBound(T x, T y, T x_max, T y_max) { } template -static void CalcGridLocations(const platform::CPUDeviceContext& ctx, - const Tensor& grid, Tensor* x_w, Tensor* x_e, - Tensor* y_n, Tensor* y_s, Tensor* d_w, - Tensor* d_e, Tensor* d_n, Tensor* d_s) { +static inline void unnormalize(const platform::CPUDeviceContext& ctx, + Tensor* grid_slice, + const int max_val, // height-1 or width-1 + bool align_corners) { auto& place = *ctx.eigen_device(); + auto grid_slice_t = EigenTensor::From(*grid_slice); + + if (!align_corners) { + auto factor = static_cast((max_val + 1) * 0.5); + grid_slice_t.device(place) = + (grid_slice_t + static_cast(1)) * factor - static_cast(0.5); + } else { + auto factor = static_cast(max_val * 0.5); + grid_slice_t.device(place) = (grid_slice_t + static_cast(1)) * factor; + } +} + +template +static inline void clip(const platform::CPUDeviceContext& ctx, + Tensor* grid_slice, + const int 
max_val, // height-1 or width-1 + bool align_corners, std::string padding_mode) { + auto& place = *ctx.eigen_device(); + auto grid_slice_t = EigenTensor::From(*grid_slice); + if (padding_mode == "border") { + grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + } else if (padding_mode == "reflect") { + if (align_corners) { + auto double_range = static_cast(max_val * 2); + auto grid_abs = grid_slice_t.abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + } else { + auto double_range = static_cast((max_val + 1) * 2); + auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + grid_slice_t.device(place) = + extra.cwiseMin(double_range - extra) - static_cast(0.5); + grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + } + } +} + +template +static inline void clipWithMask(const platform::CPUDeviceContext& ctx, + const int max_val, // height-1 or width-1 + bool align_corners, std::string padding_mode, + Tensor* grid_slice, Tensor* grid_scale) { + auto& place = *ctx.eigen_device(); + grid_scale->mutable_data(grid_slice->dims(), ctx.GetPlace()); + + auto grid_slice_t = EigenTensor::From(*grid_slice); + auto factor = static_cast(max_val * 0.5); + if (!align_corners) { + factor = static_cast((max_val + 1) * 0.5); + } + auto grid_scale_t = EigenTensor::From(*grid_scale).setConstant(factor); + + if (padding_mode == "border") { + // auto bounded_lo = grid_slice_t.cwiseMax(static_cast(0)); + auto res = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + + auto in_bound = (res == grid_slice_t); + grid_scale_t.device(place) = grid_scale_t * in_bound.template cast(); + grid_slice_t.device(place) = res; + } else if (padding_mode == "reflect") { + if (align_corners) { + auto double_range 
= static_cast(max_val * 2); + auto is_neg = (grid_slice_t < static_cast(0)); + auto grid_abs = grid_slice_t.abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + auto one_more_flip = (extra > (double_range - extra)); + grid_scale_t.device(place) = + grid_scale_t * ((is_neg == one_more_flip).template cast() - + (is_neg != one_more_flip).template cast()); + grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + } else { + auto double_range = static_cast((max_val + 1) * 2); + auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); + auto is_neg = ((grid_slice_t + static_cast(0.5)) < static_cast(0)); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + auto one_more_flip = (extra > (double_range - extra)); + auto reflected = + extra.cwiseMin(double_range - extra) - static_cast(0.5); + auto clipped = reflected.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + auto in_bound = (clipped == reflected).template cast(); + grid_scale_t.device(place) = + grid_scale_t * ((is_neg == one_more_flip).template cast() - + (is_neg != one_more_flip).template cast()) * + in_bound; + grid_slice_t.device(place) = clipped; + } + } +} + +template +static void calcGridLocations(const platform::CPUDeviceContext& ctx, + const Tensor& grid, const int in_h, + const int in_w, bool align_corners, + std::string padding_mode, Tensor* grid_x, + Tensor* grid_y) { const int n = grid.dims()[0]; - const int h = grid.dims()[1]; - const int w = grid.dims()[2]; - const T x_max = static_cast(w - 1); - const T y_max = static_cast(h - 1); + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim - Tensor grid_x, grid_y; - T* grid_x_data = grid_x.mutable_data({n, h, w}, ctx.GetPlace()); - T* grid_y_data = grid_y.mutable_data({n, h, w}, ctx.GetPlace()); + T* grid_x_data = grid_x->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + T* grid_y_data = 
grid_y->mutable_data({n, out_h, out_w}, ctx.GetPlace()); const T* grid_data = grid.data(); - for (int i = 0; i < n * h * w; i++) { + for (int i = 0; i < n * out_h * out_w; i++) { grid_x_data[i] = grid_data[2 * i]; grid_y_data[i] = grid_data[(2 * i) + 1]; } - Tensor ones; - ones.mutable_data({n, h, w}, ctx.GetPlace()); - auto ones_t = EigenTensor::From(ones).setConstant(1.0); - Tensor half_xmax; - Tensor half_ymax; - half_xmax.mutable_data({n, h, w}, ctx.GetPlace()); - auto half_xmax_t = - EigenTensor::From(half_xmax).setConstant(0.5 * x_max); - half_ymax.mutable_data({n, h, w}, ctx.GetPlace()); - auto half_ymax_t = - EigenTensor::From(half_ymax).setConstant(0.5 * y_max); - - // scale grid to [0, h-1/w-1] - auto grid_x_t = EigenTensor::From(grid_x); - auto grid_y_t = EigenTensor::From(grid_y); - grid_x_t.device(place) = (grid_x_t + ones_t) * half_xmax_t; - grid_y_t.device(place) = (grid_y_t + ones_t) * half_ymax_t; + unnormalize(ctx, grid_x, in_w - 1, align_corners); + unnormalize(ctx, grid_y, in_h - 1, align_corners); + + clip(ctx, grid_x, in_w - 1, align_corners, padding_mode); + clip(ctx, grid_y, in_h - 1, align_corners, padding_mode); +} + +template +static void calcGridLocationsWithGrad(const platform::CPUDeviceContext& ctx, + const Tensor& grid, const int in_h, + const int in_w, bool align_corners, + std::string padding_mode, Tensor* grid_x, + Tensor* grid_y, Tensor* grid_x_scale, + Tensor* grid_y_scale) { + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + + // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim + T* grid_x_data = grid_x->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + T* grid_y_data = grid_y->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + + const T* grid_data = grid.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + grid_x_data[i] = grid_data[2 * i]; + grid_y_data[i] = grid_data[(2 * i) + 1]; + } + unnormalize(ctx, grid_x, in_w - 1, align_corners); + 
unnormalize(ctx, grid_y, in_h - 1, align_corners); + + clipWithMask(ctx, in_w - 1, align_corners, padding_mode, grid_x, + grid_x_scale); + clipWithMask(ctx, in_h - 1, align_corners, padding_mode, grid_y, + grid_y_scale); +} + +template +static void getGridPointValue(const Tensor& input, Tensor* output, + const Tensor& x, const Tensor& y) { + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int in_h = input.dims()[2]; + const int in_w = input.dims()[3]; + const int out_h = x.dims()[1]; + const int out_w = x.dims()[2]; + auto x_t = EigenTensor::From(x); + auto y_t = EigenTensor::From(y); + auto output_t = EigenTensor::From(*output).setConstant((T)0); + auto input_t = EigenTensor::From(input); + + for (int i = 0; i < n; i++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), + (T)(in_h - 1))) { + for (int j = 0; j < c; j++) { + output_t(i, j, k, l) = + input_t(i, j, static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))); + } + } + } + } + } +} + +template +static void allNeigbors(const platform::CPUDeviceContext& ctx, + const Tensor& input, Tensor* grid_x, Tensor* grid_y, + Tensor* x_w, Tensor* x_e, Tensor* y_n, + Tensor* y_s, // positions + Tensor* d_w, Tensor* d_e, Tensor* d_n, + Tensor* d_s, // distance + Tensor* v_wn, Tensor* v_en, Tensor* v_ws, + Tensor* v_es) { // values + auto& place = *ctx.eigen_device(); + + const int c = input.dims()[1]; + const int n = grid_x->dims()[0]; + const int out_h = grid_x->dims()[1]; + const int out_w = grid_x->dims()[2]; // calculate coords of 4 corner points - x_w->mutable_data({n, h, w}, ctx.GetPlace()); - x_e->mutable_data({n, h, w}, ctx.GetPlace()); - y_n->mutable_data({n, h, w}, ctx.GetPlace()); - y_s->mutable_data({n, h, w}, ctx.GetPlace()); + x_w->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + x_e->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + y_n->mutable_data({n, out_h, out_w}, ctx.GetPlace()); 
+ y_s->mutable_data({n, out_h, out_w}, ctx.GetPlace()); auto x_w_t = EigenTensor::From(*x_w); auto x_e_t = EigenTensor::From(*x_e); auto y_n_t = EigenTensor::From(*y_n); auto y_s_t = EigenTensor::From(*y_s); + + auto grid_x_t = EigenTensor::From(*grid_x); + auto grid_y_t = EigenTensor::From(*grid_y); + x_w_t.device(place) = grid_x_t.floor(); - x_e_t.device(place) = x_w_t + ones_t; + x_e_t.device(place) = x_w_t + static_cast(1); y_n_t.device(place) = grid_y_t.floor(); - y_s_t.device(place) = y_n_t + ones_t; + y_s_t.device(place) = y_n_t + static_cast(1); // calculate distances to 4 sides - d_w->mutable_data({n, h, w}, ctx.GetPlace()); - d_e->mutable_data({n, h, w}, ctx.GetPlace()); - d_n->mutable_data({n, h, w}, ctx.GetPlace()); - d_s->mutable_data({n, h, w}, ctx.GetPlace()); + d_w->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + d_e->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + d_n->mutable_data({n, out_h, out_w}, ctx.GetPlace()); + d_s->mutable_data({n, out_h, out_w}, ctx.GetPlace()); auto d_w_t = EigenTensor::From(*d_w); auto d_e_t = EigenTensor::From(*d_e); auto d_n_t = EigenTensor::From(*d_n); @@ -105,28 +280,100 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx, d_e_t.device(place) = x_e_t - grid_x_t; d_n_t.device(place) = grid_y_t - y_n_t; d_s_t.device(place) = y_s_t - grid_y_t; + + // calc 4 corner points value + v_wn->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + v_en->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + v_ws->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + v_es->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + getGridPointValue(input, v_wn, *x_w, *y_n); + getGridPointValue(input, v_en, *x_e, *y_n); + getGridPointValue(input, v_ws, *x_w, *y_s); + getGridPointValue(input, v_es, *x_e, *y_s); } template -static void GetGridPointValue(const Tensor& input, Tensor* output, - const Tensor& x, const Tensor& y) { - const int n = input.dims()[0]; +static void bilinearInter(const 
platform::CPUDeviceContext& ctx, + const Tensor& input, Tensor* grid_x, Tensor* grid_y, + Tensor* out) { + auto& place = *ctx.eigen_device(); + const int n = grid_x->dims()[0]; + const int out_h = grid_x->dims()[1]; + const int out_w = grid_x->dims()[2]; const int c = input.dims()[1]; - const int h = input.dims()[2]; - const int w = input.dims()[3]; + + Tensor x_w, x_e, y_n, y_s; + Tensor d_w, d_e, d_n, d_s; + Tensor v_wn, v_en, v_ws, v_es; + + allNeigbors(ctx, input, grid_x, grid_y, &x_w, &x_e, &y_n, &y_s, &d_w, &d_e, + &d_n, &d_s, &v_wn, &v_en, &v_ws, &v_es); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + + auto d_w_scaled_t = + d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_e_scaled_t = + d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_n_scaled_t = + d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_s_scaled_t = + d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto v_wn_t = EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + auto output_t = EigenTensor::From(*out); + // bilinear interpolaetion by 4 corner points + output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t + + v_en_t * d_w_scaled_t * d_s_scaled_t + + v_ws_t * d_e_scaled_t * d_n_scaled_t + + v_es_t * d_w_scaled_t * d_n_scaled_t; +} + +template +static void nearestInter(const platform::CPUDeviceContext& ctx, + const Tensor& input, Tensor* grid_x, Tensor* grid_y, + Tensor* out) { + auto& place = *ctx.eigen_device(); + + auto grid_x_t = EigenTensor::From(*grid_x); + auto grid_y_t = EigenTensor::From(*grid_y); + grid_x_t = grid_x_t.round(); + grid_y_t = grid_y_t.round(); + getGridPointValue(input, out, *grid_x, *grid_y); +} + +template +static void 
gatherOutputGradToInputGrad(const Tensor& output_grad, + Tensor* input_grad, const Tensor& x, + const Tensor& y, const Tensor& d1, + const Tensor& d2) { + const int n = output_grad.dims()[0]; + const int c = output_grad.dims()[1]; + const int out_h = output_grad.dims()[2]; + const int out_w = output_grad.dims()[3]; + const int in_h = input_grad->dims()[2]; + const int in_w = input_grad->dims()[3]; auto x_t = EigenTensor::From(x); auto y_t = EigenTensor::From(y); - auto output_t = EigenTensor::From(*output).setConstant((T)0); - auto input_t = EigenTensor::From(input); + auto d1_t = EigenTensor::From(d1); + auto d2_t = EigenTensor::From(d2); + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); for (int i = 0; i < n; i++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), + (T)(in_h - 1))) { for (int j = 0; j < c; j++) { - output_t(i, j, k, l) = - input_t(i, j, static_cast(round(y_t(i, k, l))), - static_cast(round(x_t(i, k, l)))); + input_grad_t(i, j, static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))) += + output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l); } } } @@ -135,29 +382,28 @@ static void GetGridPointValue(const Tensor& input, Tensor* output, } template -static void GatherOutputGradToInputGrad(const Tensor& output_grad, +static void gatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad, const Tensor& x, - const Tensor& y, const Tensor& d1, - const Tensor& d2) { + const Tensor& y) { const int n = output_grad.dims()[0]; const int c = output_grad.dims()[1]; - const int h = output_grad.dims()[2]; - const int w = output_grad.dims()[3]; + const int out_h = output_grad.dims()[2]; + const int out_w = output_grad.dims()[3]; + const int in_h = 
input_grad->dims()[2]; + const int in_w = input_grad->dims()[3]; auto x_t = EigenTensor::From(x); auto y_t = EigenTensor::From(y); - auto d1_t = EigenTensor::From(d1); - auto d2_t = EigenTensor::From(d2); auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); - for (int i = 0; i < n; i++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), + (T)(in_h - 1))) { for (int j = 0; j < c; j++) { input_grad_t(i, j, static_cast(round(y_t(i, k, l))), static_cast(round(x_t(i, k, l)))) += - output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l); + output_grad_t(i, j, k, l); } } } @@ -165,65 +411,126 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad, } } +template +static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, + const Tensor& input, const Tensor& output_grad, + Tensor* grid_x, Tensor* grid_y, + Tensor* grid_x_scale, Tensor* grid_y_scale, + Tensor* input_grad, Tensor* grid_grad) { + const int n = grid_x->dims()[0]; + const int out_h = grid_x->dims()[1]; + const int out_w = grid_x->dims()[2]; + const int c = input.dims()[1]; + + Tensor x_w, x_e, y_n, y_s; + Tensor d_w, d_e, d_n, d_s; + Tensor v_wn, v_en, v_ws, v_es; + + allNeigbors(ctx, input, + grid_x, // grid_x + grid_y, // grid_y + &x_w, &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s, &v_wn, &v_en, + &v_ws, &v_es); + + // gather output grad value to input grad by corner point coords and weight + gatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_n, d_e, d_s); + gatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_s, d_e, d_n); + gatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_n, d_w, d_s); + gatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_s, d_w, d_n); + + auto v_wn_t = 
EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + + auto output_grad_t = EigenTensor::From(output_grad); + + Tensor grid_grad_x, grid_grad_y; + grid_grad_x.mutable_data({n, out_h, out_w}, ctx.GetPlace()); + grid_grad_y.mutable_data({n, out_h, out_w}, ctx.GetPlace()); + auto grid_grad_x_t = + EigenTensor::From(grid_grad_x).setConstant(static_cast(0.0)); + auto grid_grad_y_t = + EigenTensor::From(grid_grad_y).setConstant(static_cast(0.0)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < c; j++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + grid_grad_x_t(i, k, l) += + ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + + (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * + output_grad_t(i, j, k, l); + grid_grad_y_t(i, k, l) += + ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + + (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * + output_grad_t(i, j, k, l); + } + } + } + } + + // const T x_max = static_cast(in_w - 1); + // const T y_max = static_cast(in_h - 1); + + auto grid_x_scale_t = EigenTensor::From(*grid_x_scale); + auto grid_y_scale_t = EigenTensor::From(*grid_y_scale); + grid_grad_x_t = grid_grad_x_t * grid_x_scale_t; + grid_grad_y_t = grid_grad_y_t * grid_y_scale_t; + + // gather grid_grad [x, y] in 3rd Dim + T* grid_grad_data = grid_grad->data(); + T* grid_grad_x_data = grid_grad_x.data(); + T* grid_grad_y_data = grid_grad_y.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + grid_grad_data[2 * i] = grid_grad_x_data[i]; + grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + } +} + template class GridSampleOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { 
- auto& place = *ctx.template device_context().eigen_device(); + auto align_corners = ctx.Attr("align_corners"); + auto padding_mode = ctx.Attr("padding_mode"); + auto mode = ctx.Attr("mode"); + auto* input = ctx.Input("X"); auto* grid = ctx.Input("Grid"); - const int n = input->dims()[0]; + const int n = grid->dims()[0]; + const int out_h = grid->dims()[1]; + const int out_w = grid->dims()[2]; const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - - // calc locations and distances of 4 corner points - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - CalcGridLocations( - ctx.template device_context(), *grid, &x_w, - &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s); + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; auto* output = ctx.Output("Output"); - output->mutable_data({n, c, h, w}, ctx.GetPlace()); + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); math::SetConstant()( ctx.template device_context(), output, static_cast(0)); - // calc 4 corner points value - Tensor v_wn, v_en, v_ws, v_es; - v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); - GetGridPointValue(*input, &v_wn, x_w, y_n); - GetGridPointValue(*input, &v_en, x_e, y_n); - GetGridPointValue(*input, &v_ws, x_w, y_s); - GetGridPointValue(*input, &v_es, x_e, y_s); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - auto d_w_scaled_t = - d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_e_scaled_t = - d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_n_scaled_t = - d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_s_scaled_t = - d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - 
auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - auto output_t = EigenTensor::From(*output); - // bilinear interpolaetion by 4 corner points - output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t + - v_en_t * d_w_scaled_t * d_s_scaled_t + - v_ws_t * d_e_scaled_t * d_n_scaled_t + - v_es_t * d_w_scaled_t * d_n_scaled_t; + Tensor grid_x, grid_y; + calcGridLocations( + ctx.template device_context(), *grid, in_h, + in_w, align_corners, padding_mode, &grid_x, &grid_y); + if (mode == "bilinear") { + bilinearInter( + ctx.template device_context(), *input, + &grid_x, &grid_y, output); + } else if (mode == "nearest") { + auto grid_x_t = EigenTensor::From(grid_x); + auto grid_y_t = EigenTensor::From(grid_y); + grid_x_t = grid_x_t.round(); + grid_y_t = grid_y_t.round(); + getGridPointValue(*input, output, grid_x, grid_y); + } } }; @@ -231,97 +538,48 @@ template class GridSampleGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto align_corners = ctx.Attr("align_corners"); + auto padding_mode = ctx.Attr("padding_mode"); + auto mode = ctx.Attr("mode"); + auto* input = ctx.Input("X"); auto* grid = ctx.Input("Grid"); auto* output_grad = ctx.Input(framework::GradVarName("Output")); - const int n = input->dims()[0]; + const int n = grid->dims()[0]; + const int out_h = grid->dims()[1]; + const int out_w = grid->dims()[2]; const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; auto* input_grad = ctx.Output(framework::GradVarName("X")); - input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); math::SetConstant()( ctx.template device_context(), input_grad, static_cast(0)); auto* grid_grad = 
ctx.Output(framework::GradVarName("Grid")); - grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); + grid_grad->mutable_data({n, out_h, out_w, 2}, ctx.GetPlace()); math::SetConstant()( ctx.template device_context(), grid_grad, static_cast(0)); - - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - CalcGridLocations( - ctx.template device_context(), *grid, &x_w, - &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s); - - // gather output grad value to input grad by corner point coords and weight - GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_n, d_e, - d_s); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_s, d_e, - d_n); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_n, d_w, - d_s); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_s, d_w, - d_n); - - // calc 4 corner points value - Tensor v_wn, v_en, v_ws, v_es; - v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); - GetGridPointValue(*input, &v_wn, x_w, y_n); - GetGridPointValue(*input, &v_en, x_e, y_n); - GetGridPointValue(*input, &v_ws, x_w, y_s); - GetGridPointValue(*input, &v_es, x_e, y_s); - auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - - auto output_grad_t = EigenTensor::From(*output_grad); - - Tensor grid_grad_x, grid_grad_y; - grid_grad_x.mutable_data({n, h, w}, ctx.GetPlace()); - grid_grad_y.mutable_data({n, h, w}, ctx.GetPlace()); - auto grid_grad_x_t = EigenTensor::From(grid_grad_x).setConstant(0.0); - auto grid_grad_y_t = EigenTensor::From(grid_grad_y).setConstant(0.0); - for (int i = 0; i < n; i++) { - for (int j = 0; j < c; 
j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - grid_grad_x_t(i, k, l) += - ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + - (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * - output_grad_t(i, j, k, l); - grid_grad_y_t(i, k, l) += - ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + - (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * - output_grad_t(i, j, k, l); - } - } - } - } - const T x_max = static_cast(w - 1); - const T y_max = static_cast(h - 1); - grid_grad_x_t = grid_grad_x_t * (x_max / (T)2); - grid_grad_y_t = grid_grad_y_t * (y_max / (T)2); - - // gather grid_grad [x, y] in 3rd Dim - T* grid_grad_data = grid_grad->data(); - T* grid_grad_x_data = grid_grad_x.data(); - T* grid_grad_y_data = grid_grad_y.data(); - for (int i = 0; i < n * h * w; i++) { - grid_grad_data[2 * i] = grid_grad_x_data[i]; - grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + Tensor grid_x, grid_y; + Tensor grid_x_scale, grid_y_scale; + calcGridLocationsWithGrad( + ctx.template device_context(), *grid, in_h, + in_w, align_corners, padding_mode, &grid_x, &grid_y, &grid_x_scale, + &grid_y_scale); + if (mode == "bilinear") { + gatherBilinearGrad(ctx.template device_context(), + *input, *output_grad, &grid_x, &grid_y, + &grid_x_scale, &grid_y_scale, input_grad, + grid_grad); + } else { + auto grid_x_t = EigenTensor::From(grid_x); + auto grid_y_t = EigenTensor::From(grid_y); + grid_x_t = grid_x_t.round(); + grid_y_t = grid_y_t.round(); + gatherOutputGradToInputGrad(*output_grad, input_grad, grid_x, grid_y); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_grid_sample_function.py b/python/paddle/fluid/tests/unittests/test_grid_sample_function.py new file mode 100644 index 0000000000000000000000000000000000000000..4a33f32a0b6977716d8065419f8e0f88d6c4f44a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_grid_sample_function.py @@ -0,0 +1,131 @@ +# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle import fluid, nn +import paddle.fluid.dygraph as dg +import paddle.nn.functional as F +import unittest + + +class GridSampleTestCase(unittest.TestCase): + def __init__(self, + methodName='runTest', + x_shape=[2, 2, 3, 3], + grid_shape=[2, 3, 3, 2], + mode="bilinear", + padding_mode="zeros", + align_corners=False): + super(GridSampleTestCase, self).__init__(methodName) + self.padding_mode = padding_mode + self.x_shape = x_shape + self.grid_shape = grid_shape + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = "float64" + + def setUp(self): + self.x = np.random.randn(*(self.x_shape)).astype(self.dtype) + self.grid = np.random.uniform(-1, 1, self.grid_shape).astype(self.dtype) + + def static_functional(self, place): + main = fluid.Program() + start = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, start): + x = fluid.data("x", self.x_shape, dtype=self.dtype) + grid = fluid.data("grid", self.grid_shape, dtype=self.dtype) + y_var = F.grid_sample( + x, + grid, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners) + feed_dict = {"x": self.x, "grid": self.grid} + exe = fluid.Executor(place) + exe.run(start) + y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def dynamic_functional(self): + x_t = 
paddle.to_tensor(self.x) + grid_t = paddle.to_tensor(self.grid) + y_t = F.grid_sample( + x_t, + grid_t, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners) + y_np = y_t.numpy() + return y_np + + def _test_equivalence(self, place): + result1 = self.static_functional(place) + with dg.guard(place): + result2 = self.dynamic_functional() + np.testing.assert_array_almost_equal(result1, result2) + + def runTest(self): + place = fluid.CPUPlace() + self._test_equivalence(place) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self._test_equivalence(place) + + +class GridSampleErrorTestCase(GridSampleTestCase): + def runTest(self): + place = fluid.CPUPlace() + with self.assertRaises(ValueError): + self.static_functional(place) + + +def add_cases(suite): + suite.addTest(GridSampleTestCase(methodName='runTest')) + suite.addTest( + GridSampleTestCase( + methodName='runTest', + mode='bilinear', + padding_mode='reflect', + align_corners=True)) + suite.addTest( + GridSampleTestCase( + methodName='runTest', + mode='bilinear', + padding_mode='zeros', + align_corners=True)) + + +def add_error_cases(suite): + suite.addTest( + GridSampleErrorTestCase( + methodName='runTest', padding_mode="VALID")) + suite.addTest( + GridSampleErrorTestCase( + methodName='runTest', align_corners="VALID")) + suite.addTest(GridSampleErrorTestCase(methodName='runTest', mode="VALID")) + + +def load_tests(loader, standard_tests, pattern): + suite = unittest.TestSuite() + add_cases(suite) + add_error_cases(suite) + return suite + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py index bd5a07769e30de5110566f630de2d480e3426c77..4d1ed5aeb96ebbe064e35c1bee9d5775812440f7 100644 --- a/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py +++ b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py @@ -17,17 +17,17 
@@ import numpy as np from op_test import OpTest -def AffineGrid(theta, size): - n = size[0] - h = size[2] - w = size[3] +def AffineGrid(theta, grid_shape): + n = grid_shape[0] + h = grid_shape[1] + w = grid_shape[2] h_idx = np.repeat( np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis] w_idx = np.repeat( np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis] grid = np.concatenate( [w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3 - grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3 + grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3 ret = np.zeros([n, h * w, 2]) theta = theta.transpose([0, 2, 1]) @@ -40,15 +40,19 @@ def AffineGrid(theta, size): def getGridPointValue(data, x, y): data_shape = data.shape N = data_shape[0] - H = data_shape[2] - W = data_shape[3] - - out = np.zeros(data_shape, dtype='float64') + C = data_shape[1] + in_H = data_shape[2] + in_W = data_shape[3] + out_H = x.shape[1] + out_W = x.shape[2] + + #out = np.zeros(data_shape, dtype='float64') + out = np.zeros([N, C, out_H, out_W], dtype='float64') for i in range(N): - for j in range(H): - for k in range(W): - if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[ - i, j, k] > W - 1: + for j in range(out_H): + for k in range(out_W): + if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[ + i, j, k] < 0 or x[i, j, k] > in_W - 1: out[i, :, j, k] = 0 else: out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] @@ -56,44 +60,89 @@ def getGridPointValue(data, x, y): return out -def GridSampler(data, grid): - dims = data.shape - N = dims[0] - C = dims[1] - H = dims[2] - W = dims[3] +def clip(x, min_n, max_n): + return np.maximum(np.minimum(x, max_n), min_n) - x = grid[:, :, :, 0] - y = grid[:, :, :, 1] - y_max = H - 1 - x_max = W - 1 - x = 0.5 * ((x.astype('float64') + 1.0) * x_max) - y = 0.5 * ((y.astype('float64') + 1.0) * y_max) +def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): + if align_corners: + 
grid_slice = 0.5 * ((grid_slice.astype('float64') + 1.0) * max_val) + else: + grid_slice = 0.5 * ( + (grid_slice.astype('float64') + 1.0) * (max_val + 1)) - 0.5 + + if padding_mode == "border": + grid_slice = clip(grid_slice, 0, max_val) + elif padding_mode == "reflect": + double_range = 2 * max_val if align_corners else (max_val + 1) * 2 + grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + + 0.5) + extra = grid_abs - np.floor(grid_abs / double_range) * double_range + grid_slice = np.minimum(extra, double_range - extra) + grid_slice = grid_slice if align_corners else clip(grid_slice - 0.5, 0, + max_val) + return grid_slice - x0 = np.floor(x).astype('int32') - x1 = x0 + 1 - y0 = np.floor(y).astype('int32') - y1 = y0 + 1 - wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1)) - wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1)) - wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1)) - wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1)) +def GridSampler(data, + grid, + align_corners=True, + mode="bilinear", + padding_mode="zeros"): + dims = data.shape + N = dims[0] + in_C = dims[1] + in_H = dims[2] + in_W = dims[3] - va = getGridPointValue(data, x0, y0) - vb = getGridPointValue(data, x0, y1) - vc = getGridPointValue(data, x1, y0) - vd = getGridPointValue(data, x1, y1) + out_H = grid.shape[1] + out_W = grid.shape[2] - out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float64') + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + y_max = in_H - 1 + x_max = in_W - 1 + + x = unnormalizeAndClip(x, x_max, align_corners, padding_mode) + y = unnormalizeAndClip(y, y_max, align_corners, padding_mode) + + if mode == "bilinear": + x0 = np.floor(x).astype('int32') + x1 = x0 + 1 + y0 = np.floor(y).astype('int32') + y1 = y0 + 1 + + wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, 
in_C, 1, 1)) + wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + + va = getGridPointValue(data, x0, y0) + vb = getGridPointValue(data, x0, y1) + vc = getGridPointValue(data, x1, y0) + vd = getGridPointValue(data, x1, y1) + + out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float64') + elif mode == "nearest": + x = np.round(x).astype('int32') + y = np.round(y).astype('int32') + out = getGridPointValue(data, x, y) return out class TestGridSamplerOp(OpTest): def setUp(self): - self.initTestCase() + self.use_cudnn = False + self.numeric_grad_delta = 0.0001 self.op_type = 'grid_sampler' + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + self.initTestCase() x = np.random.randint(0, 255, self.x_shape).astype('float64') theta = np.zeros(self.theta_shape).astype('float64') @@ -101,22 +150,90 @@ class TestGridSamplerOp(OpTest): for j in range(2): for k in range(3): theta[i, j, k] = np.random.rand(1)[0] - grid = AffineGrid(theta, self.x_shape) + grid = AffineGrid(theta, self.grid_shape) self.inputs = {'X': x, 'Grid': grid} - self.attrs = {'use_cudnn': True} - self.outputs = {'Output': GridSampler(x, grid)} + self.attrs = { + 'use_cudnn': self.use_cudnn, + "align_corners": self.align_corners, + "padding_mode": self.padding_mode, + "mode": self.mode + } + # print("X: {}".format(x)) + self.outputs = { + 'Output': GridSampler(x, grid, self.align_corners, self.mode, + self.padding_mode) + } def test_check_output(self): self.check_output() def test_check_grad_normal(self): - self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61) + self.check_grad( + ['X', 'Grid'], + 'Output', + max_relative_error=0.01, + numeric_grad_delta=self.numeric_grad_delta) + + def initTestCase(self): + self.x_shape = (2, 3, 8, 8) + self.grid_shape = (2, 7, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + 
self.padding_mode = "zeros" + self.mode = "bilinear" + self.use_cudnn = True + + +class Case1(TestGridSamplerOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "zeros" + self.mode = "bilinear" + + +class Case1(TestGridSamplerOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "border" + self.mode = "bilinear" + + +class Case2(TestGridSamplerOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "reflect" + self.mode = "bilinear" + + +class Case3(TestGridSamplerOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "reflect" + self.mode = "bilinear" + +class Case4(TestGridSamplerOp): def initTestCase(self): - self.x_shape = (2, 5, 7, 3) - self.grid_shape = (2, 7, 3, 2) + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "reflect" + self.mode = "nearest" + self.numeric_grad_delta = 0.0001 if __name__ == "__main__": diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index ba3d80d40b0e079748e6b3e6a7ee0d5030c02e2b..a952cd587be839dda450610d87361b4729376313 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -192,7 +192,7 @@ from .vision import fsp_matrix #DEFINE_ALIAS from .vision import generate_mask_labels #DEFINE_ALIAS from .vision import generate_proposal_labels #DEFINE_ALIAS from .vision import generate_proposals #DEFINE_ALIAS -from .vision import grid_sampler #DEFINE_ALIAS +from .vision import 
grid_sample #DEFINE_ALIAS from .vision import image_resize #DEFINE_ALIAS from .vision import image_resize_short #DEFINE_ALIAS # from .vision import multi_box_head #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index a2cc8fde5ad7147b7af4765de834508f1f3cc825..23e45725a78299e7e67308400a8b9c1adbfebed7 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -28,7 +28,6 @@ from ...fluid.layers import distribute_fpn_proposals #DEFINE_ALIAS from ...fluid.layers import generate_mask_labels #DEFINE_ALIAS from ...fluid.layers import generate_proposal_labels #DEFINE_ALIAS from ...fluid.layers import generate_proposals #DEFINE_ALIAS -from ...fluid.layers import grid_sampler #DEFINE_ALIAS from ...fluid.layers import image_resize #DEFINE_ALIAS from ...fluid.layers import prior_box #DEFINE_ALIAS from ...fluid.layers import prroi_pool #DEFINE_ALIAS @@ -68,7 +67,7 @@ __all__ = [ 'generate_mask_labels', 'generate_proposal_labels', 'generate_proposals', - 'grid_sampler', + 'grid_sample', 'image_resize', 'image_resize_short', # 'multi_box_head', @@ -89,3 +88,187 @@ __all__ = [ 'yolo_box', 'yolov3_loss' ] + +from ...fluid.layer_helper import LayerHelper +from ...fluid.data_feeder import check_variable_and_dtype +from ...fluid import core, dygraph_utils +from ...fluid.framework import Variable, in_dygraph_mode +from ...device import get_cudnn_version +import numpy as np + + +def grid_sample(x, + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=True, + name=None): + """ + This operation samples input X by using bilinear interpolation or + nearest interpolation based on flow field grid, which is usually + generated by :code:`affine_grid` . 
The grid of shape [N, grid_H, grid_W, 2] + is the concatenation of (x, y) coordinates with shape [N, grid_H, grid_W] each, + where x is indexing the 4th dimension (in width dimension) of input + data x and y is indexing the 3rd dimension (in height dimension); + the final result is the bilinear interpolation of the 4 nearest corner + points or the value of the nearest point. The output tensor shape will be [N, C, grid_H, grid_W]. + + .. code-block:: text + + Step 1: + Get (x, y) grid coordinates and scale to [0, H-1/W-1]. + + .. code-block:: text + + grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1) + grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1) + + Step 2: + Index input data X with grid (x, y) in each [H, W] area, and bilinear + interpolate point value by 4 nearest points or nearest interpolate point value + by nearest point. + + wn ------- y_n ------- en + | | | + | d_n | + | | | + x_w --d_w-- grid--d_e-- x_e + | | | + | d_s | + | | | + ws ------- y_s ------- es + + For bilinear interpolation: + + x_w = floor(x) // west side x coord + x_e = x_w + 1 // east side x coord + y_n = floor(y) // north side y coord + y_s = y_n + 1 // south side y coord + + d_w = grid_x - x_w // distance to west side + d_e = x_e - grid_x // distance to east side + d_n = grid_y - y_n // distance to north side + d_s = y_s - grid_y // distance to south side + + wn = X[:, :, y_n, x_w] // north-west point value + en = X[:, :, y_n, x_e] // north-east point value + ws = X[:, :, y_s, x_w] // south-west point value + es = X[:, :, y_s, x_e] // south-east point value + + output = wn * d_e * d_s + en * d_w * d_s + + ws * d_e * d_n + es * d_w * d_n + + Args: + x(Tensor): The input tensor, which is a 4-d tensor with shape + [N, C, H, W], N is the batch size, C is the channel + number, H and W are the feature height and width. + The data type is float32 or float64. + grid(Tensor): Input grid tensor of shape [N, grid_H, grid_W, 2]. The + data type is float32 or float64. 
+ mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'. + Default: 'bilinear'. + padding_mode(str, optional): The padding method used when source index + is out of input images. It can be 'zeros', 'reflect' or 'border'. + Default: 'zeros'. + align_corners(bool, optional): If `align_corners` is true, it will project + -1 and 1 to the centers of the corner pixels. Otherwise, it will + project -1 and 1 to the image edges. Default: True. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually there is no need to set name; it is + None by default. + + Returns: Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid + and `grid_W` is the width of grid. The data type is same as input tensor. + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + # shape=[1, 1, 3, 3] + x = np.array([[[[-0.6, 0.8, -0.5], + [-0.5, 0.2, 1.2], + [ 1.4, 0.3, -0.2]]]]).astype("float64") + + # grid shape = [1, 3, 4, 2] + grid = np.array( + [[[[ 0.2, 0.3], + [-0.4, -0.3], + [-0.9, 0.3], + [-0.9, -0.6]], + [[ 0.4, 0.1], + [ 0.9, -0.8], + [ 0.4, 0.5], + [ 0.5, -0.2]], + [[ 0.1, -0.8], + [-0.3, -1. 
], + [ 0.7, 0.4], + [ 0.2, 0.8]]]]).astype("float64") + + paddle.disable_static() + x = paddle.to_tensor(x) + grid = paddle.to_tensor(grid) + y_t = F.grid_sample( + x, + grid, + mode='bilinear', + padding_mode='border', + align_corners=True) + print(y_t.numpy()) + + # output shape = [1, 1, 3, 4] + # [[[[ 0.34 0.016 0.086 -0.448] + # [ 0.55 -0.076 0.35 0.59 ] + # [ 0.596 0.38 0.52 0.24 ]]]] + """ + helper = LayerHelper("grid_sample", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sampler') + check_variable_and_dtype(grid, 'grid', ['float32', 'float64'], + 'grid_sampler') + if not isinstance(x, Variable): + raise ValueError("The x should be a Variable") + if not isinstance(grid, Variable): + raise ValueError("The grid should be a Variable") + _modes = ['bilinear', 'nearest'] + _padding_modes = ['zeros', 'reflect', 'border'] + if mode not in _modes: + raise ValueError( + "The mode of grid sample function should be in {}, but got: {}". + format(_modes, mode)) + if padding_mode not in _padding_modes: + raise ValueError( + "The padding mode of grid sample function should be in {}, but got: {}". 
+ format(_padding_modes, padding_mode)) + + if not isinstance(align_corners, bool): + raise ValueError("The align corners should be bool, but got: {}".format( + align_corners)) + + cudnn_version = get_cudnn_version() + use_cudnn = False + if (cudnn_version is not None + ) and align_corners and mode == 'bilinear' and padding_mode == 'zeros': + use_cudnn = True + ipts = {'X': x, 'Grid': grid} + attrs = { + 'mode': mode, + 'padding_mode': padding_mode, + 'align_corners': align_corners, + 'use_cudnn': use_cudnn + } + + if in_dygraph_mode(): + attrs = ('mode', mode, 'padding_mode', padding_mode, 'align_corners', + align_corners, 'use_cudnn', use_cudnn) + out = getattr(core.ops, 'grid_sampler')(x, grid, *attrs) + else: + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='grid_sampler', + inputs=ipts, + attrs=attrs, + outputs={'Output': out}) + return out