Unverified commit 79539cf1, authored by whs, committed by GitHub

【2.0 API】Add CUDA kernel and enhance options for grid_sample (#26576)

This PR enhances the CPU kernel and adds a new CUDA kernel so that grid_sample supports:

- align_corners: a bool option.
- padding_mode: one of ['zeros', 'reflect', 'border'].
- mode (interpolation): one of ['bilinear', 'nearest'].

The old CPU and cuDNN versions only supported align_corners=true, padding_mode='zeros', and mode='bilinear'.

With the default options, the new op's behavior is compatible with the old version, as the minimal usage sketch below shows.
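A minimal usage sketch (shapes and values are illustrative; it assumes the `paddle.nn.functional.grid_sample` signature added in this PR):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
# x: [N, C, H_in, W_in]; grid: [N, H_out, W_out, 2] with coords in [-1, 1].
x = paddle.to_tensor(np.random.randn(1, 1, 3, 3).astype("float64"))
grid = paddle.to_tensor(
    np.random.uniform(-1, 1, (1, 3, 4, 2)).astype("float64"))

# The three options this PR adds; the defaults ('bilinear', 'zeros', True)
# keep the old behavior.
y = F.grid_sample(x, grid, mode='nearest', padding_mode='border',
                  align_corners=False)
print(y.numpy().shape)  # (1, 1, 3, 4)
```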
Parent 39fe0d35
......@@ -41,13 +41,14 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
int n = input->dims()[0];
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
const int size[4] = {n, c, h, w};
int out_h = grid->dims()[1];
int out_w = grid->dims()[2];
const int size[4] = {n, c, out_h, out_w};
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
T* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
......@@ -97,7 +98,7 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
input_grad->mutable_data<T>(input->dims(), ctx.GetPlace());
T* grid_grad_data =
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -58,21 +59,10 @@ class GridSampleOp : public framework::OperatorWithKernel {
"Input(X) and Input(Grid) dimension[0] should be equal, but "
"received X dimension[0](%d) != Grid dimension[0](%d)",
x_dims[0], grid_dims[0]));
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
platform::errors::InvalidArgument(
"Input(X) dims[2] and Input(Grid) dims[1] should be equal, but "
"received X dimension[2](%d) != Grid dimension[1](%d)",
x_dims[2], grid_dims[1]));
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
platform::errors::InvalidArgument(
"Input(X) dims[3] and Input(Grid) dims[2] should be equal, but "
"received X dimension[3](%d) != Grid dimension[2](%d)",
x_dims[3], grid_dims[2]));
}
ctx->SetOutputDim("Output", x_dims);
ctx->SetOutputDim("Output",
{x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]});
ctx->ShareLoD("X", "Output");
}
......@@ -108,15 +98,37 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddAttr<bool>(
    "align_corners",
    "(bool, default true) If align_corners is true, it will project "
    "-1 and 1 to the centers of the corner pixels. Otherwise, it will "
    "project -1 and 1 to the image edges.")
    .SetDefault(true);
AddAttr<std::string>(
    "mode",
    "(string, default 'bilinear') The interpolation method, which can be "
    "'bilinear' or 'nearest'.")
    .SetDefault("bilinear");
AddAttr<std::string>(
    "padding_mode",
    "(string, default 'zeros') The padding method used when the source "
    "index is out of the input image. It can be 'zeros', 'reflect' or "
    "'border'.")
    .SetDefault("zeros");
AddComment(R"DOC(
This operation samples input X by using bilinear interpolation based on
This operation samples input X by using bilinear or nearest interpolation based on
flow field grid, which is usually generated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexing the 3rd
dimension (in height dimension); the final result is the bilinear
interpolation value of 4 nearest corner points.
interpolation value of the 4 nearest corner points, or the nearest point's value.
For bilinear interpolation mode:
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1] / [0, W-1].
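A numpy sketch of the Step 1 math (the coordinate unnormalization only, not the op itself); note that only align_corners=true maps -1 and 1 exactly onto the corner pixel centers:

```python
import numpy as np

def unnormalize(coord, size, align_corners):
    # Map normalized coordinates in [-1, 1] to pixel space [0, size - 1].
    coord = np.asarray(coord, dtype=np.float64)
    if align_corners:
        # -1 and 1 land on the centers of the corner pixels.
        return (coord + 1) / 2 * (size - 1)
    # -1 and 1 land on the image edges; pixel centers sit at half-integers.
    return ((coord + 1) * size - 1) / 2

print(unnormalize([-1.0, 0.0, 1.0], 5, True))   # [0.  2.  4.]
print(unnormalize([-1.0, 0.0, 1.0], 5, False))  # [-0.5  2.   4.5]
```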
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
static __forceinline__ __device__ bool in_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename T>
static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH,
int sW, int H, int W,
T delta) {
if (in_bounds(h, w, H, W)) {
atomicAdd(data + h * sH + w * sW, delta);
}
}
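`atomic_add` exists because in the backward pass several output locations can sample the same input pixel, so their gradient contributions must accumulate without races. A serial numpy analogue of the same bounds-checked scatter-add (values illustrative):

```python
import numpy as np

# Two sampled points fall on input pixel (1, 2); their gradients must
# accumulate, which atomicAdd guarantees on the GPU even under concurrency.
grad_input = np.zeros((3, 4))
contributions = [((1, 2), 0.7), ((1, 2), 0.3), ((0, 0), 1.0)]
H, W = grad_input.shape
for (h, w), delta in contributions:
    if 0 <= h < H and 0 <= w < W:  # mirrors in_bounds(); out-of-range dropped
        grad_input[h, w] += delta
print(grad_input[1, 2])  # 1.0
```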
template <typename T>
static __forceinline__ __device__ T _unnormalize(T coord, int size,
bool align_corners) {
if (align_corners) {
return ((coord + 1.f) / 2) * (size - 1);
} else {
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T clip_indexes(T in, int max_value) {
return min(static_cast<T>(max_value), max(in, static_cast<T>(0)));
}
template <typename T>
static __forceinline__ __device__ T reflect_indexes(T in, int twice_low,
int twice_high) {
if (twice_low == twice_high) {
return static_cast<T>(0);
}
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = fabs(in - min);
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T compute_positions(T coord, int size,
PaddingMode padding_mode,
bool align_corners) {
coord = _unnormalize<T>(coord, size, align_corners);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes(coord, size - 1);
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes(coord, 0, 2 * (size - 1));
} else {
coord = reflect_indexes(coord, -1, 2 * size - 1);
}
coord = clip_indexes(coord, size - 1);
}
return coord;
}
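A plain-Python sketch of the fold-counting reflection used by `compute_positions` (same scheme as `reflect_indexes` above; for align_corners=true the range is [0, size-1], i.e. twice_low=0 and twice_high=2*(size-1); the example value is illustrative):

```python
def reflect(coord, twice_low, twice_high):
    # Mirror an unnormalized coordinate back into [twice_low/2, twice_high/2],
    # flipping direction on every fold, like reflect_indexes above.
    if twice_low == twice_high:
        return 0.0
    lo = twice_low / 2.0
    span = (twice_high - twice_low) / 2.0
    coord = abs(coord - lo)
    extra = coord % span
    flips = int(coord // span)
    return extra + lo if flips % 2 == 0 else span - extra + lo

# size=5, align_corners=true: 1.5 pixels past the right edge folds back inside.
print(reflect(5.5, 0, 2 * (5 - 1)))  # 2.5
```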
template <typename T>
static __forceinline__ __device__ T _unnormalize_with_mask(T coord, int size,
bool align_corners,
T* grad_in) {
if (align_corners) {
*grad_in = static_cast<T>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
*grad_in = static_cast<T>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T clip_indexes_with_mask(T in, int clip_limit,
T* grad_in) {
if (in <= static_cast<T>(0)) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
} else {
T max = static_cast<T>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<T>(0);
return max;
} else {
*grad_in = static_cast<T>(1);
return in;
}
}
}
template <typename T>
static __forceinline__ __device__ T
reflect_indexes_with_mask(T in, int twice_low, int twice_high, T* grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
}
int grad_in_mult_;
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<T>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<T>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<T>(-grad_in_mult_);
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T
compute_positions_with_mask(T coord, int size, PaddingMode padding_mode,
bool align_corners, T* grad_in) {
T grad_clip, grad_refl;
coord = _unnormalize_with_mask<T>(coord, size, align_corners, grad_in);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes_with_mask(coord, 0, 2 * (size - 1), &grad_refl);
} else {
coord = reflect_indexes_with_mask(coord, -1, 2 * size - 1, &grad_refl);
}
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
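The `*_with_mask` variants additionally return d(coord_out)/d(coord_in) for the chain rule: unnormalization contributes a constant factor, every odd reflection fold flips the sign, and border clipping zeroes the gradient outside the valid range. A simplified Python analogue of the align_corners reflect case (omitting the final clip, which is the identity for in-range coordinates), checked against a finite difference:

```python
def positions_with_mask(coord, size):
    # Unnormalize (align_corners=True) and track the local derivative.
    grad = (size - 1) / 2.0
    coord = (coord + 1) / 2.0 * (size - 1)
    span = float(size - 1)                 # reflect range is [0, size-1]
    sign = -1.0 if coord < 0 else 1.0
    coord = abs(coord)
    extra, flips = coord % span, int(coord // span)
    if flips % 2 == 0:
        return extra, grad * sign
    return span - extra, grad * (-sign)

x, eps = 1.3, 1e-6
y0, g = positions_with_mask(x, 5)
y1, _ = positions_with_mask(x + eps, 5)
print(g, (y1 - y0) / eps)  # both ~ -2.0: the fold flipped the sign
```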
template <typename T>
__global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
int out_h, int out_w, int in_h,
int in_w, const T* input, const T* grid,
T* output, const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int out_sN = out_c * out_h * out_w;
int out_sC = out_h * out_w;
int out_sH = out_w;
int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
ix = compute_positions(ix, in_w, padding_mode, align_corners);
iy = compute_positions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<T>(0);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(round(ix));
int iy_nearest = static_cast<int>(round(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (in_bounds(iy_nearest, ix_nearest, in_h, in_w)) {
*out_ptr_NCHW =
input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<T>(0);
}
}
}
}
}
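The four bilinear weights computed above (nw, ne, sw, se) are the areas of the sub-rectangles opposite each corner, so they are non-negative and sum to 1 whenever the point lies inside its cell. A quick numpy check (coordinates illustrative):

```python
import numpy as np

def bilinear_weights(ix, iy):
    # Corner coords and opposite-area weights, mirroring the kernel above.
    ix_nw, iy_nw = np.floor(ix), np.floor(iy)
    ix_se, iy_se = ix_nw + 1, iy_nw + 1
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_nw) * (iy_se - iy)
    sw = (ix_se - ix) * (iy - iy_nw)
    se = (ix - ix_nw) * (iy - iy_nw)
    return nw, ne, sw, se

w = bilinear_weights(2.25, 3.75)
print(w, sum(w))  # (0.1875, 0.0625, 0.5625, 0.1875) 1.0
```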
template <typename T>
class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
<< "; out_w: " << out_w;
auto* output = ctx.Output<Tensor>("Output");
auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "set constant";
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
dev_ctx, output, static_cast<T>(0));
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
int block = 512;
int grid_size = (count + block - 1) / block;
grid_sample_cuda_kernel<T><<<grid_size, block, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners);
}
};
template <typename T>
__global__ void grid_sampler_cuda_backward_kernel(
const int nthreads, const T* grad_output, const T* input, const T* grid,
int n, int out_c, int out_h, int out_w, int in_h, int in_w, T* grad_input,
T* grad_grid, const Mode mode, const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int gOut_sN = out_c * out_h * out_w;
int gOut_sC = out_h * out_w;
int gOut_sH = out_w;
int gOut_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T gix_mult, giy_mult;
ix = compute_positions_with_mask(ix, in_w, padding_mode, align_corners,
&gix_mult);
iy = compute_positions_with_mask(iy, in_h, padding_mode, align_corners,
&giy_mult);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = static_cast<T>(0), giy = static_cast<T>(0);
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
int inp_offset_NC = n * inp_sN;
for (int c = 0; c < out_c; ++c, inp_offset_NC += inp_sC,
gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
T gOut = grad_output[gOut_offset];
atomic_add(gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w,
nw * gOut);
atomic_add(gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w,
ne * gOut);
atomic_add(gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w,
sw * gOut);
atomic_add(gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w,
se * gOut);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (int c = 0; c < out_c;
++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
atomic_add(gInp_ptr_NC, iy_nearest, ix_nearest, inp_sH, inp_sW, in_h,
in_w, grad_output[gOut_offset]);
}
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
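In the bilinear branch, gix/giy accumulate the analytic derivative of the interpolated value with respect to the unnormalized sampling position, which is then scaled by gix_mult/giy_mult from compute_positions_with_mask. A single-pixel numpy sanity check with gOut = 1 (corner values illustrative):

```python
import numpy as np

# One 2x2 patch of input values around the sampling point.
v = {"nw": 1.0, "ne": 3.0, "sw": 2.0, "se": 5.0}

def sample(ix, iy):
    # Bilinear interpolation of the four corner values.
    fx, fy = ix - np.floor(ix), iy - np.floor(iy)
    return ((1 - fx) * (1 - fy) * v["nw"] + fx * (1 - fy) * v["ne"] +
            (1 - fx) * fy * v["sw"] + fx * fy * v["se"])

ix, iy, eps = 0.3, 0.6, 1e-6
fy = iy - np.floor(iy)
# Analytic d(out)/d(ix), matching the kernel's gix accumulation (gOut = 1):
gix = (1 - fy) * (v["ne"] - v["nw"]) + fy * (v["se"] - v["sw"])
print(gix, (sample(ix + eps, iy) - sample(ix, iy)) / eps)  # both ~ 2.6
```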
template <typename T>
class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
int block = 512;
int grid_size = (count + block - 1) / block;
grid_sampler_cuda_backward_kernel<T><<<grid_size, block, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad->data<T>(),
mode, padding_mode, align_corners);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(grid_sampler, ops::GridSampleOpCUDAKernel<float>,
ops::GridSampleOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(grid_sampler_grad,
ops::GridSampleGradOpCUDAKernel<float>,
ops::GridSampleGradOpCUDAKernel<double>);
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <utility>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
......@@ -22,6 +25,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum class Mode {
bilinear,
nearest,
};
enum class PaddingMode { zeros, border, reflect };
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......@@ -39,64 +49,229 @@ static inline bool isInBound(T x, T y, T x_max, T y_max) {
}
template <typename T>
static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, Tensor* x_w, Tensor* x_e,
Tensor* y_n, Tensor* y_s, Tensor* d_w,
Tensor* d_e, Tensor* d_n, Tensor* d_s) {
static inline void unnormalize(const platform::CPUDeviceContext& ctx,
Tensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (!align_corners) {
auto factor = static_cast<T>((max_val + 1) * 0.5);
grid_slice_t.device(place) =
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
} else {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
}
}
template <typename T>
static inline void clip(const platform::CPUDeviceContext& ctx,
Tensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners, std::string padding_mode) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflect") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
}
}
}
template <typename T>
static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
const int max_val, // height-1 or width-1
bool align_corners, std::string padding_mode,
Tensor* grid_slice, Tensor* grid_scale) {
auto& place = *ctx.eigen_device();
grid_scale->mutable_data<T>(grid_slice->dims(), ctx.GetPlace());
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
auto factor = static_cast<T>(max_val * 0.5);
if (!align_corners) {
factor = static_cast<T>((max_val + 1) * 0.5);
}
auto grid_scale_t = EigenTensor<T, 3>::From(*grid_scale).setConstant(factor);
if (padding_mode == "border") {
auto res = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflect") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>());
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto is_neg = ((grid_slice_t + static_cast<T>(0.5)) < static_cast<T>(0));
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
auto reflected =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
auto clipped = reflected.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (clipped == reflected).template cast<T>();
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>()) *
in_bound;
grid_slice_t.device(place) = clipped;
}
}
}
template <typename T>
static void calcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, const int in_h,
const int in_w, bool align_corners,
std::string padding_mode, Tensor* grid_x,
Tensor* grid_y) {
const int n = grid.dims()[0];
const int h = grid.dims()[1];
const int w = grid.dims()[2];
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor grid_x, grid_y;
T* grid_x_data = grid_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
T* grid_y_data = grid_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
T* grid_x_data = grid_x->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
T* grid_y_data = grid_y->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * h * w; i++) {
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
Tensor ones;
ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
Tensor half_xmax;
Tensor half_ymax;
half_xmax.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto half_xmax_t =
EigenTensor<T, 3>::From(half_xmax).setConstant(0.5 * x_max);
half_ymax.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto half_ymax_t =
EigenTensor<T, 3>::From(half_ymax).setConstant(0.5 * y_max);
// scale grid to [0, h-1/w-1]
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t.device(place) = (grid_x_t + ones_t) * half_xmax_t;
grid_y_t.device(place) = (grid_y_t + ones_t) * half_ymax_t;
unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
clip<T>(ctx, grid_x, in_w - 1, align_corners, padding_mode);
clip<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode);
}
template <typename T>
static void calcGridLocationsWithGrad(const platform::CPUDeviceContext& ctx,
const Tensor& grid, const int in_h,
const int in_w, bool align_corners,
std::string padding_mode, Tensor* grid_x,
Tensor* grid_y, Tensor* grid_x_scale,
Tensor* grid_y_scale) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
T* grid_x_data = grid_x->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
T* grid_y_data = grid_y->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
unnormalize<T>(ctx, grid_x, in_w - 1, align_corners);
unnormalize<T>(ctx, grid_y, in_h - 1, align_corners);
clipWithMask<T>(ctx, in_w - 1, align_corners, padding_mode, grid_x,
grid_x_scale);
clipWithMask<T>(ctx, in_h - 1, align_corners, padding_mode, grid_y,
grid_y_scale);
}
template <typename T>
static void getGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
const int out_h = x.dims()[1];
const int out_w = x.dims()[2];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
}
}
}
template <typename T>
static void allNeigbors(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* x_w, Tensor* x_e, Tensor* y_n,
Tensor* y_s, // positions
Tensor* d_w, Tensor* d_e, Tensor* d_n,
Tensor* d_s, // distance
Tensor* v_wn, Tensor* v_en, Tensor* v_ws,
Tensor* v_es) { // values
auto& place = *ctx.eigen_device();
const int c = input.dims()[1];
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
// calculate coords of 4 corner points
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
x_w->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
x_e->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
y_n->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
y_s->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto x_w_t = EigenTensor<T, 3>::From(*x_w);
auto x_e_t = EigenTensor<T, 3>::From(*x_e);
auto y_n_t = EigenTensor<T, 3>::From(*y_n);
auto y_s_t = EigenTensor<T, 3>::From(*y_s);
auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + ones_t;
x_e_t.device(place) = x_w_t + static_cast<T>(1);
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + ones_t;
y_s_t.device(place) = y_n_t + static_cast<T>(1);
// calculate distances to 4 sides
d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_w->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_e->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_n->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
d_s->mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto d_w_t = EigenTensor<T, 3>::From(*d_w);
auto d_e_t = EigenTensor<T, 3>::From(*d_e);
auto d_n_t = EigenTensor<T, 3>::From(*d_n);
......@@ -105,28 +280,100 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
// calc 4 corner points value
v_wn->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_en->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_ws->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
v_es->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
getGridPointValue<T>(input, v_wn, *x_w, *y_n);
getGridPointValue<T>(input, v_en, *x_e, *y_n);
getGridPointValue<T>(input, v_ws, *x_w, *y_s);
getGridPointValue<T>(input, v_es, *x_e, *y_s);
}
template <typename T>
static void GetGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) {
const int n = input.dims()[0];
static void bilinearInter(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* out) {
auto& place = *ctx.eigen_device();
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
const int h = input.dims()[2];
const int w = input.dims()[3];
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
Tensor v_wn, v_en, v_ws, v_es;
allNeigbors<T>(ctx, input, grid_x, grid_y, &x_w, &x_e, &y_n, &y_s, &d_w, &d_e,
&d_n, &d_s, &v_wn, &v_en, &v_ws, &v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, out_h, out_w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*out);
// bilinear interpolation by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
template <typename T>
static void nearestInter(const platform::CPUDeviceContext& ctx,
const Tensor& input, Tensor* grid_x, Tensor* grid_y,
Tensor* out) {
auto& place = *ctx.eigen_device();
auto grid_x_t = EigenTensor<T, 3>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(*grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
getGridPointValue<T>(input, out, *grid_x, *grid_y);
}
template <typename T>
static void gatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
......@@ -135,29 +382,28 @@ static void GetGridPointValue(const Tensor& input, Tensor* output,
}
template <typename T>
static void GatherOutputGradToInputGrad(const Tensor& output_grad,
static void gatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const Tensor& y) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int h = output_grad.dims()[2];
const int w = output_grad.dims()[3];
const int out_h = output_grad.dims()[2];
const int out_w = output_grad.dims()[3];
const int in_h = input_grad->dims()[2];
const int in_w = input_grad->dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1),
(T)(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
output_grad_t(i, j, k, l);
}
}
}
......@@ -165,65 +411,126 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad,
}
}
template <typename T>
static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
const Tensor& input, const Tensor& output_grad,
Tensor* grid_x, Tensor* grid_y,
Tensor* grid_x_scale, Tensor* grid_y_scale,
Tensor* input_grad, Tensor* grid_grad) {
const int n = grid_x->dims()[0];
const int out_h = grid_x->dims()[1];
const int out_w = grid_x->dims()[2];
const int c = input.dims()[1];
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
Tensor v_wn, v_en, v_ws, v_es;
allNeigbors<T>(ctx, input,
grid_x, // grid_x
grid_y, // grid_y
&x_w, &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s, &v_wn, &v_en,
&v_ws, &v_es);
// gather output grad value to input grad by corner point coords and weight
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_n, d_e, d_s);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_w, y_s, d_e, d_n);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_n, d_w, d_s);
gatherOutputGradToInputGrad<T>(output_grad, input_grad, x_e, y_s, d_w, d_n);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto grid_grad_x_t =
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
template <typename DeviceContext, typename T>
class GridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode = ctx.Attr<std::string>("padding_mode");
auto mode = ctx.Attr<std::string>("mode");
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = input->dims()[0];
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
// calc locations and distances of 4 corner points
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
// calc 4 corner points value
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*output);
// bilinear interpolation by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
Tensor grid_x, grid_y;
calcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, in_h,
in_w, align_corners, padding_mode, &grid_x, &grid_y);
if (mode == "bilinear") {
bilinearInter<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *input,
&grid_x, &grid_y, output);
} else if (mode == "nearest") {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
getGridPointValue<T>(*input, output, grid_x, grid_y);
}
}
};
......@@ -231,97 +538,48 @@ template <typename DeviceContext, typename T>
class GridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode = ctx.Attr<std::string>("padding_mode");
auto mode = ctx.Attr<std::string>("mode");
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = input->dims()[0];
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e,
d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w,
d_n);
// calc 4 corner points value
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
grid_grad_x_t = grid_grad_x_t * (x_max / (T)2);
grid_grad_y_t = grid_grad_y_t * (y_max / (T)2);
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * h * w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, in_h,
in_w, align_corners, padding_mode, &grid_x, &grid_y, &grid_x_scale,
&grid_y_scale);
if (mode == "bilinear") {
gatherBilinearGrad<T>(ctx.template device_context<DeviceContext>(),
*input, *output_grad, &grid_x, &grid_y,
&grid_x_scale, &grid_y_scale, input_grad,
grid_grad);
} else {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
gatherOutputGradToInputGrad<T>(*output_grad, input_grad, grid_x, grid_y);
}
}
};
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import fluid, nn
import paddle.fluid.dygraph as dg
import paddle.nn.functional as F
import unittest
class GridSampleTestCase(unittest.TestCase):
def __init__(self,
methodName='runTest',
x_shape=[2, 2, 3, 3],
grid_shape=[2, 3, 3, 2],
mode="bilinear",
padding_mode="zeros",
align_corners=False):
super(GridSampleTestCase, self).__init__(methodName)
self.x_shape = x_shape
self.grid_shape = grid_shape
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = "float64"
def setUp(self):
self.x = np.random.randn(*(self.x_shape)).astype(self.dtype)
self.grid = np.random.uniform(-1, 1, self.grid_shape).astype(self.dtype)
def static_functional(self, place):
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
x = fluid.data("x", self.x_shape, dtype=self.dtype)
grid = fluid.data("grid", self.grid_shape, dtype=self.dtype)
y_var = F.grid_sample(
x,
grid,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners)
feed_dict = {"x": self.x, "grid": self.grid}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np
def dynamic_functional(self):
x_t = paddle.to_tensor(self.x)
grid_t = paddle.to_tensor(self.grid)
y_t = F.grid_sample(
x_t,
grid_t,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners)
y_np = y_t.numpy()
return y_np
def _test_equivalence(self, place):
result1 = self.static_functional(place)
with dg.guard(place):
result2 = self.dynamic_functional()
np.testing.assert_array_almost_equal(result1, result2)
def runTest(self):
place = fluid.CPUPlace()
self._test_equivalence(place)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self._test_equivalence(place)
class GridSampleErrorTestCase(GridSampleTestCase):
def runTest(self):
place = fluid.CPUPlace()
with self.assertRaises(ValueError):
self.static_functional(place)
def add_cases(suite):
suite.addTest(GridSampleTestCase(methodName='runTest'))
suite.addTest(
GridSampleTestCase(
methodName='runTest',
mode='bilinear',
padding_mode='reflect',
align_corners=True))
suite.addTest(
GridSampleTestCase(
methodName='runTest',
mode='bilinear',
padding_mode='zeros',
align_corners=True))
def add_error_cases(suite):
suite.addTest(
GridSampleErrorTestCase(
methodName='runTest', padding_mode="VALID"))
suite.addTest(
GridSampleErrorTestCase(
methodName='runTest', align_corners="VALID"))
suite.addTest(GridSampleErrorTestCase(methodName='runTest', mode="VALID"))
def load_tests(loader, standard_tests, pattern):
suite = unittest.TestSuite()
add_cases(suite)
add_error_cases(suite)
return suite
if __name__ == '__main__':
unittest.main()
......@@ -17,17 +17,17 @@ import numpy as np
from op_test import OpTest
def AffineGrid(theta, size):
n = size[0]
h = size[2]
w = size[3]
def AffineGrid(theta, grid_shape):
n = grid_shape[0]
h = grid_shape[1]
w = grid_shape[2]
h_idx = np.repeat(
np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
w_idx = np.repeat(
np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2])
theta = theta.transpose([0, 2, 1])
......@@ -40,15 +40,19 @@ def AffineGrid(theta, size):
def getGridPointValue(data, x, y):
data_shape = data.shape
N = data_shape[0]
H = data_shape[2]
W = data_shape[3]
out = np.zeros(data_shape, dtype='float64')
C = data_shape[1]
in_H = data_shape[2]
in_W = data_shape[3]
out_H = x.shape[1]
out_W = x.shape[2]
out = np.zeros([N, C, out_H, out_W], dtype='float64')
for i in range(N):
for j in range(H):
for k in range(W):
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[
i, j, k] > W - 1:
for j in range(out_H):
for k in range(out_W):
if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[
i, j, k] < 0 or x[i, j, k] > in_W - 1:
out[i, :, j, k] = 0
else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
......@@ -56,44 +60,89 @@ def getGridPointValue(data, x, y):
return out
def GridSampler(data, grid):
dims = data.shape
N = dims[0]
C = dims[1]
H = dims[2]
W = dims[3]
def clip(x, min_n, max_n):
return np.maximum(np.minimum(x, max_n), min_n)
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
y_max = H - 1
x_max = W - 1
x = 0.5 * ((x.astype('float64') + 1.0) * x_max)
y = 0.5 * ((y.astype('float64') + 1.0) * y_max)
def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
if align_corners:
grid_slice = 0.5 * ((grid_slice.astype('float64') + 1.0) * max_val)
else:
grid_slice = 0.5 * (
(grid_slice.astype('float64') + 1.0) * (max_val + 1)) - 0.5
if padding_mode == "border":
grid_slice = clip(grid_slice, 0, max_val)
elif padding_mode == "reflect":
double_range = 2 * max_val if align_corners else (max_val + 1) * 2
grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice +
0.5)
extra = grid_abs - np.floor(grid_abs / double_range) * double_range
grid_slice = np.minimum(extra, double_range - extra)
grid_slice = grid_slice if align_corners else clip(grid_slice - 0.5, 0,
max_val)
return grid_slice
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
def GridSampler(data,
grid,
align_corners=True,
mode="bilinear",
padding_mode="zeros"):
dims = data.shape
N = dims[0]
in_C = dims[1]
in_H = dims[2]
in_W = dims[3]
va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1)
vc = getGridPointValue(data, x1, y0)
vd = getGridPointValue(data, x1, y1)
out_H = grid.shape[1]
out_W = grid.shape[2]
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float64')
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
y_max = in_H - 1
x_max = in_W - 1
x = unnormalizeAndClip(x, x_max, align_corners, padding_mode)
y = unnormalizeAndClip(y, y_max, align_corners, padding_mode)
if mode == "bilinear":
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)),
(1, in_C, 1, 1))
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)),
(1, in_C, 1, 1))
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)),
(1, in_C, 1, 1))
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)),
(1, in_C, 1, 1))
va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1)
vc = getGridPointValue(data, x1, y0)
vd = getGridPointValue(data, x1, y1)
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float64')
elif mode == "nearest":
x = np.round(x).astype('int32')
y = np.round(y).astype('int32')
out = getGridPointValue(data, x, y)
return out
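GridSampler above is the pure-numpy reference implementation the tests compare the op against. A tiny smoke run (values illustrative) sampling the exact center of a 4x4 input in both modes:

```python
import numpy as np

data = np.arange(16, dtype='float64').reshape(1, 1, 4, 4)
grid = np.zeros((1, 2, 2, 2))  # every point samples the normalized center
out_bl = GridSampler(data, grid, align_corners=True, mode="bilinear")
out_nn = GridSampler(data, grid, align_corners=True, mode="nearest",
                     padding_mode="border")
print(out_bl[0, 0, 0, 0])  # 7.5: mean of the four center pixels
print(out_nn[0, 0, 0, 0])  # 10.0: np.round(1.5) rounds to 2, picks (2, 2)
```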
class TestGridSamplerOp(OpTest):
def setUp(self):
self.initTestCase()
self.use_cudnn = False
self.numeric_grad_delta = 0.0001
self.op_type = 'grid_sampler'
self.align_corners = True
self.padding_mode = "zeros"
self.mode = "bilinear"
self.initTestCase()
x = np.random.randint(0, 255, self.x_shape).astype('float64')
theta = np.zeros(self.theta_shape).astype('float64')
......@@ -101,22 +150,90 @@ class TestGridSamplerOp(OpTest):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.x_shape)
grid = AffineGrid(theta, self.grid_shape)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {'use_cudnn': True}
self.outputs = {'Output': GridSampler(x, grid)}
self.attrs = {
'use_cudnn': self.use_cudnn,
"align_corners": self.align_corners,
"padding_mode": self.padding_mode,
"mode": self.mode
}
# print("X: {}".format(x))
self.outputs = {
'Output': GridSampler(x, grid, self.align_corners, self.mode,
self.padding_mode)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
self.check_grad(
['X', 'Grid'],
'Output',
max_relative_error=0.01,
numeric_grad_delta=self.numeric_grad_delta)
def initTestCase(self):
self.x_shape = (2, 3, 8, 8)
self.grid_shape = (2, 7, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = True
self.padding_mode = "zeros"
self.mode = "bilinear"
self.use_cudnn = True
class Case1(TestGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6)
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "zeros"
self.mode = "bilinear"
class Case2(TestGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6)
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "border"
self.mode = "bilinear"
class Case3(TestGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6)
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "reflect"
self.mode = "bilinear"
class Case4(TestGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6)
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = True
self.padding_mode = "reflect"
self.mode = "bilinear"
class Case5(TestGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 5, 7, 3)
self.grid_shape = (2, 7, 3, 2)
self.x_shape = (2, 3, 5, 6)
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "reflect"
self.mode = "nearest"
self.numeric_grad_delta = 0.0001
if __name__ == "__main__":
......
......@@ -192,7 +192,7 @@ from .vision import fsp_matrix #DEFINE_ALIAS
from .vision import generate_mask_labels #DEFINE_ALIAS
from .vision import generate_proposal_labels #DEFINE_ALIAS
from .vision import generate_proposals #DEFINE_ALIAS
from .vision import grid_sampler #DEFINE_ALIAS
from .vision import grid_sample #DEFINE_ALIAS
from .vision import image_resize #DEFINE_ALIAS
from .vision import image_resize_short #DEFINE_ALIAS
# from .vision import multi_box_head #DEFINE_ALIAS
......
......@@ -28,7 +28,6 @@ from ...fluid.layers import distribute_fpn_proposals #DEFINE_ALIAS
from ...fluid.layers import generate_mask_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposal_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposals #DEFINE_ALIAS
from ...fluid.layers import grid_sampler #DEFINE_ALIAS
from ...fluid.layers import image_resize #DEFINE_ALIAS
from ...fluid.layers import prior_box #DEFINE_ALIAS
from ...fluid.layers import prroi_pool #DEFINE_ALIAS
......@@ -68,7 +67,7 @@ __all__ = [
'generate_mask_labels',
'generate_proposal_labels',
'generate_proposals',
'grid_sampler',
'grid_sample',
'image_resize',
'image_resize_short',
# 'multi_box_head',
......@@ -89,3 +88,187 @@ __all__ = [
'yolo_box',
'yolov3_loss'
]
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import core, dygraph_utils
from ...fluid.framework import Variable, in_dygraph_mode
from ...device import get_cudnn_version
import numpy as np
def grid_sample(x,
grid,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
name=None):
"""
This operation samples input X by using bilinear interpolation or
nearest interpolation based on flow field grid, which is usually
generated by :code:`affine_grid` . The grid of shape [N, H, W, 2]
is the concatenation of (x, y) coordinates with shape [N, H, W] each,
where x is indexing the 4th dimension (in width dimension) of input
data x and y is indexing the 3rd dimension (in height dimension),
finally, the result is the bilinear interpolation value of the 4 nearest corner
points, or the value of the nearest point. The output tensor shape will be
[N, C, grid_H, grid_W].
.. code-block:: text
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1] / [0, W-1].
.. code-block:: text
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Index input data X with grid (x, y) in each [H, W] area, and compute the
output either by bilinear interpolation over the 4 nearest points or by
taking the nearest point's value.
          wn ------- y_n ------- en
          |           |          |
          |          d_n         |
          |           |          |
         x_w --d_w-- grid --d_e-- x_e
          |           |          |
          |          d_s         |
          |           |          |
          ws ------- y_s ------- es
For bilinear interpolation:
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_n + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-west point value
es = X[:, :, y_s, x_e] // south-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Tensor): The input tensor, which is a 4-d tensor with shape
[N, C, H, W], N is the batch size, C is the channel
number, H and W is the feature height and width.
The data type is float32 or float64.
grid(Tensor): Input grid tensor of shape [N, grid_H, grid_W, 2]. The
data type is float32 or float64.
mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
Default: 'bilinear'.
padding_mode(str, optional): The padding method used when the source index
is out of the input image. It can be 'zeros', 'reflect' or 'border'.
Default: 'zeros'.
align_corners(bool, optional): If `align_corners` is true, it will project
-1 and 1 to the centers of the corner pixels. Otherwise, it will
project -1 and 1 to the image edges. Default: True.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name does not need to be set and
is None by default.
Returns: Tensor. The shape of the output is [N, C, grid_H, grid_W], in which `grid_H` is the height of grid
and `grid_W` is the width of grid. The data type is the same as the input tensor.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
# shape=[1, 1, 3, 3]
x = np.array([[[[-0.6, 0.8, -0.5],
[-0.5, 0.2, 1.2],
[ 1.4, 0.3, -0.2]]]]).astype("float64")
# grid shape = [1, 3, 4, 2]
grid = np.array(
[[[[ 0.2, 0.3],
[-0.4, -0.3],
[-0.9, 0.3],
[-0.9, -0.6]],
[[ 0.4, 0.1],
[ 0.9, -0.8],
[ 0.4, 0.5],
[ 0.5, -0.2]],
[[ 0.1, -0.8],
[-0.3, -1. ],
[ 0.7, 0.4],
[ 0.2, 0.8]]]]).astype("float64")
paddle.disable_static()
x = paddle.to_tensor(x)
grid = paddle.to_tensor(grid)
y_t = F.grid_sample(
x,
grid,
mode='bilinear',
padding_mode='border',
align_corners=True)
print(y_t.numpy())
# output shape = [1, 1, 3, 4]
# [[[[ 0.34 0.016 0.086 -0.448]
# [ 0.55 -0.076 0.35 0.59 ]
# [ 0.596 0.38 0.52 0.24 ]]]]
"""
helper = LayerHelper("grid_sample", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sampler')
check_variable_and_dtype(grid, 'grid', ['float32', 'float64'],
'grid_sampler')
if not isinstance(x, Variable):
raise ValueError("The x should be a Variable")
if not isinstance(grid, Variable):
raise ValueError("The grid should be a Variable")
_modes = ['bilinear', 'nearest']
_padding_modes = ['zeros', 'reflect', 'border']
if mode not in _modes:
raise ValueError(
"The mode of grid sample function should be in {}, but got: {}".
format(_modes, mode))
if padding_mode not in _padding_modes:
raise ValueError(
"The padding mode of grid sample function should be in {}, but got: {}".
format(_padding_modes, padding_mode))
if not isinstance(align_corners, bool):
raise ValueError("The align corners should be bool, but got: {}".format(
align_corners))
cudnn_version = get_cudnn_version()
use_cudnn = False
if (cudnn_version is not None and align_corners and mode == 'bilinear'
        and padding_mode == 'zeros'):
use_cudnn = True
ipts = {'X': x, 'Grid': grid}
attrs = {
'mode': mode,
'padding_mode': padding_mode,
'align_corners': align_corners,
'use_cudnn': use_cudnn
}
if in_dygraph_mode():
attrs = ('mode', mode, 'padding_mode', padding_mode, 'align_corners',
align_corners, 'use_cudnn', use_cudnn)
out = getattr(core.ops, 'grid_sampler')(x, grid, *attrs)
else:
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='grid_sampler',
inputs=ipts,
attrs=attrs,
outputs={'Output': out})
return out