未验证 提交 8fd20b5b 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move grid sample op kernel into phi (#40585)

* add grid sample phi kernel

* add grid sample phi kernel and remove original kernel

* replace mutable_data by alloc
上级 ad81f22c
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
......@@ -229,15 +229,6 @@ REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
REGISTER_OP_CPU_KERNEL(
grid_sampler,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
grid_sampler_grad,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(grid_sampler)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
static __forceinline__ __device__ bool in_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename T>
static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH,
int sW, int H, int W,
T delta) {
if (in_bounds(h, w, H, W)) {
platform::CudaAtomicAdd(data + h * sH + w * sW, delta);
}
}
template <typename T>
static __forceinline__ __device__ T _unnormalize(T coord, int size,
bool align_corners) {
if (align_corners) {
return ((coord + 1.f) / 2) * (size - 1);
} else {
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T clip_indexes(T in, int max_value) {
return min(static_cast<T>(max_value), max(in, static_cast<T>(0)));
}
template <typename T>
static __forceinline__ __device__ T reflect_indexes(T in, int twice_low,
int twice_high) {
if (twice_low == twice_high) {
return static_cast<T>(0);
}
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = fabs(in - min);
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T compute_positions(T coord, int size,
PaddingMode padding_mode,
bool align_corners) {
coord = _unnormalize<T>(coord, size, align_corners);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes(coord, size - 1);
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes(coord, 0, 2 * (size - 1));
} else {
coord = reflect_indexes(coord, -1, 2 * size - 1);
}
coord = clip_indexes(coord, size - 1);
}
return coord;
}
template <typename T>
static __forceinline__ __device__ T _unnormalize_with_mask(T coord, int size,
bool align_corners,
T* grad_in) {
if (align_corners) {
*grad_in = static_cast<T>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
*grad_in = static_cast<T>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T clip_indexes_with_mask(T in, int clip_limit,
T* grad_in) {
if (in <= static_cast<T>(0)) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
} else {
T max = static_cast<T>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<T>(0);
return max;
} else {
*grad_in = static_cast<T>(1);
return in;
}
}
}
template <typename T>
static __forceinline__ __device__ T
reflect_indexes_with_mask(T in, int twice_low, int twice_high, T* grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
}
int grad_in_mult_;
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<T>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<T>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<T>(-grad_in_mult_);
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T
compute_positions_with_mask(T coord, int size, PaddingMode padding_mode,
bool align_corners, T* grad_in) {
T grad_clip, grad_refl;
coord = _unnormalize_with_mask<T>(coord, size, align_corners, grad_in);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes_with_mask(coord, 0, 2 * (size - 1), &grad_refl);
} else {
coord = reflect_indexes_with_mask(coord, -1, 2 * size - 1, &grad_refl);
}
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
template <typename T>
__global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
int out_h, int out_w, int in_h,
int in_w, const T* input, const T* grid,
T* output, const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int out_sN = out_c * out_h * out_w;
int out_sC = out_h * out_w;
int out_sH = out_w;
int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
ix = compute_positions(ix, in_w, padding_mode, align_corners);
iy = compute_positions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<T>(0);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (in_bounds(iy_nearest, ix_nearest, in_h, in_w)) {
*out_ptr_NCHW =
input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<T>(0);
}
}
}
}
}
template <typename T>
class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
<< "; out_w: " << out_w;
auto* output = ctx.Output<Tensor>("Output");
auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
<< "; " << output->dims()[2] << "; " << output->dims()[3];
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sample_cuda_kernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners);
}
};
template <typename T>
__global__ void grid_sampler_cuda_backward_kernel(
const int nthreads, const T* grad_output, const T* input, const T* grid,
int n, int out_c, int out_h, int out_w, int in_h, int in_w, T* grad_input,
T* grad_grid, const Mode mode, const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int gOut_sN = out_c * out_h * out_w;
int gOut_sC = out_h * out_w;
int gOut_sH = out_w;
int gOut_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T gix_mult, giy_mult;
ix = compute_positions_with_mask(ix, in_w, padding_mode, align_corners,
&gix_mult);
iy = compute_positions_with_mask(iy, in_h, padding_mode, align_corners,
&giy_mult);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = static_cast<T>(0), giy = static_cast<T>(0);
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
int inp_offset_NC = n * inp_sN;
for (int c = 0; c < out_c; ++c, inp_offset_NC += inp_sC,
gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
T gOut = grad_output[gOut_offset];
atomic_add(gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w,
nw * gOut);
atomic_add(gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w,
ne * gOut);
atomic_add(gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w,
sw * gOut);
atomic_add(gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w,
se * gOut);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (int c = 0; c < out_c;
++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
atomic_add(gInp_ptr_NC, iy_nearest, ix_nearest, inp_sH, inp_sW, in_h,
in_w, grad_output[gOut_offset]);
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
}
template <typename T>
class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0));
T* grid_grad_data = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
}
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sampler_cuda_backward_kernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(grid_sampler, ops::GridSampleOpCUDAKernel<float>,
ops::GridSampleOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(grid_sampler_grad,
ops::GridSampleGradOpCUDAKernel<float>,
ops::GridSampleGradOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <utility>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
enum class Mode {
bilinear,
nearest,
};
enum class PaddingMode { zeros, border, reflect };
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T>
static inline bool isInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false;
}
return true;
}
template <typename T>
static inline void unnormalize(const platform::CPUDeviceContext& ctx,
Tensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (!align_corners) {
auto factor = static_cast<T>((max_val + 1) * 0.5);
grid_slice_t.device(place) =
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
} else {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
}
}
template <typename T>
static inline void clip(const platform::CPUDeviceContext& ctx,
Tensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners, std::string padding_mode) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
}
}
}
template <typename T>
static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
const int max_val, // height-1 or width-1
bool align_corners, std::string padding_mode,
Tensor* grid_slice, Tensor* grid_scale) {
auto& place = *ctx.eigen_device();
grid_scale->mutable_data<T>(grid_slice->dims(), ctx.GetPlace());
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
auto factor = static_cast<T>(max_val * 0.5);
if (!align_corners) {
factor = static_cast<T>((max_val + 1) * 0.5);
}
auto grid_scale_t = EigenTensor<T, 3>::From(*grid_scale).setConstant(factor);
if (padding_mode == "border") {
// auto bounded_lo = grid_slice_t.cwiseMax(static_cast<T>(0));
auto res = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>());
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto is_neg = ((grid_slice_t + static_cast<T>(0.5)) < static_cast<T>(0));
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
auto reflected =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
auto clipped = reflected.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (clipped == reflected).template cast<T>();
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>()) *
in_bound;
grid_slice_t.device(place) = clipped;
}
}
}
template <typename T>
static void calcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, const int in_h,
const int in_w, bool align_corners,
std::string padding_mode, Tensor* grid_x,
Tensor* grid_y) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
T* grid_x_data = grid_x->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
T* grid_y_data = grid_y->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
clip<T>(ctx, grid_x, in_w - 1, align_corners, padding_mode);
clip<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode);
}
template <typename T>
static void calcGridLocationsWithGrad(const platform::CPUDeviceContext& ctx,
const Tensor& grid, const int in_h,
const int in_w, bool align_corners,
std::string padding_mode, Tensor* grid_x,
Tensor* grid_y, Tensor* grid_x_scale,
Tensor* grid_y_scale) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
T* grid_x_data = grid_x->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
T* grid_y_data = grid_y->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
clipWithMask<T>(ctx, in_w - 1, align_corners, padding_mode, grid_x,
grid_x_scale);
clipWithMask<T>(ctx, in_h - 1, align_corners, padding_mode, grid_y,
grid_y_scale);
}
template <typename T>
static void getGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
const int out_h = x.dims()[1];
const int out_w = x.dims()[2];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
}
}
}
template <typename T>
static void allNeigbors(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* x_w, Tensor* x_e, Tensor* y_n,
Tensor* y_s, // positions
Tensor* d_w, Tensor* d_e, Tensor* d_n,
Tensor* d_s, // distance
Tensor* v_wn, Tensor* v_en, Tensor* v_ws,
Tensor* v_es) { // values
auto& place = *ctx.eigen_device();
const int c = input.dims()[1];
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
// calculate coords of 4 corner points
x_w->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
x_e->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
y_n->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
y_s->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto x_w_t = EigenTensor<T, 3>::From(*x_w);
auto x_e_t = EigenTensor<T, 3>::From(*x_e);
auto y_n_t = EigenTensor<T, 3>::From(*y_n);
auto y_s_t = EigenTensor<T, 3>::From(*y_s);
auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + static_cast<T>(1);
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + static_cast<T>(1);
// calculate distances to 4 sides
d_w->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_e->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_n->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_s->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto d_w_t = EigenTensor<T, 3>::From(*d_w);
auto d_e_t = EigenTensor<T, 3>::From(*d_e);
auto d_n_t = EigenTensor<T, 3>::From(*d_n);
auto d_s_t = EigenTensor<T, 3>::From(*d_s);
d_w_t.device(place) = grid_x_t - x_w_t;
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
// calc 4 corner points value
v_wn->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_en->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_ws->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_es->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
getGridPointValue<T>(input, v_wn, *x_w, *y_n);
getGridPointValue<T>(input, v_en, *x_e, *y_n);
getGridPointValue<T>(input, v_ws, *x_w, *y_s);
getGridPointValue<T>(input, v_es, *x_e, *y_s);
}
template <typename T>
static void bilinearInter(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* out) {
auto& place = *ctx.eigen_device();
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
Tensor v_wn, v_en, v_ws, v_es;
allNeigbors<T>(ctx, input, grid_x, grid_y, &x_w, &x_e, &y_n, &y_s, &d_w, &d_e,
&d_n, &d_s, &v_wn, &v_en, &v_ws, &v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*out);
// bilinear interpolaetion by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
template <typename T>
static void nearestInter(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* out) {
auto& place = *ctx.eigen_device();
auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
getGridPointValue<T>(input, out, *grid_x, *grid_y);
}
template <typename T>
static void gatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
}
}
}
template <typename T>
static void gatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l);
}
}
}
}
}
}
template <typename T>
static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
const Tensor& input, const Tensor& output_grad,
Tensor* grid_x, Tensor* grid_y,
Tensor* grid_x_scale, Tensor* grid_y_scale,
Tensor* input_grad, Tensor* grid_grad) {
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
Tensor v_wn, v_en, v_ws, v_es;
allNeigbors<T>(ctx, input,
grid_x, // grid_x
grid_y, // grid_y
&x_w, &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s, &v_wn, &v_en,
&v_ws, &v_es);
// gather output grad value to input grad by corner point coords and weight
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_n, d_e, d_s);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_s, d_e, d_n);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_n, d_w, d_s);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_s, d_w, d_n);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
if (grid_grad != nullptr) {
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto grid_grad_x_t =
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
// const T x_max = static_cast<T>(in_w - 1);
// const T y_max = static_cast<T>(in_h - 1);
auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
}
template <typename DeviceContext, typename T>
class GridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode = ctx.Attr<std::string>("padding_mode");
auto mode = ctx.Attr<std::string>("mode");
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
Tensor grid_x, grid_y;
calcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, in_h,
in_w, align_corners, padding_mode, &grid_x, &grid_y);
if (mode == "bilinear") {
bilinearInter<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *input,
&grid_x, &grid_y, output);
} else if (mode == "nearest") {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
getGridPointValue<T>(*input, output, grid_x, grid_y);
}
}
};
template <typename DeviceContext, typename T>
class GridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode = ctx.Attr<std::string>("padding_mode");
auto mode = ctx.Attr<std::string>("mode");
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
Tensor* grid_grad = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
}
Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, in_h,
in_w, align_corners, padding_mode, &grid_x, &grid_y, &grid_x_scale,
&grid_y_scale);
if (mode == "bilinear") {
gatherBilinearGrad<T>(ctx.template device_context<DeviceContext>(),
*input, *output_grad, &grid_x, &grid_y,
&grid_x_scale, &grid_y_scale, input_grad,
grid_grad);
} else {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
gatherOutputGradToInputGrad<T>(*output_grad, input_grad, grid_x, grid_y);
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/grid_sample_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
static inline void ClipWithMask(const CPUContext& ctx,
const int max_val, // height-1 or width-1
bool align_corners,
std::string padding_mode,
DenseTensor* grid_slice,
DenseTensor* grid_scale) {
auto& place = *ctx.eigen_device();
grid_scale->Resize(grid_slice->dims());
ctx.Alloc<T>(grid_scale);
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
auto factor = static_cast<T>(max_val * 0.5);
if (!align_corners) {
factor = static_cast<T>((max_val + 1) * 0.5);
}
auto grid_scale_t = EigenTensor<T, 3>::From(*grid_scale).setConstant(factor);
if (padding_mode == "border") {
// auto bounded_lo = grid_slice_t.cwiseMax(static_cast<T>(0));
auto res = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>());
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto is_neg = ((grid_slice_t + static_cast<T>(0.5)) < static_cast<T>(0));
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
auto reflected =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
auto clipped = reflected.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (clipped == reflected).template cast<T>();
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>()) *
in_bound;
grid_slice_t.device(place) = clipped;
}
}
}
template <typename T>
static void CalcGridLocationsWithGrad(const CPUContext& ctx,
const DenseTensor& grid,
const int in_h,
const int in_w,
bool align_corners,
std::string padding_mode,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_x_scale,
DenseTensor* grid_y_scale) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
grid_x->Resize({n, out_h, out_w});
grid_y->Resize({n, out_h, out_w});
T* grid_x_data = ctx.Alloc<T>(grid_x);
T* grid_y_data = ctx.Alloc<T>(grid_y);
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
Unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
Unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
ClipWithMask<T>(
ctx, in_w - 1, align_corners, padding_mode, grid_x, grid_x_scale);
ClipWithMask<T>(
ctx, in_h - 1, align_corners, padding_mode, grid_y, grid_y_scale);
}
template <typename T>
static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& d1,
const DenseTensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
}
}
}
template <typename T>
static void GatherBilinearGrad(const CPUContext& ctx,
const DenseTensor& input,
const DenseTensor& output_grad,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_x_scale,
DenseTensor* grid_y_scale,
DenseTensor* input_grad,
DenseTensor* grid_grad) {
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
DenseTensor x_w, x_e, y_n, y_s;
DenseTensor d_w, d_e, d_n, d_s;
DenseTensor v_wn, v_en, v_ws, v_es;
AllNeigbors<T>(ctx,
input,
grid_x, // grid_x
grid_y, // grid_y
&x_w,
&x_e,
&y_n,
&y_s,
&d_w,
&d_e,
&d_n,
&d_s,
&v_wn,
&v_en,
&v_ws,
&v_es);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_n, d_e, d_s);
GatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_s, d_e, d_n);
GatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_n, d_w, d_s);
GatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_s, d_w, d_n);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
if (grid_grad != nullptr) {
DenseTensor grid_grad_x, grid_grad_y;
grid_grad_x.Resize({n, out_h, out_w});
grid_grad_y.Resize({n, out_h, out_w});
ctx.Alloc<T>(&grid_grad_x);
ctx.Alloc<T>(&grid_grad_y);
auto grid_grad_x_t =
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
// const T x_max = static_cast<T>(in_w - 1);
// const T y_max = static_cast<T>(in_h - 1);
auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
}
template <typename T>
static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const DenseTensor& x,
const DenseTensor& y) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l);
}
}
}
}
}
}
template <typename T, typename Context>
void GridSampleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& grid,
const DenseTensor& out_grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners,
DenseTensor* x_grad,
DenseTensor* grid_grad) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
x_grad->Resize({n, c, in_h, in_w});
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, x_grad, static_cast<T>(0));
if (grid_grad != nullptr) {
grid_grad->Resize({n, out_h, out_w, 2});
dev_ctx.template Alloc<T>(grid_grad);
phi::funcs::SetConstant<Context, T>()(
dev_ctx, grid_grad, static_cast<T>(0));
}
DenseTensor grid_x, grid_y;
DenseTensor grid_x_scale, grid_y_scale;
CalcGridLocationsWithGrad<T>(dev_ctx,
grid,
in_h,
in_w,
align_corners,
padding_mode,
&grid_x,
&grid_y,
&grid_x_scale,
&grid_y_scale);
if (mode == "bilinear") {
GatherBilinearGrad<T>(dev_ctx,
x,
out_grid,
&grid_x,
&grid_y,
&grid_x_scale,
&grid_y_scale,
x_grad,
grid_grad);
} else {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
GatherOutputGradToInputGrad<T>(out_grid, x_grad, grid_x, grid_y);
}
}
} // namespace phi
PD_REGISTER_KERNEL(grid_sample_grad,
CPU,
ALL_LAYOUT,
phi::GridSampleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/grid_sample_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T>
static inline void Clip(const CPUContext& ctx,
DenseTensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners,
std::string padding_mode) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
}
}
}
template <typename T>
static void CalcGridLocations(const CPUContext& ctx,
const DenseTensor& grid,
const int in_h,
const int in_w,
bool align_corners,
std::string padding_mode,
DenseTensor* grid_x,
DenseTensor* grid_y) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
grid_x->Resize({n, out_h, out_w});
grid_y->Resize({n, out_h, out_w});
T* grid_x_data = ctx.Alloc<T>(grid_x);
T* grid_y_data = ctx.Alloc<T>(grid_y);
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
Unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
Unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
Clip<T>(ctx, grid_x, in_w - 1, align_corners, padding_mode);
Clip<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode);
}
template <typename T>
static void BilinearInter(const CPUContext& ctx,
const DenseTensor& input,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* out) {
auto& place = *ctx.eigen_device();
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
DenseTensor x_w, x_e, y_n, y_s;
DenseTensor d_w, d_e, d_n, d_s;
DenseTensor v_wn, v_en, v_ws, v_es;
AllNeigbors<T>(ctx,
input,
grid_x,
grid_y,
&x_w,
&x_e,
&y_n,
&y_s,
&d_w,
&d_e,
&d_n,
&d_s,
&v_wn,
&v_en,
&v_ws,
&v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*out);
// bilinear interpolaetion by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
template <typename T, typename Context>
void GridSampleKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners,
DenseTensor* out) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
out->Resize(phi::make_ddim({n, c, out_h, out_w}));
dev_ctx.template Alloc<T>(out);
phi::funcs::SetConstant<Context, T>()(dev_ctx, out, static_cast<T>(0));
DenseTensor grid_x, grid_y;
CalcGridLocations<T>(
dev_ctx, grid, in_h, in_w, align_corners, padding_mode, &grid_x, &grid_y);
if (mode == "bilinear") {
BilinearInter<T>(dev_ctx, x, &grid_x, &grid_y, out);
} else if (mode == "nearest") {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
GetGridPointValue<T>(x, out, grid_x, grid_y);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
grid_sample, CPU, ALL_LAYOUT, phi::GridSampleKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
void Unnormalize(const CPUContext& ctx,
DenseTensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (!align_corners) {
auto factor = static_cast<T>((max_val + 1) * 0.5);
grid_slice_t.device(place) =
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
} else {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
}
}
template <typename T>
inline bool IsInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false;
}
return true;
}
template <typename T>
void GetGridPointValue(const DenseTensor& input,
DenseTensor* output,
const DenseTensor& x,
const DenseTensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
const int out_h = x.dims()[1];
const int out_w = x.dims()[2];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
}
}
}
template <typename T>
void AllNeigbors(const CPUContext& ctx,
const DenseTensor& input,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* x_w,
DenseTensor* x_e,
DenseTensor* y_n,
DenseTensor* y_s, // positions
DenseTensor* d_w,
DenseTensor* d_e,
DenseTensor* d_n,
DenseTensor* d_s, // distance
DenseTensor* v_wn,
DenseTensor* v_en,
DenseTensor* v_ws,
DenseTensor* v_es) { // values
auto& place = *ctx.eigen_device();
const int c = input.dims()[1];
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
// calculate coords of 4 corner points
x_w->Resize({n, out_h, out_w});
x_e->Resize({n, out_h, out_w});
y_n->Resize({n, out_h, out_w});
y_s->Resize({n, out_h, out_w});
ctx.Alloc<T>(x_w);
ctx.Alloc<T>(x_e);
ctx.Alloc<T>(y_n);
ctx.Alloc<T>(y_s);
auto x_w_t = EigenTensor<T, 3>::From(*x_w);
auto x_e_t = EigenTensor<T, 3>::From(*x_e);
auto y_n_t = EigenTensor<T, 3>::From(*y_n);
auto y_s_t = EigenTensor<T, 3>::From(*y_s);
auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + static_cast<T>(1);
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + static_cast<T>(1);
// calculate distances to 4 sides
d_w->Resize({n, out_h, out_w});
d_e->Resize({n, out_h, out_w});
d_n->Resize({n, out_h, out_w});
d_s->Resize({n, out_h, out_w});
ctx.Alloc<T>(d_w);
ctx.Alloc<T>(d_e);
ctx.Alloc<T>(d_n);
ctx.Alloc<T>(d_s);
auto d_w_t = EigenTensor<T, 3>::From(*d_w);
auto d_e_t = EigenTensor<T, 3>::From(*d_e);
auto d_n_t = EigenTensor<T, 3>::From(*d_n);
auto d_s_t = EigenTensor<T, 3>::From(*d_s);
d_w_t.device(place) = grid_x_t - x_w_t;
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
// calc 4 corner points value
v_wn->Resize({n, c, out_h, out_w});
v_en->Resize({n, c, out_h, out_w});
v_ws->Resize({n, c, out_h, out_w});
v_es->Resize({n, c, out_h, out_w});
ctx.Alloc<T>(v_wn);
ctx.Alloc<T>(v_en);
ctx.Alloc<T>(v_ws);
ctx.Alloc<T>(v_es);
GetGridPointValue<T>(input, v_wn, *x_w, *y_n);
GetGridPointValue<T>(input, v_en, *x_e, *y_n);
GetGridPointValue<T>(input, v_ws, *x_w, *y_s);
GetGridPointValue<T>(input, v_es, *x_e, *y_s);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/grid_sample_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi {
template <typename T>
static __forceinline__ __device__ void AtomicAdd(
T* data, int h, int w, int sH, int sW, int H, int W, T delta) {
if (InBounds(h, w, H, W)) {
paddle::platform::CudaAtomicAdd(data + h * sH + w * sW, delta);
}
}
template <typename T>
static __forceinline__ __device__ T
UnnormalizeWithMask(T coord, int size, bool align_corners, T* grad_in) {
if (align_corners) {
*grad_in = static_cast<T>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
*grad_in = static_cast<T>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T ClipIndexesWithMask(T in,
int clip_limit,
T* grad_in) {
if (in <= static_cast<T>(0)) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
} else {
T max = static_cast<T>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<T>(0);
return max;
} else {
*grad_in = static_cast<T>(1);
return in;
}
}
}
template <typename T>
static __forceinline__ __device__ T
ReflectIndexesWithMask(T in, int twice_low, int twice_high, T* grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
}
int grad_in_mult_;
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<T>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<T>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<T>(-grad_in_mult_);
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T
ComputePositionsWithMask(T coord,
int size,
PaddingMode padding_mode,
bool align_corners,
T* grad_in) {
T grad_clip, grad_refl;
coord = UnnormalizeWithMask<T>(coord, size, align_corners, grad_in);
if (padding_mode == PaddingMode::border) {
coord = ClipIndexesWithMask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl);
} else {
coord = ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl);
}
coord = ClipIndexesWithMask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
template <typename T>
__global__ void GridSamplerCudaBackwardKernel(const int nthreads,
const T* grad_output,
const T* input,
const T* grid,
int n,
int out_c,
int out_h,
int out_w,
int in_h,
int in_w,
T* grad_input,
T* grad_grid,
const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int gOut_sN = out_c * out_h * out_w;
int gOut_sC = out_h * out_w;
int gOut_sH = out_w;
int gOut_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T gix_mult, giy_mult;
ix = ComputePositionsWithMask(
ix, in_w, padding_mode, align_corners, &gix_mult);
iy = ComputePositionsWithMask(
iy, in_h, padding_mode, align_corners, &giy_mult);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = static_cast<T>(0), giy = static_cast<T>(0);
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
int inp_offset_NC = n * inp_sN;
for (int c = 0; c < out_c; ++c,
inp_offset_NC += inp_sC,
gInp_ptr_NC += inp_sC,
gOut_offset += gOut_sC) {
T gOut = grad_output[gOut_offset];
AtomicAdd(
gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w, nw * gOut);
AtomicAdd(
gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w, ne * gOut);
AtomicAdd(
gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w, sw * gOut);
AtomicAdd(
gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w, se * gOut);
if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (InBounds(iy_se, ix_se, in_h, in_w)) {
T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (int c = 0; c < out_c;
++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
AtomicAdd(gInp_ptr_NC,
iy_nearest,
ix_nearest,
inp_sH,
inp_sW,
in_h,
in_w,
grad_output[gOut_offset]);
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
}
template <typename T, typename Context>
void GridSampleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& grid,
const DenseTensor& out_grad,
const std::string& mode,
const std::string& padding_mode,
bool align_corners,
DenseTensor* x_grad,
DenseTensor* grid_grad) {
PaddingMode enum_padding_mode;
Mode enum_mode;
if (padding_mode == "border") {
enum_padding_mode = PaddingMode::border;
} else if (padding_mode == "reflection") {
enum_padding_mode = PaddingMode::reflect;
} else {
enum_padding_mode = PaddingMode::zeros;
}
if (mode == "nearest") {
enum_mode = Mode::nearest;
} else {
enum_mode = Mode::bilinear;
}
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, x_grad, static_cast<T>(0));
T* grid_grad_data = nullptr;
if (grid_grad != nullptr) {
grid_grad_data = dev_ctx.template Alloc<T>(grid_grad);
}
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
GridSamplerCudaBackwardKernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count,
out_grad.data<T>(),
x.data<T>(),
grid.data<T>(),
n,
c,
out_h,
out_w,
in_h,
in_w,
x_grad->data<T>(),
grid_grad_data,
enum_mode,
enum_padding_mode,
align_corners);
}
} // namespace phi
PD_REGISTER_KERNEL(grid_sample_grad,
GPU,
ALL_LAYOUT,
phi::GridSampleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_kernel.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/grid_sample_utils.h"
namespace phi {
template <typename T>
static __forceinline__ __device__ T Unnormalize(T coord,
int size,
bool align_corners) {
if (align_corners) {
return ((coord + 1.f) / 2) * (size - 1);
} else {
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T ClipIndexes(T in, int max_value) {
return min(static_cast<T>(max_value), max(in, static_cast<T>(0)));
}
template <typename T>
static __forceinline__ __device__ T ReflectIndexes(T in,
int twice_low,
int twice_high) {
if (twice_low == twice_high) {
return static_cast<T>(0);
}
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = fabs(in - min);
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T ComputePositions(T coord,
int size,
PaddingMode padding_mode,
bool align_corners) {
coord = Unnormalize<T>(coord, size, align_corners);
if (padding_mode == PaddingMode::border) {
coord = ClipIndexes(coord, size - 1);
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = ReflectIndexes(coord, 0, 2 * (size - 1));
} else {
coord = ReflectIndexes(coord, -1, 2 * size - 1);
}
coord = ClipIndexes(coord, size - 1);
}
return coord;
}
template <typename T>
__global__ void GridSampleCudaKernel(const int nthreads,
int n,
int out_c,
int out_h,
int out_w,
int in_h,
int in_w,
const T* input,
const T* grid,
T* output,
const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int out_sN = out_c * out_h * out_w;
int out_sC = out_h * out_w;
int out_sH = out_w;
int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
ix = ComputePositions(ix, in_w, padding_mode, align_corners);
iy = ComputePositions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<T>(0);
if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (InBounds(iy_se, ix_se, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) {
*out_ptr_NCHW =
input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<T>(0);
}
}
}
}
}
template <typename T, typename Context>
void GridSampleKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners,
DenseTensor* out) {
PaddingMode enum_padding_mode;
Mode enum_mode;
if (padding_mode == "border") {
enum_padding_mode = PaddingMode::border;
} else if (padding_mode == "reflection") {
enum_padding_mode = PaddingMode::reflect;
} else {
enum_padding_mode = PaddingMode::zeros;
}
if (mode == "nearest") {
enum_mode = Mode::nearest;
} else {
enum_mode = Mode::bilinear;
}
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
<< "; out_w: " << out_w;
auto* output_data = dev_ctx.template Alloc<T>(out);
VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; "
<< out->dims()[2] << "; " << out->dims()[3];
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
GridSampleCudaKernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count,
n,
c,
out_h,
out_w,
in_h,
in_w,
x.data<T>(),
grid.data<T>(),
output_data,
enum_mode,
enum_padding_mode,
align_corners);
}
} // namespace phi
PD_REGISTER_KERNEL(
grid_sample, GPU, ALL_LAYOUT, phi::GridSampleKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
enum class Mode {
bilinear,
nearest,
};
enum class PaddingMode { zeros, border, reflect };
static __forceinline__ __device__ bool InBounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GridSampleGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &grid,
const DenseTensor &out_grid,
const std::string &mode,
const std::string &padding_mode,
bool align_corners,
DenseTensor *x_grad,
DenseTensor *grid_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GridSampleKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &grid,
const std::string &mode,
const std::string &padding_mode,
bool align_corners,
DenseTensor *out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GridSamplerOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample",
{"X", "Grid"},
{"mode", "padding_mode", "align_corners"},
{"Output"});
}
KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad",
{"X", "Grid", GradVarName("Output")},
{"mode", "padding_mode", "align_corners"},
{GradVarName("X"), GradVarName("Grid")});
}
} // namespace phi
// use Python API name as kernel name
PD_REGISTER_BASE_KERNEL_NAME(grid_sampler, grid_sample);
PD_REGISTER_BASE_KERNEL_NAME(grid_sampler_grad, grid_sample_grad);
PD_REGISTER_ARG_MAPPING_FN(grid_sampler, phi::GridSamplerOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(grid_sampler_grad,
phi::GridSamplerGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册