From 7812522c506b1e49ed2db288e06b15577e8da30d Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Thu, 11 Aug 2022 20:30:22 +0800 Subject: [PATCH] make affine_grid_op support 5d input_dim on cpu and gpu (#45012) * make affine_grid_op support 5d_input on cpu and gpu --- paddle/fluid/operators/affine_grid_op.cc | 88 ++++++--- paddle/phi/infermeta/unary.cc | 57 ++++-- .../kernels/cpu/affine_grid_grad_kernel.cc | 69 ++++++- paddle/phi/kernels/cpu/affine_grid_kernel.cc | 64 +++++- paddle/phi/kernels/funcs/affine_grid_utils.h | 95 ++++++++- .../kernels/gpu/affine_grid_grad_kernel.cu | 182 +++++++++++++++--- paddle/phi/kernels/gpu/affine_grid_kernel.cu | 146 ++++++++++++-- .../tests/unittests/test_affine_grid_op.py | 89 ++++++++- python/paddle/nn/functional/vision.py | 29 ++- 9 files changed, 698 insertions(+), 121 deletions(-) diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 871d7350e5..6eecb5e6b3 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -71,34 +71,63 @@ class AffineGridOp : public framework::OperatorWithKernel { output_shape_dims.size(), output_shape_dims)); } else { - PADDLE_ENFORCE_EQ( - output_shape.size(), - 4, - platform::errors::InvalidArgument( - "The size of attribute 'output_shape' in AffineGridOp should be " - "4. But received output_shape's size=[%d].", - output_shape.size())); + PADDLE_ENFORCE_GE(output_shape.size(), + 4, + platform::errors::InvalidArgument( + "The size of attribute 'output_shape' in " + "AffineGridOp should be >= " + "4. But received output_shape's size=[%d].", + output_shape.size())); + PADDLE_ENFORCE_LE(output_shape.size(), + 5, + platform::errors::InvalidArgument( + "The size of attribute 'output_shape' in " + "AffineGridOp should be <= " + "5. 
But received output_shape's size=[%d].", + output_shape.size())); } - PADDLE_ENFORCE_EQ( - theta_dims[1], - 2, - platform::errors::InvalidArgument( - "The second dimesion of input 'theta' in AffineGridOp should be 2. " - "But received second dimesion=[%d], dimesions=[%s]", - theta_dims[1], - theta_dims)); - PADDLE_ENFORCE_EQ( - theta_dims[2], - 3, - platform::errors::InvalidArgument( - "The third dimesion of input 'theta' in AffineGridOp should be 3. " - "But received third dimesion=[%d], dimesions=[%s]", - theta_dims[2], - theta_dims)); + PADDLE_ENFORCE_GE(theta_dims[1], + 2, + platform::errors::InvalidArgument( + "The second dimesion of input 'theta' in " + "AffineGridOp should be >= 2. " + "But received second dimesion=[%d], dimesions=[%s]", + theta_dims[1], + theta_dims)); + PADDLE_ENFORCE_LE(theta_dims[1], + 3, + platform::errors::InvalidArgument( + "The second dimesion of input 'theta' in " + "AffineGridOp should be <= 3. " + "But received second dimesion=[%d], dimesions=[%s]", + theta_dims[1], + theta_dims)); + PADDLE_ENFORCE_GE(theta_dims[2], + 3, + platform::errors::InvalidArgument( + "The third dimesion of input 'theta' in AffineGridOp " + "should be >= 3. " + "But received third dimesion=[%d], dimesions=[%s]", + theta_dims[2], + theta_dims)); + PADDLE_ENFORCE_LE(theta_dims[2], + 4, + platform::errors::InvalidArgument( + "The third dimesion of input 'theta' in AffineGridOp " + "should be <= 4. 
" + "But received third dimesion=[%d], dimesions=[%s]", + theta_dims[2], + theta_dims)); - // N * H * W * 2 - ctx->SetOutputDim("Output", phi::make_ddim({theta_dims[0], -1, -1, 2})); + if (output_shape.size() == 4) { + // N * H * W * 2 + ctx->SetOutputDim("Output", phi::make_ddim({theta_dims[0], -1, -1, 2})); + } else { + // N * D * H * W * 3 + ctx->SetOutputDim("Output", + phi::make_ddim({theta_dims[0], -1, -1, -1, 3})); + } ctx->ShareLoD("Theta", "Output"); } @@ -215,8 +244,13 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { if (ctx->HasOutput(framework::GradVarName("Theta"))) { auto output_dims = ctx->GetInputDim(framework::GradVarName("Output")); - ctx->SetOutputDim(framework::GradVarName("Theta"), - {output_dims[0], 2, 3}); + if (output_dims.size() == 4) { + ctx->SetOutputDim(framework::GradVarName("Theta"), + {output_dims[0], 2, 3}); + } else { + ctx->SetOutputDim(framework::GradVarName("Theta"), + {output_dims[0], 3, 4}); + } } } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 7da162cd0b..1382cb2e66 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -58,33 +58,64 @@ void AffineGridInferMeta(const MetaTensor& input, theta_dims.size(), theta_dims)); - PADDLE_ENFORCE_EQ( + PADDLE_ENFORCE_GE( outputShape.GetData().size(), 4, phi::errors::InvalidArgument( - "The size of attribute 'output_shape' in AffineGridOp should be " + "The size of attribute 'output_shape' in AffineGridOp should be >= " "4. But received output_shape's size=[%d].", outputShape.GetData().size())); - PADDLE_ENFORCE_EQ( - theta_dims[1], - 2, + PADDLE_ENFORCE_LE( + outputShape.GetData().size(), + 5, phi::errors::InvalidArgument( - "The second dimesion of input 'theta' in AffineGridOp should be 2. 
" - "But received second dimesion=[%d], dimesions=[%s]", - theta_dims[1], - theta_dims)); - PADDLE_ENFORCE_EQ( + "The size of attribute 'output_shape' in AffineGridOp should be <= " + "5. But received output_shape's size=[%d].", + outputShape.GetData().size())); + + PADDLE_ENFORCE_GE(theta_dims[1], + 2, + phi::errors::InvalidArgument( + "The second dimesion of input 'theta' in AffineGridOp " + "should be >= 2. " + "But received second dimesion=[%d], dimesions=[%s]", + theta_dims[1], + theta_dims)); + + PADDLE_ENFORCE_LE(theta_dims[1], + 3, + phi::errors::InvalidArgument( + "The second dimesion of input 'theta' in AffineGridOp " + "should be <= 3. " + "But received second dimesion=[%d], dimesions=[%s]", + theta_dims[1], + theta_dims)); + + PADDLE_ENFORCE_GE( theta_dims[2], 3, phi::errors::InvalidArgument( - "The third dimesion of input 'theta' in AffineGridOp should be 3. " + "The third dimesion of input 'theta' in AffineGridOp should be >= 3. " "But received third dimesion=[%d], dimesions=[%s]", theta_dims[2], theta_dims)); - // N * H * W * 2 - output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2})); + PADDLE_ENFORCE_LE( + theta_dims[2], + 4, + phi::errors::InvalidArgument( + "The third dimesion of input 'theta' in AffineGridOp should be <= 4. 
" + "But received third dimesion=[%d], dimesions=[%s]", + theta_dims[2], + theta_dims)); + if (outputShape.GetData().size() == 4) { + // N * H * W * 2 + output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2})); + } else { + // N * D * H * W * 3 + output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, -1, 3})); + } output->set_dtype(input.dtype()); output->share_lod(input); } diff --git a/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc b/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc index 7bea3ff476..5112586c1b 100644 --- a/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc @@ -43,11 +43,11 @@ struct Linspace { }; template -void AffineGridGradKernel(const Context& dev_ctx, - const DenseTensor& output_grad, - const IntArray& outputShape, - bool align_corners, - DenseTensor* input_grad) { +void AffineGridGrad4DKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { auto& theta_grad = input_grad; int n = output_grad.dims()[0]; auto& size_attr = outputShape.GetData(); @@ -59,7 +59,7 @@ void AffineGridGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(theta_grad); phi::funcs::SetConstant()(dev_ctx, theta_grad, static_cast(0)); DenseTensor grid; - GetIdxMap(n, h, w, align_corners, &grid, dev_ctx); + GetIdxMap4D(n, h, w, align_corners, &grid, dev_ctx); // output = grid * theta.T // TODO(wanghaoshuang): Refine batched matrix multiply auto blas = phi::funcs::GetBlas(dev_ctx); @@ -79,6 +79,63 @@ void AffineGridGradKernel(const Context& dev_ctx, } } +template +void AffineGridGrad5DKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + auto& theta_grad = input_grad; + int n = output_grad.dims()[0]; + auto& size_attr = outputShape.GetData(); + int d = 0; + int h = 0; + int w = 0; + d = size_attr[2]; + h = size_attr[3]; + w = 
size_attr[4]; + theta_grad->Resize(phi::make_ddim({n, 3, 4})); + dev_ctx.template Alloc(theta_grad); + phi::funcs::SetConstant()(dev_ctx, theta_grad, static_cast(0)); + DenseTensor grid; + GetIdxMap5D(n, d, h, w, align_corners, &grid, dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); + for (int i = 0; i < n; ++i) { + DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize( + {static_cast(d) * static_cast(h) * + static_cast(w), + 4}); + DenseTensor sliced_out_grad = output_grad.Slice(i, i + 1).Resize( + {static_cast(d) * static_cast(h) * + static_cast(w), + 3}); + DenseTensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({3, 4}); + blas.MatMul(sliced_out_grad, + true, + sliced_grid, + false, + T(1), + &sliced_theta_grad, + T(0)); + } +} + +template +void AffineGridGradKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + auto& size_attr = outputShape.GetData(); + if (size_attr.size() == 4) { + AffineGridGrad4DKernel( + dev_ctx, output_grad, outputShape, align_corners, input_grad); + } else { + AffineGridGrad5DKernel( + dev_ctx, output_grad, outputShape, align_corners, input_grad); + } +} + } // namespace phi PD_REGISTER_KERNEL(affine_grid_grad, diff --git a/paddle/phi/kernels/cpu/affine_grid_kernel.cc b/paddle/phi/kernels/cpu/affine_grid_kernel.cc index 712c2a1927..250df84b08 100644 --- a/paddle/phi/kernels/cpu/affine_grid_kernel.cc +++ b/paddle/phi/kernels/cpu/affine_grid_kernel.cc @@ -43,11 +43,11 @@ struct Linspace { }; template -void AffineGridKernel(const Context& dev_ctx, - const DenseTensor& input, - const IntArray& outputShape, - bool align_corners, - DenseTensor* output) { +void AffineGrid4DKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { auto* theta = &input; int n = theta->dims()[0]; auto& size_attr = outputShape.GetData(); @@ -59,7 +59,7 @@ void 
AffineGridKernel(const Context& dev_ctx, dev_ctx.template Alloc(output); phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); DenseTensor grid; - GetIdxMap(n, h, w, align_corners, &grid, dev_ctx); + GetIdxMap4D(n, h, w, align_corners, &grid, dev_ctx); // output = grid * theta.T // TODO(wanghaoshuang): Refine batched matrix multiply auto blas = phi::funcs::GetBlas(dev_ctx); @@ -74,6 +74,58 @@ void AffineGridKernel(const Context& dev_ctx, } } +template +void AffineGrid5DKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + int n = theta->dims()[0]; + auto& size_attr = outputShape.GetData(); + int d = 0; + int h = 0; + int w = 0; + d = size_attr[2]; + h = size_attr[3]; + w = size_attr[4]; + output->Resize(phi::make_ddim({n, d, h, w, 3})); + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + DenseTensor grid; + GetIdxMap5D(n, d, h, w, align_corners, &grid, dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); + for (int i = 0; i < n; ++i) { + DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize( + {static_cast(d) * static_cast(h) * + static_cast(w), + 4}); + DenseTensor sliced_theta = theta->Slice(i, i + 1).Resize({3, 4}); + DenseTensor sliced_out = output->Slice(i, i + 1).Resize( + {static_cast(d) * static_cast(h) * + static_cast(w), + 3}); + blas.MatMul( + sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0)); + } +} + +template +void AffineGridKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto& size_attr = outputShape.GetData(); + if (size_attr.size() == 4) { + AffineGrid4DKernel( + dev_ctx, input, outputShape, align_corners, output); + } else { + AffineGrid5DKernel( + dev_ctx, input, outputShape, align_corners, output); + } +} + } // namespace phi PD_REGISTER_KERNEL( diff --git 
a/paddle/phi/kernels/funcs/affine_grid_utils.h b/paddle/phi/kernels/funcs/affine_grid_utils.h index 601b7f1ba6..1e6701d0c7 100644 --- a/paddle/phi/kernels/funcs/affine_grid_utils.h +++ b/paddle/phi/kernels/funcs/affine_grid_utils.h @@ -25,6 +25,7 @@ using Array1 = Eigen::DSizes; using Array2 = Eigen::DSizes; using Array3 = Eigen::DSizes; using Array4 = Eigen::DSizes; +using Array5 = Eigen::DSizes; template struct Linspace { @@ -37,12 +38,12 @@ struct Linspace { }; template -inline void GetIdxMap(int n, - int h, - int w, - bool align_corners, - DenseTensor* grid, - const Context& dev_ctx) { +inline void GetIdxMap4D(int n, + int h, + int w, + bool align_corners, + DenseTensor* grid, + const Context& dev_ctx) { auto& place = *dev_ctx.eigen_device(); grid->Resize(phi::make_ddim({n, h, w, 3})); dev_ctx.template Alloc(grid); @@ -99,4 +100,86 @@ inline void GetIdxMap(int n, .broadcast(Array4(n, 1, 1, 1)); } +template +inline void GetIdxMap5D(int n, + int d, + int h, + int w, + bool align_corners, + DenseTensor* grid, + const Context& dev_ctx) { + auto& place = *dev_ctx.eigen_device(); + grid->Resize(phi::make_ddim({n, d, h, w, 4})); + dev_ctx.template Alloc(grid); + auto grid_t = EigenTensor::From(*grid); + // Get indexes of height with shape [depth, height, width, 1] + DenseTensor d_idx; + Linspace linspace; + linspace((T)-1, (T)1, d, align_corners, &d_idx, dev_ctx); + auto d_idx_t = EigenTensor::From(d_idx); + // Get indexes of height with shape [depth, height, width, 1] + DenseTensor h_idx; + linspace((T)-1, (T)1, h, align_corners, &h_idx, dev_ctx); + auto h_idx_t = EigenTensor::From(h_idx); + // Get indexes of width with shape [depth, height, width, 1] + DenseTensor w_idx; + linspace((T)-1, (T)1, w, align_corners, &w_idx, dev_ctx); + auto w_idx_t = EigenTensor::From(w_idx); + // Get constant ones tensor with shape [depth, height, width, 1] + DenseTensor ones; + ones.Resize(phi::make_ddim({d, h, w, 1})); + dev_ctx.template Alloc(&ones); + + 
phi::funcs::SetConstant()(dev_ctx, &ones, static_cast(1)); + auto ones_t = EigenTensor::From(ones); + // Get grid tensor with shape [n, d, h, w, 4] by concatenating d_idx, h_idx, + // w_idx and ones + DenseTensor w_idx_map; + w_idx_map.Resize(phi::make_ddim({d, h, w, 1})); + dev_ctx.template Alloc(&w_idx_map); + auto w_idx_map_t = EigenTensor::From(w_idx_map); + + DenseTensor h_idx_map; + h_idx_map.Resize(phi::make_ddim({d, h, w, 1})); + dev_ctx.template Alloc(&h_idx_map); + auto h_idx_map_t = EigenTensor::From(h_idx_map); + + DenseTensor d_idx_map; + d_idx_map.Resize(phi::make_ddim({d, h, w, 1})); + dev_ctx.template Alloc(&d_idx_map); + auto d_idx_map_t = EigenTensor::From(d_idx_map); + + DenseTensor w_h_idx_map; + w_h_idx_map.Resize(phi::make_ddim({d, h, w, 2})); + dev_ctx.template Alloc(&w_h_idx_map); + auto w_h_idx_map_t = EigenTensor::From(w_h_idx_map); + + DenseTensor w_h_d_idx_map; + w_h_d_idx_map.Resize(phi::make_ddim({d, h, w, 3})); + dev_ctx.template Alloc(&w_h_d_idx_map); + auto w_h_d_idx_map_t = EigenTensor::From(w_h_d_idx_map); + + DenseTensor w_h_d_one_idx_map; + w_h_d_one_idx_map.Resize(phi::make_ddim({d, h, w, 4})); + dev_ctx.template Alloc(&w_h_d_one_idx_map); + auto w_h_d_one_idx_map_t = EigenTensor::From(w_h_d_one_idx_map); + + w_idx_map_t.device(place) = w_idx_t.reshape(Array3(1, 1, w)) + .broadcast(Array3(d, h, 1)) + .reshape(Array4(d, h, w, 1)); + h_idx_map_t.device(place) = h_idx_t.reshape(Array3(1, h, 1)) + .broadcast(Array3(d, 1, w)) + .reshape(Array4(d, h, w, 1)); + d_idx_map_t.device(place) = d_idx_t.reshape(Array3(d, 1, 1)) + .broadcast(Array3(1, h, w)) + .reshape(Array4(d, h, w, 1)); + + w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 3); + w_h_d_idx_map_t.device(place) = w_h_idx_map_t.concatenate(d_idx_map_t, 3); + + w_h_d_one_idx_map_t.device(place) = w_h_d_idx_map_t.concatenate(ones_t, 3); + grid_t.device(place) = w_h_d_one_idx_map_t.reshape(Array5(1, d, h, w, 4)) + .broadcast(Array5(n, 1, 1, 1, 1)); +} + } // 
namespace phi diff --git a/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu index b2cb0f2ad7..5cfa8cf306 100644 --- a/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu @@ -56,16 +56,16 @@ struct Linspace { }; template -__global__ void affine_grid_grad_kernel(const int count, - int n, - int out_h, - int out_w, - T h_start, - T w_start, - T h_step, - T w_step, - const T* out_grad, // N, H, W, 2 - T* theta_grad) { // N, 2, 3 +__global__ void affine_grid_grad_kernel_4d(const int count, + int n, + int out_h, + int out_w, + T h_start, + T w_start, + T h_step, + T w_step, + const T* out_grad, // N, H, W, 2 + T* theta_grad) { // N, 2, 3 CUDA_KERNEL_LOOP(index, count) { int w = index % out_w; int h = (index / out_w) % out_h; @@ -90,12 +90,66 @@ __global__ void affine_grid_grad_kernel(const int count, } } +template +__global__ void affine_grid_grad_kernel_5d(const int count, + int n, + int out_d, + int out_h, + int out_w, + T d_start, + T h_start, + T w_start, + T d_step, + T h_step, + T w_step, + const T* out_grad, // N, D, H, W, 3 + T* theta_grad) { // N, 3, 4 + CUDA_KERNEL_LOOP(index, count) { + int w = index % out_w; + int h = (index / out_w) % out_h; + int d = (index / (out_w * out_h)) % out_d; + int n = index / (out_w * out_h * out_d); + + T d_coor = d_step * static_cast(d) + static_cast(d_start); + T h_coor = h_step * static_cast(h) + static_cast(h_start); + T w_coor = w_step * static_cast(w) + static_cast(w_start); + + int theta_offset = n * 12; // 3 * 4; + T out_grad_x = out_grad[index * 3]; + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset, + out_grad_x * w_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 1, + out_grad_x * h_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 2, + out_grad_x * d_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_x); + + T out_grad_y = out_grad[index 
* 3 + 1]; + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 4, + out_grad_y * w_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 5, + out_grad_y * h_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 6, + out_grad_y * d_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 7, out_grad_y); + + T out_grad_z = out_grad[index * 3 + 2]; + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 8, + out_grad_z * w_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 9, + out_grad_z * h_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 10, + out_grad_z * d_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 11, out_grad_z); + } +} + template -void AffineGridGradCUDAKernel(const Context& dev_ctx, - const DenseTensor& output_grad, - const IntArray& outputShape, - bool align_corners, - DenseTensor* input_grad) { +void AffineGridGrad4DCUDAKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { auto& theta_grad = input_grad; int n = output_grad.dims()[0]; auto& size_attr = outputShape.GetData(); @@ -129,16 +183,94 @@ void AffineGridGradCUDAKernel(const Context& dev_ctx, int block = 512; int grid = (count + block - 1) / block; auto cu_stream = dev_ctx.stream(); - affine_grid_grad_kernel<<>>(count, - n, - h, - w, - h_start, - w_start, - h_step, - w_step, - output_grad.data(), - theta_grad_data); + affine_grid_grad_kernel_4d<<>>( + count, + n, + h, + w, + h_start, + w_start, + h_step, + w_step, + output_grad.data(), + theta_grad_data); +} + +template +void AffineGridGrad5DCUDAKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + // VLOG(0) << "in affine grid backward 5D"; + auto& theta_grad = input_grad; + int n = output_grad.dims()[0]; + auto& size_attr = outputShape.GetData(); 
+ int d = 0; + int h = 0; + int w = 0; + d = size_attr[2]; + h = size_attr[3]; + w = size_attr[4]; + theta_grad->Resize(phi::make_ddim({n, 3, 4})); + T* theta_grad_data = dev_ctx.template Alloc(theta_grad); + phi::funcs::SetConstant()( + dev_ctx, theta_grad, static_cast(0)); + + T d_step; + T h_step; + T w_step; + T d_start = -1; + T h_start = -1; + T w_start = -1; + if (align_corners) { + d_step = static_cast(2) / static_cast(d - 1); + h_step = static_cast(2) / static_cast(h - 1); + w_step = static_cast(2) / static_cast(w - 1); + } else { + d_step = static_cast(2) / static_cast(d); + h_step = static_cast(2) / static_cast(h); + w_step = static_cast(2) / static_cast(w); + + d_start *= static_cast(d - 1) / static_cast(d); + h_start *= static_cast(h - 1) / static_cast(h); + w_start *= static_cast(w - 1) / static_cast(w); + } + const int count = n * d * h * w; + int block = 512; + int grid = (count + block - 1) / block; + auto cu_stream = dev_ctx.stream(); + affine_grid_grad_kernel_5d<<>>( + count, + n, + d, + h, + w, + d_start, + h_start, + w_start, + d_step, + h_step, + w_step, + output_grad.data(), + theta_grad_data); +} + +template +void AffineGridGradCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + auto theta_size = theta->dims().size(); + if (theta_size == 4) { + AffineGridGrad4DCUDAKernel( + dev_ctx, input, outputShape, align_corners, output); + } else { + AffineGridGrad5DCUDAKernel( + dev_ctx, input, outputShape, align_corners, output); + } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/affine_grid_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_kernel.cu index 4e5c326be7..0f42960502 100644 --- a/paddle/phi/kernels/gpu/affine_grid_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_grid_kernel.cu @@ -56,16 +56,16 @@ struct Linspace { }; template -__global__ void affine_grid_kernel(const int count, - int n, - int out_h, - int out_w, - T 
h_start, - T w_start, - T h_step, - T w_step, - const T* theta, // N, 2, 3 - T* output) { +__global__ void affine_grid_kernel_4d(const int count, + int n, + int out_h, + int out_w, + T h_start, + T w_start, + T h_step, + T w_step, + const T* theta, // N, 2, 3 + T* output) { CUDA_KERNEL_LOOP(index, count) { int w = index % out_w; int h = (index / out_w) % out_h; @@ -85,12 +85,51 @@ __global__ void affine_grid_kernel(const int count, } } +template +__global__ void affine_grid_kernel_5d(const int count, + int n, + int out_d, + int out_h, + int out_w, + T d_start, + T h_start, + T w_start, + T d_step, + T h_step, + T w_step, + const T* theta, // N, 3, 4 + T* output) { + CUDA_KERNEL_LOOP(index, count) { + int w = index % out_w; + int h = (index / out_w) % out_h; + int d = (index / (out_w * out_h)) % out_d; + int n = index / (out_w * out_h * out_d); + + T d_coor = d_step * static_cast(d) + static_cast(d_start); + T h_coor = h_step * static_cast(h) + static_cast(h_start); + T w_coor = w_step * static_cast(w) + static_cast(w_start); + + int theta_offset = n * 12; // 3 * 4 + // affine from (h_coor, w_coor) to (x, y) + output[index * 3] = + theta[theta_offset] * w_coor + theta[theta_offset + 1] * h_coor + + theta[theta_offset + 2] * d_coor + theta[theta_offset + 3]; + output[index * 3 + 1] = + theta[theta_offset + 4] * w_coor + theta[theta_offset + 5] * h_coor + + theta[theta_offset + 6] * d_coor + theta[theta_offset + 7]; + output[index * 3 + 2] = + theta[theta_offset + 8] * w_coor + theta[theta_offset + 9] * h_coor + + theta[theta_offset + 10] * d_coor + theta[theta_offset + 11]; + } +} + template -void AffineGridCUDAKernel(const Context& dev_ctx, - const DenseTensor& input, - const IntArray& outputShape, - bool align_corners, - DenseTensor* output) { +void AffineGrid4DCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + // VLOG(0) << "in affine grid 4d forward"; auto* theta = &input; 
int n = theta->dims()[0]; auto& size_attr = outputShape.GetData(); @@ -120,7 +159,7 @@ void AffineGridCUDAKernel(const Context& dev_ctx, int block = 512; int grid = (count + block - 1) / block; auto cu_stream = dev_ctx.stream(); - affine_grid_kernel<<>>( + affine_grid_kernel_4d<<>>( count, n, h, @@ -133,6 +172,81 @@ void AffineGridCUDAKernel(const Context& dev_ctx, out_data); } +template +void AffineGrid5DCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + int n = theta->dims()[0]; + auto& size_attr = outputShape.GetData(); + int d = 0; + int h = 0; + int w = 0; + d = size_attr[2]; + h = size_attr[3]; + w = size_attr[4]; + output->Resize(phi::make_ddim({n, d, h, w, 3})); + T* out_data = dev_ctx.template Alloc(output); + + T d_step; + T h_step; + T w_step; + T d_start = -1; + T h_start = -1; + T w_start = -1; + if (align_corners) { + d_step = static_cast(2) / static_cast(d - 1); + h_step = static_cast(2) / static_cast(h - 1); + w_step = static_cast(2) / static_cast(w - 1); + } else { + d_step = static_cast(2) / static_cast(d); + h_step = static_cast(2) / static_cast(h); + w_step = static_cast(2) / static_cast(w); + + d_start *= static_cast(d - 1) / static_cast(d); + h_start *= static_cast(h - 1) / static_cast(h); + w_start *= static_cast(w - 1) / static_cast(w); + } + + const int count = n * d * h * w; + int block = 512; + int grid = (count + block - 1) / block; + auto cu_stream = dev_ctx.stream(); + affine_grid_kernel_5d<<>>( + count, + n, + d, + h, + w, + d_start, + h_start, + w_start, + d_step, + h_step, + w_step, + theta->data(), // N, 3, 4 + out_data); +} + +template +void AffineGridCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + int theta_h = theta->dims()[1]; + if (theta_h == 2) { + AffineGrid4DCUDAKernel( + dev_ctx, input, 
outputShape, align_corners, output); + } else { + AffineGrid5DCUDAKernel( + dev_ctx, input, outputShape, align_corners, output); + } +} + } // namespace phi PD_REGISTER_KERNEL( diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py index 287c9edae2..f8c0dedc18 100644 --- a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py +++ b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py @@ -18,7 +18,7 @@ from op_test import OpTest import paddle -def AffineGrid(theta, size, align_corners): +def AffineGrid4D(theta, size, align_corners): n = size[0] w = size[3] h = size[2] @@ -38,10 +38,40 @@ def AffineGrid(theta, size, align_corners): theta = theta.transpose([0, 2, 1]) for i in range(len(theta)): ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i]) + return ret.reshape([n, h, w, 2]).astype("float32") -# print ret.reshape([h * w, 2]).astype("float32") - return ret.reshape([n, h, w, 2]).astype("float32") +def AffineGrid5D(theta, size, align_corners): + n = size[0] + d = size[2] + h = size[3] + w = size[4] + d_factor = h_factor = w_factor = 1 + if not align_corners: + d_factor = (d - 1) / float(d) + h_factor = (h - 1) / float(h) + w_factor = (w - 1) / float(w) + d_idx = np.repeat(np.repeat( + np.linspace(-1, 1, d)[:, np.newaxis, np.newaxis], h, axis=1), + w, + axis=2)[:, :, :, np.newaxis] * d_factor + h_idx = np.repeat(np.repeat( + np.linspace(-1, 1, h)[np.newaxis, :, np.newaxis], w, axis=2), + d, + axis=0)[:, :, :, np.newaxis] * h_factor + w_idx = np.repeat(np.repeat( + np.linspace(-1, 1, w)[np.newaxis, np.newaxis, :], h, axis=1), + d, + axis=0)[:, :, :, np.newaxis] * w_factor + grid = np.concatenate( + [w_idx, h_idx, d_idx, np.ones([d, h, w, 1])], axis=3) # d * h * w * 4 + grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * d * h * w * 4 + + ret = np.zeros([n, d * h * w, 3]) + theta = theta.transpose([0, 2, 1]) + for i in range(len(theta)): + ret[i] = 
np.dot(grid[i].reshape([d * h * w, 4]), theta[i]) + return ret.reshape([n, d, h, w, 3]).astype("float32") class TestAffineGridOp(OpTest): @@ -60,9 +90,16 @@ class TestAffineGridOp(OpTest): self.inputs['OutputShape'] = self.output_shape else: self.attrs['output_shape'] = self.output_shape - self.outputs = { - 'Output': AffineGrid(theta, self.output_shape, self.align_corners) - } + if (self.theta_shape[1] == 2 and self.theta_shape[2] == 3): + self.outputs = { + 'Output': AffineGrid4D(theta, self.output_shape, + self.align_corners) + } + else: + self.outputs = { + 'Output': AffineGrid5D(theta, self.output_shape, + self.align_corners) + } def test_check_output(self): self.check_output(check_eager=True) @@ -123,6 +160,46 @@ class TestAffineGridOpCase4(TestAffineGridOp): self.align_corners = False +class TestAffineGridOp5DCase1(TestAffineGridOp): + + def initTestCase(self): + self.theta_shape = (20, 3, 4) + self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32") + self.dynamic_shape = True + self.use_cudnn = False + self.align_corners = False + + +class TestAffineGridOp5DCase2(TestAffineGridOp): + + def initTestCase(self): + self.theta_shape = (20, 3, 4) + self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32") + self.dynamic_shape = True + self.use_cudnn = False + self.align_corners = True + + +class TestAffineGridOp5DCase3(TestAffineGridOp): + + def initTestCase(self): + self.theta_shape = (20, 3, 4) + self.output_shape = np.array([20, 1, 2, 5, 7]).astype("int32") + self.dynamic_shape = True + self.use_cudnn = False + self.align_corners = False + + +class TestAffineGridOp5DCase4(TestAffineGridOp): + + def initTestCase(self): + self.theta_shape = (25, 3, 4) + self.output_shape = np.array([25, 1, 2, 5, 6]).astype("int32") + self.dynamic_shape = False + self.use_cudnn = False + self.align_corners = False + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/nn/functional/vision.py 
b/python/paddle/nn/functional/vision.py index 7515ee66dc..293b055d8e 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -29,25 +29,22 @@ __all__ = [] def affine_grid(theta, out_shape, align_corners=True, name=None): """ - It generates a grid of (x,y) coordinates using the parameters of + It generates a grid of (x,y) or (x,y,z) coordinates using the parameters of the affine transformation that correspond to a set of points where the input feature map should be sampled to produce the transformed output feature map. Args: - theta (Tensor) - A tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters. + theta (Tensor) - A tensor with shape [N, 2, 3] or [N, 3, 4]. It contains a batch of affine transform parameters. The data type can be float32 or float64. - out_shape (Tensor | list | tuple): The shape of target output with format [batch_size, channel, height, width]. - ``out_shape`` can be a Tensor or a list or tuple. The data - type must be int32. - align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True. - name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + out_shape (Tensor | list | tuple): Type can be a 1-D Tensor, list, or tuple. It is used to represent the shape of the output in an affine transformation, in the format ``[N, C, H, W]`` or ``[N, C, D, H, W]``. + When the format is ``[N, C, H, W]``, it represents the batch size, number of channels, height and width. When the format is ``[N, C, D, H, W]``, it represents the batch size, number of channels, depth, height and width. + The data type must be int32. + align_corners(bool, optional): if True, aligns the centers of the 4 (4D) or 8 (5D) corner pixels of the input and output tensors, and preserves the value of the corner pixels. 
Default: True + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`. - - Raises: - ValueError: If the type of arguments is not supported. + Tensor, A Tensor with shape [batch_size, H, W, 2] or [batch_size, D, H, W, 3], where 'D', 'H' and 'W' are the depth (5-D output only), height and width of the feature map in affine transformation. The data type is the same as `theta`. Examples: @@ -55,13 +52,11 @@ def affine_grid(theta, out_shape, align_corners=True, name=None): import paddle import paddle.nn.functional as F - import numpy as np # theta shape = [1, 2, 3] - theta = np.array([[[-0.7, -0.4, 0.3], - [ 0.6, 0.5, 1.5]]]).astype("float32") - theta_t = paddle.to_tensor(theta) + theta = paddle.to_tensor([[[-0.7, -0.4, 0.3], + [ 0.6, 0.5, 1.5]]], dtype="float32") y_t = F.affine_grid( - theta_t, + theta, [1, 2, 3, 3], align_corners=False) print(y_t) @@ -86,6 +81,8 @@ def affine_grid(theta, out_shape, align_corners=True, name=None): use_cudnn = True else: use_cudnn = False + if theta.shape[1] == 3: + use_cudnn = False if is_compiled_with_rocm(): use_cudnn = False # ROCM platform do not have MIOPEN kernel for affine_grid -- GitLab