From 0bb0e0c10ff05553c85b17a12d3b4ef430323202 Mon Sep 17 00:00:00 2001
From: dengkaipeng
Date: Fri, 19 Oct 2018 22:55:03 +0800
Subject: [PATCH] add Grid Sampler Operator for STN.

---
 paddle/fluid/API.spec                              |   1 +
 .../operators/grid_sampler_cudnn_op.cu.cc          | 125 +++++++
 paddle/fluid/operators/grid_sampler_op.cc          | 147 +++++++++
 paddle/fluid/operators/grid_sampler_op.h           | 311 ++++++++++++++++++
 paddle/fluid/platform/cudnn_helper.h               |  22 ++
 paddle/fluid/platform/dynload/cudnn.h              |   7 +
 python/paddle/fluid/layers/nn.py                   |  36 ++
 .../tests/unittests/test_grid_sampler_op.py        | 121 +++++++
 .../fluid/tests/unittests/test_layers.py           |  10 +
 9 files changed, 780 insertions(+)
 create mode 100644 paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
 create mode 100644 paddle/fluid/operators/grid_sampler_op.cc
 create mode 100644 paddle/fluid/operators/grid_sampler_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_grid_sampler_op.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 2b8b82e74..fec54e985 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -175,6 +175,7 @@ paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims'], varargs=None, keywords=None, defaults=(1, 1))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
 paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
new file mode 100644
index 000000000..3da8af332
--- /dev/null
+++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
@@ -0,0 +1,125 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
+using DataLayout = platform::DataLayout;
+using ScopedSpatialTransformerDescriptor =
+    platform::ScopedSpatialTransformerDescriptor;
+template <typename T>
+using CudnnDataType = platform::CudnnDataType<T>;
+
+template <typename T>
+class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use CUDAPlace");
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto handle = dev_ctx.cudnn_handle();
+    auto* input = ctx.Input<Tensor>("X");
+    auto* grid = ctx.Input<Tensor>("Grid");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    int n = input->dims()[0];
+    int c = input->dims()[1];
+    int h = input->dims()[2];
+    int w = input->dims()[3];
+    const int size[4] = {n, c, h, w};
+
+    const T* input_data = input->data<T>();
+    const T* grid_data = grid->data<T>();
+    T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+
+    ScopedSpatialTransformerDescriptor st_desc;
+    cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
+        st_desc.descriptor<T>(4, size);
+
+    ScopedTensorDescriptor input_desc;
+    ScopedTensorDescriptor output_desc;
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        DataLayout::kNCHW, framework::vectorize2int(output->dims()));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
+        handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
+        input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
+        output_data));
+  }
+};
+
+template <typename T>
+class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use CUDAPlace");
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto handle = dev_ctx.cudnn_handle();
+    auto* input = ctx.Input<Tensor>("X");
+    auto* grid = ctx.Input<Tensor>("Grid");
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
+
+    auto output_grad_dims = output_grad->dims();
+    const int n = output_grad_dims[0];
+    const int c = output_grad_dims[1];
+    const int h = output_grad_dims[2];
+    const int w = output_grad_dims[3];
+    const int size[4] = {n, c, h, w};
+
+    ScopedSpatialTransformerDescriptor st_dest;
+    cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
+        st_dest.descriptor<T>(4, size);
+
+    const T* input_data = input->data<T>();
+    const T* grid_data = grid->data<T>();
+    const T* output_grad_data = output_grad->data<T>();
+    T* input_grad_data =
+        input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
+    T* grid_grad_data =
+        grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
+
+    ScopedTensorDescriptor input_desc;
+    ScopedTensorDescriptor input_grad_desc;
+    ScopedTensorDescriptor output_grad_desc;
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
+    cudnnTensorDescriptor_t cudnn_input_grad_desc =
+        input_grad_desc.descriptor<T>(
+            DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
+    cudnnTensorDescriptor_t cudnn_output_grad_desc =
+        output_grad_desc.descriptor<T>(
+            DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
+
+    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
+        handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
+        input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
+        input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
+        output_grad_data, grid_data, CudnnDataType<T>::kZero(),
+        grid_grad_data));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace plat = paddle::platform;
+REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
+                   paddle::operators::CUDNNGridSampleOpKernel<float>,
+                   paddle::operators::CUDNNGridSampleOpKernel<double>);
+REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
+                   paddle::operators::CUDNNGridSampleGradOpKernel<float>,
+                   paddle::operators::CUDNNGridSampleGradOpKernel<double>);
diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc
new file mode 100644
index 000000000..3f28ed5df
--- /dev/null
+++ b/paddle/fluid/operators/grid_sampler_op.cc
@@ -0,0 +1,147 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/grid_sampler_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class GridSampleOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of GridSampleOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grid"),
+                   "Input(Grid) of GridSampleOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Output"),
+                   "Output(Output) of GridSampleOp should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto grid_dims = ctx->GetInputDim("Grid");
+    PADDLE_ENFORCE(x_dims.size() == 4,
+                   "Input(X) of GridSampleOp should be a 4-D Tensor.");
+    PADDLE_ENFORCE(grid_dims.size() == 4,
+                   "Input(Grid) of GridSampleOp should be a 4-D Tensor.");
+    PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
+    PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
+                      "Input(X) and Input(Grid) dims[0] should be equal.");
+    PADDLE_ENFORCE_EQ(
+        grid_dims[1], x_dims[2],
+        "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
+    PADDLE_ENFORCE_EQ(
+        grid_dims[2], x_dims[3],
+        "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
+
+    ctx->SetOutputDim("Output", x_dims);
+    ctx->ShareLoD("X", "Output");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_CUDA
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        framework::DataLayout::kAnyLayout, library_);
+  }
+};
+
+class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor) The input tensor of GridSampleOp. "
+             "This is a 4-D tensor with shape [N, C, H, W].");
+    AddInput("Grid",
+             "(Tensor) The sampling grid, usually the output of AffineGridOp. "
+             "This is a 4-D tensor with shape [N, H, W, 2].");
+    AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W].");
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default true) Only used in cudnn kernel, need install cudnn")
+        .SetDefault(true);
+
+    AddComment(R"DOC(
+      This operator samples input X at the (x, y) locations given by Grid,
+      which is usually generated by AffineGridOp. Each output value is the
+      bilinear interpolation of the four input pixels nearest to its
+      grid location.
+      )DOC");
+  }
+};
+
+class GridSampleOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto input_dims = ctx->GetInputDim("X");
+    auto grid_dims = ctx->GetInputDim("Grid");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Grid"))) {
+      ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_CUDA
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        framework::DataLayout::kAnyLayout, library_);
+  }
+};
+
+class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("grid_sampler_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Grid", Input("Grid"));
+    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
+                  ops::GridSampleGradMaker);
+REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    grid_sampler,
+    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    grid_sampler_grad,
+    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h
new file mode 100644
index 000000000..7f42fa66c
--- /dev/null
+++ b/paddle/fluid/operators/grid_sampler_op.h
@@ -0,0 +1,311 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/gather.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+using Array3 = Eigen::DSizes<int64_t, 3>;
+using Array4 = Eigen::DSizes<int64_t, 4>;
+
+template <typename T>
+inline bool isInBound(T x, T y, T x_max, T y_max) {
+  if (x < 0 || x > x_max || y < 0 || y > y_max) {
+    return false;
+  }
+  return true;
+}
+
+template <typename DeviceContext, typename T>
+void CalcGridLocations(const framework::ExecutionContext& ctx,
+                       const Tensor& grid, Tensor* x_w, Tensor* x_e,
+                       Tensor* y_n, Tensor* y_s, Tensor* d_w, Tensor* d_e,
+                       Tensor* d_n, Tensor* d_s) {
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  const int n = grid.dims()[0];
+  const int h = grid.dims()[1];
+  const int w = grid.dims()[2];
+  const T x_max = static_cast<T>(w - 1);
+  const T y_max = static_cast<T>(h - 1);
+
+  // split grid with shape (n, h, w, 2) into (x, y) by the 3rd dim
+  Tensor grid_x, grid_y;
+  T* grid_x_data = grid_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
+  T* grid_y_data = grid_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
+  const T* grid_data = grid.data<T>();
+  for (int i = 0; i < n * h * w; i++) {
+    grid_x_data[i] = grid_data[2 * i];
+    grid_y_data[i] = grid_data[(2 * i) + 1];
+  }
+
+  Tensor ones;
+  ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
+  auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
+
+  // scale grid from [-1, 1] to [0, w-1] in x and [0, h-1] in y
+  auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
+  auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
+  grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
+  grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
+
+  // west/east x coords and north/south y coords of the 4 corner points
+  x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  y_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  auto x_w_t = EigenTensor<T, 3>::From(*x_w);
+  auto x_e_t = EigenTensor<T, 3>::From(*x_e);
+  auto y_n_t = EigenTensor<T, 3>::From(*y_n);
+  auto y_s_t = EigenTensor<T, 3>::From(*y_s);
+  x_w_t.device(place) = grid_x_t.floor();
+  x_e_t.device(place) = x_w_t + ones_t;
+  y_n_t.device(place) = grid_y_t.floor();
+  y_s_t.device(place) = y_n_t + ones_t;
+
+  // distances from the sample point to the 4 corner points
+  d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  d_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
+  auto d_w_t = EigenTensor<T, 3>::From(*d_w);
+  auto d_e_t = EigenTensor<T, 3>::From(*d_e);
+  auto d_n_t = EigenTensor<T, 3>::From(*d_n);
+  auto d_s_t = EigenTensor<T, 3>::From(*d_s);
+  d_w_t.device(place) = grid_x_t - x_w_t;
+  d_e_t.device(place) = x_e_t - grid_x_t;
+  d_n_t.device(place) = grid_y_t - y_n_t;
+  d_s_t.device(place) = y_s_t - grid_y_t;
+}
+
+template <typename T>
+void GetGridPointValue(const Tensor& input, Tensor* output, const Tensor& x,
+                       const Tensor& y) {
+  const int n = input.dims()[0];
+  const int c = input.dims()[1];
+  const int h = input.dims()[2];
+  const int w = input.dims()[3];
+  auto x_t = EigenTensor<T, 3>::From(x);
+  auto y_t = EigenTensor<T, 3>::From(y);
+  auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
+  auto input_t = EigenTensor<T, 4>::From(input);
+
+  for (int i = 0; i < n; i++) {
+    for (int k = 0; k < h; k++) {
+      for (int l = 0; l < w; l++) {
+        if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
+          for (int j = 0; j < c; j++) {
+            output_t(i, j, k, l) = input_t(i, j, (int)round(y_t(i, k, l)),
+                                           (int)round(x_t(i, k, l)));
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad,
+                                 const Tensor& x, const Tensor& y,
+                                 const Tensor& d1, const Tensor& d2) {
+  const int n = output_grad.dims()[0];
+  const int c = output_grad.dims()[1];
+  const int h = output_grad.dims()[2];
+  const int w = output_grad.dims()[3];
+  auto x_t = EigenTensor<T, 3>::From(x);
+  auto y_t = EigenTensor<T, 3>::From(y);
+  auto d1_t = EigenTensor<T, 3>::From(d1);
+  auto d2_t = EigenTensor<T, 3>::From(d2);
+  auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
+  auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
+
+  for (int i = 0; i < n; i++) {
+    for (int k = 0; k < h; k++) {
+      for (int l = 0; l < w; l++) {
+        if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
+          for (int j = 0; j < c; j++) {
+            input_grad_t(i, j, (int)y_t(i, k, l), (int)x_t(i, k, l)) +=
+                output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+class GridSampleOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto* input = ctx.Input<Tensor>("X");
+    auto* grid = ctx.Input<Tensor>("Grid");
+
+    const int n = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+
+    // calc locations and distances of 4 corner points
+    Tensor x_w, x_e, y_n, y_s;
+    Tensor d_w, d_e, d_n, d_s;
+    CalcGridLocations<DeviceContext, T>(ctx, *grid, &x_w, &x_e, &y_n, &y_s,
+                                        &d_w, &d_e, &d_n, &d_s);
+
+    auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    math::SetConstant<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), output,
+        static_cast<T>(0));
+
+    // calc 4 corner points value
+    Tensor v_wn, v_en, v_ws, v_es;
+    v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
+    GetGridPointValue<T>(*input, &v_en, x_e, y_n);
+    GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
+    GetGridPointValue<T>(*input, &v_es, x_e, y_s);
+
+    auto d_w_t = EigenTensor<T, 3>::From(d_w);
+    auto d_e_t = EigenTensor<T, 3>::From(d_e);
+    auto d_n_t = EigenTensor<T, 3>::From(d_n);
+    auto d_s_t = EigenTensor<T, 3>::From(d_s);
+    auto d_w_scaled_t =
+        d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
+    auto d_e_scaled_t =
+        d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
+    auto d_n_scaled_t =
+        d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
+    auto d_s_scaled_t =
+        d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
+    auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
+    auto v_en_t = EigenTensor<T, 4>::From(v_en);
+    auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
+    auto v_es_t = EigenTensor<T, 4>::From(v_es);
+    auto output_t = EigenTensor<T, 4>::From(*output);
+    // bilinear interpolation by 4 corner points
+    output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
+                             v_en_t * d_w_scaled_t * d_s_scaled_t +
+                             v_ws_t * d_e_scaled_t * d_n_scaled_t +
+                             v_es_t * d_w_scaled_t * d_n_scaled_t;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class GridSampleGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* grid = ctx.Input<Tensor>("Grid");
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+
+    const int n = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    math::SetConstant<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), input_grad,
+        static_cast<T>(0));
+    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
+    grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
+    math::SetConstant<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), grid_grad,
+        static_cast<T>(0));
+
+    Tensor x_w, x_e, y_n, y_s;
+    Tensor d_w, d_e, d_n, d_s;
+    CalcGridLocations<DeviceContext, T>(ctx, *grid, &x_w, &x_e, &y_n, &y_s,
+                                        &d_w, &d_e, &d_n, &d_s);
+
+    // gather output grad value to input grad by corner point coords and
+    // weights
+    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e,
+                                   d_s);
+    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e,
+                                   d_n);
+    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w,
+                                   d_s);
+    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w,
+                                   d_n);
+
+    // calc 4 corner points value
+    Tensor v_wn, v_en, v_ws, v_es;
+    v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
+    GetGridPointValue<T>(*input, &v_en, x_e, y_n);
+    GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
+    GetGridPointValue<T>(*input, &v_es, x_e, y_s);
+    auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
+    auto v_en_t = EigenTensor<T, 4>::From(v_en);
+    auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
+    auto v_es_t = EigenTensor<T, 4>::From(v_es);
+
+    auto d_w_t = EigenTensor<T, 3>::From(d_w);
+    auto d_e_t = EigenTensor<T, 3>::From(d_e);
+    auto d_n_t = EigenTensor<T, 3>::From(d_n);
+    auto d_s_t = EigenTensor<T, 3>::From(d_s);
+
+    auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);
+
+    Tensor grid_grad_x, grid_grad_y;
+    grid_grad_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
+    grid_grad_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
+    auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
+    auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
+    for (int i = 0; i < n; i++) {
+      for (int j = 0; j < c; j++) {
+        for (int k = 0; k < h; k++) {
+          for (int l = 0; l < w; l++) {
+            grid_grad_x_t(i, k, l) +=
+                ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
+                 (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
+                output_grad_t(i, j, k, l);
+            grid_grad_y_t(i, k, l) +=
+                ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
+                 (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
+                output_grad_t(i, j, k, l);
+          }
+        }
+      }
+    }
+    // undo the [-1, 1] -> [0, w-1]/[0, h-1] scaling of CalcGridLocations
+    const T x_max = static_cast<T>(w - 1);
+    const T y_max = static_cast<T>(h - 1);
+    grid_grad_x_t = grid_grad_x_t * (x_max / (T)2);
+    grid_grad_y_t = grid_grad_y_t * (y_max / (T)2);
+
+    // gather grid_grad [x, y] in 3rd dim
+    T* grid_grad_data = grid_grad->data<T>();
+    T* grid_grad_x_data = grid_grad_x.data<T>();
+    T* grid_grad_y_data = grid_grad_y.data<T>();
+    for (int i = 0; i < n * h * w; i++) {
+      grid_grad_data[2 * i] = grid_grad_x_data[i];
+      grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index bb8b14bb9..140c8c382 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -341,6 +341,28 @@ class ScopedPoolingDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
 };
 
+class ScopedSpatialTransformerDescriptor {
+ public:
+  ScopedSpatialTransformerDescriptor() {
+    PADDLE_ENFORCE(dynload::cudnnCreateSpatialTransformerDescriptor(&desc_));
+  }
+  ~ScopedSpatialTransformerDescriptor() {
+    PADDLE_ENFORCE(dynload::cudnnDestroySpatialTransformerDescriptor(desc_));
+  }
+
+  template <typename T>
+  inline cudnnSpatialTransformerDescriptor_t descriptor(const int nbDims,
+                                                        const int dimA[]) {
+    PADDLE_ENFORCE(dynload::cudnnSetSpatialTransformerNdDescriptor(
+        desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType<T>::type, nbDims, dimA));
+    return desc_;
+  }
+
+ private:
+  cudnnSpatialTransformerDescriptor_t desc_;
+  DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor);
+};
+
 inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h
index e6353f67e..0a531ec11 100644
--- a/paddle/fluid/platform/dynload/cudnn.h
+++ b/paddle/fluid/platform/dynload/cudnn.h
@@ -90,6 +90,13 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnSetConvolutionNdDescriptor);           \
   __macro(cudnnGetConvolutionNdDescriptor);           \
   __macro(cudnnDeriveBNTensorDescriptor);             \
+  __macro(cudnnCreateSpatialTransformerDescriptor);   \
+  __macro(cudnnSetSpatialTransformerNdDescriptor);    \
+  __macro(cudnnDestroySpatialTransformerDescriptor);  \
+  __macro(cudnnSpatialTfGridGeneratorForward);        \
+  __macro(cudnnSpatialTfGridGeneratorBackward);       \
+  __macro(cudnnSpatialTfSamplerForward);              \
+  __macro(cudnnSpatialTfSamplerBackward);             \
   __macro(cudnnCreate);                               \
   __macro(cudnnDestroy);                              \
   __macro(cudnnSetStream);                            \
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 4bfa89d9f..6770f7421 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -157,6 +157,7 @@ __all__ = [
     'sequence_reverse',
     'affine_channel',
     'hash',
+    'grid_sampler',
 ]
 
 
@@ -7580,3 +7581,38 @@ def hash(input, hash_size, num_hash=1, name=None):
         attrs={'num_hash': num_hash,
                'mod_by': hash_size})
     return out
+
+
+@templatedoc()
+def grid_sampler(x, grid, name=None):
+    """
+    This operation samples input x at the locations given by grid, computing
+    each output point by bilinear interpolation of the four nearest input
+    points.
+
+    Args:
+        x(Variable): Input data of shape [N, C, H, W].
+        grid(Variable): Input grid tensor of shape [N, H, W, 2], with
+                        (x, y) coordinates in the range [-1, 1].
+        name(str|None): A name for this layer (optional).
+
+    Returns:
+        out(Variable): Output of shape [N, C, H, W], sampled from x by grid.
+    """
+    helper = LayerHelper("grid_sampler", **locals())
+
+    if not isinstance(x, Variable):
+        raise ValueError("The x should be a Variable")
+
+    if not isinstance(grid, Variable):
+        raise ValueError("The grid should be a Variable")
+
+    out = helper.create_tmp_variable(x.dtype)
+    ipts = {'X': x, 'Grid': grid}
+
+    helper.append_op(
+        type='grid_sampler', inputs=ipts, outputs={'Output': out})
+
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
new file mode 100644
index 000000000..958573c08
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def AffineGrid(theta, size):
+    n = size[0]
+    h = size[2]
+    w = size[3]
+    h_idx = np.repeat(
+        np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
+    w_idx = np.repeat(
+        np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
+    grid = np.concatenate(
+        [w_idx, h_idx, np.ones([h, w, 1])], axis=2)  # h * w * 3
+    grid = np.repeat(grid[np.newaxis, :], size[0], axis=0)  # n * h * w * 3
+
+    ret = np.zeros([n, h * w, 2])
+    theta = theta.transpose([0, 2, 1])
+    for i in range(len(theta)):
+        ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
+
+    return ret.reshape([n, h, w, 2]).astype("float32")
+
+
+def getGridPointValue(data, x, y):
+    data_shape = data.shape
+    N = data_shape[0]
+    H = data_shape[2]
+    W = data_shape[3]
+
+    out = np.zeros(data_shape, dtype='float')
+    for i in range(N):
+        for j in range(H):
+            for k in range(W):
+                if y[i, j, k] < 0 or y[i, j, k] > H - 1 or \
+                        x[i, j, k] < 0 or x[i, j, k] > W - 1:
+                    out[i, :, j, k] = 0
+                else:
+                    out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
+
+    return out
+
+
+def GridSampler(data, grid):
+    dims = data.shape
+    N = dims[0]
+    C = dims[1]
+    H = dims[2]
+    W = dims[3]
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+    y_max = H - 1
+    x_max = W - 1
+
+    x = 0.5 * ((x.astype('float32') + 1.0) * x_max)
+    y = 0.5 * ((y.astype('float32') + 1.0) * y_max)
+
+    x0 = np.floor(x).astype('int32')
+    x1 = x0 + 1
+    y0 = np.floor(y).astype('int32')
+    y1 = y0 + 1
+
+    wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
+    wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
+    wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
+    wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
+
+    va = getGridPointValue(data, x0, y0)
+    vb = getGridPointValue(data, x0, y1)
+    vc = getGridPointValue(data, x1, y0)
+    vd = getGridPointValue(data, x1, y1)
+
+    out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
+    return out
+
+
+class TestGridSamplerOp(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'grid_sampler'
+        x = np.random.randint(0, 255, self.x_shape).astype('float32')
+
+        theta = np.zeros(self.theta_shape).astype('float32')
+        for i in range(self.theta_shape[0]):
+            for j in range(2):
+                for k in range(3):
+                    theta[i, j, k] = np.random.rand(1)[0]
+        grid = AffineGrid(theta, self.x_shape)
+
+        self.inputs = {'X': x, 'Grid': grid}
+        self.attrs = {'use_cudnn': True}
+        self.outputs = {'Output': GridSampler(x, grid)}
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.6)
+
+    def initTestCase(self):
+        self.x_shape = (2, 5, 7, 3)
+        self.grid_shape = (2, 7, 3, 2)
+        self.theta_shape = (2, 2, 3)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 50de468db..17c94a1d4 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -865,6 +865,16 @@ class TestBook(unittest.TestCase):
         self.assertIsNotNone(out)
         print(str(program))
 
+    def test_grid_sampler(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[3, 5, 7], dtype='float32')
+            grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
+            out = layers.grid_sampler(x, grid)
+            self.assertIsNotNone(out)
+        print(str(program))
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
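
Reviewer note (not part of the patch): a minimal usage sketch of the new layer. It is a sketch under assumptions, not part of the operator's tests: it presumes a build of this branch with the 1.x fluid API, feeds a hand-built identity grid instead of the output of an affine-grid op, and all variable names are illustrative. Because the identity grid hits exact pixel centers, the bilinear interpolation degenerates to a copy, so the output should reproduce the input up to floating-point error.

    import numpy as np
    import paddle.fluid as fluid

    # Feature map [N, C, H, W] and sampling grid [N, H, W, 2],
    # with (x, y) coordinates in [-1, 1]. layers.data prepends the
    # batch dimension N automatically.
    x = fluid.layers.data(name='x', shape=[3, 5, 7], dtype='float32')
    grid = fluid.layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
    out = fluid.layers.grid_sampler(x, grid)

    exe = fluid.Executor(fluid.CPUPlace())

    # Identity grid: grid[..., 0] sweeps x over [-1, 1] across the width,
    # grid[..., 1] sweeps y over [-1, 1] down the height.
    ys, xs = np.meshgrid(
        np.linspace(-1, 1, 5), np.linspace(-1, 1, 7), indexing='ij')
    identity = np.stack([xs, ys], axis=-1)[np.newaxis].astype('float32')
    x_np = np.random.rand(1, 3, 5, 7).astype('float32')
    out_np, = exe.run(feed={'x': x_np, 'grid': identity}, fetch_list=[out])
    print(np.abs(out_np - x_np).max())  # expected to be ~0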