未验证 提交 0c319e0b 编写于 作者: W whs 提交者: GitHub

Add affine grid generator op (#12238)

* Add affine grid generator.

* fix ffine grid.

* Add unitest.

* Add CPU kernel and fix unitest.

* Fix CPU kernel.

* Refine code.
test=develop

* Fix python api.
test=develop

* Update python api.
test=develop

* Fix comment.
test=develop

* Rename affine_grid_generator to affine_grid and enhence unitest.
test=develop

* Fix unitest.
test=develop
上级 d325e668
......@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using ScopedSpatialTransformerDescriptor =
platform::ScopedSpatialTransformerDescriptor;
template <typename T>
class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* theta = ctx.Input<Tensor>("Theta");
auto* output = ctx.Output<Tensor>("Output");
const T* theta_data = theta->data<T>();
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
Tensor h_sizes;
int* h_size_data;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
h_size_data = h_sizes.data<int>();
} else {
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
}
T* output_data = output->mutable_data<T>(
{n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorForward(
handle, cudnn_st_desc, theta_data, output_data));
}
};
template <typename T>
class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
Tensor h_sizes;
int* h_size_data;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
h_size_data = h_sizes.data<int>();
} else {
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
}
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
const T* output_grad_data = output_grad->data<T>();
T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorBackward(
handle, cudnn_st_desc, output_grad_data, theta_grad_data));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNAffineGridOpKernel<float>,
paddle::operators::CUDNNAffineGridOpKernel<double>);
REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNAffineGridGradOpKernel<float>,
paddle::operators::CUDNNAffineGridGradOpKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct Linspace<paddle::platform::CPUDeviceContext, T> {
framework::Tensor operator()(T start, T end, int count,
const framework::ExecutionContext& ctx) {
Tensor numbers;
T* number_data = numbers.mutable_data<T>({count}, platform::CPUPlace());
T slice = (end - start) / (T)(count - 1);
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
return numbers;
}
};
class AffineGridOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Theta"),
"Input(Theta) of AffineGridOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of AffineGridOp should not be null.");
auto theta_dims = ctx->GetInputDim("Theta");
PADDLE_ENFORCE(theta_dims.size() == 3,
"AffineGrid's Input(Theta) should be 3-D tensor.");
auto output_shape = ctx->Attrs().Get<std::vector<int>>("output_shape");
if (output_shape.size() == 0) {
PADDLE_ENFORCE(ctx->HasInput("OutputShape"),
"Input(OutputShape) of AffineGridOp should not be null if "
"attr(output_shape) is not configured.");
auto output_shape_dims = ctx->GetInputDim("OutputShape");
PADDLE_ENFORCE(output_shape_dims.size() == 1,
"AffineGrid's Input(OutputShape) should be 1-D tensor.");
} else {
PADDLE_ENFORCE(output_shape.size() == 4,
"The size of attr(output_shape) should be 4.");
}
PADDLE_ENFORCE(theta_dims[1] == 2, "Input(theta) dims[1] should be 2.");
PADDLE_ENFORCE(theta_dims[2] == 3, "Input(theta) dims[2] should be 3.");
// N * H * W * 2
ctx->SetOutputDim("Output",
framework::make_ddim({theta_dims[0], -1, -1, 2}));
ctx->ShareLoD("Theta", "Output");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library);
}
};
class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Theta",
"(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. "
"It is used to transform coordinate (x_0, y_0) to coordinate (x_1, "
"y_1).");
AddInput("OutputShape",
"(Tensor) The shape of target image with format [N, C, H, W].")
.AsDispensable();
AddOutput("Output", "(Tensor) Output Tensor with shape [N, H, W, 2].");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddAttr<std::vector<int>>(
"output_shape",
"The target output image shape with format [N, C, H, W].")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
It generates a grid of (x,y) coordinates using the parameters of the
affine transformation that correspond to a set of points where the input
feature map should be sampled to produce the transformed output feature map.
Given:
Theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
OutputShape = [2, 3, 5, 5]
Step 1:
Generate relative coordinates according to OutputShape.
The values of relative coordinates are in the interval between -1 and 1.
The shape of the relative coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
)DOC");
}
};
class AffineGridOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto theta_dims = ctx->GetInputDim("Theta");
if (ctx->HasOutput(framework::GradVarName("Theta"))) {
ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
}
};
class AffineGridGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("affine_grid_grad");
op->SetInput("Theta", Input("Theta"));
op->SetInput("OutputShape", Input("OutputShape"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("Theta"), InputGrad("Theta"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(affine_grid, ops::AffineGridOp, ops::AffineGridOpMaker,
ops::AffineGridGradMaker);
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
REGISTER_OP_CPU_KERNEL(
affine_grid,
ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
affine_grid_grad,
ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
/**
*Return a tensor with evenly spaced numbers over a specified interval.
*/
template <typename DeviceContext, typename T>
struct Linspace {
framework::Tensor operator()(T start, T end, int count,
const framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
class AffineGridOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
Linspace<DeviceContext, T> linspace;
// Get indexes of height with shape [height, width, 1]
auto h_idx = linspace((T)-1, (T)1, h, ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get indexes of width with shape [height, width, 1]
auto w_idx = linspace((T)-1, (T)1, w, ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
Tensor ones;
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor grid;
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
auto grid_t = EigenTensor<T, 4>::From(grid);
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1))
.concatenate(h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
.reshape(Array3(h, w, 1)),
2)
.eval()
.concatenate(ones_t, 2)
.reshape(Array4(1, h, w, 3))
.broadcast(Array4(n, 1, 1, 1));
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2});
blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out,
T(0));
}
}
};
template <typename DeviceContext, typename T>
class AffineGridGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), theta_grad,
static_cast<T>(0));
Linspace<DeviceContext, T> linspace;
// Get indexes of height with shape [height, width, 1]
auto h_idx = linspace((T)-1, (T)1, h, ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get indexes of width with shape [height, width, 1]
auto w_idx = linspace((T)-1, (T)1, w, ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
Tensor ones;
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor grid;
grid.mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
auto grid_t = EigenTensor<T, 4>::From(grid);
grid_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1))
.concatenate(h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
.reshape(Array3(h, w, 1)),
2)
.eval()
.concatenate(ones_t, 2)
.reshape(Array4(1, h, w, 3))
.broadcast(Array4(n, 1, 1, 1));
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2});
Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1),
&sliced_theta_grad, T(0));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -341,6 +341,28 @@ class ScopedPoolingDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};
class ScopedSpatialTransformerDescriptor {
public:
ScopedSpatialTransformerDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateSpatialTransformerDescriptor(&desc_));
}
~ScopedSpatialTransformerDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroySpatialTransformerDescriptor(desc_));
}
template <typename T>
inline cudnnSpatialTransformerDescriptor_t descriptor(const int nbDims,
const int dimA[]) {
PADDLE_ENFORCE(dynload::cudnnSetSpatialTransformerNdDescriptor(
desc_, CUDNN_SAMPLER_BILINEAR, CudnnDataType<T>::type, nbDims, dimA));
return desc_;
}
private:
cudnnSpatialTransformerDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor);
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
......
......@@ -90,6 +90,13 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
......
......@@ -154,6 +154,7 @@ __all__ = [
'mul',
'sigmoid_cross_entropy_with_logits',
'maxout',
'affine_grid',
'sequence_reverse',
'affine_channel',
'hash',
......@@ -6140,6 +6141,124 @@ def crop(x, shape=None, offsets=None, name=None):
return out
def affine_grid(theta, out_shape, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
the affine transformation that correspond to a set of points where
the input feature map should be sampled to produce the transformed
output feature map.
.. code-block:: text
* Case 1:
Given:
theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
out_shape = [2, 3, 5, 5]
Step 1:
Generate normalized coordinates according to out_shape.
The values of the normalized coordinates are in the interval between -1 and 1.
The shape of the normalized coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
Args:
theta (Variable): A batch of affine transform parameters with shape [N, 2, 3].
out_shape (Variable | list | tuple): The shape of target output with format [N, C, H, W].
out_shape can be a Variable or a list or tuple.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The output with shape [N, H, W, 2].
Raises:
ValueError: If the type of arguments is not supported.
Examples:
.. code-block:: python
theta = fluid.layers.data(name="x", shape=[2, 3], dtype="float32")
out_shape = fluid.layers.data(name="y", shape=[-1], dtype="float32")
data = fluid.layers.affine_grid(theta, out_shape)
# or
data = fluid.layers.affine_grid(theta, [5, 3, 28, 28])
"""
helper = LayerHelper('affine_grid')
if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Variable.")
if not isinstance(theta, Variable):
raise ValueError("The theta should be a Variable.")
out = helper.create_variable_for_type_inference(theta.dtype)
ipts = {'Theta': theta}
attrs = {}
if isinstance(out_shape, Variable):
ipts['OutputShape'] = out_shape
else:
attrs['output_shape'] = out_shape
helper.append_op(
type='affine_grid',
inputs=ipts,
outputs={'Output': out},
attrs=None if len(attrs) == 0 else attrs)
return out
def rank_loss(label, left, right, name=None):
"""
**Rank loss layer for RankNet**
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def AffineGrid(theta, size):
n = size[0]
w = size[3]
h = size[2]
h_idx = np.repeat(
np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
w_idx = np.repeat(
np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2])
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
# print ret.reshape([h * w, 2]).astype("float32")
return ret.reshape([n, h, w, 2]).astype("float32")
class TestAffineGridOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "affine_grid"
theta = np.random.randint(1, 3, self.theta_shape).astype("float32")
theta = np.ones(self.theta_shape).astype("float32")
self.inputs = {'Theta': theta}
self.attrs = {"use_cudnn": True}
if self.dynamic_shape:
self.inputs['OutputShape'] = self.output_shape
else:
self.attrs['output_shape'] = self.output_shape
self.outputs = {'Output': AffineGrid(theta, self.output_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Theta'],
'Output',
no_grad_set=['OutputShape'],
max_relative_error=0.006)
def initTestCase(self):
self.theta_shape = (3, 2, 3)
self.output_shape = np.array([3, 2, 5, 7]).astype("int32")
self.dynamic_shape = False
class TestAffineGridOpCase1(TestAffineGridOp):
def initTestCase(self):
self.theta_shape = (3, 2, 3)
self.output_shape = np.array([3, 2, 5, 7]).astype("int32")
self.dynamic_shape = True
if __name__ == '__main__':
unittest.main()
......@@ -865,6 +865,22 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_affine_grid(self):
program = Program()
with program_guard(program):
data = layers.data(name='data', shape=[2, 3, 3], dtype="float32")
out, ids = layers.argsort(input=data, axis=1)
theta = layers.data(name="theta", shape=[2, 3], dtype="float32")
out_shape = layers.data(
name="out_shape", shape=[-1], dtype="float32")
data_0 = layers.affine_grid(theta, out_shape)
data_1 = layers.affine_grid(theta, [5, 3, 28, 28])
self.assertIsNotNone(data_0)
self.assertIsNotNone(data_1)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册