diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 3bbe7c2b8cd60be93cbe71cb1cdfe1b85aa7e461..bb0146dd0abcb5d3bc9f638871a32b0483b57def 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) diff --git a/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc b/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed71594ba5781590f3291d56c4ba1a4443003bd5 --- /dev/null +++ b/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedSpatialTransformerDescriptor = + platform::ScopedSpatialTransformerDescriptor; + +template +class CUDNNAffineGridOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto* theta = ctx.Input("Theta"); + auto* output = ctx.Output("Output"); + const T* theta_data = theta->data(); + + int n = theta->dims()[0]; + auto size_attr = ctx.Attr>("output_shape"); + Tensor h_sizes; + int* h_size_data; + if (size_attr.size() == 0) { + auto* output_shape = ctx.Input("OutputShape"); + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); + h_size_data = h_sizes.data(); + } else { + h_size_data = h_sizes.mutable_data({4}, platform::CPUPlace()); + h_size_data[0] = n; + h_size_data[1] = size_attr[1]; + h_size_data[2] = size_attr[2]; + h_size_data[3] = size_attr[3]; + } + + T* output_data = output->mutable_data( + {n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace()); + ScopedSpatialTransformerDescriptor st_desc; + cudnnSpatialTransformerDescriptor_t cudnn_st_desc = + st_desc.descriptor(4, h_size_data); + + PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorForward( + handle, cudnn_st_desc, theta_data, output_data)); + } +}; + +template +class CUDNNAffineGridGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto theta_grad = ctx.Output(framework::GradVarName("Theta")); + + int n = output_grad->dims()[0]; + auto size_attr = ctx.Attr>("output_shape"); + Tensor h_sizes; + int* h_size_data; + if (size_attr.size() == 0) { + auto* output_shape = ctx.Input("OutputShape"); + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); + h_size_data = h_sizes.data(); + } else { + h_size_data = h_sizes.mutable_data({4}, platform::CPUPlace()); + h_size_data[0] = n; + h_size_data[1] = size_attr[1]; + h_size_data[2] = size_attr[2]; + h_size_data[3] = size_attr[3]; + } + + ScopedSpatialTransformerDescriptor st_desc; + cudnnSpatialTransformerDescriptor_t cudnn_st_desc = + st_desc.descriptor(4, h_size_data); + + const T* output_grad_data = output_grad->data(); + T* theta_grad_data = theta_grad->mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorBackward( + handle, cudnn_st_desc, output_grad_data, theta_grad_data)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNAffineGridOpKernel, + paddle::operators::CUDNNAffineGridOpKernel); +REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNAffineGridGradOpKernel, + paddle::operators::CUDNNAffineGridGradOpKernel); diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ea28265a245c9cd1a35a79324a33f7cf208a159 --- /dev/null +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -0,0 +1,233 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/affine_grid_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +struct Linspace { + framework::Tensor operator()(T start, T end, int count, + const framework::ExecutionContext& ctx) { + Tensor numbers; + T* number_data = numbers.mutable_data({count}, platform::CPUPlace()); + T slice = (end - start) / (T)(count - 1); + for (int i = 0; i < count; ++i) { + number_data[i] = start + (T)i * slice; + } + return numbers; + } +}; + +class AffineGridOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Theta"), + "Input(Theta) of AffineGridOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of AffineGridOp should not be null."); + auto theta_dims = ctx->GetInputDim("Theta"); + PADDLE_ENFORCE(theta_dims.size() == 3, + "AffineGrid's Input(Theta) should be 3-D tensor."); + + auto output_shape = ctx->Attrs().Get>("output_shape"); + if (output_shape.size() == 0) { + PADDLE_ENFORCE(ctx->HasInput("OutputShape"), + "Input(OutputShape) of AffineGridOp should not be null if " + "attr(output_shape) is not configured."); + auto output_shape_dims = ctx->GetInputDim("OutputShape"); + PADDLE_ENFORCE(output_shape_dims.size() == 1, + "AffineGrid's Input(OutputShape) should be 1-D tensor."); + } else { + PADDLE_ENFORCE(output_shape.size() == 4, + "The size of attr(output_shape) should be 4."); + } + + PADDLE_ENFORCE(theta_dims[1] == 2, "Input(theta) dims[1] should be 2."); + PADDLE_ENFORCE(theta_dims[2] == 3, "Input(theta) dims[2] should be 3."); + // N * H * W * 2 + ctx->SetOutputDim("Output", + framework::make_ddim({theta_dims[0], -1, -1, 2})); + ctx->ShareLoD("Theta", "Output"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_CUDA + if (platform::CanCUDNNBeUsed(ctx)) { + library = framework::LibraryType::kCUDNN; + } +#endif + auto data_type = framework::ToDataType(ctx.Input("Theta")->type()); + return framework::OpKernelType(data_type, ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library); + } +}; + +class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Theta", + "(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. " + "It is used to transform coordinate (x_0, y_0) to coordinate (x_1, " + "y_1)."); + AddInput("OutputShape", + "(Tensor) The shape of target image with format [N, C, H, W].") + .AsDispensable(); + AddOutput("Output", "(Tensor) Output Tensor with shape [N, H, W, 2]."); + AddAttr( + "use_cudnn", + "(bool, default false) Only used in cudnn kernel, need install cudnn") + .SetDefault(true); + AddAttr>( + "output_shape", + "The target output image shape with format [N, C, H, W].") + .SetDefault(std::vector()); + + AddComment(R"DOC( + It generates a grid of (x,y) coordinates using the parameters of the + affine transformation that correspond to a set of points where the input + feature map should be sampled to produce the transformed output feature map. + + Given: + Theta = [[[x_11, x_12, x_13] + [x_14, x_15, x_16]] + [[x_21, x_22, x_23] + [x_24, x_25, x_26]]] + + OutputShape = [2, 3, 5, 5] + + Step 1: + + Generate relative coordinates according to OutputShape. + The values of relative coordinates are in the interval between -1 and 1. + The shape of the relative coordinates is [2, H, W] as below: + + C = [[[-1. -1. -1. -1. -1. ] + [-0.5 -0.5 -0.5 -0.5 -0.5] + [ 0. 0. 0. 0. 0. ] + [ 0.5 0.5 0.5 0.5 0.5] + [ 1. 1. 1. 1. 1. ]] + [[-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ]]] + C[0] is the coordinates in height axis and C[1] is the coordinates in width axis. + + Step2: + Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get: + C_ = [[-1. -1. 1. ] + [-0.5 -1. 1. ] + [ 0. -1. 1. ] + [ 0.5 -1. 1. ] + [ 1. -1. 1. ] + [-1. -0.5 1. ] + [-0.5 -0.5 1. ] + [ 0. -0.5 1. ] + [ 0.5 -0.5 1. ] + [ 1. -0.5 1. ] + [-1. 0. 1. ] + [-0.5 0. 1. ] + [ 0. 0. 1. ] + [ 0.5 0. 1. ] + [ 1. 0. 1. ] + [-1. 0.5 1. ] + [-0.5 0.5 1. ] + [ 0. 0.5 1. ] + [ 0.5 0.5 1. ] + [ 1. 0.5 1. ] + [-1. 1. 1. ] + [-0.5 1. 1. ] + [ 0. 1. 1. ] + [ 0.5 1. 1. ] + [ 1. 1. 1. ]] + Step3: + Compute output by equation $$Output[i] = C_ * Theta[i]^T$$ + )DOC"); + } +}; + +class AffineGridOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + auto theta_dims = ctx->GetInputDim("Theta"); + if (ctx->HasOutput(framework::GradVarName("Theta"))) { + ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_CUDA + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; + } +#endif + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Theta")->type()), + ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); + } +}; + +class AffineGridGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("affine_grid_grad"); + op->SetInput("Theta", Input("Theta")); + op->SetInput("OutputShape", Input("OutputShape")); + op->SetInput(framework::GradVarName("Output"), OutputGrad("Output")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("Theta"), InputGrad("Theta")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(affine_grid, ops::AffineGridOp, ops::AffineGridOpMaker, + ops::AffineGridGradMaker); +REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad); + +REGISTER_OP_CPU_KERNEL( + affine_grid, + ops::AffineGridOpKernel, + ops::AffineGridOpKernel); +REGISTER_OP_CPU_KERNEL( + affine_grid_grad, + ops::AffineGridGradOpKernel, + ops::AffineGridGradOpKernel); diff --git a/paddle/fluid/operators/affine_grid_op.h b/paddle/fluid/operators/affine_grid_op.h new file mode 100644 index 0000000000000000000000000000000000000000..07e26c292c3bafc4d98bd392a9e1e21a9eb383a8 --- /dev/null +++ b/paddle/fluid/operators/affine_grid_op.h @@ -0,0 +1,190 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenTensor = framework::EigenTensor; + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; +using Array3 = Eigen::DSizes; +using Array4 = Eigen::DSizes; + +/** + *Return a tensor with evenly spaced numbers over a specified interval. + */ +template +struct Linspace { + framework::Tensor operator()(T start, T end, int count, + const framework::ExecutionContext& ctx); +}; + +template +class AffineGridOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto* theta = ctx.Input("Theta"); + int n = theta->dims()[0]; + + auto size_attr = ctx.Attr>("output_shape"); + int h = 0; + int w = 0; + if (size_attr.size() == 0) { + auto* output_shape = ctx.Input("OutputShape"); + Tensor h_sizes; + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); + const int* h_size_data = h_sizes.data(); + h = h_size_data[2]; + w = h_size_data[3]; + } else { + h = size_attr[2]; + w = size_attr[3]; + } + + auto* output = ctx.Output("Output"); + output->mutable_data({n, h, w, 2}, ctx.GetPlace()); + + math::SetConstant()( + ctx.template device_context(), output, + static_cast(0)); + + Linspace linspace; + // Get indexes of height with shape [height, width, 1] + auto h_idx = linspace((T)-1, (T)1, h, ctx); + auto h_idx_t = EigenTensor::From(h_idx); + // Get indexes of width with shape [height, width, 1] + auto w_idx = linspace((T)-1, (T)1, w, ctx); + auto w_idx_t = EigenTensor::From(w_idx); + // Get constant ones tensor with shape [height, width, 1] + Tensor ones; + ones.mutable_data({h, w, 1}, ctx.GetPlace()); + auto ones_t = EigenTensor::From(ones).setConstant((T)1); + // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and + // ones + Tensor grid; + grid.mutable_data({n, h, w, 3}, ctx.GetPlace()); + auto grid_t = EigenTensor::From(grid); + + grid_t.device(place) = w_idx_t.reshape(Array2(1, w)) + .broadcast(Array2(h, 1)) + .reshape(Array3(h, w, 1)) + .concatenate(h_idx_t.reshape(Array2(1, h)) + .broadcast(Array2(w, 1)) + .shuffle(Array2(1, 0)) + .reshape(Array3(h, w, 1)), + 2) + .eval() + .concatenate(ones_t, 2) + .reshape(Array4(1, h, w, 3)) + .broadcast(Array4(n, 1, 1, 1)); + + // output = grid * theta.T + // TODO(wanghaoshuang): Refine batched matrix multiply + auto blas = math::GetBlas(ctx); + for (int i = 0; i < n; ++i) { + Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3}); + Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3}); + Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2}); + blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out, + T(0)); + } + } +}; + +template +class AffineGridGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto theta_grad = ctx.Output(framework::GradVarName("Theta")); + + int n = output_grad->dims()[0]; + auto size_attr = ctx.Attr>("output_shape"); + int h = 0; + int w = 0; + if (size_attr.size() == 0) { + auto* output_shape = ctx.Input("OutputShape"); + Tensor h_sizes; + framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); + const int* h_size_data = h_sizes.data(); + h = h_size_data[2]; + w = h_size_data[3]; + } else { + h = size_attr[2]; + w = size_attr[3]; + } + + theta_grad->mutable_data({n, 2, 3}, ctx.GetPlace()); + + math::SetConstant()( + ctx.template device_context(), theta_grad, + static_cast(0)); + + Linspace linspace; + + // Get indexes of height with shape [height, width, 1] + auto h_idx = linspace((T)-1, (T)1, h, ctx); + auto h_idx_t = EigenTensor::From(h_idx); + // Get indexes of width with shape [height, width, 1] + auto w_idx = linspace((T)-1, (T)1, w, ctx); + auto w_idx_t = EigenTensor::From(w_idx); + // Get constant ones tensor with shape [height, width, 1] + Tensor ones; + ones.mutable_data({h, w, 1}, ctx.GetPlace()); + auto ones_t = EigenTensor::From(ones).setConstant((T)1); + // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and + // ones + Tensor grid; + grid.mutable_data({n, h, w, 3}, ctx.GetPlace()); + auto grid_t = EigenTensor::From(grid); + grid_t.device(place) = w_idx_t.reshape(Array2(1, w)) + .broadcast(Array2(h, 1)) + .reshape(Array3(h, w, 1)) + .concatenate(h_idx_t.reshape(Array2(1, h)) + .broadcast(Array2(w, 1)) + .shuffle(Array2(1, 0)) + .reshape(Array3(h, w, 1)), + 2) + .eval() + .concatenate(ones_t, 2) + .reshape(Array4(1, h, w, 3)) + .broadcast(Array4(n, 1, 1, 1)); + // output = grid * theta.T + // TODO(wanghaoshuang): Refine batched matrix multiply + auto blas = math::GetBlas(ctx); + for (int i = 0; i < n; ++i) { + Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3}); + Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2}); + Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3}); + blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1), + &sliced_theta_grad, T(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index bb8b14bb9fa41942c3aa653ca224c0842fbf9a00..1ad66f0525e4857533d5cf5ef3ee041aec297794 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -341,6 +341,28 @@ class ScopedPoolingDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor); }; +class ScopedSpatialTransformerDescriptor { + public: + ScopedSpatialTransformerDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateSpatialTransformerDescriptor(&desc_)); + } + ~ScopedSpatialTransformerDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroySpatialTransformerDescriptor(desc_)); + } + + template + inline cudnnSpatialTransformerDescriptor_t descriptor(const int nbDims, + const int dimA[]) { + PADDLE_ENFORCE(dynload::cudnnSetSpatialTransformerNdDescriptor( + desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType::type, nbDims, dimA)); + return desc_; + } + + private: + cudnnSpatialTransformerDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor); +}; + inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index e6353f67ef118072a2d8e49111e8ecc486589998..d3d754b6f58d25a9dfacafaf55d50b353a71ee6d 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -65,44 +65,51 @@ extern void EnforceCUDNNLoaded(const char* fn_name); * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor); \ - __macro(cudnnSetTensor4dDescriptorEx); \ - __macro(cudnnSetTensorNdDescriptor); \ - __macro(cudnnGetTensorNdDescriptor); \ - __macro(cudnnGetConvolutionNdForwardOutputDim); \ - __macro(cudnnGetConvolutionForwardAlgorithm); \ - __macro(cudnnCreateTensorDescriptor); \ - __macro(cudnnDestroyTensorDescriptor); \ - __macro(cudnnCreateFilterDescriptor); \ - __macro(cudnnSetFilter4dDescriptor); \ - __macro(cudnnSetFilterNdDescriptor); \ - __macro(cudnnGetFilterNdDescriptor); \ - __macro(cudnnSetPooling2dDescriptor); \ - __macro(cudnnSetPoolingNdDescriptor); \ - __macro(cudnnGetPoolingNdDescriptor); \ - __macro(cudnnDestroyFilterDescriptor); \ - __macro(cudnnCreateConvolutionDescriptor); \ - __macro(cudnnCreatePoolingDescriptor); \ - __macro(cudnnDestroyPoolingDescriptor); \ - __macro(cudnnSetConvolution2dDescriptor); \ - __macro(cudnnDestroyConvolutionDescriptor); \ - __macro(cudnnSetConvolutionNdDescriptor); \ - __macro(cudnnGetConvolutionNdDescriptor); \ - __macro(cudnnDeriveBNTensorDescriptor); \ - __macro(cudnnCreate); \ - __macro(cudnnDestroy); \ - __macro(cudnnSetStream); \ - __macro(cudnnActivationForward); \ - __macro(cudnnConvolutionForward); \ - __macro(cudnnConvolutionBackwardBias); \ - __macro(cudnnGetConvolutionForwardWorkspaceSize); \ - __macro(cudnnTransformTensor); \ - __macro(cudnnPoolingForward); \ - __macro(cudnnPoolingBackward); \ - __macro(cudnnSoftmaxBackward); \ - __macro(cudnnSoftmaxForward); \ - __macro(cudnnGetVersion); \ +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnSetTensorNdDescriptor); \ + __macro(cudnnGetTensorNdDescriptor); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetFilterNdDescriptor); \ + __macro(cudnnGetFilterNdDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnSetPoolingNdDescriptor); \ + __macro(cudnnGetPoolingNdDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnSetConvolutionNdDescriptor); \ + __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ + __macro(cudnnCreateSpatialTransformerDescriptor); \ + __macro(cudnnSetSpatialTransformerNdDescriptor); \ + __macro(cudnnDestroySpatialTransformerDescriptor); \ + __macro(cudnnSpatialTfGridGeneratorForward); \ + __macro(cudnnSpatialTfGridGeneratorBackward); \ + __macro(cudnnSpatialTfSamplerForward); \ + __macro(cudnnSpatialTfSamplerBackward); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ __macro(cudnnGetErrorString); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b60a2438014de36c001b263cdae095237a8e389a..cdfa26dfe910ba65d2a6424ab083c57d8422ee69 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -154,6 +154,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', + 'affine_grid', 'sequence_reverse', 'affine_channel', 'hash', @@ -6140,6 +6141,124 @@ def crop(x, shape=None, offsets=None, name=None): return out +def affine_grid(theta, out_shape, name=None): + """ + It generates a grid of (x,y) coordinates using the parameters of + the affine transformation that correspond to a set of points where + the input feature map should be sampled to produce the transformed + output feature map. + + .. code-block:: text + + * Case 1: + + Given: + + theta = [[[x_11, x_12, x_13] + [x_14, x_15, x_16]] + [[x_21, x_22, x_23] + [x_24, x_25, x_26]]] + + out_shape = [2, 3, 5, 5] + + Step 1: + + Generate normalized coordinates according to out_shape. + The values of the normalized coordinates are in the interval between -1 and 1. + The shape of the normalized coordinates is [2, H, W] as below: + + C = [[[-1. -1. -1. -1. -1. ] + [-0.5 -0.5 -0.5 -0.5 -0.5] + [ 0. 0. 0. 0. 0. ] + [ 0.5 0.5 0.5 0.5 0.5] + [ 1. 1. 1. 1. 1. ]] + [[-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ] + [-1. -0.5 0. 0.5 1. ]]] + C[0] is the coordinates in height axis and C[1] is the coordinates in width axis. + + Step2: + + Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get: + C_ = [[-1. -1. 1. ] + [-0.5 -1. 1. ] + [ 0. -1. 1. ] + [ 0.5 -1. 1. ] + [ 1. -1. 1. ] + [-1. -0.5 1. ] + [-0.5 -0.5 1. ] + [ 0. -0.5 1. ] + [ 0.5 -0.5 1. ] + [ 1. -0.5 1. ] + [-1. 0. 1. ] + [-0.5 0. 1. ] + [ 0. 0. 1. ] + [ 0.5 0. 1. ] + [ 1. 0. 1. ] + [-1. 0.5 1. ] + [-0.5 0.5 1. ] + [ 0. 0.5 1. ] + [ 0.5 0.5 1. ] + [ 1. 0.5 1. ] + [-1. 1. 1. ] + [-0.5 1. 1. ] + [ 0. 1. 1. ] + [ 0.5 1. 1. ] + [ 1. 1. 1. ]] + Step3: + Compute output by equation $$Output[i] = C_ * Theta[i]^T$$ + + Args: + theta (Variable): A batch of affine transform parameters with shape [N, 2, 3]. + out_shape (Variable | list | tuple): The shape of target output with format [N, C, H, W]. + out_shape can be a Variable or a list or tuple. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The output with shape [N, H, W, 2]. + + Raises: + ValueError: If the type of arguments is not supported. + + Examples: + + .. code-block:: python + theta = fluid.layers.data(name="x", shape=[2, 3], dtype="float32") + out_shape = fluid.layers.data(name="y", shape=[-1], dtype="float32") + data = fluid.layers.affine_grid(theta, out_shape) + + # or + data = fluid.layers.affine_grid(theta, [5, 3, 28, 28]) + + """ + helper = LayerHelper('affine_grid') + + if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \ + isinstance(out_shape, Variable)): + raise ValueError("The out_shape should be a list, tuple or Variable.") + + if not isinstance(theta, Variable): + raise ValueError("The theta should be a Variable.") + + out = helper.create_variable_for_type_inference(theta.dtype) + ipts = {'Theta': theta} + attrs = {} + if isinstance(out_shape, Variable): + ipts['OutputShape'] = out_shape + else: + attrs['output_shape'] = out_shape + + helper.append_op( + type='affine_grid', + inputs=ipts, + outputs={'Output': out}, + attrs=None if len(attrs) == 0 else attrs) + return out + + def rank_loss(label, left, right, name=None): """ **Rank loss layer for RankNet** diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py new file mode 100644 index 0000000000000000000000000000000000000000..576d00940c4c7a5e30af5550e14b674a73e7df11 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py @@ -0,0 +1,79 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +def AffineGrid(theta, size): + n = size[0] + w = size[3] + h = size[2] + h_idx = np.repeat( + np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis] + w_idx = np.repeat( + np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis] + grid = np.concatenate( + [w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3 + grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3 + + ret = np.zeros([n, h * w, 2]) + theta = theta.transpose([0, 2, 1]) + for i in range(len(theta)): + ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i]) + +# print ret.reshape([h * w, 2]).astype("float32") + return ret.reshape([n, h, w, 2]).astype("float32") + + +class TestAffineGridOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "affine_grid" + theta = np.random.randint(1, 3, self.theta_shape).astype("float32") + theta = np.ones(self.theta_shape).astype("float32") + self.inputs = {'Theta': theta} + self.attrs = {"use_cudnn": True} + if self.dynamic_shape: + self.inputs['OutputShape'] = self.output_shape + else: + self.attrs['output_shape'] = self.output_shape + self.outputs = {'Output': AffineGrid(theta, self.output_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['Theta'], + 'Output', + no_grad_set=['OutputShape'], + max_relative_error=0.006) + + def initTestCase(self): + self.theta_shape = (3, 2, 3) + self.output_shape = np.array([3, 2, 5, 7]).astype("int32") + self.dynamic_shape = False + + +class TestAffineGridOpCase1(TestAffineGridOp): + def initTestCase(self): + self.theta_shape = (3, 2, 3) + self.output_shape = np.array([3, 2, 5, 7]).astype("int32") + self.dynamic_shape = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 50de468dba803d0a2a0c129ad04aac8a3822cdbc..8081813b71d6ecedd874046a95283587a15a004c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -865,6 +865,22 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_affine_grid(self): + program = Program() + with program_guard(program): + data = layers.data(name='data', shape=[2, 3, 3], dtype="float32") + out, ids = layers.argsort(input=data, axis=1) + + theta = layers.data(name="theta", shape=[2, 3], dtype="float32") + out_shape = layers.data( + name="out_shape", shape=[-1], dtype="float32") + data_0 = layers.affine_grid(theta, out_shape) + data_1 = layers.affine_grid(theta, [5, 3, 28, 28]) + + self.assertIsNotNone(data_0) + self.assertIsNotNone(data_1) + print(str(program)) + if __name__ == '__main__': unittest.main()