未验证 提交 64f3e3ed 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #14069 from heavengate/grid_sampler

Grid sampler operator for spatial transformer network.
......@@ -178,6 +178,7 @@ paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], var
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using ScopedSpatialTransformerDescriptor =
platform::ScopedSpatialTransformerDescriptor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output = ctx.Output<Tensor>("Output");
int n = input->dims()[0];
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
const int size[4] = {n, c, h, w};
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, size);
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
output_data));
}
};
template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
auto output_grad_dims = output_grad->dims();
const int n = output_grad_dims[0];
const int c = output_grad_dims[1];
const int h = output_grad_dims[2];
const int w = output_grad_dims[3];
const int size[4] = {n, c, h, w};
ScopedSpatialTransformerDescriptor st_dest;
cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
st_dest.descriptor<T>(4, size);
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
T* grid_grad_data =
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_input_grad_desc =
input_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
output_grad_data, grid_data, CudnnDataType<T>::kZero(),
grid_grad_data));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleOpKernel<float>,
paddle::operators::CUDNNGridSampleOpKernel<double>);
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleGradOpKernel<float>,
paddle::operators::CUDNNGridSampleGradOpKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grid"),
"Input(Grid) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of GridSampleOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE(x_dims.size() == 4,
"Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]");
AddInput(
"Grid",
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention");
AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
AddAttr<bool>(
"use_cudnn",
"(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddComment(R"DOC(
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC");
}
};
class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
if (ctx->HasOutput(framework::GradVarName("Grid"))) {
ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("grid_sampler_grad");
op->SetInput("X", Input("X"));
op->SetInput("Grid", Input("Grid"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
REGISTER_OP_CPU_KERNEL(
grid_sampler,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
grid_sampler_grad,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T>
static inline bool isInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false;
}
return true;
}
template <typename T>
static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, Tensor* x_w, Tensor* x_e,
Tensor* y_n, Tensor* y_s, Tensor* d_w,
Tensor* d_e, Tensor* d_n, Tensor* d_s) {
auto& place = *ctx.eigen_device();
const int n = grid.dims()[0];
const int h = grid.dims()[1];
const int w = grid.dims()[2];
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor grid_x, grid_y;
T* grid_x_data = grid_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
T* grid_y_data = grid_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * h * w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
Tensor ones;
ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
// scale grid to [0, h-1/w-1]
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
// calculate coords of 4 corner points
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
auto x_w_t = EigenTensor<T, 3>::From(*x_w);
auto x_e_t = EigenTensor<T, 3>::From(*x_e);
auto y_n_t = EigenTensor<T, 3>::From(*y_n);
auto y_s_t = EigenTensor<T, 3>::From(*y_s);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + ones_t;
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + ones_t;
// calculate distances to 4 sides
d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
auto d_w_t = EigenTensor<T, 3>::From(*d_w);
auto d_e_t = EigenTensor<T, 3>::From(*d_e);
auto d_n_t = EigenTensor<T, 3>::From(*d_n);
auto d_s_t = EigenTensor<T, 3>::From(*d_s);
d_w_t.device(place) = grid_x_t - x_w_t;
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
}
template <typename T>
static void GetGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int h = input.dims()[2];
const int w = input.dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
}
}
}
template <typename T>
static void GatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int h = output_grad.dims()[2];
const int w = output_grad.dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
}
}
}
template <typename DeviceContext, typename T>
class GridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
// calc locations and distances of 4 corner points
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
// calc 4 corner points value
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*output);
// bilinear interpolaetion by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
};
template <typename DeviceContext, typename T>
class GridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e,
d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w,
d_n);
// calc 4 corner points value
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
grid_grad_x_t = grid_grad_x_t * (x_max / (T)2);
grid_grad_y_t = grid_grad_y_t * (y_max / (T)2);
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * h * w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
};
} // namespace operators
} // namespace paddle
......@@ -158,6 +158,7 @@ __all__ = [
'sequence_reverse',
'affine_channel',
'hash',
'grid_sampler',
'log_loss',
'add_position_encoding',
]
......@@ -7801,6 +7802,87 @@ def hash(input, hash_size, num_hash=1, name=None):
return out
@templatedoc()
def grid_sampler(x, grid, name=None):
"""
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Variable): Input data of shape [N, C, H, W].
grid(Variable): Input grid tensor of shape [N, H, W, 2].
name (str, default None): The name of this layer.
Returns:
out(Variable): Output of shape [N, C, H, W] data samples input X
using bilnear interpolation based on input grid.
Exmples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32]})
out = fluid.layers.grid_sampler(x=x, grid=grid)
"""
helper = LayerHelper("grid_sampler", **locals())
if not isinstance(x, Variable):
return ValueError("The x should be a Variable")
if not isinstance(grid, Variable):
return ValueError("The grid should be a Variable")
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x, 'Grid': grid}
helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output': out})
return out
def log_loss(input, label, epsilon=1e-4, name=None):
"""
**Negative Log Loss Layer**
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def AffineGrid(theta, size):
n = size[0]
h = size[2]
w = size[3]
h_idx = np.repeat(
np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
w_idx = np.repeat(
np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2])
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
return ret.reshape([n, h, w, 2]).astype("float32")
def getGridPointValue(data, x, y):
data_shape = data.shape
N = data_shape[0]
H = data_shape[2]
W = data_shape[3]
out = np.zeros(data_shape, dtype='float')
for i in range(N):
for j in range(H):
for k in range(W):
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[
i, j, k] > W - 1:
out[i, :, j, k] = 0
else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
return out
def GridSampler(data, grid):
dims = data.shape
N = dims[0]
C = dims[1]
H = dims[2]
W = dims[3]
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
y_max = H - 1
x_max = W - 1
x = 0.5 * ((x.astype('float32') + 1.0) * x_max)
y = 0.5 * ((y.astype('float32') + 1.0) * y_max)
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1)
vc = getGridPointValue(data, x1, y0)
vd = getGridPointValue(data, x1, y1)
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
return out
class TestGridSamplerOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'grid_sampler'
x = np.random.randint(0, 255, self.x_shape).astype('float32')
theta = np.zeros(self.theta_shape).astype('float32')
for i in range(self.theta_shape[0]):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.x_shape)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {'use_cudnn': True}
self.outputs = {'Output': GridSampler(x, grid)}
def test_check_output(self):
self.check_output(atol=1e-3)
def test_check_grad_normal(self):
self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
def initTestCase(self):
self.x_shape = (2, 5, 7, 3)
self.grid_shape = (2, 7, 3, 2)
self.theta_shape = (2, 2, 3)
if __name__ == "__main__":
unittest.main()
......@@ -865,6 +865,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_grid_sampler(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 5, 7], dtype='float32')
grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
out = layers.grid_sampler(x, grid)
self.assertIsNotNone(out)
print(str(program))
def test_affine_grid(self):
program = Program()
with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册