未验证 提交 7812522c 编写于 作者: C carryyu 提交者: GitHub

make affine_grid_op support 5d input_dim on cpu and gpu (#45012)

* make affine_grid_op support 5d_input on cpu and gpu
上级 f4bc69ec
......@@ -71,34 +71,63 @@ class AffineGridOp : public framework::OperatorWithKernel {
output_shape_dims.size(),
output_shape_dims));
} else {
PADDLE_ENFORCE_EQ(
output_shape.size(),
4,
platform::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be "
"4. But received output_shape's size=[%d].",
output_shape.size()));
PADDLE_ENFORCE_GE(output_shape.size(),
4,
platform::errors::InvalidArgument(
"The size of attribute 'output_shape' in "
"AffineGridOp should be >= "
"4. But received output_shape's size=[%d].",
output_shape.size()));
PADDLE_ENFORCE_LE(output_shape.size(),
5,
platform::errors::InvalidArgument(
"The size of attribute 'output_shape' in "
"AffineGridOp should be <= "
"5. But received output_shape's size=[%d].",
output_shape.size()));
}
PADDLE_ENFORCE_EQ(
theta_dims[1],
2,
platform::errors::InvalidArgument(
"The second dimesion of input 'theta' in AffineGridOp should be 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_EQ(
theta_dims[2],
3,
platform::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp should be 3. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
PADDLE_ENFORCE_GE(theta_dims[1],
2,
platform::errors::InvalidArgument(
"The second dimesion of input 'theta' in "
"AffineGridOp should be >= 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_LE(theta_dims[1],
3,
platform::errors::InvalidArgument(
"The second dimesion of input 'theta' in "
"AffineGridOp should be <= 3. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_GE(theta_dims[2],
3,
platform::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp "
"should be >= 3. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
PADDLE_ENFORCE_LE(theta_dims[2],
4,
platform::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp "
"should be <= 4. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
// N * H * W * 2
ctx->SetOutputDim("Output", phi::make_ddim({theta_dims[0], -1, -1, 2}));
if (output_shape.size() == 4) {
// N * H * W * 2
ctx->SetOutputDim("Output", phi::make_ddim({theta_dims[0], -1, -1, 2}));
} else {
// N * D * H * W * 3
ctx->SetOutputDim("Output",
phi::make_ddim({theta_dims[0], -1, -1, -1, 3}));
}
ctx->ShareLoD("Theta", "Output");
}
......@@ -215,8 +244,13 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
// Infers the shape of the Theta gradient from the rank of the incoming
// Output gradient: a rank-4 Output gradient yields [N, 2, 3]; any other rank
// yields [N, 3, 4]. Nothing is set when the Theta gradient output is absent.
//
// The shape is set exactly once, inside the rank branch; an earlier
// unconditional {N, 2, 3} assignment was always overwritten and is removed.
void InferShape(framework::InferShapeContext* ctx) const override {
  if (ctx->HasOutput(framework::GradVarName("Theta"))) {
    auto output_dims = ctx->GetInputDim(framework::GradVarName("Output"));
    if (output_dims.size() == 4) {
      ctx->SetOutputDim(framework::GradVarName("Theta"),
                        {output_dims[0], 2, 3});
    } else {
      ctx->SetOutputDim(framework::GradVarName("Theta"),
                        {output_dims[0], 3, 4});
    }
  }
}
......
......@@ -58,33 +58,64 @@ void AffineGridInferMeta(const MetaTensor& input,
theta_dims.size(),
theta_dims));
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_GE(
outputShape.GetData().size(),
4,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be "
"The size of attribute 'output_shape' in AffineGridOp should be >= "
"4. But received output_shape's size=[%d].",
outputShape.GetData().size()));
PADDLE_ENFORCE_EQ(
theta_dims[1],
2,
PADDLE_ENFORCE_LE(
outputShape.GetData().size(),
5,
phi::errors::InvalidArgument(
"The second dimesion of input 'theta' in AffineGridOp should be 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_EQ(
"The size of attribute 'output_shape' in AffineGridOp should be <= "
"5. But received output_shape's size=[%d].",
outputShape.GetData().size()));
PADDLE_ENFORCE_GE(theta_dims[1],
2,
phi::errors::InvalidArgument(
"The second dimesion of input 'theta' in AffineGridOp "
"should be >= 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_LE(theta_dims[1],
3,
phi::errors::InvalidArgument(
"The second dimesion of input 'theta' in AffineGridOp "
"should be <= 3. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_GE(
theta_dims[2],
3,
phi::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp should be 3. "
"The third dimesion of input 'theta' in AffineGridOp should be >= 3. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
// N * H * W * 2
output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2}));
PADDLE_ENFORCE_LE(
theta_dims[2],
4,
phi::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp should be <= 4. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
if (outputShape.GetData().size() == 4) {
// N * H * W * 2
output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2}));
} else {
// N * D * H * W * 3
output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, -1, 3}));
}
output->set_dtype(input.dtype());
output->share_lod(input);
}
......
......@@ -43,11 +43,11 @@ struct Linspace<phi::CPUContext, T> {
};
template <typename T, typename Context>
void AffineGridGradKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
void AffineGridGrad4DKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
......@@ -59,7 +59,7 @@ void AffineGridGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(theta_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, theta_grad, static_cast<T>(0));
DenseTensor grid;
GetIdxMap<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
GetIdxMap4D<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
......@@ -79,6 +79,63 @@ void AffineGridGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void AffineGridGrad5DKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
int d = 0;
int h = 0;
int w = 0;
d = size_attr[2];
h = size_attr[3];
w = size_attr[4];
theta_grad->Resize(phi::make_ddim({n, 3, 4}));
dev_ctx.template Alloc<T>(theta_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, theta_grad, static_cast<T>(0));
DenseTensor grid;
GetIdxMap5D<Context, T>(n, d, h, w, align_corners, &grid, dev_ctx);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
for (int i = 0; i < n; ++i) {
DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(d) * static_cast<int64_t>(h) *
static_cast<int64_t>(w),
4});
DenseTensor sliced_out_grad = output_grad.Slice(i, i + 1).Resize(
{static_cast<int64_t>(d) * static_cast<int64_t>(h) *
static_cast<int64_t>(w),
3});
DenseTensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({3, 4});
blas.MatMul(sliced_out_grad,
true,
sliced_grid,
false,
T(1),
&sliced_theta_grad,
T(0));
}
}
// CPU entry point for the affine_grid backward pass.
//
// Picks the implementation from the number of entries in `outputShape`:
// exactly 4 entries select AffineGridGrad4DKernel, anything else selects
// AffineGridGrad5DKernel. All arguments are forwarded unchanged.
template <typename T, typename Context>
void AffineGridGradKernel(const Context& dev_ctx,
                          const DenseTensor& output_grad,
                          const IntArray& outputShape,
                          bool align_corners,
                          DenseTensor* input_grad) {
  const bool wants_4d = outputShape.GetData().size() == 4;
  if (!wants_4d) {
    AffineGridGrad5DKernel<T, Context>(
        dev_ctx, output_grad, outputShape, align_corners, input_grad);
    return;
  }
  AffineGridGrad4DKernel<T, Context>(
      dev_ctx, output_grad, outputShape, align_corners, input_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(affine_grid_grad,
......
......@@ -43,11 +43,11 @@ struct Linspace<phi::CPUContext, T> {
};
template <typename T, typename Context>
void AffineGridKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
void AffineGrid4DKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
auto* theta = &input;
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
......@@ -59,7 +59,7 @@ void AffineGridKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(output);
phi::funcs::SetConstant<Context, T>()(dev_ctx, output, static_cast<T>(0));
DenseTensor grid;
GetIdxMap<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
GetIdxMap4D<Context, T>(n, h, w, align_corners, &grid, dev_ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
......@@ -74,6 +74,58 @@ void AffineGridKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void AffineGrid5DKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
auto* theta = &input;
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
int d = 0;
int h = 0;
int w = 0;
d = size_attr[2];
h = size_attr[3];
w = size_attr[4];
output->Resize(phi::make_ddim({n, d, h, w, 3}));
dev_ctx.template Alloc<T>(output);
phi::funcs::SetConstant<Context, T>()(dev_ctx, output, static_cast<T>(0));
DenseTensor grid;
GetIdxMap5D<Context, T>(n, d, h, w, align_corners, &grid, dev_ctx);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
for (int i = 0; i < n; ++i) {
DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(d) * static_cast<int64_t>(h) *
static_cast<int64_t>(w),
4});
DenseTensor sliced_theta = theta->Slice(i, i + 1).Resize({3, 4});
DenseTensor sliced_out = output->Slice(i, i + 1).Resize(
{static_cast<int64_t>(d) * static_cast<int64_t>(h) *
static_cast<int64_t>(w),
3});
blas.MatMul(
sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0));
}
}
// CPU entry point for the affine_grid forward pass.
//
// Picks the implementation from the number of entries in `outputShape`:
// exactly 4 entries select AffineGrid4DKernel, anything else selects
// AffineGrid5DKernel. All arguments are forwarded unchanged.
template <typename T, typename Context>
void AffineGridKernel(const Context& dev_ctx,
                      const DenseTensor& input,
                      const IntArray& outputShape,
                      bool align_corners,
                      DenseTensor* output) {
  const bool wants_4d = outputShape.GetData().size() == 4;
  if (!wants_4d) {
    AffineGrid5DKernel<T, Context>(
        dev_ctx, input, outputShape, align_corners, output);
    return;
  }
  AffineGrid4DKernel<T, Context>(
      dev_ctx, input, outputShape, align_corners, output);
}
} // namespace phi
PD_REGISTER_KERNEL(
......
......@@ -25,6 +25,7 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
using Array5 = Eigen::DSizes<int64_t, 5>;
template <typename Context, typename T>
struct Linspace {
......@@ -37,12 +38,12 @@ struct Linspace {
};
template <typename Context, typename T>
inline void GetIdxMap(int n,
int h,
int w,
bool align_corners,
DenseTensor* grid,
const Context& dev_ctx) {
inline void GetIdxMap4D(int n,
int h,
int w,
bool align_corners,
DenseTensor* grid,
const Context& dev_ctx) {
auto& place = *dev_ctx.eigen_device();
grid->Resize(phi::make_ddim({n, h, w, 3}));
dev_ctx.template Alloc<T>(grid);
......@@ -99,4 +100,86 @@ inline void GetIdxMap(int n,
.broadcast(Array4(n, 1, 1, 1));
}
template <typename Context, typename T>
inline void GetIdxMap5D(int n,
int d,
int h,
int w,
bool align_corners,
DenseTensor* grid,
const Context& dev_ctx) {
auto& place = *dev_ctx.eigen_device();
grid->Resize(phi::make_ddim({n, d, h, w, 4}));
dev_ctx.template Alloc<T>(grid);
auto grid_t = EigenTensor<T, 5>::From(*grid);
// Get indexes of height with shape [depth, height, width, 1]
DenseTensor d_idx;
Linspace<Context, T> linspace;
linspace((T)-1, (T)1, d, align_corners, &d_idx, dev_ctx);
auto d_idx_t = EigenTensor<T, 1>::From(d_idx);
// Get indexes of height with shape [depth, height, width, 1]
DenseTensor h_idx;
linspace((T)-1, (T)1, h, align_corners, &h_idx, dev_ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get indexes of width with shape [depth, height, width, 1]
DenseTensor w_idx;
linspace((T)-1, (T)1, w, align_corners, &w_idx, dev_ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [depth, height, width, 1]
DenseTensor ones;
ones.Resize(phi::make_ddim({d, h, w, 1}));
dev_ctx.template Alloc<T>(&ones);
phi::funcs::SetConstant<Context, T>()(dev_ctx, &ones, static_cast<T>(1));
auto ones_t = EigenTensor<T, 4>::From(ones);
// Get grid tensor with shape [n, d, h, w, 4] by concatenating d_idx, h_idx,
// w_idx and ones
DenseTensor w_idx_map;
w_idx_map.Resize(phi::make_ddim({d, h, w, 1}));
dev_ctx.template Alloc<T>(&w_idx_map);
auto w_idx_map_t = EigenTensor<T, 4>::From(w_idx_map);
DenseTensor h_idx_map;
h_idx_map.Resize(phi::make_ddim({d, h, w, 1}));
dev_ctx.template Alloc<T>(&h_idx_map);
auto h_idx_map_t = EigenTensor<T, 4>::From(h_idx_map);
DenseTensor d_idx_map;
d_idx_map.Resize(phi::make_ddim({d, h, w, 1}));
dev_ctx.template Alloc<T>(&d_idx_map);
auto d_idx_map_t = EigenTensor<T, 4>::From(d_idx_map);
DenseTensor w_h_idx_map;
w_h_idx_map.Resize(phi::make_ddim({d, h, w, 2}));
dev_ctx.template Alloc<T>(&w_h_idx_map);
auto w_h_idx_map_t = EigenTensor<T, 4>::From(w_h_idx_map);
DenseTensor w_h_d_idx_map;
w_h_d_idx_map.Resize(phi::make_ddim({d, h, w, 3}));
dev_ctx.template Alloc<T>(&w_h_d_idx_map);
auto w_h_d_idx_map_t = EigenTensor<T, 4>::From(w_h_d_idx_map);
DenseTensor w_h_d_one_idx_map;
w_h_d_one_idx_map.Resize(phi::make_ddim({d, h, w, 4}));
dev_ctx.template Alloc<T>(&w_h_d_one_idx_map);
auto w_h_d_one_idx_map_t = EigenTensor<T, 4>::From(w_h_d_one_idx_map);
w_idx_map_t.device(place) = w_idx_t.reshape(Array3(1, 1, w))
.broadcast(Array3(d, h, 1))
.reshape(Array4(d, h, w, 1));
h_idx_map_t.device(place) = h_idx_t.reshape(Array3(1, h, 1))
.broadcast(Array3(d, 1, w))
.reshape(Array4(d, h, w, 1));
d_idx_map_t.device(place) = d_idx_t.reshape(Array3(d, 1, 1))
.broadcast(Array3(1, h, w))
.reshape(Array4(d, h, w, 1));
w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 3);
w_h_d_idx_map_t.device(place) = w_h_idx_map_t.concatenate(d_idx_map_t, 3);
w_h_d_one_idx_map_t.device(place) = w_h_d_idx_map_t.concatenate(ones_t, 3);
grid_t.device(place) = w_h_d_one_idx_map_t.reshape(Array5(1, d, h, w, 4))
.broadcast(Array5(n, 1, 1, 1, 1));
}
} // namespace phi
......@@ -56,16 +56,16 @@ struct Linspace<phi::GPUContext, T> {
};
template <typename T>
__global__ void affine_grid_grad_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
__global__ void affine_grid_grad_kernel_4d(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
......@@ -90,12 +90,66 @@ __global__ void affine_grid_grad_kernel(const int count,
}
}
// Device kernel for the 5-D affine_grid backward pass.
//
// Each loop iteration handles one output location (batch, d, h, w) decoded
// from the flat `index`. For each of the three output components it
// atomically accumulates into the matching row of that batch item's 3 x 4
// theta gradient:
//   row[0] += g * w_coor, row[1] += g * h_coor,
//   row[2] += g * d_coor, row[3] += g
// where g is the incoming gradient of that component.
//
// count:            number of output locations to process.
// n:                batch size (not read inside the kernel body).
// out_d/out_h/out_w: spatial extents used to decode `index`.
// *_start, *_step:  coordinate of index 0 and spacing along each axis.
// out_grad:         laid out as [N, D, H, W, 3].
// theta_grad:       laid out as [N, 3, 4]; accumulated into with atomic adds.
template <typename T>
__global__ void affine_grid_grad_kernel_5d(const int count,
                                           int n,
                                           int out_d,
                                           int out_h,
                                           int out_w,
                                           T d_start,
                                           T h_start,
                                           T w_start,
                                           T d_step,
                                           T h_step,
                                           T w_step,
                                           const T* out_grad,  // N, D, H, W, 3
                                           T* theta_grad) {    // N, 3, 4
  CUDA_KERNEL_LOOP(index, count) {
    // Decode the flat index into (batch, depth, height, width).
    const int w_idx = index % out_w;
    const int h_idx = (index / out_w) % out_h;
    const int d_idx = (index / (out_w * out_h)) % out_d;
    const int batch = index / (out_w * out_h * out_d);

    const T d_coor = d_step * static_cast<T>(d_idx) + static_cast<T>(d_start);
    const T h_coor = h_step * static_cast<T>(h_idx) + static_cast<T>(h_start);
    const T w_coor = w_step * static_cast<T>(w_idx) + static_cast<T>(w_start);

    const int theta_offset = batch * 12;  // 3 * 4 values per batch item
    for (int row = 0; row < 3; ++row) {
      const T grad = out_grad[index * 3 + row];
      T* row_grad = theta_grad + theta_offset + row * 4;
      paddle::platform::CudaAtomicAdd(row_grad, grad * w_coor);
      paddle::platform::CudaAtomicAdd(row_grad + 1, grad * h_coor);
      paddle::platform::CudaAtomicAdd(row_grad + 2, grad * d_coor);
      paddle::platform::CudaAtomicAdd(row_grad + 3, grad);
    }
  }
}
template <typename T, typename Context>
void AffineGridGradCUDAKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
void AffineGridGrad4DCUDAKernel(const Context& dev_ctx,
const DenseTensor& output_grad,
const IntArray& outputShape,
bool align_corners,
DenseTensor* input_grad) {
auto& theta_grad = input_grad;
int n = output_grad.dims()[0];
auto& size_attr = outputShape.GetData();
......@@ -129,16 +183,94 @@ void AffineGridGradCUDAKernel(const Context& dev_ctx,
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = dev_ctx.stream();
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
output_grad.data<T>(),
theta_grad_data);
affine_grid_grad_kernel_4d<<<grid, block, 0, cu_stream>>>(
count,
n,
h,
w,
h_start,
w_start,
h_step,
w_step,
output_grad.data<T>(),
theta_grad_data);
}
// GPU backward pass of affine_grid for 5-element output shapes.
//
// Resizes the theta gradient to [N, 3, 4], zero-fills it, derives the
// per-axis start/step of the coordinate grid and launches
// affine_grid_grad_kernel_5d with one work item per output location.
//
// dev_ctx:       GPU context providing allocation and the CUDA stream.
// output_grad:   gradient of the forward output.
// outputShape:   target shape; entries 2, 3 and 4 are read as D, H and W.
// align_corners: true  -> start = -1, step = 2 / (size - 1);
//                false -> start = -(size - 1) / size, step = 2 / size.
// input_grad:    receives the gradient of theta.
//
// NOTE(review): with align_corners == true and an extent of 1 the step is
// computed as 2 / 0 -- confirm callers never pass such a shape.
template <typename T, typename Context>
void AffineGridGrad5DCUDAKernel(const Context& dev_ctx,
                                const DenseTensor& output_grad,
                                const IntArray& outputShape,
                                bool align_corners,
                                DenseTensor* input_grad) {
  DenseTensor* theta_grad = input_grad;
  const auto& shape_values = outputShape.GetData();
  const int n = output_grad.dims()[0];
  const int d = shape_values[2];
  const int h = shape_values[3];
  const int w = shape_values[4];

  theta_grad->Resize(phi::make_ddim({n, 3, 4}));
  T* theta_grad_data = dev_ctx.template Alloc<T>(theta_grad);
  phi::funcs::SetConstant<phi::GPUContext, T>()(
      dev_ctx, theta_grad, static_cast<T>(0));

  // Spacing between neighbouring grid coordinates along each axis.
  const T d_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? d - 1 : d);
  const T h_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? h - 1 : h);
  const T w_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? w - 1 : w);

  // Coordinate of index 0 along each axis; scaled when corners are not
  // aligned.
  T d_start = -1;
  T h_start = -1;
  T w_start = -1;
  if (!align_corners) {
    d_start *= static_cast<T>(d - 1) / static_cast<T>(d);
    h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
    w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
  }

  const int count = n * d * h * w;
  const int threads_per_block = 512;
  const int num_blocks = (count + threads_per_block - 1) / threads_per_block;
  auto cu_stream = dev_ctx.stream();
  affine_grid_grad_kernel_5d<<<num_blocks, threads_per_block, 0, cu_stream>>>(
      count,
      n,
      d,
      h,
      w,
      d_start,
      h_start,
      w_start,
      d_step,
      h_step,
      w_step,
      output_grad.data<T>(),
      theta_grad_data);
}
// GPU entry point for the affine_grid backward pass.
//
// Dispatches on the rank of the incoming output gradient: a rank-4 gradient
// selects AffineGridGrad4DCUDAKernel, any other rank selects
// AffineGridGrad5DCUDAKernel. All arguments are forwarded unchanged.
//
// The parameters are named after what they carry (the output gradient and
// the theta gradient), matching the 4D/5D kernels this function forwards to.
// The tensor inspected here is the output gradient, not theta.
//
// dev_ctx:       GPU context.
// output_grad:   gradient of the forward output.
// outputShape:   target shape, forwarded to the selected kernel.
// align_corners: forwarded to the selected kernel.
// input_grad:    receives the gradient of theta.
template <typename T, typename Context>
void AffineGridGradCUDAKernel(const Context& dev_ctx,
                              const DenseTensor& output_grad,
                              const IntArray& outputShape,
                              bool align_corners,
                              DenseTensor* input_grad) {
  const auto output_grad_rank = output_grad.dims().size();
  if (output_grad_rank == 4) {
    AffineGridGrad4DCUDAKernel<T, Context>(
        dev_ctx, output_grad, outputShape, align_corners, input_grad);
  } else {
    AffineGridGrad5DCUDAKernel<T, Context>(
        dev_ctx, output_grad, outputShape, align_corners, input_grad);
  }
}
} // namespace phi
......
......@@ -56,16 +56,16 @@ struct Linspace<phi::GPUContext, T> {
};
template <typename T>
__global__ void affine_grid_kernel(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* theta, // N, 2, 3
T* output) {
__global__ void affine_grid_kernel_4d(const int count,
int n,
int out_h,
int out_w,
T h_start,
T w_start,
T h_step,
T w_step,
const T* theta, // N, 2, 3
T* output) {
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
......@@ -85,12 +85,51 @@ __global__ void affine_grid_kernel(const int count,
}
}
// Device kernel for the 5-D affine_grid forward pass.
//
// Each loop iteration handles one output location (batch, d, h, w) decoded
// from the flat `index` and writes three values: row r of that batch item's
// 3 x 4 theta applied to (w_coor, h_coor, d_coor, 1).
//
// count:            number of output locations to process.
// n:                batch size (not read inside the kernel body).
// out_d/out_h/out_w: spatial extents used to decode `index`.
// *_start, *_step:  coordinate of index 0 and spacing along each axis.
// theta:            laid out as [N, 3, 4].
// output:           three consecutive values are written per location.
template <typename T>
__global__ void affine_grid_kernel_5d(const int count,
                                      int n,
                                      int out_d,
                                      int out_h,
                                      int out_w,
                                      T d_start,
                                      T h_start,
                                      T w_start,
                                      T d_step,
                                      T h_step,
                                      T w_step,
                                      const T* theta,  // N, 3, 4
                                      T* output) {
  CUDA_KERNEL_LOOP(index, count) {
    // Decode the flat index into (batch, depth, height, width).
    const int w_idx = index % out_w;
    const int h_idx = (index / out_w) % out_h;
    const int d_idx = (index / (out_w * out_h)) % out_d;
    const int batch = index / (out_w * out_h * out_d);

    const T d_coor = d_step * static_cast<T>(d_idx) + static_cast<T>(d_start);
    const T h_coor = h_step * static_cast<T>(h_idx) + static_cast<T>(h_start);
    const T w_coor = w_step * static_cast<T>(w_idx) + static_cast<T>(w_start);

    const int theta_offset = batch * 12;  // 3 * 4 values per batch item
    // Affine map from (w_coor, h_coor, d_coor) to the three output values.
    for (int row = 0; row < 3; ++row) {
      const int base = theta_offset + row * 4;
      output[index * 3 + row] =
          theta[base] * w_coor + theta[base + 1] * h_coor +
          theta[base + 2] * d_coor + theta[base + 3];
    }
  }
}
template <typename T, typename Context>
void AffineGridCUDAKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
void AffineGrid4DCUDAKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& outputShape,
bool align_corners,
DenseTensor* output) {
// VLOG(0) << "in affine grid 4d forward";
auto* theta = &input;
int n = theta->dims()[0];
auto& size_attr = outputShape.GetData();
......@@ -120,7 +159,7 @@ void AffineGridCUDAKernel(const Context& dev_ctx,
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = dev_ctx.stream();
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
affine_grid_kernel_4d<<<grid, block, 0, cu_stream>>>(
count,
n,
h,
......@@ -133,6 +172,81 @@ void AffineGridCUDAKernel(const Context& dev_ctx,
out_data);
}
// GPU forward pass of affine_grid for 5-element output shapes.
//
// Resizes the output to [N, D, H, W, 3], derives the per-axis start/step of
// the coordinate grid and launches affine_grid_kernel_5d with one work item
// per output location.
//
// dev_ctx:       GPU context providing allocation and the CUDA stream.
// input:         theta; its first dimension is read as the batch size N.
// outputShape:   target shape; entries 2, 3 and 4 are read as D, H and W.
// align_corners: true  -> start = -1, step = 2 / (size - 1);
//                false -> start = -(size - 1) / size, step = 2 / size.
// output:        receives the sampling grid.
//
// NOTE(review): with align_corners == true and an extent of 1 the step is
// computed as 2 / 0 -- confirm callers never pass such a shape.
template <typename T, typename Context>
void AffineGrid5DCUDAKernel(const Context& dev_ctx,
                            const DenseTensor& input,
                            const IntArray& outputShape,
                            bool align_corners,
                            DenseTensor* output) {
  const DenseTensor* theta = &input;
  const auto& shape_values = outputShape.GetData();
  const int n = theta->dims()[0];
  const int d = shape_values[2];
  const int h = shape_values[3];
  const int w = shape_values[4];

  output->Resize(phi::make_ddim({n, d, h, w, 3}));
  T* out_data = dev_ctx.template Alloc<T>(output);

  // Spacing between neighbouring grid coordinates along each axis.
  const T d_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? d - 1 : d);
  const T h_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? h - 1 : h);
  const T w_step =
      static_cast<T>(2) / static_cast<T>(align_corners ? w - 1 : w);

  // Coordinate of index 0 along each axis; scaled when corners are not
  // aligned.
  T d_start = -1;
  T h_start = -1;
  T w_start = -1;
  if (!align_corners) {
    d_start *= static_cast<T>(d - 1) / static_cast<T>(d);
    h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
    w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
  }

  const int count = n * d * h * w;
  const int threads_per_block = 512;
  const int num_blocks = (count + threads_per_block - 1) / threads_per_block;
  auto cu_stream = dev_ctx.stream();
  affine_grid_kernel_5d<<<num_blocks, threads_per_block, 0, cu_stream>>>(
      count,
      n,
      d,
      h,
      w,
      d_start,
      h_start,
      w_start,
      d_step,
      h_step,
      w_step,
      theta->data<T>(),  // N, 3, 4
      out_data);
}
// GPU entry point for the affine_grid forward pass.
//
// Dispatches on the second dimension of theta (`input`): a value of 2
// selects AffineGrid4DCUDAKernel, anything else selects
// AffineGrid5DCUDAKernel. All arguments are forwarded unchanged.
template <typename T, typename Context>
void AffineGridCUDAKernel(const Context& dev_ctx,
                          const DenseTensor& input,
                          const IntArray& outputShape,
                          bool align_corners,
                          DenseTensor* output) {
  const int theta_rows = input.dims()[1];
  if (theta_rows != 2) {
    AffineGrid5DCUDAKernel<T, Context>(
        dev_ctx, input, outputShape, align_corners, output);
    return;
  }
  AffineGrid4DCUDAKernel<T, Context>(
      dev_ctx, input, outputShape, align_corners, output);
}
} // namespace phi
PD_REGISTER_KERNEL(
......
......@@ -18,7 +18,7 @@ from op_test import OpTest
import paddle
def AffineGrid(theta, size, align_corners):
def AffineGrid4D(theta, size, align_corners):
n = size[0]
w = size[3]
h = size[2]
......@@ -38,10 +38,40 @@ def AffineGrid(theta, size, align_corners):
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
return ret.reshape([n, h, w, 2]).astype("float32")
# print ret.reshape([h * w, 2]).astype("float32")
return ret.reshape([n, h, w, 2]).astype("float32")
def AffineGrid5D(theta, size, align_corners):
    """NumPy reference implementation of the 5-D affine grid.

    Args:
        theta: array of shape [n, 3, 4] holding one affine matrix per batch
            item.
        size: sequence read as [n, _, d, h, w]; index 1 is not used.
        align_corners: when False, every axis of the base grid is scaled by
            (extent - 1) / extent; when True the axes span [-1, 1] unscaled.

    Returns:
        float32 array of shape [n, d, h, w, 3]. Each location holds theta
        applied to the homogeneous base coordinate (w, h, d, 1).
    """
    n, d, h, w = size[0], size[2], size[3], size[4]

    def axis_coords(extent):
        # Evenly spaced coordinates in [-1, 1], shrunk when corners are not
        # aligned.
        coords = np.linspace(-1, 1, extent)
        if align_corners:
            return coords
        return coords * ((extent - 1) / float(extent))

    # Coordinate volumes of shape [d, h, w], one per spatial axis.
    d_vol, h_vol, w_vol = np.meshgrid(axis_coords(d),
                                      axis_coords(h),
                                      axis_coords(w),
                                      indexing="ij")

    # Homogeneous base grid, flattened to [d * h * w, 4] in (w, h, d, 1) order.
    base = np.stack([w_vol, h_vol, d_vol, np.ones([d, h, w])],
                    axis=-1).reshape([d * h * w, 4])

    theta_t = theta.transpose([0, 2, 1])  # [n, 4, 3]
    ret = np.zeros([n, d * h * w, 3])
    for i, mat in enumerate(theta_t):
        ret[i] = np.dot(base, mat)
    return ret.reshape([n, d, h, w, 3]).astype("float32")
class TestAffineGridOp(OpTest):
......@@ -60,9 +90,16 @@ class TestAffineGridOp(OpTest):
self.inputs['OutputShape'] = self.output_shape
else:
self.attrs['output_shape'] = self.output_shape
self.outputs = {
'Output': AffineGrid(theta, self.output_shape, self.align_corners)
}
if (self.theta_shape[1] == 2 and self.theta_shape[2] == 3):
self.outputs = {
'Output': AffineGrid4D(theta, self.output_shape,
self.align_corners)
}
else:
self.outputs = {
'Output': AffineGrid5D(theta, self.output_shape,
self.align_corners)
}
def test_check_output(self):
self.check_output(check_eager=True)
......@@ -123,6 +160,46 @@ class TestAffineGridOpCase4(TestAffineGridOp):
self.align_corners = False
class TestAffineGridOp5DCase1(TestAffineGridOp):
    """5-D case: output shape supplied as an input, align_corners off."""

    def initTestCase(self):
        # theta is [N, 3, 4]; output_shape has five entries.
        self.theta_shape = (20, 3, 4)
        self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32")
        self.align_corners = False
        self.dynamic_shape = True
        self.use_cudnn = False
class TestAffineGridOp5DCase2(TestAffineGridOp):
    """5-D case: output shape supplied as an input, align_corners on."""

    def initTestCase(self):
        # theta is [N, 3, 4]; output_shape has five entries.
        self.theta_shape = (20, 3, 4)
        self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32")
        self.align_corners = True
        self.dynamic_shape = True
        self.use_cudnn = False
class TestAffineGridOp5DCase3(TestAffineGridOp):
    """5-D case: output shape supplied as an attribute, align_corners on.

    This configuration previously duplicated TestAffineGridOp5DCase1 exactly
    and therefore added no coverage. It now exercises the static-shape path
    together with align_corners=True, which no other 5-D case combines.
    """

    def initTestCase(self):
        # theta is [N, 3, 4]; output_shape has five entries.
        self.theta_shape = (20, 3, 4)
        self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32")
        self.dynamic_shape = False
        self.use_cudnn = False
        self.align_corners = True
class TestAffineGridOp5DCase4(TestAffineGridOp):
    """5-D case: output shape supplied as an attribute, align_corners off."""

    def initTestCase(self):
        # theta is [N, 3, 4]; output_shape has five entries.
        self.theta_shape = (25, 3, 4)
        self.output_shape = np.array([25, 1, 2, 5, 6]).astype("int32")
        self.align_corners = False
        self.dynamic_shape = False
        self.use_cudnn = False
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -29,25 +29,22 @@ __all__ = []
def affine_grid(theta, out_shape, align_corners=True, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
It generates a grid of (x,y) or (x,y,z) coordinates using the parameters of
the affine transformation that correspond to a set of points where
the input feature map should be sampled to produce the transformed
output feature map.
Args:
theta (Tensor) - A tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters.
theta (Tensor) - A tensor with shape [N, 2, 3] or [N, 3, 4]. It contains a batch of affine transform parameters.
The data type can be float32 or float64.
out_shape (Tensor | list | tuple): The shape of target output with format [batch_size, channel, height, width].
``out_shape`` can be a Tensor or a list or tuple. The data
type must be int32.
align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
out_shape (Tensor | list | tuple): Type can be a 1-D Tensor, list, or tuple. It is used to represent the shape of the output in an affine transformation, in the format ``[N, C, H, W]`` or ``[N, C, D, H, W]``.
When the format is ``[N, C, H, W]``, it represents the batch size, number of channels, height and width. When the format is ``[N, C, D, H, W]``, it represents the batch size, number of channels, depth, height and width.
The data type must be int32.
align_corners(bool, optional): if True, aligns the centers of the 4 (4D) or 8 (5D) corner pixels of the input and output tensors, and preserves the value of the corner pixels. Default: True
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
Raises:
ValueError: If the type of arguments is not supported.
Tensor, A Tensor with shape [batch_size, H, W, 2] or [batch, D, H, W, 3] while ('D')'H', 'W' are the (depth)height, width of feature map in affine transformation. The data type is the same as `theta`.
Examples:
......@@ -55,13 +52,11 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
import paddle
import paddle.nn.functional as F
import numpy as np
# theta shape = [1, 2, 3]
theta = np.array([[[-0.7, -0.4, 0.3],
[ 0.6, 0.5, 1.5]]]).astype("float32")
theta_t = paddle.to_tensor(theta)
theta = paddle.to_tensor([[[-0.7, -0.4, 0.3],
[ 0.6, 0.5, 1.5]]], dtype="float32")
y_t = F.affine_grid(
theta_t,
theta,
[1, 2, 3, 3],
align_corners=False)
print(y_t)
......@@ -86,6 +81,8 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
use_cudnn = True
else:
use_cudnn = False
if theta.shape[1] == 3:
use_cudnn = False
if is_compiled_with_rocm():
use_cudnn = False # ROCM platform do not have MIOPEN kernel for affine_grid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册