Unverified commit d94b9686, authored by Thomas Young, committed by GitHub

[operator migration] Migrate affine grid op (#44663)

* save change

* save change by YSL

* save change by YSL

* change by YSL

* test pre commit

* Revert "test pre commit"

This reverts commit eee5e116331186cc544de871b4a5174a6431f17c.

* fix code style

* fix ctest

* temp save

* save change

* change by YSL

* final change by ysl

* fix ci

* fix code style

* delete unuse code

* change by ysl
Parent commit: 51ed2788
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP does not support cudnnSpatialTfGridGeneratorForward
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using ScopedSpatialTransformerDescriptor =
platform::ScopedSpatialTransformerDescriptor;
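// Forward kernel backed by cuDNN's spatial transformer API: Theta is the
// batch of N x 2 x 3 affine matrices and Output receives the sampling grid
// of shape [N, H, W, 2] produced by cudnnSpatialTfGridGeneratorForward.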
template <typename T>
class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"Only support for CUDAPlace.Please switch your context from "
"CPUPlace to CUDAPlace or update your cudnn."));
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
auto handle = dev_ctx.cudnn_handle();
auto* theta = ctx.Input<Tensor>("Theta");
auto* output = ctx.Output<Tensor>("Output");
const T* theta_data = theta->data<T>();
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
Tensor h_sizes;
int* h_size_data;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
h_size_data = h_sizes.data<int>();
} else {
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
}
T* output_data = output->mutable_data<T>(
{n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
PADDLE_ENFORCE_EQ(
platform::dynload::cudnnSpatialTfGridGeneratorForward(
handle, cudnn_st_desc, theta_data, output_data),
0,
platform::errors::Fatal("Some errors has occurred "
"during forward computation in cudnn."));
}
};
template <typename T>
class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"Only "
"support for CUDAPlace. Please switch "
"your context from CPUPlace to "
"CUDAPlace or update your cudnn."));
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
auto handle = dev_ctx.cudnn_handle();
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
Tensor h_sizes;
int* h_size_data;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
h_size_data = h_sizes.data<int>();
} else {
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
}
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
const T* output_grad_data = output_grad->data<T>();
T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnSpatialTfGridGeneratorBackward(
handle, cudnn_st_desc, output_grad_data, theta_grad_data));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(affine_grid,
CUDNN,
plat::CUDAPlace,
paddle::operators::CUDNNAffineGridOpKernel<float>,
paddle::operators::CUDNNAffineGridOpKernel<double>);
REGISTER_OP_KERNEL(affine_grid_grad,
CUDNN,
plat::CUDAPlace,
paddle::operators::CUDNNAffineGridGradOpKernel<float>,
paddle::operators::CUDNNAffineGridGradOpKernel<double>);
#endif // not PADDLE_WITH_HIP
......@@ -12,41 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct Linspace<phi::CPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
T slice = (end - start) / (T)(count - 1);
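// With align_corners=true the samples include both endpoints, so the step is
// (end - start) / (count - 1). Otherwise samples sit at pixel centers: the
// step becomes (end - start) / count and the start is scaled by
// (count - 1) / count.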
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
}
};
class AffineGridOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -280,14 +263,8 @@ REGISTER_OPERATOR(affine_grid,
ops::AffineGridOpMaker,
ops::AffineGridGradMaker<paddle::framework::OpDesc>,
ops::AffineGridGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
REGISTER_OP_CPU_KERNEL(affine_grid,
ops::AffineGridOpKernel<phi::CPUContext, float>,
ops::AffineGridOpKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(affine_grid_grad,
ops::AffineGridGradOpKernel<phi::CPUContext, float>,
ops::AffineGridGradOpKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
REGISTER_OP_VERSION(affine_grid)
.AddCheckpoint(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
struct Linspace<phi::GPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
auto stream = ctx.cuda_device_context().stream();
int block = 512;
int grid = (count + block - 1) / block;
LinspaceKernel<T>
<<<grid, block, 0, stream>>>(start, slice, count, number_data);
}
};
template <typename T>
__global__ void affine_grid_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* theta, // N, 2, 3
T* output) {
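// One thread per output location (n, h, w): apply the 2 x 3 matrix theta[n]
// to the normalized homogeneous coordinate (w_coor, h_coor, 1) and write the
// resulting (x, y) sampling point at output[index * 2].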
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
// affine from (h_coor, w_coor) to (x, y)
output[index * 2] = theta[theta_offset] * w_coor +
theta[theta_offset + 1] * h_coor +
theta[theta_offset + 2];
output[index * 2 + 1] = theta[theta_offset + 3] * w_coor +
theta[theta_offset + 4] * h_coor +
theta[theta_offset + 5];
}
}
template <typename T>
__global__ void affine_grid_grad_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
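// The gradient of the grid w.r.t. theta[n] is the coordinate vector
// (w_coor, h_coor, 1), so each thread atomically accumulates
// out_grad * coordinate into the six entries of theta_grad for batch n.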
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
platform::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
T out_grad_y = out_grad[index * 2 + 1];
platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * w_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * h_coor);
platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}
template <typename T>
class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
auto* output = ctx.Output<Tensor>("Output");
T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
theta->data<T>(), // N, 2, 3
out_data);
}
};
template <typename T>
class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, T>()(
ctx.cuda_device_context(), theta_grad, static_cast<T>(0));
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
VLOG(3) << "count: " << count << "; h_step: " << h_step
<< "; w_step: " << w_step << "; h_start: " << h_start
<< "; w_start: " << w_start;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
output_grad->data<T>(),
theta_grad_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(affine_grid,
ops::AffineGridOpCUDAKernel<float>,
ops::AffineGridOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
ops::AffineGridGradOpCUDAKernel<float>,
ops::AffineGridGradOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
/**
* Return a tensor with evenly spaced numbers over a specified interval.
*/
template <typename DeviceContext, typename T>
struct Linspace {
void operator()(T start,
T end,
int count,
bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
inline void GetIdxMap(int n,
int h,
int w,
bool align_corners,
Tensor* grid,
const framework::ExecutionContext& ctx) {
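// Build the base grid of shape [n, h, w, 3]: every spatial location stores
// the normalized homogeneous coordinate (x, y, 1), which is later multiplied
// by theta^T to produce the final sampling grid.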
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
auto grid_t = EigenTensor<T, 4>::From(*grid);
// Get height indices (a 1-D tensor of length h, later broadcast to [h, w, 1])
Tensor h_idx;
Linspace<DeviceContext, T> linspace;
linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get width indices (a 1-D tensor of length w, later broadcast to [h, w, 1])
Tensor w_idx;
linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
Tensor ones;
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &ones, static_cast<T>(1));
auto ones_t = EigenTensor<T, 3>::From(ones);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor w_idx_map;
w_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
auto w_idx_map_t = EigenTensor<T, 3>::From(w_idx_map);
Tensor h_idx_map;
h_idx_map.mutable_data<T>({h, w, 1}, ctx.GetPlace());
auto h_idx_map_t = EigenTensor<T, 3>::From(h_idx_map);
Tensor w_h_idx_map;
w_h_idx_map.mutable_data<T>({h, w, 2}, ctx.GetPlace());
auto w_h_idx_map_t = EigenTensor<T, 3>::From(w_h_idx_map);
Tensor w_h_one_idx_map;
w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1));
h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
.reshape(Array3(h, w, 1));
w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2);
w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2);
grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3))
.broadcast(Array4(n, 1, 1, 1));
}
template <typename DeviceContext, typename T>
class AffineGridOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(),
output,
static_cast<T>(0));
Tensor grid;
GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
Tensor sliced_out = output->Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
blas.MatMul(
sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0));
}
}
};
template <typename DeviceContext, typename T>
class AffineGridGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(),
theta_grad,
static_cast<T>(0));
Tensor grid;
GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// theta_grad = output_grad^T * grid (the gradient of output = grid * theta^T)
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
blas.MatMul(sliced_out_grad,
true,
sliced_grid,
false,
T(1),
&sliced_theta_grad,
T(0));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -112,6 +112,19 @@
func : addmm
backward : addmm_grad
- api : affine_grid
args : (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true)
output : Tensor
infer_meta :
func : AffineGridInferMeta
param : [input, outputShape, align_corners]
kernel :
func : affine_grid
param : [input, outputShape, align_corners]
data_type : input
use_gpudnn: use_cudnn
backward : affine_grid_grad
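# Note: use_gpudnn is bound to the use_cudnn argument, so the GPUDNN (cuDNN)
# kernel is selected when use_cudnn is true; otherwise the plain phi CPU/GPU
# kernels handle the op.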
- api : all
args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
output : Tensor(out)
......
......@@ -92,6 +92,18 @@
kernel :
func : addmm_grad
- backward_api : affine_grid_grad
forward : affine_grid (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) -> Tensor(output)
args : (Tensor output_grad, IntArray outputShape, bool use_cudnn=true, bool align_corners=true)
output : Tensor(input_grad)
infer_meta :
func : AffineGridGradInferMeta
param : [output_grad, outputShape, align_corners]
kernel :
func : affine_grid_grad
param : [output_grad, outputShape, align_corners]
use_gpudnn: use_cudnn
- backward_api : amax_grad
forward: amax (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false)
......
......@@ -18,6 +18,16 @@ limitations under the License. */
namespace phi {
void AffineGridGradInferMeta(const MetaTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
MetaTensor* input_grad) {
if (input_grad) {
auto output_dims = output_grad.dims();
input_grad->set_dims(phi::make_ddim({output_dims[0], 2, 3}));
}
}
void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad) {
......
......@@ -27,6 +27,11 @@ namespace phi {
//
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void AffineGridGradInferMeta(const MetaTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
MetaTensor* input_grad);
void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad);
......
......@@ -44,6 +44,51 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) {
}
} // namespace detail
void AffineGridInferMeta(const MetaTensor& input,
const IntArray& outputShape,
bool align_corners,
MetaTensor* output) {
auto theta_dims = input.dims();
PADDLE_ENFORCE_EQ(
theta_dims.size(),
3,
phi::errors::InvalidArgument(
"The input Theta's dimensions size should be 3. But received "
"Theta's demensions size=[%d], Theta's dimensions=[%s].",
theta_dims.size(),
theta_dims));
PADDLE_ENFORCE_EQ(
outputShape.GetData().size(),
4,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be "
"4. But received output_shape's size=[%d].",
outputShape.GetData().size()));
PADDLE_ENFORCE_EQ(
theta_dims[1],
2,
phi::errors::InvalidArgument(
"The second dimesion of input 'theta' in AffineGridOp should be 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_EQ(
theta_dims[2],
3,
phi::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp should be 3. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
// The output grid is N x H x W x 2; H and W are set to -1 because they may
// only be known at runtime (e.g. when the output shape comes from a tensor).
output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2}));
output->set_dtype(input.dtype());
output->share_lod(input);
}
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
......
......@@ -34,6 +34,11 @@ class MetaConfig;
//
// The InferMeta Functions in this file are arranged in alphabetic order.
void AffineGridInferMeta(const MetaTensor& input,
const IntArray& outputShape,
bool align_corners,
MetaTensor* output);
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/kernels/affine_grid_impl.h"
namespace phi {
template <typename T, typename Context>
void AffineGridGradKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename Context, typename T>
struct Linspace {
void operator()(T start,
T end,
int count,
bool align_corners,
DenseTensor* numbers,
const Context& dev_ctx);
};
template <typename Context, typename T>
inline void GetIdxMap(int n,
int h,
int w,
bool align_corners,
DenseTensor* grid,
const Context& dev_ctx) {
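// Same construction as the fluid GetIdxMap above: build the [n, h, w, 3]
// grid of normalized homogeneous coordinates (x, y, 1) that theta^T is
// later applied to.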
auto& place = *dev_ctx.eigen_device();
grid->Resize(phi::make_ddim({n, h, w, 3}));
dev_ctx.template Alloc<T>(grid);
auto grid_t = EigenTensor<T, 4>::From(*grid);
// Get height indices (a 1-D tensor of length h, later broadcast to [h, w, 1])
DenseTensor h_idx;
Linspace<Context, T> linspace;
linspace((T)-1, (T)1, h, align_corners, &h_idx, dev_ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get width indices (a 1-D tensor of length w, later broadcast to [h, w, 1])
DenseTensor w_idx;
linspace((T)-1, (T)1, w, align_corners, &w_idx, dev_ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
DenseTensor ones;
ones.Resize(phi::make_ddim({h, w, 1}));
dev_ctx.template Alloc<T>(&ones);
phi::funcs::SetConstant<Context, T>()(dev_ctx, &ones, static_cast<T>(1));
auto ones_t = EigenTensor<T, 3>::From(ones);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
DenseTensor w_idx_map;
w_idx_map.Resize(phi::make_ddim({h, w, 1}));
dev_ctx.template Alloc<T>(&w_idx_map);
auto w_idx_map_t = EigenTensor<T, 3>::From(w_idx_map);
DenseTensor h_idx_map;
h_idx_map.Resize(phi::make_ddim({h, w, 1}));
dev_ctx.template Alloc<T>(&h_idx_map);
auto h_idx_map_t = EigenTensor<T, 3>::From(h_idx_map);
DenseTensor w_h_idx_map;
w_h_idx_map.Resize(phi::make_ddim({h, w, 2}));
dev_ctx.template Alloc<T>(&w_h_idx_map);
auto w_h_idx_map_t = EigenTensor<T, 3>::From(w_h_idx_map);
DenseTensor w_h_one_idx_map;
w_h_one_idx_map.Resize(phi::make_ddim({h, w, 3}));
dev_ctx.template Alloc<T>(&w_h_one_idx_map);
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1));
h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
.reshape(Array3(h, w, 1));
w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2);
w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2);
grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3))
.broadcast(Array4(n, 1, 1, 1));
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/kernels/affine_grid_impl.h"
namespace phi {
template <typename T, typename Context>
void AffineGridKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/affine_grid_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct Linspace<phi::CPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
DenseTensor* numbers,
const phi::CPUContext& dev_ctx) {
numbers->Resize(phi::make_ddim({count}));
T* number_data = dev_ctx.template Alloc<T>(numbers);
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
}
};
template <typename T, typename Context>
void AffineGridGradKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
int h = 0;
int w = 0;
h = size_attr[2];
w = size_attr[3];
theta_grad->Resize(phi::make_ddim({n, 2, 3}));
dev_ctx.template Alloc<T>(theta_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, theta_grad, static_cast<T>(0));
DenseTensor grid;
GetIdxMap<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
// theta_grad = output_grad^T * grid (the gradient of output = grid * theta^T)
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
for (int i = 0; i < n; ++i) {
DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
DenseTensor sliced_out_grad = output_grad.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
DenseTensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
blas.MatMul(sliced_out_grad,
true,
sliced_grid,
false,
T(1),
&sliced_theta_grad,
T(0));
}
}
} // namespace phi
PD_REGISTER_KERNEL(affine_grid_grad,
CPU,
ALL_LAYOUT,
phi::AffineGridGradKernel,
float,
double){};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/affine_grid_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct Linspace<phi::CPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
DenseTensor* numbers,
const phi::CPUContext& dev_ctx) {
numbers->Resize(phi::make_ddim({count}));
T* number_data = dev_ctx.template Alloc<T>(numbers);
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
}
};
template <typename T, typename Context>
void AffineGridKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
auto* theta = &input;
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
int h = 0;
int w = 0;
h = size_attr[2];
w = size_attr[3];
output->Resize(phi::make_ddim({n, h, w, 2}));
dev_ctx.template Alloc<T>(output);
phi::funcs::SetConstant<Context, T>()(dev_ctx, output, static_cast<T>(0));
DenseTensor grid;
GetIdxMap<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
for (int i = 0; i < n; ++i) {
DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
DenseTensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
DenseTensor sliced_out = output->Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
blas.MatMul(
sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
affine_grid, CPU, ALL_LAYOUT, phi::AffineGridKernel, float, double){};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/affine_grid_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
struct Linspace<phi::GPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
DenseTensor* numbers,
const phi::GPUContext& dev_ctx) {
numbers->Resize(phi::make_ddim({count}));
T* number_data = dev_ctx.template Alloc<T>(numbers);
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
auto stream = dev_ctx.stream();
int block = 512;
int grid = (count + block - 1) / block;
LinspaceKernel<T>
<<<grid, block, 0, stream>>>(start, slice, count, number_data);
}
};
template <typename T>
__global__ void affine_grid_grad_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset,
out_grad_x * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 1,
out_grad_x * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
T out_grad_y = out_grad[index * 2 + 1];
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 3,
out_grad_y * w_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 4,
out_grad_y * h_coor);
paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}
template <typename T, typename Context>
void AffineGridGradCUDAKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
int h = 0;
int w = 0;
h = size_attr[2];
w = size_attr[3];
theta_grad->Resize(phi::make_ddim({n, 2, 3}));
T* theta_grad_data = dev_ctx.template Alloc<T>(theta_grad);
phi::funcs::SetConstant<phi::GPUContext, T>()(
dev_ctx, theta_grad, static_cast<T>(0));
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
VLOG(3) << "count: " << count << "; h_step: " << h_step
<< "; w_step: " << w_step << "; h_start: " << h_start
<< "; w_start: " << w_start;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = dev_ctx.stream();
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
output_grad.data<T>(),
theta_grad_data);
}
} // namespace phi
PD_REGISTER_KERNEL(affine_grid_grad,
GPU,
ALL_LAYOUT,
phi::AffineGridGradCUDAKernel,
float,
double){};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/affine_grid_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
struct Linspace<phi::GPUContext, T> {
void operator()(T start,
T end,
int count,
bool align_corners,
DenseTensor* numbers,
const phi::GPUContext& dev_ctx) {
numbers->Resize(phi::make_ddim({count}));
T* number_data = dev_ctx.template Alloc<T>(numbers);
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
auto stream = dev_ctx.stream();
int block = 512;
int grid = (count + block - 1) / block;
LinspaceKernel<T>
<<<grid, block, 0, stream>>>(start, slice, count, number_data);
}
};
template <typename T>
__global__ void affine_grid_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* theta, // N, 2, 3
T* output) {
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
// affine from (h_coor, w_coor) to (x, y)
output[index * 2] = theta[theta_offset] * w_coor +
theta[theta_offset + 1] * h_coor +
theta[theta_offset + 2];
output[index * 2 + 1] = theta[theta_offset + 3] * w_coor +
theta[theta_offset + 4] * h_coor +
theta[theta_offset + 5];
}
}
template <typename T, typename Context>
void AffineGridCUDAKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
auto* theta = &input;
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
int h = 0;
int w = 0;
h = size_attr[2];
w = size_attr[3];
output->Resize(phi::make_ddim({n, h, w, 2}));
T* out_data = dev_ctx.template Alloc<T>(output);
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = dev_ctx.stream();
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
theta->data<T>(), // N, 2, 3
out_data);
}
} // namespace phi
PD_REGISTER_KERNEL(
affine_grid, GPU, ALL_LAYOUT, phi::AffineGridCUDAKernel, float, double){};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/affine_grid_grad_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using ScopedSpatialTransformerDescriptor =
paddle::platform::ScopedSpatialTransformerDescriptor;
template <typename T, typename Context>
void AffineGridGradCudnnKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
phi::errors::InvalidArgument(
"Only support for CUDAPlace.Please switch your context from "
"CPUPlace to CUDAPlace or update your cudnn."));
auto handle = dev_ctx.cudnn_handle();
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
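// Pack the N x C x H x W output shape into the 4-element array expected by
// the cuDNN spatial transformer descriptor.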
int h_size_data[4] = {0};
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
const T* output_grad_data = output_grad.data<T>();
T* theta_grad_data = dev_ctx.template Alloc<T>(theta_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnSpatialTfGridGeneratorBackward(
handle, cudnn_st_desc, output_grad_data, theta_grad_data));
}
} // namespace phi
PD_REGISTER_KERNEL(affine_grid_grad, // cuda_only
GPUDNN,
ALL_LAYOUT,
phi::AffineGridGradCudnnKernel,
float,
double){};
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/affine_grid_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
using ScopedSpatialTransformerDescriptor =
paddle::platform::ScopedSpatialTransformerDescriptor;
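// GPUDNN forward kernel: grid generation is delegated to
// cudnnSpatialTfGridGeneratorForward via a spatial transformer descriptor
// built from the requested output shape.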
template <typename T, typename Context>
void AffineGridCudnnKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
phi::errors::InvalidArgument(
"Only support for CUDAPlace.Please switch your context from "
"CPUPlace to CUDAPlace or update your cudnn."));
auto handle = dev_ctx.cudnn_handle();
auto* theta = &input;
const T* theta_data = theta->data<T>();
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
int h_size_data[4] = {0};
h_size_data[0] = n;
h_size_data[1] = size_attr[1];
h_size_data[2] = size_attr[2];
h_size_data[3] = size_attr[3];
output->Resize(phi::make_ddim({n, h_size_data[2], h_size_data[3], 2}));
T* output_data = dev_ctx.template Alloc<T>(output);
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, h_size_data);
PADDLE_ENFORCE_EQ(
paddle::platform::dynload::cudnnSpatialTfGridGeneratorForward(
handle, cudnn_st_desc, theta_data, output_data),
0,
phi::errors::Fatal("Some errors has occurred "
"during forward computation in cudnn."));
}
} // namespace phi
PD_REGISTER_KERNEL(affine_grid, // cuda_only
GPUDNN,
ALL_LAYOUT,
phi::AffineGridCudnnKernel,
float,
double){};
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
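// Map the legacy fluid operator onto the new phi kernel signature. When the
// OutputShape input tensor is present it takes precedence over the static
// output_shape attribute.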
KernelSignature AffineGridOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("OutputShape")) {
return KernelSignature(
"affine_grid", {"Theta"}, {"OutputShape", "align_corners"}, {"Output"});
} else {
return KernelSignature("affine_grid",
{"Theta"},
{"output_shape", "align_corners"},
{"Output"});
}
}
KernelSignature AffineGridGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("OutputShape")) {
return KernelSignature("affine_grid_grad",
{"Output@GRAD"},
{"OutputShape", "align_corners"},
{"Theta@GRAD"});
} else {
return KernelSignature("affine_grid_grad",
{"Output@GRAD"},
{"output_shape", "align_corners"},
{"Theta@GRAD"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(affine_grid, phi::AffineGridOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(affine_grid_grad,
phi::AffineGridGradOpArgumentMapping);
......@@ -49,6 +49,7 @@ class TestAffineGridOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "affine_grid"
self.python_api = paddle.nn.functional.vision.affine_grid
theta = np.random.randint(1, 3, self.theta_shape).astype("float32")
self.inputs = {'Theta': theta}
self.attrs = {
......@@ -64,10 +65,13 @@ class TestAffineGridOp(OpTest):
}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['Theta'], 'Output', no_grad_set=['OutputShape'])
self.check_grad(['Theta'],
'Output',
no_grad_set=['OutputShape'],
check_eager=True)
def initTestCase(self):
self.theta_shape = (17, 2, 3)
......
......@@ -89,11 +89,12 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
if is_compiled_with_rocm():
use_cudnn = False  # The ROCm platform does not have a MIOPEN kernel for affine_grid
if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Tensor.")
if in_dynamic_mode():
if in_dygraph_mode():
_out_shape = out_shape.numpy().tolist() if isinstance(
out_shape, Variable) else out_shape
return _C_ops.final_state_affine_grid(theta, _out_shape, use_cudnn,
align_corners)
elif in_dynamic_mode():
_out_shape = out_shape.numpy().tolist() if isinstance(
out_shape, Variable) else out_shape
return _C_ops.affine_grid(theta, "output_shape", _out_shape,
......
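For context, a minimal usage sketch of the migrated operator from the Python API; the shapes and values below are illustrative only and are not taken from this PR's tests:

import paddle
import paddle.nn.functional as F

# Theta holds a batch of 2 x 3 affine matrices with shape [N, 2, 3].
theta = paddle.rand([4, 2, 3], dtype='float32')
# out_shape is [N, C, H, W]; the returned sampling grid has shape [N, H, W, 2].
grid = F.affine_grid(theta, [4, 3, 16, 16], align_corners=True)
print(grid.shape)  # [4, 16, 16, 2]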