diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc
index 04aa6a3e10f6e3f55f9845d1b4b6bd6aa762c016..6ee9582dacde372886075bb7c5619c6bc1b99c98 100644
--- a/paddle/fluid/operators/grid_sampler_op.cc
+++ b/paddle/fluid/operators/grid_sampler_op.cc
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/grid_sampler_op.h"
 #include <memory>
 #include <string>
+
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
@@ -229,15 +229,6 @@ REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
                   ops::GridSampleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
 
-REGISTER_OP_CPU_KERNEL(
-    grid_sampler,
-    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    grid_sampler_grad,
-    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
-
 REGISTER_OP_VERSION(grid_sampler)
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/grid_sampler_op.cu b/paddle/fluid/operators/grid_sampler_op.cu
deleted file mode 100644
index a227a8e312765b4311314ea884f2c32443924fbc..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/grid_sampler_op.cu
+++ /dev/null
@@ -1,492 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/grid_sampler_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -static __forceinline__ __device__ bool in_bounds(int h, int w, int H, int W) { - return h >= 0 && h < H && w >= 0 && w < W; -} - -template -static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH, - int sW, int H, int W, - T delta) { - if (in_bounds(h, w, H, W)) { - platform::CudaAtomicAdd(data + h * sH + w * sW, delta); - } -} - -template -static __forceinline__ __device__ T _unnormalize(T coord, int size, - bool align_corners) { - if (align_corners) { - return ((coord + 1.f) / 2) * (size - 1); - } else { - return ((coord + 1.f) * size - 1) / 2; - } -} - -template -static __forceinline__ __device__ T clip_indexes(T in, int max_value) { - return min(static_cast(max_value), max(in, static_cast(0))); -} - -template -static __forceinline__ __device__ T reflect_indexes(T in, int twice_low, - int twice_high) { - if (twice_low == twice_high) { - return static_cast(0); - } - T min = static_cast(twice_low) / 2; - T span = static_cast(twice_high - twice_low) / 2; - in = fabs(in - min); - T extra = fmod(in, span); - int flips = static_cast(floor(in / span)); - if (flips % 2 == 0) { - return extra + min; - } else { - return span - extra + min; - } -} - -template -static __forceinline__ __device__ T compute_positions(T coord, int size, - PaddingMode padding_mode, - bool align_corners) { - coord = _unnormalize(coord, size, align_corners); - if (padding_mode == PaddingMode::border) { - coord = clip_indexes(coord, size - 1); - } else if (padding_mode == PaddingMode::reflect) { - if (align_corners) { - coord = reflect_indexes(coord, 0, 2 * (size - 1)); - } else { - coord = reflect_indexes(coord, -1, 2 * size - 1); - } - coord = clip_indexes(coord, size - 1); - } - return coord; -} - -template -static __forceinline__ __device__ T _unnormalize_with_mask(T coord, int size, - bool align_corners, - T* grad_in) { - if (align_corners) { - *grad_in = static_cast(size - 1) / 2; - return ((coord + 1.f) / 2) * (size - 1); - } else { - *grad_in = static_cast(size) / 2; - return ((coord + 1.f) * size - 1) / 2; - } -} - -template -static __forceinline__ __device__ T clip_indexes_with_mask(T in, int clip_limit, - T* grad_in) { - if (in <= static_cast(0)) { - *grad_in = static_cast(0); - return static_cast(0); - } else { - T max = static_cast(clip_limit - 1); - if (in >= max) { - *grad_in = static_cast(0); - return max; - } else { - *grad_in = static_cast(1); - return in; - } - } -} - -template -static __forceinline__ __device__ T -reflect_indexes_with_mask(T in, int twice_low, int twice_high, T* grad_in) { - if (twice_low == twice_high) { - *grad_in = static_cast(0); - return static_cast(0); - } - int grad_in_mult_; - T min = static_cast(twice_low) / 2; - T span = static_cast(twice_high - twice_low) / 2; - in = in - min; - if (in < static_cast(0)) { - grad_in_mult_ = -1; - in = -in; - } else { - grad_in_mult_ = 1; - } - T extra = fmod(in, span); - int flips = static_cast(floor(in / span)); - if (flips % 2 == 0) { - *grad_in = static_cast(grad_in_mult_); - return extra + min; - } else { - *grad_in = static_cast(-grad_in_mult_); - return span - extra + min; - } -} - -template -static 
__forceinline__ __device__ T -compute_positions_with_mask(T coord, int size, PaddingMode padding_mode, - bool align_corners, T* grad_in) { - T grad_clip, grad_refl; - coord = _unnormalize_with_mask(coord, size, align_corners, grad_in); - if (padding_mode == PaddingMode::border) { - coord = clip_indexes_with_mask(coord, size, &grad_clip); - *grad_in = (*grad_in) * grad_clip; - } else if (padding_mode == PaddingMode::reflect) { - if (align_corners) { - coord = reflect_indexes_with_mask(coord, 0, 2 * (size - 1), &grad_refl); - } else { - coord = reflect_indexes_with_mask(coord, -1, 2 * size - 1, &grad_refl); - } - coord = clip_indexes_with_mask(coord, size, &grad_clip); - *grad_in = (*grad_in) * grad_refl * grad_clip; - } - - return coord; -} - -template -__global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c, - int out_h, int out_w, int in_h, - int in_w, const T* input, const T* grid, - T* output, const Mode mode, - const PaddingMode padding_mode, - bool align_corners) { - int inp_sN = out_c * in_h * in_w; - - int inp_sC = in_h * in_w; - int inp_sH = in_w; - int inp_sW = 1; - int grid_sN = out_h * out_w * 2; - int grid_sH = out_w * 2; - int grid_sW = 2; - int grid_sCoor = 1; - int out_sN = out_c * out_h * out_w; - int out_sC = out_h * out_w; - int out_sH = out_w; - int out_sW = 1; - CUDA_KERNEL_LOOP(index, nthreads) { - const int w = index % out_w; - const int h = (index / out_w) % out_h; - const int n = index / (out_h * out_w); - const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - - T ix = grid[grid_offset]; - T iy = grid[grid_offset + grid_sCoor]; - - ix = compute_positions(ix, in_w, padding_mode, align_corners); - iy = compute_positions(iy, in_h, padding_mode, align_corners); - if (mode == Mode::bilinear) { - int ix_nw = static_cast(floor(ix)); - int iy_nw = static_cast(floor(iy)); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; - - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); - - auto inp_offset_NC = n * inp_sN; - - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < out_c; - ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { - *out_ptr_NCHW = static_cast(0); - if (in_bounds(iy_nw, ix_nw, in_h, in_w)) { - *out_ptr_NCHW += - input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; - } - if (in_bounds(iy_ne, ix_ne, in_h, in_w)) { - *out_ptr_NCHW += - input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne; - } - if (in_bounds(iy_sw, ix_sw, in_h, in_w)) { - *out_ptr_NCHW += - input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; - } - if (in_bounds(iy_se, ix_se, in_h, in_w)) { - *out_ptr_NCHW += - input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; - } - } - } else if (mode == Mode::nearest) { - int ix_nearest = static_cast(std::nearbyint(ix)); - int iy_nearest = static_cast(std::nearbyint(iy)); - auto inp_offset_NC = n * inp_sN; - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < out_c; - ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { - if (in_bounds(iy_nearest, ix_nearest, in_h, in_w)) { - *out_ptr_NCHW = - input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCHW = static_cast(0); - } - } - } - } -} - -template -class GridSampleOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const 
framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.cuda_device_context(); - auto align_corners = ctx.Attr("align_corners"); - auto padding_mode_s = ctx.Attr("padding_mode"); - auto mode_s = ctx.Attr("mode"); - PaddingMode padding_mode; - Mode mode; - if (padding_mode_s == "border") { - padding_mode = PaddingMode::border; - } else if (padding_mode_s == "reflection") { - padding_mode = PaddingMode::reflect; - } else { - padding_mode = PaddingMode::zeros; - } - - if (mode_s == "nearest") { - mode = Mode::nearest; - } else { - mode = Mode::bilinear; - } - - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - const int n = grid->dims()[0]; - const int out_h = grid->dims()[1]; - const int out_w = grid->dims()[2]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h - << "; out_w: " << out_w; - auto* output = ctx.Output("Output"); - auto* output_data = output->mutable_data(ctx.GetPlace()); - VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1] - << "; " << output->dims()[2] << "; " << output->dims()[3]; - int count = static_cast(n * out_h * out_w); - auto cu_stream = dev_ctx.stream(); - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(dev_ctx, count); - grid_sample_cuda_kernel< - T><<>>( - count, n, c, out_h, out_w, in_h, in_w, input->data(), - grid->data(), output_data, mode, padding_mode, align_corners); - } -}; - -template -__global__ void grid_sampler_cuda_backward_kernel( - const int nthreads, const T* grad_output, const T* input, const T* grid, - int n, int out_c, int out_h, int out_w, int in_h, int in_w, T* grad_input, - T* grad_grid, const Mode mode, const PaddingMode padding_mode, - bool align_corners) { - int inp_sN = out_c * in_h * in_w; - int inp_sC = in_h * in_w; - int inp_sH = in_w; - int inp_sW = 1; - int grid_sN = out_h * out_w * 2; - int grid_sH = out_w * 2; - int grid_sW = 2; - int grid_sCoor = 1; - - int gOut_sN = out_c * out_h * out_w; - int gOut_sC = out_h * out_w; - int gOut_sH = out_w; - int gOut_sW = 1; - - CUDA_KERNEL_LOOP(index, nthreads) { - const int w = index % out_w; - const int h = (index / out_w) % out_h; - const int n = index / (out_h * out_w); - const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - - T ix = grid[grid_offset]; - T iy = grid[grid_offset + grid_sCoor]; - - T gix_mult, giy_mult; - ix = compute_positions_with_mask(ix, in_w, padding_mode, align_corners, - &gix_mult); - iy = compute_positions_with_mask(iy, in_h, padding_mode, align_corners, - &giy_mult); - - if (mode == Mode::bilinear) { - int ix_nw = static_cast(floor(ix)); - int iy_nw = static_cast(floor(iy)); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; - - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); - - T gix = static_cast(0), giy = static_cast(0); - int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; - T* gInp_ptr_NC = grad_input + n * inp_sN; - int inp_offset_NC = n * inp_sN; - for (int c = 0; c < out_c; ++c, inp_offset_NC += inp_sC, - gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { - T gOut = grad_output[gOut_offset]; - - atomic_add(gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w, - nw * gOut); - atomic_add(gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w, - ne * gOut); - 
atomic_add(gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w, - sw * gOut); - atomic_add(gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w, - se * gOut); - - if (in_bounds(iy_nw, ix_nw, in_h, in_w)) { - T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW]; - gix -= nw_val * (iy_se - iy) * gOut; - giy -= nw_val * (ix_se - ix) * gOut; - } - if (in_bounds(iy_ne, ix_ne, in_h, in_w)) { - T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW]; - gix += ne_val * (iy_sw - iy) * gOut; - giy -= ne_val * (ix - ix_sw) * gOut; - } - if (in_bounds(iy_sw, ix_sw, in_h, in_w)) { - T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW]; - gix -= sw_val * (iy - iy_ne) * gOut; - giy += sw_val * (ix_ne - ix) * gOut; - } - if (in_bounds(iy_se, ix_se, in_h, in_w)) { - T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW]; - gix += se_val * (iy - iy_nw) * gOut; - giy += se_val * (ix - ix_nw) * gOut; - } - } - - if (grad_grid != nullptr) { - T* gGrid_ptr_NHW = grad_grid + index * grid_sW; - gGrid_ptr_NHW[0] = gix_mult * gix; - gGrid_ptr_NHW[1] = giy_mult * giy; - } - } else if (mode == Mode::nearest) { - int ix_nearest = static_cast(std::nearbyint(ix)); - int iy_nearest = static_cast(std::nearbyint(iy)); - - int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; - T* gInp_ptr_NC = grad_input + n * inp_sN; - for (int c = 0; c < out_c; - ++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { - atomic_add(gInp_ptr_NC, iy_nearest, ix_nearest, inp_sH, inp_sW, in_h, - in_w, grad_output[gOut_offset]); - } - - if (grad_grid != nullptr) { - T* gGrid_ptr_NHW = grad_grid + index * grid_sW; - gGrid_ptr_NHW[0] = static_cast(0); - gGrid_ptr_NHW[1] = static_cast(0); - } - } - } -} - -template -class GridSampleGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.cuda_device_context(); - auto align_corners = ctx.Attr("align_corners"); - auto padding_mode_s = ctx.Attr("padding_mode"); - auto mode_s = ctx.Attr("mode"); - - PaddingMode padding_mode; - Mode mode; - if (padding_mode_s == "border") { - padding_mode = PaddingMode::border; - } else if (padding_mode_s == "reflection") { - padding_mode = PaddingMode::reflect; - } else { - padding_mode = PaddingMode::zeros; - } - - if (mode_s == "nearest") { - mode = Mode::nearest; - } else { - mode = Mode::bilinear; - } - - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - auto* output_grad = ctx.Input(framework::GradVarName("Output")); - - const int n = grid->dims()[0]; - const int out_h = grid->dims()[1]; - const int out_w = grid->dims()[2]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - - auto* input_grad = ctx.Output(framework::GradVarName("X")); - input_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), - input_grad, static_cast(0)); - - T* grid_grad_data = nullptr; - if (ctx.HasOutput(framework::GradVarName("Grid"))) { - auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); - grid_grad_data = grid_grad->mutable_data(ctx.GetPlace()); - } - - int count = static_cast(n * out_h * out_w); - auto cu_stream = dev_ctx.stream(); - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(dev_ctx, count); - grid_sampler_cuda_backward_kernel< - T><<>>( - count, output_grad->data(), input->data(), grid->data(), n, c, - out_h, out_w, in_h, in_w, input_grad->data(), grid_grad_data, mode, - padding_mode, 
align_corners); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(grid_sampler, ops::GridSampleOpCUDAKernel, - ops::GridSampleOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(grid_sampler_grad, - ops::GridSampleGradOpCUDAKernel, - ops::GridSampleGradOpCUDAKernel); diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h deleted file mode 100644 index 93e96694270a458844bbcabf78f2559975902c2f..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/grid_sampler_op.h +++ /dev/null @@ -1,600 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -enum class Mode { - bilinear, - nearest, -}; - -enum class PaddingMode { zeros, border, reflect }; - -using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; - -using Array3 = Eigen::DSizes; -using Array4 = Eigen::DSizes; - -template -static inline bool isInBound(T x, T y, T x_max, T y_max) { - if (x < 0 || x > x_max || y < 0 || y > y_max) { - return false; - } - return true; -} - -template -static inline void unnormalize(const platform::CPUDeviceContext& ctx, - Tensor* grid_slice, - const int max_val, // height-1 or width-1 - bool align_corners) { - auto& place = *ctx.eigen_device(); - auto grid_slice_t = EigenTensor::From(*grid_slice); - - if (!align_corners) { - auto factor = static_cast((max_val + 1) * 0.5); - grid_slice_t.device(place) = - (grid_slice_t + static_cast(1)) * factor - static_cast(0.5); - } else { - auto factor = static_cast(max_val * 0.5); - grid_slice_t.device(place) = (grid_slice_t + static_cast(1)) * factor; - } -} - -template -static inline void clip(const platform::CPUDeviceContext& ctx, - Tensor* grid_slice, - const int max_val, // height-1 or width-1 - bool align_corners, std::string padding_mode) { - auto& place = *ctx.eigen_device(); - auto grid_slice_t = EigenTensor::From(*grid_slice); - if (padding_mode == "border") { - grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) - .cwiseMin(static_cast(max_val)); - } else if (padding_mode == "reflection") { - if (align_corners) { - auto double_range = static_cast(max_val * 2); - auto grid_abs = grid_slice_t.abs(); - auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; - grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); - if (max_val == 0) { - grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); - } - } else { - auto double_range = static_cast((max_val + 1) * 2); - auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); - auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; - 
grid_slice_t.device(place) = - extra.cwiseMin(double_range - extra) - static_cast(0.5); - grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) - .cwiseMin(static_cast(max_val)); - } - } -} - -template -static inline void clipWithMask(const platform::CPUDeviceContext& ctx, - const int max_val, // height-1 or width-1 - bool align_corners, std::string padding_mode, - Tensor* grid_slice, Tensor* grid_scale) { - auto& place = *ctx.eigen_device(); - grid_scale->mutable_data(grid_slice->dims(), ctx.GetPlace()); - - auto grid_slice_t = EigenTensor::From(*grid_slice); - auto factor = static_cast(max_val * 0.5); - if (!align_corners) { - factor = static_cast((max_val + 1) * 0.5); - } - auto grid_scale_t = EigenTensor::From(*grid_scale).setConstant(factor); - - if (padding_mode == "border") { - // auto bounded_lo = grid_slice_t.cwiseMax(static_cast(0)); - auto res = grid_slice_t.cwiseMax(static_cast(0)) - .cwiseMin(static_cast(max_val)); - - auto in_bound = (res == grid_slice_t); - grid_scale_t.device(place) = grid_scale_t * in_bound.template cast(); - grid_slice_t.device(place) = res; - } else if (padding_mode == "reflection") { - if (align_corners) { - auto double_range = static_cast(max_val * 2); - auto is_neg = (grid_slice_t < static_cast(0)); - auto grid_abs = grid_slice_t.abs(); - auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; - auto one_more_flip = (extra > (double_range - extra)); - grid_scale_t.device(place) = - grid_scale_t * ((is_neg == one_more_flip).template cast() - - (is_neg != one_more_flip).template cast()); - grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); - if (max_val == 0) { - grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); - } - } else { - auto double_range = static_cast((max_val + 1) * 2); - auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); - auto is_neg = ((grid_slice_t + static_cast(0.5)) < static_cast(0)); - auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; - auto one_more_flip = (extra > (double_range - extra)); - auto reflected = - extra.cwiseMin(double_range - extra) - static_cast(0.5); - auto clipped = reflected.cwiseMax(static_cast(0)) - .cwiseMin(static_cast(max_val)); - auto in_bound = (clipped == reflected).template cast(); - grid_scale_t.device(place) = - grid_scale_t * ((is_neg == one_more_flip).template cast() - - (is_neg != one_more_flip).template cast()) * - in_bound; - grid_slice_t.device(place) = clipped; - } - } -} - -template -static void calcGridLocations(const platform::CPUDeviceContext& ctx, - const Tensor& grid, const int in_h, - const int in_w, bool align_corners, - std::string padding_mode, Tensor* grid_x, - Tensor* grid_y) { - const int n = grid.dims()[0]; - const int out_h = grid.dims()[1]; - const int out_w = grid.dims()[2]; - - // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim - T* grid_x_data = grid_x->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - T* grid_y_data = grid_y->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - const T* grid_data = grid.data(); - for (int i = 0; i < n * out_h * out_w; i++) { - grid_x_data[i] = grid_data[2 * i]; - grid_y_data[i] = grid_data[(2 * i) + 1]; - } - - unnormalize(ctx, grid_x, in_w - 1, align_corners); - unnormalize(ctx, grid_y, in_h - 1, align_corners); - - clip(ctx, grid_x, in_w - 1, align_corners, padding_mode); - clip(ctx, grid_y, in_h - 1, align_corners, padding_mode); -} - -template -static void calcGridLocationsWithGrad(const platform::CPUDeviceContext& ctx, - const Tensor& 
grid, const int in_h, - const int in_w, bool align_corners, - std::string padding_mode, Tensor* grid_x, - Tensor* grid_y, Tensor* grid_x_scale, - Tensor* grid_y_scale) { - const int n = grid.dims()[0]; - const int out_h = grid.dims()[1]; - const int out_w = grid.dims()[2]; - - // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim - T* grid_x_data = grid_x->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - T* grid_y_data = grid_y->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - - const T* grid_data = grid.data(); - for (int i = 0; i < n * out_h * out_w; i++) { - grid_x_data[i] = grid_data[2 * i]; - grid_y_data[i] = grid_data[(2 * i) + 1]; - } - - unnormalize(ctx, grid_x, in_w - 1, align_corners); - unnormalize(ctx, grid_y, in_h - 1, align_corners); - - clipWithMask(ctx, in_w - 1, align_corners, padding_mode, grid_x, - grid_x_scale); - clipWithMask(ctx, in_h - 1, align_corners, padding_mode, grid_y, - grid_y_scale); -} - -template -static void getGridPointValue(const Tensor& input, Tensor* output, - const Tensor& x, const Tensor& y) { - const int n = input.dims()[0]; - const int c = input.dims()[1]; - const int in_h = input.dims()[2]; - const int in_w = input.dims()[3]; - const int out_h = x.dims()[1]; - const int out_w = x.dims()[2]; - auto x_t = EigenTensor::From(x); - auto y_t = EigenTensor::From(y); - auto output_t = EigenTensor::From(*output).setConstant((T)0); - auto input_t = EigenTensor::From(input); - - for (int i = 0; i < n; i++) { - for (int k = 0; k < out_h; k++) { - for (int l = 0; l < out_w; l++) { - if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), - (T)(in_h - 1))) { - for (int j = 0; j < c; j++) { - output_t(i, j, k, l) = - input_t(i, j, static_cast(round(y_t(i, k, l))), - static_cast(round(x_t(i, k, l)))); - } - } - } - } - } -} - -template -static void allNeigbors(const platform::CPUDeviceContext& ctx, - const Tensor& input, Tensor* grid_x, Tensor* grid_y, - Tensor* x_w, Tensor* x_e, Tensor* y_n, - Tensor* y_s, // positions - Tensor* d_w, Tensor* d_e, Tensor* d_n, - Tensor* d_s, // distance - Tensor* v_wn, Tensor* v_en, Tensor* v_ws, - Tensor* v_es) { // values - auto& place = *ctx.eigen_device(); - - const int c = input.dims()[1]; - const int n = grid_x->dims()[0]; - const int out_h = grid_x->dims()[1]; - const int out_w = grid_x->dims()[2]; - // calculate coords of 4 corner points - x_w->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - x_e->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - y_n->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - y_s->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - auto x_w_t = EigenTensor::From(*x_w); - auto x_e_t = EigenTensor::From(*x_e); - auto y_n_t = EigenTensor::From(*y_n); - auto y_s_t = EigenTensor::From(*y_s); - - auto grid_x_t = EigenTensor::From(*grid_x); - auto grid_y_t = EigenTensor::From(*grid_y); - - x_w_t.device(place) = grid_x_t.floor(); - x_e_t.device(place) = x_w_t + static_cast(1); - y_n_t.device(place) = grid_y_t.floor(); - y_s_t.device(place) = y_n_t + static_cast(1); - - // calculate distances to 4 sides - d_w->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - d_e->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - d_n->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - d_s->mutable_data({n, out_h, out_w}, ctx.GetPlace()); - auto d_w_t = EigenTensor::From(*d_w); - auto d_e_t = EigenTensor::From(*d_e); - auto d_n_t = EigenTensor::From(*d_n); - auto d_s_t = EigenTensor::From(*d_s); - d_w_t.device(place) = grid_x_t - x_w_t; - d_e_t.device(place) = x_e_t - grid_x_t; - d_n_t.device(place) = 
grid_y_t - y_n_t; - d_s_t.device(place) = y_s_t - grid_y_t; - - // calc 4 corner points value - v_wn->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - v_en->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - v_ws->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - v_es->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - getGridPointValue(input, v_wn, *x_w, *y_n); - getGridPointValue(input, v_en, *x_e, *y_n); - getGridPointValue(input, v_ws, *x_w, *y_s); - getGridPointValue(input, v_es, *x_e, *y_s); -} - -template -static void bilinearInter(const platform::CPUDeviceContext& ctx, - const Tensor& input, Tensor* grid_x, Tensor* grid_y, - Tensor* out) { - auto& place = *ctx.eigen_device(); - const int n = grid_x->dims()[0]; - const int out_h = grid_x->dims()[1]; - const int out_w = grid_x->dims()[2]; - const int c = input.dims()[1]; - - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - Tensor v_wn, v_en, v_ws, v_es; - - allNeigbors(ctx, input, grid_x, grid_y, &x_w, &x_e, &y_n, &y_s, &d_w, &d_e, - &d_n, &d_s, &v_wn, &v_en, &v_ws, &v_es); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - - auto d_w_scaled_t = - d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); - auto d_e_scaled_t = - d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); - auto d_n_scaled_t = - d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); - auto d_s_scaled_t = - d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); - auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - auto output_t = EigenTensor::From(*out); - // bilinear interpolaetion by 4 corner points - output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t + - v_en_t * d_w_scaled_t * d_s_scaled_t + - v_ws_t * d_e_scaled_t * d_n_scaled_t + - v_es_t * d_w_scaled_t * d_n_scaled_t; -} - -template -static void nearestInter(const platform::CPUDeviceContext& ctx, - const Tensor& input, Tensor* grid_x, Tensor* grid_y, - Tensor* out) { - auto& place = *ctx.eigen_device(); - - auto grid_x_t = EigenTensor::From(*grid_x); - auto grid_y_t = EigenTensor::From(*grid_y); - grid_x_t = grid_x_t.round(); - grid_y_t = grid_y_t.round(); - getGridPointValue(input, out, *grid_x, *grid_y); -} - -template -static void gatherOutputGradToInputGrad(const Tensor& output_grad, - Tensor* input_grad, const Tensor& x, - const Tensor& y, const Tensor& d1, - const Tensor& d2) { - const int n = output_grad.dims()[0]; - const int c = output_grad.dims()[1]; - const int out_h = output_grad.dims()[2]; - const int out_w = output_grad.dims()[3]; - const int in_h = input_grad->dims()[2]; - const int in_w = input_grad->dims()[3]; - auto x_t = EigenTensor::From(x); - auto y_t = EigenTensor::From(y); - auto d1_t = EigenTensor::From(d1); - auto d2_t = EigenTensor::From(d2); - auto input_grad_t = EigenTensor::From(*input_grad); - auto output_grad_t = EigenTensor::From(output_grad); - - for (int i = 0; i < n; i++) { - for (int k = 0; k < out_h; k++) { - for (int l = 0; l < out_w; l++) { - if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), - (T)(in_h - 1))) { - for (int j = 0; j < c; j++) { - input_grad_t(i, j, static_cast(round(y_t(i, k, l))), - static_cast(round(x_t(i, k, l)))) += - output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l); - } - } - } - } - } -} - -template -static 
void gatherOutputGradToInputGrad(const Tensor& output_grad, - Tensor* input_grad, const Tensor& x, - const Tensor& y) { - const int n = output_grad.dims()[0]; - const int c = output_grad.dims()[1]; - const int out_h = output_grad.dims()[2]; - const int out_w = output_grad.dims()[3]; - const int in_h = input_grad->dims()[2]; - const int in_w = input_grad->dims()[3]; - auto x_t = EigenTensor::From(x); - auto y_t = EigenTensor::From(y); - auto input_grad_t = EigenTensor::From(*input_grad); - auto output_grad_t = EigenTensor::From(output_grad); - for (int i = 0; i < n; i++) { - for (int k = 0; k < out_h; k++) { - for (int l = 0; l < out_w; l++) { - if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), - (T)(in_h - 1))) { - for (int j = 0; j < c; j++) { - input_grad_t(i, j, static_cast(round(y_t(i, k, l))), - static_cast(round(x_t(i, k, l)))) += - output_grad_t(i, j, k, l); - } - } - } - } - } -} - -template -static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, - const Tensor& input, const Tensor& output_grad, - Tensor* grid_x, Tensor* grid_y, - Tensor* grid_x_scale, Tensor* grid_y_scale, - Tensor* input_grad, Tensor* grid_grad) { - const int n = grid_x->dims()[0]; - const int out_h = grid_x->dims()[1]; - const int out_w = grid_x->dims()[2]; - const int c = input.dims()[1]; - - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - Tensor v_wn, v_en, v_ws, v_es; - - allNeigbors(ctx, input, - grid_x, // grid_x - grid_y, // grid_y - &x_w, &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s, &v_wn, &v_en, - &v_ws, &v_es); - - // gather output grad value to input grad by corner point coords and weight - gatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_n, d_e, d_s); - gatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_s, d_e, d_n); - gatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_n, d_w, d_s); - gatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_s, d_w, d_n); - - auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - - auto output_grad_t = EigenTensor::From(output_grad); - - if (grid_grad != nullptr) { - Tensor grid_grad_x, grid_grad_y; - grid_grad_x.mutable_data({n, out_h, out_w}, ctx.GetPlace()); - grid_grad_y.mutable_data({n, out_h, out_w}, ctx.GetPlace()); - auto grid_grad_x_t = - EigenTensor::From(grid_grad_x).setConstant(static_cast(0.0)); - auto grid_grad_y_t = - EigenTensor::From(grid_grad_y).setConstant(static_cast(0.0)); - for (int i = 0; i < n; i++) { - for (int j = 0; j < c; j++) { - for (int k = 0; k < out_h; k++) { - for (int l = 0; l < out_w; l++) { - grid_grad_x_t(i, k, l) += - ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + - (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * - output_grad_t(i, j, k, l); - grid_grad_y_t(i, k, l) += - ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + - (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * - output_grad_t(i, j, k, l); - } - } - } - } - - // const T x_max = static_cast(in_w - 1); - // const T y_max = static_cast(in_h - 1); - - auto grid_x_scale_t = EigenTensor::From(*grid_x_scale); - auto grid_y_scale_t = EigenTensor::From(*grid_y_scale); - grid_grad_x_t = grid_grad_x_t * grid_x_scale_t; - grid_grad_y_t = grid_grad_y_t * grid_y_scale_t; - - // gather grid_grad [x, y] in 3rd 
Dim - T* grid_grad_data = grid_grad->data(); - T* grid_grad_x_data = grid_grad_x.data(); - T* grid_grad_y_data = grid_grad_y.data(); - for (int i = 0; i < n * out_h * out_w; i++) { - grid_grad_data[2 * i] = grid_grad_x_data[i]; - grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; - } - } -} - -template -class GridSampleOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto align_corners = ctx.Attr("align_corners"); - auto padding_mode = ctx.Attr("padding_mode"); - auto mode = ctx.Attr("mode"); - - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - - const int n = grid->dims()[0]; - const int out_h = grid->dims()[1]; - const int out_w = grid->dims()[2]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - - auto* output = ctx.Output("Output"); - output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), output, - static_cast(0)); - - Tensor grid_x, grid_y; - calcGridLocations( - ctx.template device_context(), *grid, in_h, - in_w, align_corners, padding_mode, &grid_x, &grid_y); - if (mode == "bilinear") { - bilinearInter( - ctx.template device_context(), *input, - &grid_x, &grid_y, output); - } else if (mode == "nearest") { - auto grid_x_t = EigenTensor::From(grid_x); - auto grid_y_t = EigenTensor::From(grid_y); - grid_x_t = grid_x_t.round(); - grid_y_t = grid_y_t.round(); - getGridPointValue(*input, output, grid_x, grid_y); - } - } -}; - -template -class GridSampleGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto align_corners = ctx.Attr("align_corners"); - auto padding_mode = ctx.Attr("padding_mode"); - auto mode = ctx.Attr("mode"); - - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - auto* output_grad = ctx.Input(framework::GradVarName("Output")); - - const int n = grid->dims()[0]; - const int out_h = grid->dims()[1]; - const int out_w = grid->dims()[2]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; - - auto* input_grad = ctx.Output(framework::GradVarName("X")); - input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), input_grad, - static_cast(0)); - - Tensor* grid_grad = nullptr; - if (ctx.HasOutput(framework::GradVarName("Grid"))) { - grid_grad = ctx.Output(framework::GradVarName("Grid")); - grid_grad->mutable_data({n, out_h, out_w, 2}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), grid_grad, - static_cast(0)); - } - - Tensor grid_x, grid_y; - Tensor grid_x_scale, grid_y_scale; - calcGridLocationsWithGrad( - ctx.template device_context(), *grid, in_h, - in_w, align_corners, padding_mode, &grid_x, &grid_y, &grid_x_scale, - &grid_y_scale); - if (mode == "bilinear") { - gatherBilinearGrad(ctx.template device_context(), - *input, *output_grad, &grid_x, &grid_y, - &grid_x_scale, &grid_y_scale, input_grad, - grid_grad); - } else { - auto grid_x_t = EigenTensor::From(grid_x); - auto grid_y_t = EigenTensor::From(grid_y); - grid_x_t = grid_x_t.round(); - grid_y_t = grid_y_t.round(); - gatherOutputGradToInputGrad(*output_grad, input_grad, grid_x, grid_y); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/grid_sample_grad_kernel.cc b/paddle/phi/kernels/cpu/grid_sample_grad_kernel.cc new 
file mode 100644 index 0000000000000000000000000000000000000000..923cb8424115e00f07274f959ffe34adaa9a0327 --- /dev/null +++ b/paddle/phi/kernels/cpu/grid_sample_grad_kernel.cc @@ -0,0 +1,357 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/grid_sample_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/grid_sample_utils.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +static inline void ClipWithMask(const CPUContext& ctx, + const int max_val, // height-1 or width-1 + bool align_corners, + std::string padding_mode, + DenseTensor* grid_slice, + DenseTensor* grid_scale) { + auto& place = *ctx.eigen_device(); + grid_scale->Resize(grid_slice->dims()); + ctx.Alloc(grid_scale); + + auto grid_slice_t = EigenTensor::From(*grid_slice); + auto factor = static_cast(max_val * 0.5); + if (!align_corners) { + factor = static_cast((max_val + 1) * 0.5); + } + auto grid_scale_t = EigenTensor::From(*grid_scale).setConstant(factor); + + if (padding_mode == "border") { + // auto bounded_lo = grid_slice_t.cwiseMax(static_cast(0)); + auto res = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + + auto in_bound = (res == grid_slice_t); + grid_scale_t.device(place) = grid_scale_t * in_bound.template cast(); + grid_slice_t.device(place) = res; + } else if (padding_mode == "reflection") { + if (align_corners) { + auto double_range = static_cast(max_val * 2); + auto is_neg = (grid_slice_t < static_cast(0)); + auto grid_abs = grid_slice_t.abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + auto one_more_flip = (extra > (double_range - extra)); + grid_scale_t.device(place) = + grid_scale_t * ((is_neg == one_more_flip).template cast() - + (is_neg != one_more_flip).template cast()); + grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } + } else { + auto double_range = static_cast((max_val + 1) * 2); + auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); + auto is_neg = ((grid_slice_t + static_cast(0.5)) < static_cast(0)); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + auto one_more_flip = (extra > (double_range - extra)); + auto reflected = + extra.cwiseMin(double_range - extra) - static_cast(0.5); + auto clipped = reflected.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + auto in_bound = (clipped == reflected).template cast(); + grid_scale_t.device(place) = + grid_scale_t * ((is_neg == one_more_flip).template cast() - + (is_neg != one_more_flip).template cast()) * + in_bound; + grid_slice_t.device(place) = clipped; + } + } +} + +template +static void CalcGridLocationsWithGrad(const CPUContext& ctx, + const DenseTensor& grid, + const int in_h, + const int in_w, + bool 
align_corners, + std::string padding_mode, + DenseTensor* grid_x, + DenseTensor* grid_y, + DenseTensor* grid_x_scale, + DenseTensor* grid_y_scale) { + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + + // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim + grid_x->Resize({n, out_h, out_w}); + grid_y->Resize({n, out_h, out_w}); + T* grid_x_data = ctx.Alloc(grid_x); + T* grid_y_data = ctx.Alloc(grid_y); + + const T* grid_data = grid.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + grid_x_data[i] = grid_data[2 * i]; + grid_y_data[i] = grid_data[(2 * i) + 1]; + } + + Unnormalize(ctx, grid_x, in_w - 1, align_corners); + Unnormalize(ctx, grid_y, in_h - 1, align_corners); + + ClipWithMask( + ctx, in_w - 1, align_corners, padding_mode, grid_x, grid_x_scale); + ClipWithMask( + ctx, in_h - 1, align_corners, padding_mode, grid_y, grid_y_scale); +} + +template +static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, + DenseTensor* input_grad, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& d1, + const DenseTensor& d2) { + const int n = output_grad.dims()[0]; + const int c = output_grad.dims()[1]; + const int out_h = output_grad.dims()[2]; + const int out_w = output_grad.dims()[3]; + const int in_h = input_grad->dims()[2]; + const int in_w = input_grad->dims()[3]; + auto x_t = EigenTensor::From(x); + auto y_t = EigenTensor::From(y); + auto d1_t = EigenTensor::From(d1); + auto d2_t = EigenTensor::From(d2); + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + + for (int i = 0; i < n; i++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + if (IsInBound( + x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) { + for (int j = 0; j < c; j++) { + input_grad_t(i, + j, + static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))) += + output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l); + } + } + } + } + } +} + +template +static void GatherBilinearGrad(const CPUContext& ctx, + const DenseTensor& input, + const DenseTensor& output_grad, + DenseTensor* grid_x, + DenseTensor* grid_y, + DenseTensor* grid_x_scale, + DenseTensor* grid_y_scale, + DenseTensor* input_grad, + DenseTensor* grid_grad) { + const int n = grid_x->dims()[0]; + const int out_h = grid_x->dims()[1]; + const int out_w = grid_x->dims()[2]; + const int c = input.dims()[1]; + + DenseTensor x_w, x_e, y_n, y_s; + DenseTensor d_w, d_e, d_n, d_s; + DenseTensor v_wn, v_en, v_ws, v_es; + + AllNeigbors(ctx, + input, + grid_x, // grid_x + grid_y, // grid_y + &x_w, + &x_e, + &y_n, + &y_s, + &d_w, + &d_e, + &d_n, + &d_s, + &v_wn, + &v_en, + &v_ws, + &v_es); + + // gather output grad value to input grad by corner point coords and weight + GatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_n, d_e, d_s); + GatherOutputGradToInputGrad(output_grad, input_grad, x_w, y_s, d_e, d_n); + GatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_n, d_w, d_s); + GatherOutputGradToInputGrad(output_grad, input_grad, x_e, y_s, d_w, d_n); + + auto v_wn_t = EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + + auto output_grad_t = EigenTensor::From(output_grad); + + if (grid_grad != nullptr) { + 
DenseTensor grid_grad_x, grid_grad_y; + grid_grad_x.Resize({n, out_h, out_w}); + grid_grad_y.Resize({n, out_h, out_w}); + ctx.Alloc(&grid_grad_x); + ctx.Alloc(&grid_grad_y); + auto grid_grad_x_t = + EigenTensor::From(grid_grad_x).setConstant(static_cast(0.0)); + auto grid_grad_y_t = + EigenTensor::From(grid_grad_y).setConstant(static_cast(0.0)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < c; j++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + grid_grad_x_t(i, k, l) += + ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + + (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * + output_grad_t(i, j, k, l); + grid_grad_y_t(i, k, l) += + ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + + (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * + output_grad_t(i, j, k, l); + } + } + } + } + + // const T x_max = static_cast(in_w - 1); + // const T y_max = static_cast(in_h - 1); + + auto grid_x_scale_t = EigenTensor::From(*grid_x_scale); + auto grid_y_scale_t = EigenTensor::From(*grid_y_scale); + grid_grad_x_t = grid_grad_x_t * grid_x_scale_t; + grid_grad_y_t = grid_grad_y_t * grid_y_scale_t; + + // gather grid_grad [x, y] in 3rd Dim + T* grid_grad_data = grid_grad->data(); + T* grid_grad_x_data = grid_grad_x.data(); + T* grid_grad_y_data = grid_grad_y.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + grid_grad_data[2 * i] = grid_grad_x_data[i]; + grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + } + } +} + +template +static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, + DenseTensor* input_grad, + const DenseTensor& x, + const DenseTensor& y) { + const int n = output_grad.dims()[0]; + const int c = output_grad.dims()[1]; + const int out_h = output_grad.dims()[2]; + const int out_w = output_grad.dims()[3]; + const int in_h = input_grad->dims()[2]; + const int in_w = input_grad->dims()[3]; + auto x_t = EigenTensor::From(x); + auto y_t = EigenTensor::From(y); + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + for (int i = 0; i < n; i++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + if (IsInBound( + x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) { + for (int j = 0; j < c; j++) { + input_grad_t(i, + j, + static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))) += + output_grad_t(i, j, k, l); + } + } + } + } + } +} + +template +void GridSampleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const DenseTensor& out_grid, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* x_grad, + DenseTensor* grid_grad) { + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + const int c = x.dims()[1]; + const int in_h = x.dims()[2]; + const int in_w = x.dims()[3]; + + x_grad->Resize({n, c, in_h, in_w}); + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant()(dev_ctx, x_grad, static_cast(0)); + + if (grid_grad != nullptr) { + grid_grad->Resize({n, out_h, out_w, 2}); + dev_ctx.template Alloc(grid_grad); + phi::funcs::SetConstant()( + dev_ctx, grid_grad, static_cast(0)); + } + + DenseTensor grid_x, grid_y; + DenseTensor grid_x_scale, grid_y_scale; + CalcGridLocationsWithGrad(dev_ctx, + grid, + in_h, + in_w, + align_corners, + padding_mode, + &grid_x, + &grid_y, + &grid_x_scale, + &grid_y_scale); + if (mode == "bilinear") { + GatherBilinearGrad(dev_ctx, + x, + 
out_grid, + &grid_x, + &grid_y, + &grid_x_scale, + &grid_y_scale, + x_grad, + grid_grad); + } else { + auto grid_x_t = EigenTensor::From(grid_x); + auto grid_y_t = EigenTensor::From(grid_y); + grid_x_t = grid_x_t.round(); + grid_y_t = grid_y_t.round(); + GatherOutputGradToInputGrad(out_grid, x_grad, grid_x, grid_y); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(grid_sample_grad, + CPU, + ALL_LAYOUT, + phi::GridSampleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/grid_sample_kernel.cc b/paddle/phi/kernels/cpu/grid_sample_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..92a528cdda96a191cf73115feb3cf3dd3656305d --- /dev/null +++ b/paddle/phi/kernels/cpu/grid_sample_kernel.cc @@ -0,0 +1,184 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/grid_sample_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/grid_sample_utils.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using Array4 = Eigen::DSizes; + +template +static inline void Clip(const CPUContext& ctx, + DenseTensor* grid_slice, + const int max_val, // height-1 or width-1 + bool align_corners, + std::string padding_mode) { + auto& place = *ctx.eigen_device(); + auto grid_slice_t = EigenTensor::From(*grid_slice); + if (padding_mode == "border") { + grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + } else if (padding_mode == "reflection") { + if (align_corners) { + auto double_range = static_cast(max_val * 2); + auto grid_abs = grid_slice_t.abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } + } else { + auto double_range = static_cast((max_val + 1) * 2); + auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); + auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; + grid_slice_t.device(place) = + extra.cwiseMin(double_range - extra) - static_cast(0.5); + grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(max_val)); + } + } +} + +template +static void CalcGridLocations(const CPUContext& ctx, + const DenseTensor& grid, + const int in_h, + const int in_w, + bool align_corners, + std::string padding_mode, + DenseTensor* grid_x, + DenseTensor* grid_y) { + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + + // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim + grid_x->Resize({n, out_h, out_w}); + grid_y->Resize({n, out_h, out_w}); + T* grid_x_data = ctx.Alloc(grid_x); + T* grid_y_data = ctx.Alloc(grid_y); + const T* grid_data = grid.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + 
grid_x_data[i] = grid_data[2 * i]; + grid_y_data[i] = grid_data[(2 * i) + 1]; + } + + Unnormalize(ctx, grid_x, in_w - 1, align_corners); + Unnormalize(ctx, grid_y, in_h - 1, align_corners); + + Clip(ctx, grid_x, in_w - 1, align_corners, padding_mode); + Clip(ctx, grid_y, in_h - 1, align_corners, padding_mode); +} + +template +static void BilinearInter(const CPUContext& ctx, + const DenseTensor& input, + DenseTensor* grid_x, + DenseTensor* grid_y, + DenseTensor* out) { + auto& place = *ctx.eigen_device(); + const int n = grid_x->dims()[0]; + const int out_h = grid_x->dims()[1]; + const int out_w = grid_x->dims()[2]; + const int c = input.dims()[1]; + + DenseTensor x_w, x_e, y_n, y_s; + DenseTensor d_w, d_e, d_n, d_s; + DenseTensor v_wn, v_en, v_ws, v_es; + + AllNeigbors(ctx, + input, + grid_x, + grid_y, + &x_w, + &x_e, + &y_n, + &y_s, + &d_w, + &d_e, + &d_n, + &d_s, + &v_wn, + &v_en, + &v_ws, + &v_es); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + + auto d_w_scaled_t = + d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_e_scaled_t = + d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_n_scaled_t = + d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto d_s_scaled_t = + d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1)); + auto v_wn_t = EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + auto output_t = EigenTensor::From(*out); + // bilinear interpolaetion by 4 corner points + output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t + + v_en_t * d_w_scaled_t * d_s_scaled_t + + v_ws_t * d_e_scaled_t * d_n_scaled_t + + v_es_t * d_w_scaled_t * d_n_scaled_t; +} + +template +void GridSampleKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* out) { + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + const int c = x.dims()[1]; + const int in_h = x.dims()[2]; + const int in_w = x.dims()[3]; + + out->Resize(phi::make_ddim({n, c, out_h, out_w})); + dev_ctx.template Alloc(out); + phi::funcs::SetConstant()(dev_ctx, out, static_cast(0)); + + DenseTensor grid_x, grid_y; + CalcGridLocations( + dev_ctx, grid, in_h, in_w, align_corners, padding_mode, &grid_x, &grid_y); + + if (mode == "bilinear") { + BilinearInter(dev_ctx, x, &grid_x, &grid_y, out); + } else if (mode == "nearest") { + auto grid_x_t = EigenTensor::From(grid_x); + auto grid_y_t = EigenTensor::From(grid_y); + grid_x_t = grid_x_t.round(); + grid_y_t = grid_y_t.round(); + GetGridPointValue(x, out, grid_x, grid_y); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + grid_sample, CPU, ALL_LAYOUT, phi::GridSampleKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/grid_sample_utils.h b/paddle/phi/kernels/cpu/grid_sample_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..53a16446d7e8c65b3d2d63835e6a2b86c1f96795 --- /dev/null +++ b/paddle/phi/kernels/cpu/grid_sample_utils.h @@ -0,0 +1,160 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T>
+void Unnormalize(const CPUContext& ctx,
+                 DenseTensor* grid_slice,
+                 const int max_val,  // height-1 or width-1
+                 bool align_corners) {
+  auto& place = *ctx.eigen_device();
+  auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
+
+  if (!align_corners) {
+    auto factor = static_cast<T>((max_val + 1) * 0.5);
+    grid_slice_t.device(place) =
+        (grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
+  } else {
+    auto factor = static_cast<T>(max_val * 0.5);
+    grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
+  }
+}
+
+template <typename T>
+inline bool IsInBound(T x, T y, T x_max, T y_max) {
+  if (x < 0 || x > x_max || y < 0 || y > y_max) {
+    return false;
+  }
+  return true;
+}
+
+template <typename T>
+void GetGridPointValue(const DenseTensor& input,
+                       DenseTensor* output,
+                       const DenseTensor& x,
+                       const DenseTensor& y) {
+  const int n = input.dims()[0];
+  const int c = input.dims()[1];
+  const int in_h = input.dims()[2];
+  const int in_w = input.dims()[3];
+  const int out_h = x.dims()[1];
+  const int out_w = x.dims()[2];
+  auto x_t = EigenTensor<T, 3>::From(x);
+  auto y_t = EigenTensor<T, 3>::From(y);
+  auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
+  auto input_t = EigenTensor<T, 4>::From(input);
+
+  for (int i = 0; i < n; i++) {
+    for (int k = 0; k < out_h; k++) {
+      for (int l = 0; l < out_w; l++) {
+        if (IsInBound(
+                x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
+          for (int j = 0; j < c; j++) {
+            output_t(i, j, k, l) =
+                input_t(i,
+                        j,
+                        static_cast<int>(round(y_t(i, k, l))),
+                        static_cast<int>(round(x_t(i, k, l))));
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void AllNeigbors(const CPUContext& ctx,
+                 const DenseTensor& input,
+                 DenseTensor* grid_x,
+                 DenseTensor* grid_y,
+                 DenseTensor* x_w,
+                 DenseTensor* x_e,
+                 DenseTensor* y_n,
+                 DenseTensor* y_s,  // positions
+                 DenseTensor* d_w,
+                 DenseTensor* d_e,
+                 DenseTensor* d_n,
+                 DenseTensor* d_s,  // distance
+                 DenseTensor* v_wn,
+                 DenseTensor* v_en,
+                 DenseTensor* v_ws,
+                 DenseTensor* v_es) {  // values
+  auto& place = *ctx.eigen_device();
+
+  const int c = input.dims()[1];
+  const int n = grid_x->dims()[0];
+  const int out_h = grid_x->dims()[1];
+  const int out_w = grid_x->dims()[2];
+  // calculate coords of 4 corner points
+  x_w->Resize({n, out_h, out_w});
+  x_e->Resize({n, out_h, out_w});
+  y_n->Resize({n, out_h, out_w});
+  y_s->Resize({n, out_h, out_w});
+  ctx.Alloc<T>(x_w);
+  ctx.Alloc<T>(x_e);
+  ctx.Alloc<T>(y_n);
+  ctx.Alloc<T>(y_s);
+  auto x_w_t = EigenTensor<T, 3>::From(*x_w);
+  auto x_e_t = EigenTensor<T, 3>::From(*x_e);
+  auto y_n_t = EigenTensor<T, 3>::From(*y_n);
+  auto y_s_t = EigenTensor<T, 3>::From(*y_s);
+
+  auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
+  auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
+
+  x_w_t.device(place) = grid_x_t.floor();
+  x_e_t.device(place) = x_w_t + static_cast<T>(1);
+  y_n_t.device(place) = grid_y_t.floor();
+  y_s_t.device(place) = y_n_t + static_cast<T>(1);
+
+  // calculate distances to 4 sides
+  d_w->Resize({n, out_h, out_w});
+  d_e->Resize({n, out_h, out_w});
+  d_n->Resize({n, out_h, out_w});
d_s->Resize({n, out_h, out_w}); + ctx.Alloc(d_w); + ctx.Alloc(d_e); + ctx.Alloc(d_n); + ctx.Alloc(d_s); + auto d_w_t = EigenTensor::From(*d_w); + auto d_e_t = EigenTensor::From(*d_e); + auto d_n_t = EigenTensor::From(*d_n); + auto d_s_t = EigenTensor::From(*d_s); + d_w_t.device(place) = grid_x_t - x_w_t; + d_e_t.device(place) = x_e_t - grid_x_t; + d_n_t.device(place) = grid_y_t - y_n_t; + d_s_t.device(place) = y_s_t - grid_y_t; + + // calc 4 corner points value + v_wn->Resize({n, c, out_h, out_w}); + v_en->Resize({n, c, out_h, out_w}); + v_ws->Resize({n, c, out_h, out_w}); + v_es->Resize({n, c, out_h, out_w}); + ctx.Alloc(v_wn); + ctx.Alloc(v_en); + ctx.Alloc(v_ws); + ctx.Alloc(v_es); + GetGridPointValue(input, v_wn, *x_w, *y_n); + GetGridPointValue(input, v_en, *x_e, *y_n); + GetGridPointValue(input, v_ws, *x_w, *y_s); + GetGridPointValue(input, v_es, *x_e, *y_s); +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..457a348be832b006d9f224e3032c369a7fe4bb62 --- /dev/null +++ b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu @@ -0,0 +1,324 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
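Editor's note: the CUDA backward kernel added in this file scatters the upstream gradient to the four bilinear neighbours with AtomicAdd and differentiates the interpolation weights with respect to the sampling position. As a reading aid only (not part of the patch; the function and all names below are hypothetical), here is a minimal host-side scalar reference of that arithmetic for one already-unnormalized sample position (ix, iy) on a single-channel H x W map:

// Illustrative reference only -- not part of the patch.
#include <cmath>
#include <vector>

// Accumulates d(out)/d(input) into grad_input and returns d(out)/d(ix),
// d(out)/d(iy): the quantities the CUDA kernel scatters with AtomicAdd and
// later scales by gix_mult/giy_mult before writing grad_grid.
void BilinearBackwardRef(const std::vector<float>& input,  // H*W, one channel
                         std::vector<float>* grad_input,   // H*W, same layout
                         int H, int W, float ix, float iy, float g_out,
                         float* gix, float* giy) {
  const int ix_nw = static_cast<int>(std::floor(ix));
  const int iy_nw = static_cast<int>(std::floor(iy));
  const int ix_se = ix_nw + 1;
  const int iy_se = iy_nw + 1;

  // Bilinear weights of the four neighbours (north-west, north-east, ...).
  const float nw = (ix_se - ix) * (iy_se - iy);
  const float ne = (ix - ix_nw) * (iy_se - iy);
  const float sw = (ix_se - ix) * (iy - iy_nw);
  const float se = (ix - ix_nw) * (iy - iy_nw);

  auto in_bounds = [&](int y, int x) {
    return y >= 0 && y < H && x >= 0 && x < W;
  };
  auto scatter = [&](int y, int x, float w) {  // grad w.r.t. the input pixels
    if (in_bounds(y, x)) (*grad_input)[y * W + x] += w * g_out;
  };
  scatter(iy_nw, ix_nw, nw);
  scatter(iy_nw, ix_se, ne);
  scatter(iy_se, ix_nw, sw);
  scatter(iy_se, ix_se, se);

  // Grad w.r.t. the (unnormalized) sampling coordinates: differentiate the
  // weights; out-of-bounds neighbours contribute zero, as in the kernel.
  auto value = [&](int y, int x) {
    return in_bounds(y, x) ? input[y * W + x] : 0.f;
  };
  *gix = g_out * (-(iy_se - iy) * value(iy_nw, ix_nw) +
                  (iy_se - iy) * value(iy_nw, ix_se) -
                  (iy - iy_nw) * value(iy_se, ix_nw) +
                  (iy - iy_nw) * value(iy_se, ix_se));
  *giy = g_out * (-(ix_se - ix) * value(iy_nw, ix_nw) -
                  (ix - ix_nw) * value(iy_nw, ix_se) +
                  (ix_se - ix) * value(iy_se, ix_nw) +
                  (ix - ix_nw) * value(iy_se, ix_se));
}

Multiplying the returned gix/giy by the unnormalization factors (gix_mult/giy_mult in the kernel) gives the gradient with respect to the normalized grid, which is what the kernel writes to grad_grid.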
+ +#include "paddle/phi/kernels/grid_sample_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpu/grid_sample_utils.h" + +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace phi { + +template +static __forceinline__ __device__ void AtomicAdd( + T* data, int h, int w, int sH, int sW, int H, int W, T delta) { + if (InBounds(h, w, H, W)) { + paddle::platform::CudaAtomicAdd(data + h * sH + w * sW, delta); + } +} + +template +static __forceinline__ __device__ T +UnnormalizeWithMask(T coord, int size, bool align_corners, T* grad_in) { + if (align_corners) { + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T ClipIndexesWithMask(T in, + int clip_limit, + T* grad_in) { + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + T max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +template +static __forceinline__ __device__ T +ReflectIndexesWithMask(T in, int twice_low, int twice_high, T* grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T +ComputePositionsWithMask(T coord, + int size, + PaddingMode padding_mode, + bool align_corners, + T* grad_in) { + T grad_clip, grad_refl; + coord = UnnormalizeWithMask(coord, size, align_corners, grad_in); + if (padding_mode == PaddingMode::border) { + coord = ClipIndexesWithMask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl); + } else { + coord = ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl); + } + coord = ClipIndexesWithMask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + return coord; +} + +template +__global__ void GridSamplerCudaBackwardKernel(const int nthreads, + const T* grad_output, + const T* input, + const T* grid, + int n, + int out_c, + int out_h, + int out_w, + int in_h, + int in_w, + T* grad_input, + T* grad_grid, + const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + int inp_sN = out_c * in_h * in_w; + int inp_sC = in_h * in_w; + int inp_sH = in_w; + int inp_sW = 1; + int grid_sN = out_h * out_w * 2; + int grid_sH = out_w * 2; + int grid_sW = 2; + int grid_sCoor = 1; + + int gOut_sN = out_c * out_h * out_w; + int gOut_sC = out_h * out_w; + int gOut_sH = out_w; + int gOut_sW = 1; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_w; + 
const int h = (index / out_w) % out_h; + const int n = index / (out_h * out_w); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + T gix_mult, giy_mult; + ix = ComputePositionsWithMask( + ix, in_w, padding_mode, align_corners, &gix_mult); + iy = ComputePositionsWithMask( + iy, in_h, padding_mode, align_corners, &giy_mult); + + if (mode == Mode::bilinear) { + int ix_nw = static_cast(floor(ix)); + int iy_nw = static_cast(floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + T gix = static_cast(0), giy = static_cast(0); + int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + int inp_offset_NC = n * inp_sN; + for (int c = 0; c < out_c; ++c, + inp_offset_NC += inp_sC, + gInp_ptr_NC += inp_sC, + gOut_offset += gOut_sC) { + T gOut = grad_output[gOut_offset]; + + AtomicAdd( + gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w, nw * gOut); + AtomicAdd( + gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w, ne * gOut); + AtomicAdd( + gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w, sw * gOut); + AtomicAdd( + gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w, se * gOut); + + if (InBounds(iy_nw, ix_nw, in_h, in_w)) { + T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (InBounds(iy_ne, ix_ne, in_h, in_w)) { + T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (InBounds(iy_sw, ix_sw, in_h, in_w)) { + T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (InBounds(iy_se, ix_se, in_h, in_w)) { + T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } + } else if (mode == Mode::nearest) { + int ix_nearest = static_cast(std::nearbyint(ix)); + int iy_nearest = static_cast(std::nearbyint(iy)); + + int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (int c = 0; c < out_c; + ++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { + AtomicAdd(gInp_ptr_NC, + iy_nearest, + ix_nearest, + inp_sH, + inp_sW, + in_h, + in_w, + grad_output[gOut_offset]); + } + + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = static_cast(0); + gGrid_ptr_NHW[1] = static_cast(0); + } + } + } +} + +template +void GridSampleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const DenseTensor& out_grad, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* x_grad, + DenseTensor* grid_grad) { + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = 
PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + const int c = x.dims()[1]; + const int in_h = x.dims()[2]; + const int in_w = x.dims()[3]; + + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant()(dev_ctx, x_grad, static_cast(0)); + + T* grid_grad_data = nullptr; + if (grid_grad != nullptr) { + grid_grad_data = dev_ctx.template Alloc(grid_grad); + } + + int count = static_cast(n * out_h * out_w); + auto cu_stream = dev_ctx.stream(); + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); + GridSamplerCudaBackwardKernel< + T><<>>( + count, + out_grad.data(), + x.data(), + grid.data(), + n, + c, + out_h, + out_w, + in_h, + in_w, + x_grad->data(), + grid_grad_data, + enum_mode, + enum_padding_mode, + align_corners); +} + +} // namespace phi + +PD_REGISTER_KERNEL(grid_sample_grad, + GPU, + ALL_LAYOUT, + phi::GridSampleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/grid_sample_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f611b46911c4f1555ad27a538d8918f11ae761cc --- /dev/null +++ b/paddle/phi/kernels/gpu/grid_sample_kernel.cu @@ -0,0 +1,233 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
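Editor's note: the forward kernel added below reuses the coordinate handling of the deleted fluid grid_sampler_op.cu: grid values in [-1, 1] are unnormalized to pixel space and then clamped or reflected back into range according to padding_mode. A standalone sketch of that mapping under both align_corners settings (illustrative only, not part of the patch; the names are made up, cf. Unnormalize, ClipIndexes and ReflectIndexes in the kernel):

// Illustrative reference only -- not part of the patch.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Map a normalized coordinate in [-1, 1] to pixel space.
float UnnormalizeRef(float coord, int size, bool align_corners) {
  return align_corners ? (coord + 1.f) / 2.f * (size - 1)
                       : ((coord + 1.f) * size - 1.f) / 2.f;
}

// "border" padding: clamp into the valid index range [0, max_value].
float ClipRef(float coord, int max_value) {
  return std::min(static_cast<float>(max_value), std::max(coord, 0.f));
}

// "reflection" padding: fold the coordinate back into [low, low + span].
float ReflectRef(float coord, int twice_low, int twice_high) {
  if (twice_low == twice_high) return 0.f;
  const float low = twice_low / 2.f;
  const float span = (twice_high - twice_low) / 2.f;
  const float in = std::fabs(coord - low);
  const float extra = std::fmod(in, span);
  const int flips = static_cast<int>(std::floor(in / span));
  return (flips % 2 == 0) ? extra + low : span - extra + low;
}

int main() {
  const int w = 5;  // a 5-pixel-wide feature map
  for (float x : {-1.f, 0.f, 1.f, 1.4f}) {
    const float ac = UnnormalizeRef(x, w, true);   // -1 -> 0,    1 -> 4
    const float nc = UnnormalizeRef(x, w, false);  // -1 -> -0.5, 1 -> 4.5
    const float refl = ClipRef(ReflectRef(ac, 0, 2 * (w - 1)), w - 1);
    std::printf("x=%4.1f  align=%.2f  no_align=%.2f  reflect=%.2f\n",
                x, ac, nc, refl);
  }
  return 0;
}

For x = 1.4 the align_corners mapping lands at pixel 4.8, which reflection folds back to 3.2 inside the 5-pixel range; this is the behaviour the GPU kernel applies per coordinate before interpolating.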
+ +#include "paddle/phi/kernels/grid_sample_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/grid_sample_utils.h" + +namespace phi { + +template +static __forceinline__ __device__ T Unnormalize(T coord, + int size, + bool align_corners) { + if (align_corners) { + return ((coord + 1.f) / 2) * (size - 1); + } else { + return ((coord + 1.f) * size - 1) / 2; + } +} + +template +static __forceinline__ __device__ T ClipIndexes(T in, int max_value) { + return min(static_cast(max_value), max(in, static_cast(0))); +} + +template +static __forceinline__ __device__ T ReflectIndexes(T in, + int twice_low, + int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = fabs(in - min); + T extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +template +static __forceinline__ __device__ T ComputePositions(T coord, + int size, + PaddingMode padding_mode, + bool align_corners) { + coord = Unnormalize(coord, size, align_corners); + if (padding_mode == PaddingMode::border) { + coord = ClipIndexes(coord, size - 1); + } else if (padding_mode == PaddingMode::reflect) { + if (align_corners) { + coord = ReflectIndexes(coord, 0, 2 * (size - 1)); + } else { + coord = ReflectIndexes(coord, -1, 2 * size - 1); + } + coord = ClipIndexes(coord, size - 1); + } + return coord; +} + +template +__global__ void GridSampleCudaKernel(const int nthreads, + int n, + int out_c, + int out_h, + int out_w, + int in_h, + int in_w, + const T* input, + const T* grid, + T* output, + const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + int inp_sN = out_c * in_h * in_w; + + int inp_sC = in_h * in_w; + int inp_sH = in_w; + int inp_sW = 1; + int grid_sN = out_h * out_w * 2; + int grid_sH = out_w * 2; + int grid_sW = 2; + int grid_sCoor = 1; + int out_sN = out_c * out_h * out_w; + int out_sC = out_h * out_w; + int out_sH = out_w; + int out_sW = 1; + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_w; + const int h = (index / out_w) % out_h; + const int n = index / (out_h * out_w); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + ix = ComputePositions(ix, in_w, padding_mode, align_corners); + iy = ComputePositions(iy, in_h, padding_mode, align_corners); + if (mode == Mode::bilinear) { + int ix_nw = static_cast(floor(ix)); + int iy_nw = static_cast(floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + auto inp_offset_NC = n * inp_sN; + + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + *out_ptr_NCHW = static_cast(0); + if (InBounds(iy_nw, ix_nw, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (InBounds(iy_ne, ix_ne, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if 
(InBounds(iy_sw, ix_sw, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (InBounds(iy_se, ix_se, in_h, in_w)) { + *out_ptr_NCHW += + input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } else if (mode == Mode::nearest) { + int ix_nearest = static_cast(std::nearbyint(ix)); + int iy_nearest = static_cast(std::nearbyint(iy)); + auto inp_offset_NC = n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) { + *out_ptr_NCHW = + input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast(0); + } + } + } + } +} + +template +void GridSampleKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* out) { + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + + const int n = grid.dims()[0]; + const int out_h = grid.dims()[1]; + const int out_w = grid.dims()[2]; + const int c = x.dims()[1]; + const int in_h = x.dims()[2]; + const int in_w = x.dims()[3]; + VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h + << "; out_w: " << out_w; + + auto* output_data = dev_ctx.template Alloc(out); + VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; " + << out->dims()[2] << "; " << out->dims()[3]; + + int count = static_cast(n * out_h * out_w); + auto cu_stream = dev_ctx.stream(); + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); + GridSampleCudaKernel< + T><<>>( + count, + n, + c, + out_h, + out_w, + in_h, + in_w, + x.data(), + grid.data(), + output_data, + enum_mode, + enum_padding_mode, + align_corners); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + grid_sample, GPU, ALL_LAYOUT, phi::GridSampleKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/grid_sample_utils.h b/paddle/phi/kernels/gpu/grid_sample_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..098eb9defb54904c41f33326b54eabdda657360a --- /dev/null +++ b/paddle/phi/kernels/gpu/grid_sample_utils.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+namespace phi {
+
+enum class Mode {
+  bilinear,
+  nearest,
+};
+
+enum class PaddingMode { zeros, border, reflect };
+
+static __forceinline__ __device__ bool InBounds(int h, int w, int H, int W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/grid_sample_grad_kernel.h b/paddle/phi/kernels/grid_sample_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..50a8d5be260bd387476467e2cdddaeb59f943b9b
--- /dev/null
+++ b/paddle/phi/kernels/grid_sample_grad_kernel.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GridSampleGradKernel(const Context &dev_ctx,
+                          const DenseTensor &x,
+                          const DenseTensor &grid,
+                          const DenseTensor &out_grid,
+                          const std::string &mode,
+                          const std::string &padding_mode,
+                          bool align_corners,
+                          DenseTensor *x_grad,
+                          DenseTensor *grid_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/grid_sample_kernel.h b/paddle/phi/kernels/grid_sample_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2e1e9b508649b22de086a103537f4984b7f693e5
--- /dev/null
+++ b/paddle/phi/kernels/grid_sample_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GridSampleKernel(const Context &dev_ctx,
+                      const DenseTensor &x,
+                      const DenseTensor &grid,
+                      const std::string &mode,
+                      const std::string &padding_mode,
+                      bool align_corners,
+                      DenseTensor *out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/grid_sampler_sig.cc b/paddle/phi/ops/compat/grid_sampler_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b76a9770d4dede5ea604f69858201c2fb035070d
--- /dev/null
+++ b/paddle/phi/ops/compat/grid_sampler_sig.cc
@@ -0,0 +1,43 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature GridSamplerOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("grid_sample", + {"X", "Grid"}, + {"mode", "padding_mode", "align_corners"}, + {"Output"}); +} + +KernelSignature GridSamplerGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("grid_sample_grad", + {"X", "Grid", GradVarName("Output")}, + {"mode", "padding_mode", "align_corners"}, + {GradVarName("X"), GradVarName("Grid")}); +} + +} // namespace phi + +// use Python API name as kernel name +PD_REGISTER_BASE_KERNEL_NAME(grid_sampler, grid_sample); +PD_REGISTER_BASE_KERNEL_NAME(grid_sampler_grad, grid_sample_grad); + +PD_REGISTER_ARG_MAPPING_FN(grid_sampler, phi::GridSamplerOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(grid_sampler_grad, + phi::GridSamplerGradOpArgumentMapping);
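Editor's note: the compat signatures above route the legacy grid_sampler operator (inputs X and Grid, output Output, attributes mode / padding_mode / align_corners) onto the new phi grid_sample kernels, so existing programs keep resolving to the GPU code added in this diff. As a quick end-to-end sanity check of what those kernels compute, a tiny CPU-only reference for "nearest" mode with zeros padding and align_corners=true (illustrative only, not part of the patch; all names are hypothetical):

// Illustrative reference only -- not part of the patch.
#include <cmath>
#include <cstdio>
#include <vector>

// input:  one channel of size in_h x in_w (row-major)
// grid:   out_h x out_w x 2 normalized (x, y) coordinates in [-1, 1]
// output: out_h x out_w sampled values
std::vector<float> GridSampleNearestRef(const std::vector<float>& input,
                                        int in_h, int in_w,
                                        const std::vector<float>& grid,
                                        int out_h, int out_w) {
  std::vector<float> out(out_h * out_w, 0.f);
  for (int h = 0; h < out_h; ++h) {
    for (int w = 0; w < out_w; ++w) {
      const float gx = grid[(h * out_w + w) * 2 + 0];
      const float gy = grid[(h * out_w + w) * 2 + 1];
      // align_corners=true unnormalization, as in the forward kernel.
      const float ix = (gx + 1.f) / 2.f * (in_w - 1);
      const float iy = (gy + 1.f) / 2.f * (in_h - 1);
      const int x = static_cast<int>(std::nearbyint(ix));
      const int y = static_cast<int>(std::nearbyint(iy));
      if (x >= 0 && x < in_w && y >= 0 && y < in_h) {  // zeros padding
        out[h * out_w + w] = input[y * in_w + x];
      }
    }
  }
  return out;
}

int main() {
  const std::vector<float> input = {1.f, 2.f, 3.f, 4.f};  // 2x2 feature map
  const std::vector<float> grid = {-1.f, -1.f, 1.f, 1.f};  // 1x2 grid of (x, y)
  const auto out = GridSampleNearestRef(input, 2, 2, grid, 1, 2);
  std::printf("%.1f %.1f\n", out[0], out[1]);  // expected: 1.0 4.0
  return 0;
}

Built on its own it prints 1.0 4.0, the two corner pixels of the 2x2 input, matching what the nearest-mode kernel registered above returns for the same grid.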