From ff6329bd5f789893aea2721abb27d5650131aef9 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 29 Oct 2018 12:14:59 +0800 Subject: [PATCH] fix some inappropriate expressions in api doc for grid_sampler. test=develop --- .../operators/grid_sampler_cudnn_op.cu.cc | 172 ++++----- paddle/fluid/operators/grid_sampler_op.cc | 188 +++++----- paddle/fluid/operators/grid_sampler_op.h | 335 +++++++++--------- paddle/fluid/platform/cudnn_helper.h | 10 +- paddle/fluid/platform/dynload/cudnn.h | 90 ++--- python/paddle/fluid/layers/nn.py | 29 +- .../tests/unittests/test_grid_sampler_op.py | 16 +- .../fluid/tests/unittests/test_layers.py | 5 +- 8 files changed, 436 insertions(+), 409 deletions(-) diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc index 0e8ca01eba..7cde7ca462 100644 --- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc +++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc @@ -22,107 +22,111 @@ using framework::Tensor; using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using DataLayout = platform::DataLayout; using ScopedSpatialTransformerDescriptor = - platform::ScopedSpatialTransformerDescriptor; + platform::ScopedSpatialTransformerDescriptor; template using CudnnDataType = platform::CudnnDataType; template class CUDNNGridSampleOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace"); - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - auto* output = ctx.Output("Output"); - - int n = input->dims()[0]; - int c = input->dims()[1]; - int h = input->dims()[2]; - int w = input->dims()[3]; - const int size[4] = {n, c, h, w}; - - const T* input_data = input->data(); - const T* grid_data = grid->data(); - T* output_data = output->mutable_data({n, c, h, w}, ctx.GetPlace()); - - ScopedSpatialTransformerDescriptor st_desc; - cudnnSpatialTransformerDescriptor_t cudnn_st_desc = + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace"); + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + auto* output = ctx.Output("Output"); + + int n = input->dims()[0]; + int c = input->dims()[1]; + int h = input->dims()[2]; + int w = input->dims()[3]; + const int size[4] = {n, c, h, w}; + + const T* input_data = input->data(); + const T* grid_data = grid->data(); + T* output_data = output->mutable_data({n, c, h, w}, ctx.GetPlace()); + + ScopedSpatialTransformerDescriptor st_desc; + cudnnSpatialTransformerDescriptor_t cudnn_st_desc = st_desc.descriptor(4, size); - ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor output_desc; - cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - DataLayout::kNCHW, framework::vectorize2int(input->dims())); - cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - DataLayout::kNCHW, framework::vectorize2int(output->dims())); - - CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward( - handle, cudnn_st_desc, CudnnDataType::kOne(), cudnn_input_desc, input_data, - grid_data, CudnnDataType::kZero(), cudnn_output_desc, output_data)); - } - + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + DataLayout::kNCHW, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + DataLayout::kNCHW, framework::vectorize2int(output->dims())); + + CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward( + handle, cudnn_st_desc, CudnnDataType::kOne(), cudnn_input_desc, + input_data, grid_data, CudnnDataType::kZero(), cudnn_output_desc, + output_data)); + } }; template class CUDNNGridSampleGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace"); - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - auto* output_grad = ctx.Input(framework::GradVarName("Output")); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); - - auto output_grad_dims = output_grad->dims(); - const int n = output_grad_dims[0]; - const int c = output_grad_dims[1]; - const int h = output_grad_dims[2]; - const int w = output_grad_dims[3]; - const int size[4] = {n, c, h, w}; - - ScopedSpatialTransformerDescriptor st_dest; - cudnnSpatialTransformerDescriptor_t cudnn_st_dest = + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace"); + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + auto* output_grad = ctx.Input(framework::GradVarName("Output")); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); + + auto output_grad_dims = output_grad->dims(); + const int n = output_grad_dims[0]; + const int c = output_grad_dims[1]; + const int h = output_grad_dims[2]; + const int w = output_grad_dims[3]; + const int size[4] = {n, c, h, w}; + + ScopedSpatialTransformerDescriptor st_dest; + cudnnSpatialTransformerDescriptor_t cudnn_st_dest = st_dest.descriptor(4, size); - const T* input_data = input->data(); - const T* grid_data = grid->data(); - const T* output_grad_data = output_grad->data(); - T* input_grad_data = input_grad->mutable_data(output_grad_dims, ctx.GetPlace()); - T* grid_grad_data = grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); - - ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor input_grad_desc; - ScopedTensorDescriptor output_grad_desc; - cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - DataLayout::kNCHW, framework::vectorize2int(input->dims())); - cudnnTensorDescriptor_t cudnn_input_grad_desc = input_grad_desc.descriptor( - DataLayout::kNCHW, framework::vectorize2int(input_grad->dims())); - cudnnTensorDescriptor_t cudnn_output_grad_desc = output_grad_desc.descriptor( - DataLayout::kNCHW, framework::vectorize2int(output_grad->dims())); - - CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward( - handle, cudnn_st_dest, CudnnDataType::kOne(), - cudnn_input_desc, input_data, CudnnDataType::kZero(), - cudnn_input_grad_desc, input_grad_data, CudnnDataType::kOne(), - cudnn_output_grad_desc, output_grad_data, grid_data, - CudnnDataType::kZero(), grid_grad_data)); - } + const T* input_data = input->data(); + const T* grid_data = grid->data(); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = + input_grad->mutable_data(output_grad_dims, ctx.GetPlace()); + T* grid_grad_data = + grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); + + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor input_grad_desc; + ScopedTensorDescriptor output_grad_desc; + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + DataLayout::kNCHW, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_input_grad_desc = + input_grad_desc.descriptor( + DataLayout::kNCHW, framework::vectorize2int(input_grad->dims())); + cudnnTensorDescriptor_t cudnn_output_grad_desc = + output_grad_desc.descriptor( + DataLayout::kNCHW, framework::vectorize2int(output_grad->dims())); + + CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward( + handle, cudnn_st_dest, CudnnDataType::kOne(), cudnn_input_desc, + input_data, CudnnDataType::kZero(), cudnn_input_grad_desc, + input_grad_data, CudnnDataType::kOne(), cudnn_output_grad_desc, + output_grad_data, grid_data, CudnnDataType::kZero(), + grid_grad_data)); + } }; } // namespace operators } // namespace paddle namespace plat = paddle::platform; -REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNGridSampleOpKernel, - paddle::operators::CUDNNGridSampleOpKernel); +REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNGridSampleOpKernel, + paddle::operators::CUDNNGridSampleOpKernel); REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNGridSampleGradOpKernel, - paddle::operators::CUDNNGridSampleGradOpKernel); + paddle::operators::CUDNNGridSampleGradOpKernel, + paddle::operators::CUDNNGridSampleGradOpKernel); diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 599ff9a9c1..e76eb6893b 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -24,70 +24,76 @@ namespace operators { using Tensor = framework::Tensor; class GridSampleOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of GridSampleOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grid"), - "Input(Grid) of GridSampleOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of GridSampleOp should not be null."); - - auto x_dims = ctx->GetInputDim("X"); - auto grid_dims = ctx->GetInputDim("Grid"); - PADDLE_ENFORCE(x_dims.size() == 4, "Input(X) of GridSampleOp should be 4-D Tensor."); - PADDLE_ENFORCE(grid_dims.size() == 4, "Input(Grid) of GridSampleOp should be 4-D Tensor."); - PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); - PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], "Input(X) and Input(Grid) dims[0] should be equal."); - PADDLE_ENFORCE_EQ(grid_dims[1], x_dims[2], "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); - PADDLE_ENFORCE_EQ(grid_dims[2], x_dims[3], "Input(X) dims[3] and Input(Grid) dims[2] should be equal."); - - ctx->SetOutputDim("Output", x_dims); - ctx->ShareLoD("X", "Output"); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of GridSampleOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grid"), + "Input(Grid) of GridSampleOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of GridSampleOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto grid_dims = ctx->GetInputDim("Grid"); + PADDLE_ENFORCE(x_dims.size() == 4, + "Input(X) of GridSampleOp should be 4-D Tensor."); + PADDLE_ENFORCE(grid_dims.size() == 4, + "Input(Grid) of GridSampleOp should be 4-D Tensor."); + PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); + PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], + "Input(X) and Input(Grid) dims[0] should be equal."); + PADDLE_ENFORCE_EQ( + grid_dims[1], x_dims[2], + "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); + PADDLE_ENFORCE_EQ( + grid_dims[2], x_dims[3], + "Input(X) dims[3] and Input(Grid) dims[2] should be equal."); + + ctx->SetOutputDim("Output", x_dims); + ctx->ShareLoD("X", "Output"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; } +#endif + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); + } }; class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(Tensor) The input data of GridSampleOp, " - "This is a 4-D tensor with shape of [N, C, H, W]"); - AddInput( - "Grid", - "(Tensor) The input grid of GridSampleOp generated by AffineGridOp, " - "This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation " - "of x and y coordinates with shape [N, H, W] in last dimention"); - AddOutput( - "Output", - "(Tensor) Output tensor with shape [N, C, H, W]"); - AddAttr( - "use_cudnn", - "(bool, default true) Only used in cudnn kernel, need install cudnn") - .SetDefault(true); - - AddComment(R"DOC( - It sample input X by grid gennerate by AffineGridOp. The grid of shape - [N, H, W, 2] is the concatenation of (x, y) coordinates with shape - [N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to - indexng the 3rd-D(H), finally results is the bilinear interpolation value - of 4 nearest corner points. + public: + void Make() override { + AddInput("X", + "(Tensor) The input data of GridSampleOp, " + "This is a 4-D tensor with shape of [N, C, H, W]"); + AddInput( + "Grid", + "(Tensor) The input grid of GridSampleOp generated by AffineGridOp, " + "This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation " + "of x and y coordinates with shape [N, H, W] in last dimention"); + AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]"); + AddAttr( + "use_cudnn", + "(bool, default true) Only used in cudnn kernel, need install cudnn") + .SetDefault(true); + + AddComment(R"DOC( + This operation samples input X by using bilinear interpolation based on + flow field grid, which is usually gennerated by affine_grid. The grid of + shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates + with shape [N, H, W] each, where grid_x is indexing the 4th dimension + (in width dimension) of input data x and grid_y is indexng the 3rd + dimention (in height dimension), finally results is the bilinear + interpolation value of 4 nearest corner points. Step 1: Get (x, y) grid coordinates and scale to [0, H-1/W-1]. @@ -127,11 +133,11 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { output = wn * d_e * d_s + en * d_w * d_s + ws * d_e * d_n + es * d_w * d_n )DOC"); - } + } }; class GridSampleOpGrad : public framework::OperatorWithKernel { - public: + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { auto input_dims = ctx->GetInputDim("X"); @@ -144,43 +150,43 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { } } - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; - } -#endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; } +#endif + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); + } }; class GridSampleGradMaker : public framework::SingleGradOpDescMaker { - public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; - - protected: - std::unique_ptr Apply() const override { - auto* op = new framework::OpDesc(); - op->SetType("grid_sampler_grad"); - op->SetInput("X", Input("X")); - op->SetInput("Grid", Input("Grid")); - op->SetInput(framework::GradVarName("Output"), OutputGrad("Output")); - - op->SetAttrMap(Attrs()); - - op->SetOutput(framework::GradVarName("X"), InputGrad("X")); - op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid")); - return std::unique_ptr(op); - } + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("grid_sampler_grad"); + op->SetInput("X", Input("X")); + op->SetInput("Grid", Input("Grid")); + op->SetInput(framework::GradVarName("Output"), OutputGrad("Output")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid")); + return std::unique_ptr(op); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker, diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index 1e8f36567f..0d5874fc0c 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -19,19 +19,17 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/hostdevice.h" - namespace paddle { namespace operators { using Tensor = framework::Tensor; template + typename IndexType = Eigen::DenseIndex> using EigenTensor = framework::EigenTensor; using Array3 = Eigen::DSizes; using Array4 = Eigen::DSizes; - template static inline bool isInBound(T x, T y, T x_max, T y_max) { if (x < 0 || x > x_max || y < 0 || y > y_max) { @@ -40,16 +38,17 @@ static inline bool isInBound(T x, T y, T x_max, T y_max) { return true; } -template -static void CalcGridLocations(const DeviceContext& ctx, const Tensor& grid, - Tensor* x_w, Tensor* x_e, Tensor* y_n, Tensor* y_s, - Tensor* d_w, Tensor* d_e, Tensor* d_n, Tensor* d_s) { +template +static void CalcGridLocations(const platform::CPUDeviceContext& ctx, + const Tensor& grid, Tensor* x_w, Tensor* x_e, + Tensor* y_n, Tensor* y_s, Tensor* d_w, + Tensor* d_e, Tensor* d_n, Tensor* d_s) { auto& place = *ctx.eigen_device(); const int n = grid.dims()[0]; const int h = grid.dims()[1]; const int w = grid.dims()[2]; - const T x_max = static_cast (w - 1); - const T y_max = static_cast (h - 1); + const T x_max = static_cast(w - 1); + const T y_max = static_cast(h - 1); // split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim Tensor grid_x, grid_y; @@ -102,7 +101,7 @@ static void CalcGridLocations(const DeviceContext& ctx, const Tensor& grid, template static void GetGridPointValue(const Tensor& input, Tensor* output, - const Tensor& x, const Tensor& y) { + const Tensor& x, const Tensor& y) { const int n = input.dims()[0]; const int c = input.dims()[1]; const int h = input.dims()[2]; @@ -117,7 +116,9 @@ static void GetGridPointValue(const Tensor& input, Tensor* output, for (int l = 0; l < w; l++) { if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) { for (int j = 0; j < c; j++) { - output_t(i, j, k, l) = input_t(i, j, (int)round(y_t(i, k, l)), (int)round(x_t(i, k, l))); + output_t(i, j, k, l) = + input_t(i, j, static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))); } } } @@ -126,9 +127,10 @@ static void GetGridPointValue(const Tensor& input, Tensor* output, } template -static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad, - const Tensor& x, const Tensor& y, - const Tensor& d1, const Tensor& d2) { +static void GatherOutputGradToInputGrad(const Tensor& output_grad, + Tensor* input_grad, const Tensor& x, + const Tensor& y, const Tensor& d1, + const Tensor& d2) { const int n = output_grad.dims()[0]; const int c = output_grad.dims()[1]; const int h = output_grad.dims()[2]; @@ -143,10 +145,11 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input for (int i = 0; i < n; i++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - if(isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) { + if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) { for (int j = 0; j < c; j++) { - input_grad_t(i, j, (int) y_t(i, k, l), (int) x_t(i, k, l)) += - output_grad_t(i, j, k ,l) * d1_t(i, k, l) * d2_t(i, k, l); + input_grad_t(i, j, static_cast(round(y_t(i, k, l))), + static_cast(round(x_t(i, k, l)))) += + output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l); } } } @@ -154,162 +157,166 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input } } - - template class GridSampleOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& place = *ctx.template device_context().eigen_device(); - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - - // calc locations and distances of 4 corner points - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - CalcGridLocations(ctx.template device_context(), - *grid, - &x_w, &x_e, &y_n, &y_s, - &d_w, &d_e, &d_n, &d_s); - - auto* output = ctx.Output("Output"); - output->mutable_data({n, c, h, w}, ctx.GetPlace()); - math::SetConstant()( - ctx.template device_context(), output, - static_cast(0)); - - // calc 4 corner points value - Tensor v_wn, v_en, v_ws, v_es; - v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); - GetGridPointValue(*input, &v_wn, x_w, y_n); - GetGridPointValue(*input, &v_en, x_e, y_n); - GetGridPointValue(*input, &v_ws, x_w, y_s); - GetGridPointValue(*input, &v_es, x_e, y_s); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - auto d_w_scaled_t = d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_e_scaled_t = d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_n_scaled_t = d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto d_s_scaled_t = d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); - auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - auto output_t = EigenTensor::From(*output); - //bilinear interpolaetion by 4 corner points - output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t - + v_en_t * d_w_scaled_t * d_s_scaled_t - + v_ws_t * d_e_scaled_t * d_n_scaled_t - + v_es_t * d_w_scaled_t * d_n_scaled_t; - } - + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + // calc locations and distances of 4 corner points + Tensor x_w, x_e, y_n, y_s; + Tensor d_w, d_e, d_n, d_s; + CalcGridLocations( + ctx.template device_context(), *grid, &x_w, + &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s); + + auto* output = ctx.Output("Output"); + output->mutable_data({n, c, h, w}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), output, + static_cast(0)); + + // calc 4 corner points value + Tensor v_wn, v_en, v_ws, v_es; + v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); + GetGridPointValue(*input, &v_wn, x_w, y_n); + GetGridPointValue(*input, &v_en, x_e, y_n); + GetGridPointValue(*input, &v_ws, x_w, y_s); + GetGridPointValue(*input, &v_es, x_e, y_s); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + auto d_w_scaled_t = + d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); + auto d_e_scaled_t = + d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); + auto d_n_scaled_t = + d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); + auto d_s_scaled_t = + d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1)); + auto v_wn_t = EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + auto output_t = EigenTensor::From(*output); + // bilinear interpolaetion by 4 corner points + output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t + + v_en_t * d_w_scaled_t * d_s_scaled_t + + v_ws_t * d_e_scaled_t * d_n_scaled_t + + v_es_t * d_w_scaled_t * d_n_scaled_t; + } }; template class GridSampleGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* grid = ctx.Input("Grid"); - auto* output_grad = ctx.Input(framework::GradVarName("Output")); - - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - - auto* input_grad = ctx.Output(framework::GradVarName("X")); - input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - math::SetConstant()( - ctx.template device_context(), input_grad, - static_cast(0)); - auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); - grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); - math::SetConstant()( - ctx.template device_context(), grid_grad, - static_cast(0)); - - Tensor x_w, x_e, y_n, y_s; - Tensor d_w, d_e, d_n, d_s; - CalcGridLocations(ctx.template device_context(), - *grid, - &x_w, &x_e, &y_n, &y_s, - &d_w, &d_e, &d_n, &d_s); - - // gather output grad value to input grad by corner point coords and weight - GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_n, d_e, d_s); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_s, d_e, d_n); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_n, d_w, d_s); - GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_s, d_w, d_n); - - // calc 4 corner points value - Tensor v_wn, v_en, v_ws, v_es; - v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); - v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); - GetGridPointValue(*input, &v_wn, x_w, y_n); - GetGridPointValue(*input, &v_en, x_e, y_n); - GetGridPointValue(*input, &v_ws, x_w, y_s); - GetGridPointValue(*input, &v_es, x_e, y_s); - auto v_wn_t = EigenTensor::From(v_wn); - auto v_en_t = EigenTensor::From(v_en); - auto v_ws_t = EigenTensor::From(v_ws); - auto v_es_t = EigenTensor::From(v_es); - - auto d_w_t = EigenTensor::From(d_w); - auto d_e_t = EigenTensor::From(d_e); - auto d_n_t = EigenTensor::From(d_n); - auto d_s_t = EigenTensor::From(d_s); - - auto output_grad_t = EigenTensor::From(*output_grad); - - Tensor grid_grad_x, grid_grad_y; - grid_grad_x.mutable_data({n, h, w}, ctx.GetPlace()); - grid_grad_y.mutable_data({n, h, w}, ctx.GetPlace()); - auto grid_grad_x_t = EigenTensor::From(grid_grad_x).setConstant(0.0); - auto grid_grad_y_t = EigenTensor::From(grid_grad_y).setConstant(0.0); - for (int i = 0; i < n; i++) { - for(int j = 0; j < c; j++) { - for(int k = 0; k < h; k++) { - for(int l = 0; l < w; l++) { - grid_grad_x_t(i, k, l) += ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) - + (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) - * output_grad_t(i, j, k, l); - grid_grad_y_t(i, k, l) += ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) - + (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) - * output_grad_t(i, j, k, l); - } + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* grid = ctx.Input("Grid"); + auto* output_grad = ctx.Input(framework::GradVarName("Output")); + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + auto* input_grad = ctx.Output(framework::GradVarName("X")); + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), input_grad, + static_cast(0)); + auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); + grid_grad->mutable_data({n, h, w, 2}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), grid_grad, + static_cast(0)); + + Tensor x_w, x_e, y_n, y_s; + Tensor d_w, d_e, d_n, d_s; + CalcGridLocations( + ctx.template device_context(), *grid, &x_w, + &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s); + + // gather output grad value to input grad by corner point coords and weight + GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_n, d_e, + d_s); + GatherOutputGradToInputGrad(*output_grad, input_grad, x_w, y_s, d_e, + d_n); + GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_n, d_w, + d_s); + GatherOutputGradToInputGrad(*output_grad, input_grad, x_e, y_s, d_w, + d_n); + + // calc 4 corner points value + Tensor v_wn, v_en, v_ws, v_es; + v_wn.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_en.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_ws.mutable_data({n, c, h, w}, ctx.GetPlace()); + v_es.mutable_data({n, c, h, w}, ctx.GetPlace()); + GetGridPointValue(*input, &v_wn, x_w, y_n); + GetGridPointValue(*input, &v_en, x_e, y_n); + GetGridPointValue(*input, &v_ws, x_w, y_s); + GetGridPointValue(*input, &v_es, x_e, y_s); + auto v_wn_t = EigenTensor::From(v_wn); + auto v_en_t = EigenTensor::From(v_en); + auto v_ws_t = EigenTensor::From(v_ws); + auto v_es_t = EigenTensor::From(v_es); + + auto d_w_t = EigenTensor::From(d_w); + auto d_e_t = EigenTensor::From(d_e); + auto d_n_t = EigenTensor::From(d_n); + auto d_s_t = EigenTensor::From(d_s); + + auto output_grad_t = EigenTensor::From(*output_grad); + + Tensor grid_grad_x, grid_grad_y; + grid_grad_x.mutable_data({n, h, w}, ctx.GetPlace()); + grid_grad_y.mutable_data({n, h, w}, ctx.GetPlace()); + auto grid_grad_x_t = EigenTensor::From(grid_grad_x).setConstant(0.0); + auto grid_grad_y_t = EigenTensor::From(grid_grad_y).setConstant(0.0); + for (int i = 0; i < n; i++) { + for (int j = 0; j < c; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + grid_grad_x_t(i, k, l) += + ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + + (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * + output_grad_t(i, j, k, l); + grid_grad_y_t(i, k, l) += + ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + + (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * + output_grad_t(i, j, k, l); } } } - const T x_max = static_cast(w - 1); - const T y_max = static_cast(h - 1); - grid_grad_x_t = grid_grad_x_t * (x_max / (T)2); - grid_grad_y_t = grid_grad_y_t * (y_max / (T)2); - - // gather grid_grad [x, y] in 3rd Dim - T* grid_grad_data = grid_grad->data(); - T* grid_grad_x_data = grid_grad_x.data(); - T* grid_grad_y_data = grid_grad_y.data(); - for (int i = 0; i < n * h * w; i++) { - grid_grad_data[2 * i] = grid_grad_x_data[i]; - grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; - } } - + const T x_max = static_cast(w - 1); + const T y_max = static_cast(h - 1); + grid_grad_x_t = grid_grad_x_t * (x_max / (T)2); + grid_grad_y_t = grid_grad_y_t * (y_max / (T)2); + + // gather grid_grad [x, y] in 3rd Dim + T* grid_grad_data = grid_grad->data(); + T* grid_grad_x_data = grid_grad_x.data(); + T* grid_grad_y_data = grid_grad_y.data(); + for (int i = 0; i < n * h * w; i++) { + grid_grad_data[2 * i] = grid_grad_x_data[i]; + grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + } + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 140c8c3829..1ad66f0525 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -342,7 +342,7 @@ class ScopedPoolingDescriptor { }; class ScopedSpatialTransformerDescriptor { - public: + public: ScopedSpatialTransformerDescriptor() { PADDLE_ENFORCE(dynload::cudnnCreateSpatialTransformerDescriptor(&desc_)); } @@ -354,13 +354,13 @@ class ScopedSpatialTransformerDescriptor { inline cudnnSpatialTransformerDescriptor_t descriptor(const int nbDims, const int dimA[]) { PADDLE_ENFORCE(dynload::cudnnSetSpatialTransformerNdDescriptor( - desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType::type, nbDims, dimA)); + desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType::type, nbDims, dimA)); return desc_; } - private: - cudnnSpatialTransformerDescriptor_t desc_; - DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor); + private: + cudnnSpatialTransformerDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor); }; inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 0a531ec118..d3d754b6f5 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -65,51 +65,51 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreateSpatialTransformerDescriptor); \ - __macro(cudnnSetSpatialTransformerNdDescriptor); \ - __macro(cudnnDestroySpatialTransformerDescriptor);\ - __macro(cudnnSpatialTfGridGeneratorForward); \ - __macro(cudnnSpatialTfGridGeneratorBackward); \ - __macro(cudnnSpatialTfSamplerForward); \ - __macro(cudnnSpatialTfSamplerBackward); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnSetTensorNdDescriptor); \ + __macro(cudnnGetTensorNdDescriptor); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetFilterNdDescriptor); \ + __macro(cudnnGetFilterNdDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnSetPoolingNdDescriptor); \ + __macro(cudnnGetPoolingNdDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnSetConvolutionNdDescriptor); \ + __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ + __macro(cudnnCreateSpatialTransformerDescriptor); \ + __macro(cudnnSetSpatialTransformerNdDescriptor); \ + __macro(cudnnDestroySpatialTransformerDescriptor); \ + __macro(cudnnSpatialTfGridGeneratorForward); \ + __macro(cudnnSpatialTfGridGeneratorBackward); \ + __macro(cudnnSpatialTfSamplerForward); \ + __macro(cudnnSpatialTfSamplerBackward); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ __macro(cudnnGetErrorString); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f4c2c2813f..a3ae9bdcf5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7586,11 +7586,13 @@ def hash(input, hash_size, num_hash=1, name=None): @templatedoc() def grid_sampler(x, grid, name=None): """ - It sample input X by grid gennerate by AffineGridOp. The grid of shape - [N, H, W, 2] is the concatenation of (x, y) coordinates with shape - [N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to - indexng the 3rd-D(H), finally results is the bilinear interpolation value - of 4 nearest corner points. + This operation samples input X by using bilinear interpolation based on + flow field grid, which is usually gennerated by affine_grid. The grid of + shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates + with shape [N, H, W] each, where grid_x is indexing the 4th dimension + (in width dimension) of input data x and grid_y is indexng the 3rd + dimention (in height dimension), finally results is the bilinear + interpolation value of 4 nearest corner points. Step 1: Get (x, y) grid coordinates and scale to [0, H-1/W-1]. @@ -7636,7 +7638,16 @@ def grid_sampler(x, grid, name=None): name (str, default None): The name of this layer. Returns: - out(Variable): Output data indices by grid from x of shape [N, C, H, W]. + out(Variable): Output of shape [N, C, H, W] data samples input X + using bilnear interpolation based on input grid. + + Exmples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32') + theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32') + grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32]}) + out = fluid.layers.grid_sampler(x=x, grid=grid) """ helper = LayerHelper("grid_sampler", **locals()) @@ -7649,10 +7660,6 @@ def grid_sampler(x, grid, name=None): out = helper.create_tmp_variable(x.dtype) ipts = {'X': x, 'Grid': grid} - helper.apppend_op( - type='grid_sampler', - inputs=ipts, - outputs={'Output', out}) + helper.apppend_op(type='grid_sampler', inputs=ipts, outputs={'Output', out}) return out - diff --git a/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py index 5a0b2d41b2..c2529e0d70 100644 --- a/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py +++ b/python/paddle/fluid/tests/unittests/test_grid_sampler_op.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import unittest import numpy as np from op_test import OpTest @@ -23,11 +22,11 @@ def AffineGrid(theta, size): h = size[2] w = size[3] h_idx = np.repeat( - np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis] + np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis] w_idx = np.repeat( - np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis] + np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis] grid = np.concatenate( - [w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3 + [w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3 grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3 ret = np.zeros([n, h * w, 2]) @@ -37,6 +36,7 @@ def AffineGrid(theta, size): return ret.reshape([n, h, w, 2]).astype("float32") + def getGridPointValue(data, x, y): data_shape = data.shape N = data_shape[0] @@ -47,13 +47,15 @@ def getGridPointValue(data, x, y): for i in range(N): for j in range(H): for k in range(W): - if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[i, j, k] > W - 1: + if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[ + i, j, k] > W - 1: out[i, :, j, k] = 0 else: out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] return out + def GridSampler(data, grid): dims = data.shape N = dims[0] @@ -71,7 +73,7 @@ def GridSampler(data, grid): x0 = np.floor(x).astype('int32') x1 = x0 + 1 - y0 = np.floor(y).astype('int32') + y0 = np.floor(y).astype('int32') y1 = y0 + 1 wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1)) @@ -87,6 +89,7 @@ def GridSampler(data, grid): out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32') return out + class TestGridSamplerOp(OpTest): def setUp(self): self.initTestCase() @@ -115,5 +118,6 @@ class TestGridSamplerOp(OpTest): self.grid_shape = (2, 7, 3, 2) self.theta_shape = (2, 2, 3) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 17c94a1d47..c6493b2ecc 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -868,13 +868,12 @@ class TestBook(unittest.TestCase): def test_affine_grid_gen(self): program = Program() with program_guard(program): - x = layers.data(name='x', shape=[2, 5, 7, 3 ], dtype='float32') - grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32' ) + x = layers.data(name='x', shape=[2, 5, 7, 3], dtype='float32') + grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32') out = layers.grid_sampler(x, grid) self.assertIsNotNone(out) print(str(program)) - if __name__ == '__main__': unittest.main() -- GitLab