diff --git a/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc b/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc deleted file mode 100644 index 48832ac1d6dadf274e5389b953de1486669c0871..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef PADDLE_WITH_HIP -// HIP not support cudnnSpatialTfGridGeneratorForward - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using ScopedSpatialTransformerDescriptor = - platform::ScopedSpatialTransformerDescriptor; - -template -class CUDNNAffineGridOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument( - "Only support for CUDAPlace.Please switch your context from " - "CPUPlace to CUDAPlace or update your cudnn.")); - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - auto* theta = ctx.Input("Theta"); - auto* output = ctx.Output("Output"); - const T* theta_data = theta->data(); - - int n = theta->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - Tensor h_sizes; - int* h_size_data; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - h_size_data = h_sizes.data(); - } else { - h_size_data = h_sizes.mutable_data({4}, platform::CPUPlace()); - h_size_data[0] = n; - h_size_data[1] = size_attr[1]; - h_size_data[2] = size_attr[2]; - h_size_data[3] = size_attr[3]; - } - - T* output_data = output->mutable_data( - {n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace()); - ScopedSpatialTransformerDescriptor st_desc; - cudnnSpatialTransformerDescriptor_t cudnn_st_desc = - st_desc.descriptor(4, h_size_data); - - PADDLE_ENFORCE_EQ( - platform::dynload::cudnnSpatialTfGridGeneratorForward( - handle, cudnn_st_desc, theta_data, output_data), - 0, - platform::errors::Fatal("Some errors has occurred " - "during forward computation in cudnn.")); - } -}; - -template -class CUDNNAffineGridGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument( - "Only " - "support for CUDAPlace. Please switch " - "your context from CPUPlace to " - "CUDAPlace or update your cudnn.")); - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - auto output_grad = ctx.Input(framework::GradVarName("Output")); - auto theta_grad = ctx.Output(framework::GradVarName("Theta")); - - int n = output_grad->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - Tensor h_sizes; - int* h_size_data; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - h_size_data = h_sizes.data(); - } else { - h_size_data = h_sizes.mutable_data({4}, platform::CPUPlace()); - h_size_data[0] = n; - h_size_data[1] = size_attr[1]; - h_size_data[2] = size_attr[2]; - h_size_data[3] = size_attr[3]; - } - - ScopedSpatialTransformerDescriptor st_desc; - cudnnSpatialTransformerDescriptor_t cudnn_st_desc = - st_desc.descriptor(4, h_size_data); - - const T* output_grad_data = output_grad->data(); - T* theta_grad_data = theta_grad->mutable_data(ctx.GetPlace()); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnSpatialTfGridGeneratorBackward( - handle, cudnn_st_desc, output_grad_data, theta_grad_data)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace plat = paddle::platform; -REGISTER_OP_KERNEL(affine_grid, - CUDNN, - plat::CUDAPlace, - paddle::operators::CUDNNAffineGridOpKernel, - paddle::operators::CUDNNAffineGridOpKernel); -REGISTER_OP_KERNEL(affine_grid_grad, - CUDNN, - plat::CUDAPlace, - paddle::operators::CUDNNAffineGridGradOpKernel, - paddle::operators::CUDNNAffineGridGradOpKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 1977a33fc197e2220927b7f805c60bb4af258767..871d7350e540dce93031aa4e33ee58df0dceec19 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -12,41 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/affine_grid_op.h" - #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -struct Linspace { - void operator()(T start, - T end, - int count, - bool align_corners, - framework::Tensor* numbers, - const framework::ExecutionContext& ctx) { - T* number_data = numbers->mutable_data({count}, platform::CPUPlace()); - T slice = (end - start) / (T)(count - 1); - if (!align_corners) { - slice = (end - start) / (T)count; - start *= (T)(count - 1) / (T)count; - } - for (int i = 0; i < count; ++i) { - number_data[i] = start + (T)i * slice; - } - } -}; - class AffineGridOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -280,14 +263,8 @@ REGISTER_OPERATOR(affine_grid, ops::AffineGridOpMaker, ops::AffineGridGradMaker, ops::AffineGridGradMaker); -REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad); -REGISTER_OP_CPU_KERNEL(affine_grid, - ops::AffineGridOpKernel, - ops::AffineGridOpKernel); -REGISTER_OP_CPU_KERNEL(affine_grid_grad, - ops::AffineGridGradOpKernel, - ops::AffineGridGradOpKernel); +REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad); REGISTER_OP_VERSION(affine_grid) .AddCheckpoint( diff --git a/paddle/fluid/operators/affine_grid_op.cu b/paddle/fluid/operators/affine_grid_op.cu deleted file mode 100644 index a5d4c6484a1f9b1bd01a99568a59edcdad40d651..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/affine_grid_op.cu +++ /dev/null @@ -1,241 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/affine_grid_op.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { - CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } -} - -template -struct Linspace { - void operator()(T start, - T end, - int count, - bool align_corners, - framework::Tensor* numbers, - const framework::ExecutionContext& ctx) { - T* number_data = numbers->mutable_data({count}, ctx.GetPlace()); - T slice = (end - start) / (T)(count - 1); - if (!align_corners) { - slice = (end - start) / (T)count; - start *= (T)(count - 1) / (T)count; - } - auto stream = ctx.cuda_device_context().stream(); - int block = 512; - int grid = (count + block - 1) / block; - LinspaceKernel - <<>>(start, slice, count, number_data); - } -}; - -template -__global__ void affine_grid_kernel(const int count, - int n, - int out_h, - int out_w, - T h_start, - T w_start, - T h_step, - T w_step, - const T* theta, // N, 2, 3 - T* output) { - CUDA_KERNEL_LOOP(index, count) { - int w = index % out_w; - int h = (index / out_w) % out_h; - int n = index / (out_w * out_h); - - T h_coor = h_step * static_cast(h) + static_cast(h_start); - T w_coor = w_step * static_cast(w) + static_cast(w_start); - - int theta_offset = n * 6; // 2 * 3; - // affine from (h_coor, w_coor) to (x, y) - output[index * 2] = theta[theta_offset] * w_coor + - theta[theta_offset + 1] * h_coor + - theta[theta_offset + 2]; - output[index * 2 + 1] = theta[theta_offset + 3] * w_coor + - theta[theta_offset + 4] * h_coor + - theta[theta_offset + 5]; - } -} - -template -__global__ void affine_grid_grad_kernel(const int count, - int n, - int out_h, - int out_w, - T h_start, - T w_start, - T h_step, - T w_step, - const T* out_grad, // N, H, W, 2 - T* theta_grad) { // N, 2, 3 - CUDA_KERNEL_LOOP(index, count) { - int w = index % out_w; - int h = (index / out_w) % out_h; - int n = index / (out_w * out_h); - T h_coor = h_step * static_cast(h) + static_cast(h_start); - T w_coor = w_step * static_cast(w) + static_cast(w_start); - - int theta_offset = n * 6; // 2 * 3; - T out_grad_x = out_grad[index * 2]; - platform::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * w_coor); - platform::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * h_coor); - platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x); - - T out_grad_y = out_grad[index * 2 + 1]; - platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * w_coor); - platform::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * h_coor); - platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y); - } -} - -template -class AffineGridOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* theta = ctx.Input("Theta"); - int n = theta->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - auto align_corners = ctx.Attr("align_corners"); - int h = 0; - int w = 0; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - Tensor h_sizes; - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - const int* h_size_data = h_sizes.data(); - h = h_size_data[2]; - w = h_size_data[3]; - } else { - h = size_attr[2]; - w = size_attr[3]; - } - auto* output = ctx.Output("Output"); - T* out_data = output->mutable_data({n, h, w, 2}, ctx.GetPlace()); - - T h_step; - T w_step; - T h_start = -1; - T w_start = -1; - if (align_corners) { - h_step = static_cast(2) / static_cast(h - 1); - w_step = static_cast(2) / static_cast(w - 1); - } else { - h_step = static_cast(2) / static_cast(h); - w_step = static_cast(2) / static_cast(w); - - h_start *= static_cast(h - 1) / static_cast(h); - w_start *= static_cast(w - 1) / static_cast(w); - } - - const int count = n * h * w; - int block = 512; - int grid = (count + block - 1) / block; - auto cu_stream = ctx.cuda_device_context().stream(); - affine_grid_kernel<<>>( - count, - n, - h, - w, - h_start, - w_start, - h_step, - w_step, - theta->data(), // N, 2, 3 - out_data); - } -}; - -template -class AffineGridGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto output_grad = ctx.Input(framework::GradVarName("Output")); - auto theta_grad = ctx.Output(framework::GradVarName("Theta")); - int n = output_grad->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - auto align_corners = ctx.Attr("align_corners"); - int h = 0; - int w = 0; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - Tensor h_sizes; - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - const int* h_size_data = h_sizes.data(); - h = h_size_data[2]; - w = h_size_data[3]; - } else { - h = size_attr[2]; - w = size_attr[3]; - } - T* theta_grad_data = theta_grad->mutable_data({n, 2, 3}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.cuda_device_context(), theta_grad, static_cast(0)); - - T h_step; - T w_step; - T h_start = -1; - T w_start = -1; - if (align_corners) { - h_step = static_cast(2) / static_cast(h - 1); - w_step = static_cast(2) / static_cast(w - 1); - } else { - h_step = static_cast(2) / static_cast(h); - w_step = static_cast(2) / static_cast(w); - - h_start *= static_cast(h - 1) / static_cast(h); - w_start *= static_cast(w - 1) / static_cast(w); - } - const int count = n * h * w; - VLOG(3) << "count: " << count << "; h_step: " << h_step - << "; w_step: " << w_step << "; h_start: " << h_start - << "; w_start: " << w_start; - int block = 512; - int grid = (count + block - 1) / block; - auto cu_stream = ctx.cuda_device_context().stream(); - affine_grid_grad_kernel<<>>( - count, - n, - h, - w, - h_start, - w_start, - h_step, - w_step, - output_grad->data(), - theta_grad_data); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(affine_grid, - ops::AffineGridOpCUDAKernel, - ops::AffineGridOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(affine_grid_grad, - ops::AffineGridGradOpCUDAKernel, - ops::AffineGridGradOpCUDAKernel); diff --git a/paddle/fluid/operators/affine_grid_op.h b/paddle/fluid/operators/affine_grid_op.h deleted file mode 100644 index a0cdadecbe456e239fd03077c8474419163643a8..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/affine_grid_op.h +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; - -using Array1 = Eigen::DSizes; -using Array2 = Eigen::DSizes; -using Array3 = Eigen::DSizes; -using Array4 = Eigen::DSizes; - -/** - *Return a tensor with evenly spaced numbers over a specified interval. - */ -template -struct Linspace { - void operator()(T start, - T end, - int count, - bool align_corners, - framework::Tensor* numbers, - const framework::ExecutionContext& ctx); -}; - -template -inline void GetIdxMap(int n, - int h, - int w, - bool align_corners, - Tensor* grid, - const framework::ExecutionContext& ctx) { - auto& place = *ctx.template device_context().eigen_device(); - grid->mutable_data({n, h, w, 3}, ctx.GetPlace()); - auto grid_t = EigenTensor::From(*grid); - // Get indexes of height with shape [height, width, 1] - Tensor h_idx; - Linspace linspace; - linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx); - auto h_idx_t = EigenTensor::From(h_idx); - // Get indexes of width with shape [height, width, 1] - Tensor w_idx; - linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx); - auto w_idx_t = EigenTensor::From(w_idx); - // Get constant ones tensor with shape [height, width, 1] - Tensor ones; - ones.mutable_data({h, w, 1}, ctx.GetPlace()); - - phi::funcs::SetConstant()( - ctx.template device_context(), &ones, static_cast(1)); - auto ones_t = EigenTensor::From(ones); - // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and - // ones - Tensor w_idx_map; - w_idx_map.mutable_data({h, w, 1}, ctx.GetPlace()); - auto w_idx_map_t = EigenTensor::From(w_idx_map); - Tensor h_idx_map; - h_idx_map.mutable_data({h, w, 1}, ctx.GetPlace()); - auto h_idx_map_t = EigenTensor::From(h_idx_map); - Tensor w_h_idx_map; - w_h_idx_map.mutable_data({h, w, 2}, ctx.GetPlace()); - auto w_h_idx_map_t = EigenTensor::From(w_h_idx_map); - Tensor w_h_one_idx_map; - w_h_one_idx_map.mutable_data({h, w, 3}, ctx.GetPlace()); - auto w_h_one_idx_map_t = EigenTensor::From(w_h_one_idx_map); - w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w)) - .broadcast(Array2(h, 1)) - .reshape(Array3(h, w, 1)); - h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h)) - .broadcast(Array2(w, 1)) - .shuffle(Array2(1, 0)) - .reshape(Array3(h, w, 1)); - - w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2); - w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2); - grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3)) - .broadcast(Array4(n, 1, 1, 1)); -} - -template -class AffineGridOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* theta = ctx.Input("Theta"); - int n = theta->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - auto align_corners = ctx.Attr("align_corners"); - int h = 0; - int w = 0; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - Tensor h_sizes; - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - const int* h_size_data = h_sizes.data(); - h = h_size_data[2]; - w = h_size_data[3]; - } else { - h = size_attr[2]; - w = size_attr[3]; - } - auto* output = ctx.Output("Output"); - output->mutable_data({n, h, w, 2}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), - output, - static_cast(0)); - Tensor grid; - GetIdxMap(n, h, w, align_corners, &grid, ctx); - // output = grid * theta.T - // TODO(wanghaoshuang): Refine batched matrix multiply - auto blas = phi::funcs::GetBlas(ctx); - for (int i = 0; i < n; ++i) { - Tensor sliced_grid = grid.Slice(i, i + 1).Resize( - {static_cast(h) * static_cast(w), 3}); - Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3}); - Tensor sliced_out = output->Slice(i, i + 1).Resize( - {static_cast(h) * static_cast(w), 2}); - blas.MatMul( - sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0)); - } - } -}; - -template -class AffineGridGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto output_grad = ctx.Input(framework::GradVarName("Output")); - auto theta_grad = ctx.Output(framework::GradVarName("Theta")); - int n = output_grad->dims()[0]; - auto size_attr = ctx.Attr>("output_shape"); - auto align_corners = ctx.Attr("align_corners"); - int h = 0; - int w = 0; - if (size_attr.size() == 0) { - auto* output_shape = ctx.Input("OutputShape"); - Tensor h_sizes; - framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes); - const int* h_size_data = h_sizes.data(); - h = h_size_data[2]; - w = h_size_data[3]; - } else { - h = size_attr[2]; - w = size_attr[3]; - } - theta_grad->mutable_data({n, 2, 3}, ctx.GetPlace()); - phi::funcs::SetConstant()( - ctx.template device_context(), - theta_grad, - static_cast(0)); - Tensor grid; - GetIdxMap(n, h, w, align_corners, &grid, ctx); - // output = grid * theta.T - // TODO(wanghaoshuang): Refine batched matrix multiply - auto blas = phi::funcs::GetBlas(ctx); - for (int i = 0; i < n; ++i) { - Tensor sliced_grid = grid.Slice(i, i + 1).Resize( - {static_cast(h) * static_cast(w), 3}); - Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize( - {static_cast(h) * static_cast(w), 2}); - Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3}); - blas.MatMul(sliced_out_grad, - true, - sliced_grid, - false, - T(1), - &sliced_theta_grad, - T(0)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 60e3012f5715286bf93ebee90116a03e9f405e83..4857603c080fe60992943dbc245bfd95e30bc7d3 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -112,6 +112,19 @@ func : addmm backward : addmm_grad +- api : affine_grid + args : (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) + output : Tensor + infer_meta : + func : AffineGridInferMeta + param : [input, outputShape, align_corners] + kernel : + func : affine_grid + param : [input, outputShape, align_corners] + data_type : input + use_gpudnn: use_cudnn + backward : affine_grid_grad + - api : all args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) output : Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 5182709967213b23220edfd968d1ccc409031ff7..000f88979f3fafb0063f4738381fb3083eebeff9 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -92,6 +92,18 @@ kernel : func : addmm_grad +- backward_api : affine_grid_grad + forward : affine_grid (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) -> Tensor(output) + args : (Tensor output_grad, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) + output : Tensor(input_grad) + infer_meta : + func : AffineGridGradInferMeta + param : [output_grad, outputShape, align_corners] + kernel : + func : affine_grid_grad + param : [output_grad, outputShape, align_corners] + use_gpudnn: use_cudnn + - backward_api : amax_grad forward: amax (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e375999bfbafd86513fd3c28519a6854c5db4e5b..5395b4e23dcd4b6582c26aba29a5d5a618ae8bc3 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -18,6 +18,16 @@ limitations under the License. */ namespace phi { +void AffineGridGradInferMeta(const MetaTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + MetaTensor* input_grad) { + if (input_grad) { + auto output_dims = output_grad.dims(); + input_grad->set_dims(phi::make_ddim({output_dims[0], 2, 3})); + } +} + void AngleGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, MetaTensor* x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 2d31860c17b49e764a178a95eabb9196427b407b..a0e79cfaf04363547ed83327353cd015daaff35e 100755 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -27,6 +27,11 @@ namespace phi { // // NOTE: The InferMeta Functions in this file are arranged in alphabetic order. +void AffineGridGradInferMeta(const MetaTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + MetaTensor* input_grad); + void AngleGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, MetaTensor* x_grad); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8add8e970c60fac48a8efdf79c154ccbb244be11..8cc7f75533ca44c8a1199220430e80281dceed6f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -44,6 +44,51 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) { } } // namespace detail +void AffineGridInferMeta(const MetaTensor& input, + const IntArray& outputShape, + bool align_corners, + MetaTensor* output) { + auto theta_dims = input.dims(); + PADDLE_ENFORCE_EQ( + theta_dims.size(), + 3, + phi::errors::InvalidArgument( + "The input Theta's dimensions size should be 3. But received " + "Theta's demensions size=[%d], Theta's dimensions=[%s].", + theta_dims.size(), + theta_dims)); + + PADDLE_ENFORCE_EQ( + outputShape.GetData().size(), + 4, + phi::errors::InvalidArgument( + "The size of attribute 'output_shape' in AffineGridOp should be " + "4. But received output_shape's size=[%d].", + outputShape.GetData().size())); + + PADDLE_ENFORCE_EQ( + theta_dims[1], + 2, + phi::errors::InvalidArgument( + "The second dimesion of input 'theta' in AffineGridOp should be 2. " + "But received second dimesion=[%d], dimesions=[%s]", + theta_dims[1], + theta_dims)); + PADDLE_ENFORCE_EQ( + theta_dims[2], + 3, + phi::errors::InvalidArgument( + "The third dimesion of input 'theta' in AffineGridOp should be 3. " + "But received third dimesion=[%d], dimesions=[%s]", + theta_dims[2], + theta_dims)); + + // N * H * W * 2 + output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2})); + output->set_dtype(input.dtype()); + output->share_lod(input); +} + void ArgMinMaxInferMeta(const MetaTensor& x, int64_t axis, bool keepdims, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a2753e46c89ce11f87950faa506dcc8d5ca1d6dc..ae88ecb40c3c037de1531dc9c48d60ecb7f87d2b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -34,6 +34,11 @@ class MetaConfig; // // The InferMeta Functions in this file are arranged in alphabetic order. +void AffineGridInferMeta(const MetaTensor& input, + const IntArray& outputShape, + bool align_corners, + MetaTensor* output); + void ArgMinMaxInferMeta(const MetaTensor& x, int64_t axis, bool keepdims, diff --git a/paddle/phi/kernels/affine_grid_grad_kernel.h b/paddle/phi/kernels/affine_grid_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..061b763ed33f0ef422f8175f0daa7da6a3100777 --- /dev/null +++ b/paddle/phi/kernels/affine_grid_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/kernels/affine_grid_impl.h" + +namespace phi { + +template +void AffineGridGradKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad); +} // namespace phi diff --git a/paddle/phi/kernels/affine_grid_impl.h b/paddle/phi/kernels/affine_grid_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..14c9fa7b56f8c0fdead99b33337296c9c0808217 --- /dev/null +++ b/paddle/phi/kernels/affine_grid_impl.h @@ -0,0 +1,103 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; +using Array3 = Eigen::DSizes; +using Array4 = Eigen::DSizes; + +template +struct Linspace { + void operator()(T start, + T end, + int count, + bool align_corners, + DenseTensor* numbers, + const Context& dev_ctx); +}; + +template +inline void GetIdxMap(int n, + int h, + int w, + bool align_corners, + DenseTensor* grid, + const Context& dev_ctx) { + auto& place = *dev_ctx.eigen_device(); + grid->Resize(phi::make_ddim({n, h, w, 3})); + dev_ctx.template Alloc(grid); + auto grid_t = EigenTensor::From(*grid); + // Get indexes of height with shape [height, width, 1] + DenseTensor h_idx; + Linspace linspace; + linspace((T)-1, (T)1, h, align_corners, &h_idx, dev_ctx); + auto h_idx_t = EigenTensor::From(h_idx); + // Get indexes of width with shape [height, width, 1] + DenseTensor w_idx; + linspace((T)-1, (T)1, w, align_corners, &w_idx, dev_ctx); + auto w_idx_t = EigenTensor::From(w_idx); + // Get constant ones tensor with shape [height, width, 1] + DenseTensor ones; + ones.Resize(phi::make_ddim({h, w, 1})); + dev_ctx.template Alloc(&ones); + + phi::funcs::SetConstant()(dev_ctx, &ones, static_cast(1)); + auto ones_t = EigenTensor::From(ones); + // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and + // ones + DenseTensor w_idx_map; + w_idx_map.Resize(phi::make_ddim({h, w, 1})); + dev_ctx.template Alloc(&w_idx_map); + auto w_idx_map_t = EigenTensor::From(w_idx_map); + + DenseTensor h_idx_map; + h_idx_map.Resize(phi::make_ddim({h, w, 1})); + dev_ctx.template Alloc(&h_idx_map); + auto h_idx_map_t = EigenTensor::From(h_idx_map); + + DenseTensor w_h_idx_map; + w_h_idx_map.Resize(phi::make_ddim({h, w, 2})); + dev_ctx.template Alloc(&w_h_idx_map); + auto w_h_idx_map_t = EigenTensor::From(w_h_idx_map); + + DenseTensor w_h_one_idx_map; + w_h_one_idx_map.Resize(phi::make_ddim({h, w, 3})); + dev_ctx.template Alloc(&w_h_one_idx_map); + auto w_h_one_idx_map_t = EigenTensor::From(w_h_one_idx_map); + + w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w)) + .broadcast(Array2(h, 1)) + .reshape(Array3(h, w, 1)); + h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h)) + .broadcast(Array2(w, 1)) + .shuffle(Array2(1, 0)) + .reshape(Array3(h, w, 1)); + + w_h_idx_map_t.device(place) = w_idx_map_t.concatenate(h_idx_map_t, 2); + w_h_one_idx_map_t.device(place) = w_h_idx_map_t.concatenate(ones_t, 2); + grid_t.device(place) = w_h_one_idx_map_t.reshape(Array4(1, h, w, 3)) + .broadcast(Array4(n, 1, 1, 1)); +} + +} // namespace phi diff --git a/paddle/phi/kernels/affine_grid_kernel.h b/paddle/phi/kernels/affine_grid_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..febd5be0a36617c64fdd2f13372bab338aa6b84c --- /dev/null +++ b/paddle/phi/kernels/affine_grid_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/kernels/affine_grid_impl.h" + +namespace phi { + +template +void AffineGridKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output); +} // namespace phi diff --git a/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc b/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..778a0adc9ca934f237fc76ab548599b60ef79bd3 --- /dev/null +++ b/paddle/phi/kernels/cpu/affine_grid_grad_kernel.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/affine_grid_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +struct Linspace { + void operator()(T start, + T end, + int count, + bool align_corners, + DenseTensor* numbers, + const phi::CPUContext& dev_ctx) { + numbers->Resize(phi::make_ddim({count})); + T* number_data = dev_ctx.template Alloc(numbers); + T slice = (end - start) / (T)(count - 1); + if (!align_corners) { + slice = (end - start) / (T)count; + start *= (T)(count - 1) / (T)count; + } + for (int i = 0; i < count; ++i) { + number_data[i] = start + (T)i * slice; + } + } +}; + +template +void AffineGridGradKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + auto& theta_grad = input_grad; + int n = output_grad.dims()[0]; + auto& size_attr = outputShape.GetData(); + int h = 0; + int w = 0; + h = size_attr[2]; + w = size_attr[3]; + theta_grad->Resize(phi::make_ddim({n, 2, 3})); + dev_ctx.template Alloc(theta_grad); + phi::funcs::SetConstant()(dev_ctx, theta_grad, static_cast(0)); + DenseTensor grid; + GetIdxMap(n, h, w, align_corners, &grid, dev_ctx); + // output = grid * theta.T + // TODO(wanghaoshuang): Refine batched matrix multiply + auto blas = phi::funcs::GetBlas(dev_ctx); + for (int i = 0; i < n; ++i) { + DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize( + {static_cast(h) * static_cast(w), 3}); + DenseTensor sliced_out_grad = output_grad.Slice(i, i + 1).Resize( + {static_cast(h) * static_cast(w), 2}); + DenseTensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3}); + blas.MatMul(sliced_out_grad, + true, + sliced_grid, + false, + T(1), + &sliced_theta_grad, + T(0)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(affine_grid_grad, + CPU, + ALL_LAYOUT, + phi::AffineGridGradKernel, + float, + double){}; diff --git a/paddle/phi/kernels/cpu/affine_grid_kernel.cc b/paddle/phi/kernels/cpu/affine_grid_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6584a8eb263ea6eda970a68f81c85faacfdfff3e --- /dev/null +++ b/paddle/phi/kernels/cpu/affine_grid_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/affine_grid_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +struct Linspace { + void operator()(T start, + T end, + int count, + bool align_corners, + DenseTensor* numbers, + const phi::CPUContext& dev_ctx) { + numbers->Resize(phi::make_ddim({count})); + T* number_data = dev_ctx.template Alloc(numbers); + T slice = (end - start) / (T)(count - 1); + if (!align_corners) { + slice = (end - start) / (T)count; + start *= (T)(count - 1) / (T)count; + } + for (int i = 0; i < count; ++i) { + number_data[i] = start + (T)i * slice; + } + } +}; + +template +void AffineGridKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + int n = theta->dims()[0]; + auto& size_attr = outputShape.GetData(); + int h = 0; + int w = 0; + h = size_attr[2]; + w = size_attr[3]; + output->Resize(phi::make_ddim({n, h, w, 2})); + dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, static_cast(0)); + DenseTensor grid; + GetIdxMap(n, h, w, align_corners, &grid, dev_ctx); + // output = grid * theta.T + // TODO(wanghaoshuang): Refine batched matrix multiply + auto blas = phi::funcs::GetBlas(dev_ctx); + for (int i = 0; i < n; ++i) { + DenseTensor sliced_grid = grid.Slice(i, i + 1).Resize( + {static_cast(h) * static_cast(w), 3}); + DenseTensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3}); + DenseTensor sliced_out = output->Slice(i, i + 1).Resize( + {static_cast(h) * static_cast(w), 2}); + blas.MatMul( + sliced_grid, false, sliced_theta, true, T(1), &sliced_out, T(0)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + affine_grid, CPU, ALL_LAYOUT, phi::AffineGridKernel, float, double){}; diff --git a/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..baa2616267d537dd7e016274f67abb5a110ac7bb --- /dev/null +++ b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu @@ -0,0 +1,149 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/affine_grid_grad_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { + CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +} + +template +struct Linspace { + void operator()(T start, + T end, + int count, + bool align_corners, + DenseTensor* numbers, + const phi::GPUContext& dev_ctx) { + numbers->Resize(phi::make_ddim({count})); + T* number_data = dev_ctx.template Alloc(numbers); + T slice = (end - start) / (T)(count - 1); + if (!align_corners) { + slice = (end - start) / (T)count; + start *= (T)(count - 1) / (T)count; + } + auto stream = dev_ctx.stream(); + int block = 512; + int grid = (count + block - 1) / block; + LinspaceKernel + <<>>(start, slice, count, number_data); + } +}; + +template +__global__ void affine_grid_grad_kernel(const int count, + int n, + int out_h, + int out_w, + T h_start, + T w_start, + T h_step, + T w_step, + const T* out_grad, // N, H, W, 2 + T* theta_grad) { // N, 2, 3 + CUDA_KERNEL_LOOP(index, count) { + int w = index % out_w; + int h = (index / out_w) % out_h; + int n = index / (out_w * out_h); + T h_coor = h_step * static_cast(h) + static_cast(h_start); + T w_coor = w_step * static_cast(w) + static_cast(w_start); + + int theta_offset = n * 6; // 2 * 3; + T out_grad_x = out_grad[index * 2]; + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset, + out_grad_x * w_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 1, + out_grad_x * h_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x); + + T out_grad_y = out_grad[index * 2 + 1]; + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 3, + out_grad_y * w_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 4, + out_grad_y * h_coor); + paddle::platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y); + } +} + +template +void AffineGridGradCUDAKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + auto& theta_grad = input_grad; + int n = output_grad.dims()[0]; + auto& size_attr = outputShape.GetData(); + int h = 0; + int w = 0; + h = size_attr[2]; + w = size_attr[3]; + theta_grad->Resize(phi::make_ddim({n, 2, 3})); + T* theta_grad_data = dev_ctx.template Alloc(theta_grad); + phi::funcs::SetConstant()( + dev_ctx, theta_grad, static_cast(0)); + + T h_step; + T w_step; + T h_start = -1; + T w_start = -1; + if (align_corners) { + h_step = static_cast(2) / static_cast(h - 1); + w_step = static_cast(2) / static_cast(w - 1); + } else { + h_step = static_cast(2) / static_cast(h); + w_step = static_cast(2) / static_cast(w); + + h_start *= static_cast(h - 1) / static_cast(h); + w_start *= static_cast(w - 1) / static_cast(w); + } + const int count = n * h * w; + VLOG(3) << "count: " << count << "; h_step: " << h_step + << "; w_step: " << w_step << "; h_start: " << h_start + << "; w_start: " << w_start; + int block = 512; + int grid = (count + block - 1) / block; + auto cu_stream = dev_ctx.stream(); + affine_grid_grad_kernel<<>>(count, + n, + h, + w, + h_start, + w_start, + h_step, + w_step, + output_grad.data(), + theta_grad_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL(affine_grid_grad, + GPU, + ALL_LAYOUT, + phi::AffineGridGradCUDAKernel, + float, + double){}; diff --git a/paddle/phi/kernels/gpu/affine_grid_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad5072f4bacd1d40f7368d64e15cabe113df8d80 --- /dev/null +++ b/paddle/phi/kernels/gpu/affine_grid_kernel.cu @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/affine_grid_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { + CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +} + +template +struct Linspace { + void operator()(T start, + T end, + int count, + bool align_corners, + DenseTensor* numbers, + const phi::GPUContext& dev_ctx) { + numbers->Resize(phi::make_ddim({count})); + T* number_data = dev_ctx.template Alloc(numbers); + T slice = (end - start) / (T)(count - 1); + if (!align_corners) { + slice = (end - start) / (T)count; + start *= (T)(count - 1) / (T)count; + } + auto stream = dev_ctx.stream(); + int block = 512; + int grid = (count + block - 1) / block; + LinspaceKernel + <<>>(start, slice, count, number_data); + } +}; + +template +__global__ void affine_grid_kernel(const int count, + int n, + int out_h, + int out_w, + T h_start, + T w_start, + T h_step, + T w_step, + const T* theta, // N, 2, 3 + T* output) { + CUDA_KERNEL_LOOP(index, count) { + int w = index % out_w; + int h = (index / out_w) % out_h; + int n = index / (out_w * out_h); + + T h_coor = h_step * static_cast(h) + static_cast(h_start); + T w_coor = w_step * static_cast(w) + static_cast(w_start); + + int theta_offset = n * 6; // 2 * 3; + // affine from (h_coor, w_coor) to (x, y) + output[index * 2] = theta[theta_offset] * w_coor + + theta[theta_offset + 1] * h_coor + + theta[theta_offset + 2]; + output[index * 2 + 1] = theta[theta_offset + 3] * w_coor + + theta[theta_offset + 4] * h_coor + + theta[theta_offset + 5]; + } +} + +template +void AffineGridCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + auto* theta = &input; + int n = theta->dims()[0]; + auto& size_attr = outputShape.GetData(); + int h = 0; + int w = 0; + h = size_attr[2]; + w = size_attr[3]; + output->Resize(phi::make_ddim({n, h, w, 2})); + T* out_data = dev_ctx.template Alloc(output); + + T h_step; + T w_step; + T h_start = -1; + T w_start = -1; + if (align_corners) { + h_step = static_cast(2) / static_cast(h - 1); + w_step = static_cast(2) / static_cast(w - 1); + } else { + h_step = static_cast(2) / static_cast(h); + w_step = static_cast(2) / static_cast(w); + + h_start *= static_cast(h - 1) / static_cast(h); + w_start *= static_cast(w - 1) / static_cast(w); + } + + const int count = n * h * w; + int block = 512; + int grid = (count + block - 1) / block; + auto cu_stream = dev_ctx.stream(); + affine_grid_kernel<<>>( + count, + n, + h, + w, + h_start, + w_start, + h_step, + w_step, + theta->data(), // N, 2, 3 + out_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + affine_grid, GPU, ALL_LAYOUT, phi::AffineGridCUDAKernel, float, double){}; diff --git a/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu b/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..3090e97e359d952d4b1265ee695e06a9629eada0 --- /dev/null +++ b/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP + +#include "paddle/phi/kernels/affine_grid_grad_kernel.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +using ScopedSpatialTransformerDescriptor = + paddle::platform::ScopedSpatialTransformerDescriptor; + +template +void AffineGridGradCudnnKernel(const Context& dev_ctx, + const DenseTensor& output_grad, + const IntArray& outputShape, + bool align_corners, + DenseTensor* input_grad) { + PADDLE_ENFORCE_EQ( + paddle::platform::is_gpu_place(dev_ctx.GetPlace()), + true, + phi::errors::InvalidArgument( + "Only support for CUDAPlace.Please switch your context from " + "CPUPlace to CUDAPlace or update your cudnn.")); + auto handle = dev_ctx.cudnn_handle(); + auto& theta_grad = input_grad; + + int n = output_grad.dims()[0]; + auto& size_attr = outputShape.GetData(); + int h_size_data[4] = {0}; + h_size_data[0] = n; + h_size_data[1] = size_attr[1]; + h_size_data[2] = size_attr[2]; + h_size_data[3] = size_attr[3]; + + ScopedSpatialTransformerDescriptor st_desc; + cudnnSpatialTransformerDescriptor_t cudnn_st_desc = + st_desc.descriptor(4, h_size_data); + + const T* output_grad_data = output_grad.data(); + T* theta_grad_data = dev_ctx.template Alloc(theta_grad); + + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnSpatialTfGridGeneratorBackward( + handle, cudnn_st_desc, output_grad_data, theta_grad_data)); +} + +} // namespace phi + +PD_REGISTER_KERNEL(affine_grid_grad, // cuda_only + GPUDNN, + ALL_LAYOUT, + phi::AffineGridGradCudnnKernel, + float, + double){}; +#endif diff --git a/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu b/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..1cbd0c364672c651c96b60d0392998d0c92bb49d --- /dev/null +++ b/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP + +#include "paddle/phi/kernels/affine_grid_kernel.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +using ScopedSpatialTransformerDescriptor = + paddle::platform::ScopedSpatialTransformerDescriptor; + +template +void AffineGridCudnnKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& outputShape, + bool align_corners, + DenseTensor* output) { + PADDLE_ENFORCE_EQ( + paddle::platform::is_gpu_place(dev_ctx.GetPlace()), + true, + phi::errors::InvalidArgument( + "Only support for CUDAPlace.Please switch your context from " + "CPUPlace to CUDAPlace or update your cudnn.")); + auto handle = dev_ctx.cudnn_handle(); + auto* theta = &input; + const T* theta_data = theta->data(); + int n = theta->dims()[0]; + auto& size_attr = outputShape.GetData(); + int h_size_data[4] = {0}; + h_size_data[0] = n; + h_size_data[1] = size_attr[1]; + h_size_data[2] = size_attr[2]; + h_size_data[3] = size_attr[3]; + output->Resize(phi::make_ddim({n, h_size_data[2], h_size_data[3], 2})); + T* output_data = dev_ctx.template Alloc(output); + ScopedSpatialTransformerDescriptor st_desc; + cudnnSpatialTransformerDescriptor_t cudnn_st_desc = + st_desc.descriptor(4, h_size_data); + + PADDLE_ENFORCE_EQ( + paddle::platform::dynload::cudnnSpatialTfGridGeneratorForward( + handle, cudnn_st_desc, theta_data, output_data), + 0, + phi::errors::Fatal("Some errors has occurred " + "during forward computation in cudnn.")); +} + +} // namespace phi + +PD_REGISTER_KERNEL(affine_grid, // cuda_only + GPUDNN, + ALL_LAYOUT, + phi::AffineGridCudnnKernel, + float, + double){}; +#endif diff --git a/paddle/phi/ops/compat/affine_grid_sig.cc b/paddle/phi/ops/compat/affine_grid_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..2506b4b7557189b3134817593dc4844e8f32f066 --- /dev/null +++ b/paddle/phi/ops/compat/affine_grid_sig.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" +namespace phi { + +KernelSignature AffineGridOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("OutputShape")) { + return KernelSignature( + "affine_grid", {"Theta"}, {"OutputShape", "align_corners"}, {"Output"}); + } else { + return KernelSignature("affine_grid", + {"Theta"}, + {"output_shape", "align_corners"}, + {"Output"}); + } +} + +KernelSignature AffineGridGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.HasInput("OutputShape")) { + return KernelSignature("affine_grid_grad", + {"Output@GRAD"}, + {"OutputShape", "align_corners"}, + {"Theta@GRAD"}); + } else { + return KernelSignature("affine_grid_grad", + {"Output@GRAD"}, + {"output_shape", "align_corners"}, + {"Theta@GRAD"}); + } +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(affine_grid, phi::AffineGridOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(affine_grid_grad, + phi::AffineGridGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py index 9c5b2e9971e70fbbf18c1bea93bdac69d06b9c5f..287c9edae27540c9fe753bfdd57544cacbf7d24f 100644 --- a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py +++ b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py @@ -49,6 +49,7 @@ class TestAffineGridOp(OpTest): def setUp(self): self.initTestCase() self.op_type = "affine_grid" + self.python_api = paddle.nn.functional.vision.affine_grid theta = np.random.randint(1, 3, self.theta_shape).astype("float32") self.inputs = {'Theta': theta} self.attrs = { @@ -64,10 +65,13 @@ class TestAffineGridOp(OpTest): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['Theta'], 'Output', no_grad_set=['OutputShape']) + self.check_grad(['Theta'], + 'Output', + no_grad_set=['OutputShape'], + check_eager=True) def initTestCase(self): self.theta_shape = (17, 2, 3) diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index b1a962845c4c54078056a56cbf4064187cb162ec..0e06612bbb7169624ab1de869f50c53647a07d88 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -89,11 +89,12 @@ def affine_grid(theta, out_shape, align_corners=True, name=None): if is_compiled_with_rocm(): use_cudnn = False # ROCM platform do not have MIOPEN kernel for affine_grid - if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \ - isinstance(out_shape, Variable)): - raise ValueError("The out_shape should be a list, tuple or Tensor.") - - if in_dynamic_mode(): + if in_dygraph_mode(): + _out_shape = out_shape.numpy().tolist() if isinstance( + out_shape, Variable) else out_shape + return _C_ops.final_state_affine_grid(theta, _out_shape, use_cudnn, + align_corners) + elif in_dynamic_mode(): _out_shape = out_shape.numpy().tolist() if isinstance( out_shape, Variable) else out_shape return _C_ops.affine_grid(theta, "output_shape", _out_shape,