From 7e3752bbd389b2f58d336454d0f95aa7b6c4fa92 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Wed, 23 Mar 2022 14:50:32 +0800
Subject: [PATCH] [Phi] Move deformable_conv and deformable_conv_v1 to phi
 (#40794)

* move deformable_conv_grad to phi

* move infershape of deformable_conv to phi

* adjust some code format

* move deformable_conv_v1 to phi
---
 paddle/fluid/operators/deformable_conv_func.h |  149 ----
 paddle/fluid/operators/deformable_conv_op.cc  |  168 +----
 paddle/fluid/operators/deformable_conv_op.cu  |  643 ------------------
 paddle/fluid/operators/deformable_conv_op.h   |  509 --------------
 .../fluid/operators/deformable_conv_v1_op.cc  |  141 +---
 .../fluid/operators/deformable_conv_v1_op.cu  |  604 ----------------
 .../fluid/operators/deformable_conv_v1_op.h   |  556 ---------------
 paddle/fluid/pybind/imperative.cc             |    1 +
 paddle/phi/infermeta/multiary.cc              |  209 ++++++
 paddle/phi/infermeta/multiary.h               |   13 +
 paddle/phi/kernels/CMakeLists.txt             |    4 +-
 .../cpu/deformable_conv_grad_kernel.cc        |  333 +++++++++
 .../phi/kernels/cpu/deformable_conv_kernel.cc |  120 ----
 .../phi/kernels/deformable_conv_grad_kernel.h |   39 ++
 paddle/phi/kernels/deformable_conv_kernel.h   |    3 +-
 paddle/phi/kernels/funcs/CMakeLists.txt       |    1 +
 .../kernels/funcs/deformable_conv_functor.cc  |  172 +++++
 .../kernels/funcs/deformable_conv_functor.cu  |  185 +++++
 .../kernels/funcs/deformable_conv_functor.h   |   74 ++
 .../gpu/deformable_conv_grad_kernel.cu        |  366 ++++++++++
 .../phi/kernels/gpu/deformable_conv_kernel.cu |  134 ----
 .../impl/deformable_conv_grad_kernel_impl.h   |  364 ++++++++++
 .../impl/deformable_conv_kernel_impl.h        |   90 +--
 paddle/phi/ops/compat/deformable_conv_sig.cc  |   28 +
 24 files changed, 1830 insertions(+), 3076 deletions(-)
 delete mode 100644 paddle/fluid/operators/deformable_conv_func.h
 delete mode 100644 paddle/fluid/operators/deformable_conv_op.cu
 delete mode 100644 paddle/fluid/operators/deformable_conv_op.h
 delete mode 100644 paddle/fluid/operators/deformable_conv_v1_op.cu
 delete mode 100644 paddle/fluid/operators/deformable_conv_v1_op.h
 create mode 100644 paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/deformable_conv_grad_kernel.h
 create mode 100644 paddle/phi/kernels/funcs/deformable_conv_functor.cc
 create mode 100644 paddle/phi/kernels/funcs/deformable_conv_functor.cu
 create mode 100644 paddle/phi/kernels/funcs/deformable_conv_functor.h
 create mode 100644 paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h

diff --git a/paddle/fluid/operators/deformable_conv_func.h b/paddle/fluid/operators/deformable_conv_func.h
deleted file mode 100644
index b0fdf31e1ce..00000000000
--- a/paddle/fluid/operators/deformable_conv_func.h
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
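// The helpers deleted below (DmcnGetGradientWeight, DmcnGetCoordinateWeight,
// DmcnIm2colBilinear) implement bilinear sampling and its gradients, shared by
// the CPU and CUDA kernels. Judging from the diffstat, they are presumably
// superseded by the new paddle/phi/kernels/funcs/deformable_conv_functor.h
// added later in this patch. DmcnIm2colBilinear evaluates standard bilinear
// interpolation at a fractional location (h, w):
//
//   lh = h - floor(h),  lw = w - floor(w)
//   val = (1-lh)*(1-lw)*v1 + (1-lh)*lw*v2 + lh*(1-lw)*v3 + lh*lw*v4
//
// where v1..v4 are the four neighbouring pixels (out-of-range neighbours
// contribute 0); these factors are the w1..w4 weights computed in the code.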
-// -// Part of the following code in this file refs to -// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu -// -// Copyright (c) 2017 Microsoft -// Licensed under The Apache-2.0 License [see LICENSE for details] -// \file deformable_psroi_pooling.cu -// \brief -// \author Yi Li, Guodong Zhang, Jifeng Dai - -#pragma once -#include "paddle/phi/core/hostdevice.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -template -HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h, - const int w, const int height, - const int width) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - T weight = 0; - - weight = (h == argmax_h_low && w == argmax_w_low) - ? (h + 1 - argmax_h) * (w + 1 - argmax_w) - : weight; - weight = (h == argmax_h_low && w == argmax_w_high) - ? (h + 1 - argmax_h) * (argmax_w + 1 - w) - : weight; - weight = (h == argmax_h_high && w == argmax_w_low) - ? (argmax_h + 1 - h) * (w + 1 - argmax_w) - : weight; - weight = (h == argmax_h_high && w == argmax_w_high) - ? (argmax_h + 1 - h) * (argmax_w + 1 - w) - : weight; - - return weight; -} - -template -HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height, - const int width, const T* im_data, - const int data_width, const int bp_dir) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - T weight = 0; - - if (bp_dir == 0) { - weight += (argmax_h_low >= 0 && argmax_w_low >= 0) - ? -1 * (argmax_w_low + 1 - argmax_w) * - im_data[argmax_h_low * data_width + argmax_w_low] - : 0; - - weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) - ? -1 * (argmax_w - argmax_w_low) * - im_data[argmax_h_low * data_width + argmax_w_high] - : 0; - - weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) - ? (argmax_w_low + 1 - argmax_w) * - im_data[argmax_h_high * data_width + argmax_w_low] - : 0; - weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - ? (argmax_w - argmax_w_low) * - im_data[argmax_h_high * data_width + argmax_w_high] - : 0; - } else if (bp_dir == 1) { - weight += (argmax_h_low >= 0 && argmax_w_low >= 0) - ? -1 * (argmax_h_low + 1 - argmax_h) * - im_data[argmax_h_low * data_width + argmax_w_low] - : 0; - weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) - ? (argmax_h_low + 1 - argmax_h) * - im_data[argmax_h_low * data_width + argmax_w_high] - : 0; - weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) - ? -1 * (argmax_h - argmax_h_low) * - im_data[argmax_h_high * data_width + argmax_w_low] - : 0; - weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - ? 
(argmax_h - argmax_h_low) * - im_data[argmax_h_high * data_width + argmax_w_high] - : 0; - } - - return weight; -} - -template -HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width, - const int height, const int width, T h, T w) { - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh; - T hw = 1 - lw; - - T v1 = - (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0; - T v2 = (h_low >= 0 && w_high <= width - 1) - ? bottom_data[h_low * data_width + w_high] - : 0; - T v3 = (h_high <= height - 1 && w_low >= 0) - ? bottom_data[h_high * data_width + w_low] - : 0; - T v4 = (h_high <= height - 1 && w_high <= width - 1) - ? bottom_data[h_high * data_width + w_high] - : 0; - - T w1 = hh * hw; - T w2 = hh * lw; - T w3 = lh * hw; - T w4 = lh * lw; - - return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; -} diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index 6e15fd090b8..1b76aca1e66 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/deformable_conv_op.h" #include -#include "paddle/fluid/operators/conv_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -108,158 +110,6 @@ $$ class DeformableConvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "deformable_conv"); - OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset", - "deformable_conv)"); - OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "deformable_conv"); - OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", - "deformable_conv"); - OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", - "deformable_conv"); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - auto offset_dims = ctx->GetInputDim("Offset"); - auto mask_dims = ctx->GetInputDim("Mask"); - - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - std::vector dilations = - ctx->Attrs().Get>("dilations"); - int groups = ctx->Attrs().Get("groups"); - int deformable_groups = ctx->Attrs().Get("deformable_groups"); - int im2col_step = ctx->Attrs().Get("im2col_step"); - - PADDLE_ENFORCE_EQ( - in_dims.size(), 4, - platform::errors::InvalidArgument( - "Conv input should be 4-D tensor, get %u", in_dims.size())); - PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(), - platform::errors::InvalidArgument( - "Conv input dimension and filter dimension should be " - "the same. The difference is [%d]: [%d]", - in_dims.size(), filter_dims.size())); - PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(), 2U, - platform::errors::InvalidArgument( - "Conv input dimension and strides " - "dimension should be consistent. 
But received input " - "dimension:[%d], strides dimension:[%d]", - in_dims.size(), strides.size())); - PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), - platform::errors::InvalidArgument( - "Conv paddings dimension and Conv strides dimension " - "should be the same. The difference is [%d]: [%d]", - paddings.size(), strides.size())); - - PADDLE_ENFORCE_EQ( - in_dims[1], filter_dims[1] * groups, - platform::errors::InvalidArgument( - "The number of input channels should be equal to filter " - "channels * groups. The difference is [%d]: [%d]", - in_dims[1], filter_dims[1] * groups)); - PADDLE_ENFORCE_EQ( - filter_dims[0] % groups, 0, - platform::errors::InvalidArgument( - "The number of output channels should be divided by groups. But " - "received output channels:[%d], groups:[%d]", - filter_dims[0], groups)); - PADDLE_ENFORCE_EQ( - filter_dims[0] % deformable_groups, 0, - platform::errors::InvalidArgument( - "The number of output channels should be " - "divided by deformable groups. The difference is [%d]: [%d]", - filter_dims[0] % groups, 0)); - - if (in_dims[0] > im2col_step) { - PADDLE_ENFORCE_EQ( - in_dims[0] % im2col_step, 0U, - platform::errors::InvalidArgument( - "Input batchsize must be smaller than or divide im2col_step. But " - "received Input batchsize:[%d], im2col_step:[%d]", - in_dims[0], im2col_step)); - } - - for (size_t i = 0; i < strides.size(); ++i) { - PADDLE_ENFORCE_GT(strides[i], 0U, platform::errors::InvalidArgument( - "stride %d size incorrect", i)); - } - for (size_t i = 0; i < dilations.size(); ++i) { - PADDLE_ENFORCE_GT(dilations[i], 0U, platform::errors::InvalidArgument( - "dilation %d size incorrect", i)); - } - - std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { - if ((!ctx->IsRuntime()) && - (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) { - output_shape.push_back(-1); - } else { - output_shape.push_back(ConvOutputSize(in_dims[i + 2], - filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); - } - } - - PADDLE_ENFORCE_EQ( - output_shape[1] % deformable_groups, 0U, - platform::errors::InvalidArgument( - "output num_filter must divide deformable group size. But received " - "output num_filter:[%d], deformable group size:[%d]", - output_shape[1], deformable_groups)); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2], - platform::errors::InvalidArgument( - "output height must equal to offset map height. " - "The difference is [%d]: [%d]", - output_shape[2], offset_dims[2])); - PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3], - platform::errors::InvalidArgument( - "output width must equal to offset map width. The " - "difference is [%d]: [%d]", - output_shape[3], offset_dims[3])); - - PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, - platform::errors::InvalidArgument( - "offset filter must divide deformable group size. " - "But received [%d]: [%d]", - offset_dims[1], filter_dims[2] * filter_dims[3])); - PADDLE_ENFORCE_EQ( - offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), - deformable_groups, - platform::errors::InvalidArgument( - "offset filter must divide deformable group size. But received " - "[%d]: [%d]", - offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), - deformable_groups)); - PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2], - platform::errors::InvalidArgument( - "output height must equal to mask map height. 
The " - "difference is [%d] vs [%d]", - output_shape[2], mask_dims[2])); - PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3], - platform::errors::InvalidArgument( - "output width must equal to mask map width. The " - "difference is [%d] vs [%d]", - output_shape[3], mask_dims[3])); - - PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, - platform::errors::InvalidArgument( - "mask filter must divide deformable group size. " - "But received [%d]: [%d]", - mask_dims[1], filter_dims[2] * filter_dims[3])); - PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]), - deformable_groups, - platform::errors::InvalidArgument( - "mask filter must divide deformable group size. " - "But received [%d]: [%d]", - mask_dims[1] / (filter_dims[2] * filter_dims[3]), - deformable_groups)); - } - - ctx->SetOutputDim("Output", phi::make_ddim(output_shape)); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -331,13 +181,13 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(deformable_conv, DeformableConvInferShapeFunctor, + PD_INFER_META(phi::DeformableConvInferMeta)); + REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp, ops::DeformableConvOpMaker, ops::DeformableConvGradOpMaker, - ops::DeformableConvGradOpMaker); + ops::DeformableConvGradOpMaker, + DeformableConvInferShapeFunctor); REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp); - -REGISTER_OP_CPU_KERNEL(deformable_conv_grad, - ops::DeformableConvGradCPUKernel, - ops::DeformableConvGradCPUKernel); diff --git a/paddle/fluid/operators/deformable_conv_op.cu b/paddle/fluid/operators/deformable_conv_op.cu deleted file mode 100644 index ad10abf9c64..00000000000 --- a/paddle/fluid/operators/deformable_conv_op.cu +++ /dev/null @@ -1,643 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
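//
// The InferShape logic removed above now lives in phi::DeformableConvInferMeta
// (paddle/phi/infermeta/multiary.cc), wired to the operator through
// DECLARE_INFER_SHAPE_FUNCTOR. The ConvOutputSize helper it relied on computes
// the usual convolution output extent,
//   out = (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1.
// The CUDA kernels deleted below (modulated im2col / col2im / col2im_coord and
// the grad compute kernel) are presumably replaced by the new
// paddle/phi/kernels/funcs/deformable_conv_functor.cu and
// paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu listed in the diffstat.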
-// -// Part of the following code in this file refs to -// https://github.com/msracver/Deformable-ConvNets/blob/master/DCNv2_op/nn/modulated_deformable_im2col.cuh -// -// Copyright (c) 2018 Microsoft -// Licensed under The MIT License [see LICENSE for details] -// \file modulated_deformable_im2col.cuh -// \brief -// \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu - -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/deformable_conv_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaximumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); -} - -template -__device__ T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h, - const int w, const int height, - const int width) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - T weight = 0; - if (h == argmax_h_low && w == argmax_w_low) - weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); - if (h == argmax_h_low && w == argmax_w_high) - weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); - if (h == argmax_h_high && w == argmax_w_low) - weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); - if (h == argmax_h_high && w == argmax_w_high) - weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); - return weight; -} - -template -__global__ void ModulatedDeformableCol2imGpuKernel( - const int nthreads, const T* data_col, const T* data_offset, - const T* data_mask, const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, - const int batch_size, const int deformable_group, const int height_col, - const int width_col, T* grad_im) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t thread = index; thread < nthreads; thread += offset) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = - (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = - thread / width_col / height_col / batch_size / kernel_w / kernel_h; - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const T* data_mask_ptr = data_mask + - (b * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const 
int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - const T cur_inv_h_data = h_in + i * dilation_h + offset_h; - const T cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const T cur_top_grad = data_col[thread] * mask; - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { - if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && - cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = - ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - T weight = - DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, - cur_w + dx, height, width); - - platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos, - weight * cur_top_grad); - } - } - } - } -} - -template -inline void ModulatedDeformableCol2im( - const platform::DeviceContext& ctx, const T* data_col, const T* data_offset, - const T* data_mask, const std::vector im_shape, - const std::vector col_shape, - const std::vector kernel_shape, const std::vector pad, - const std::vector stride, const std::vector dilation, - const int deformable_group, T* grad_im) { - int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableCol2imGpuKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_col, data_offset, data_mask, im_shape[0], im_shape[1], - im_shape[2], kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], - stride[1], dilation[0], dilation[1], channel_per_deformable_group, - col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im); -} - -template -__device__ T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height, - const int width, const T* im_data, - const int data_width, const int bp_dir) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - T weight = 0; - - if (bp_dir == 0) { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_w_low + 1 - argmax_w) * - im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += -1 * (argmax_w - argmax_w_low) * - im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += (argmax_w_low + 1 - argmax_w) * - im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_w - argmax_w_low) * - im_data[argmax_h_high * data_width + argmax_w_high]; - } else if (bp_dir == 1) { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_h_low + 1 - argmax_h) * - im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += (argmax_h_low + 1 - argmax_h) * - im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight 
+= -1 * (argmax_h - argmax_h_low) * - im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_h - argmax_h_low) * - im_data[argmax_h_high * data_width + argmax_w_high]; - } - return weight; -} - -template -__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, - const int height, const int width, T h, T w) { - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh, hw = 1 - lw; - - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = bottom_data[h_low * data_width + w_high]; - T v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = bottom_data[h_high * data_width + w_low]; - T v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = bottom_data[h_high * data_width + w_high]; - - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -template -__global__ void ModulatedDeformableCol2imCoordGpuKernel( - const int nthreads, const T* data_col, const T* data_im, - const T* data_offset, const T* data_mask, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int offset_channels, const int deformable_group, const int height_col, - const int width_col, T* grad_offset, T* grad_mask) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const T* data_col_ptr = data_col + - deformable_group_index * - channel_per_deformable_group * batch_size * - width_col * height_col; - const T* data_im_ptr = data_im + - (b * deformable_group + deformable_group_index) * - channel_per_deformable_group / kernel_h / - kernel_w * height * width; - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const T* data_mask_ptr = data_mask + - (b * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; - col_c += col_step) { - const int col_pos = - (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = - (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = - (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = 
- (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + - w_out); - const int data_mask_hw_ptr = - (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T inv_h = h_in + i * dilation_h + offset_h; - T inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { - inv_h = inv_w = -2; - } else { - mval += data_col_ptr[col_pos] * - DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, - height, width, inv_h, inv_w); - } - const T weight = DmcnGetCoordinateWeight( - inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, - width, bp_dir); - val += weight * data_col_ptr[col_pos] * mask; - cnt += 1; - } - grad_offset[i] = val; - if (offset_c % 2 == 0) - grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * - kernel_w + - offset_c / 2) * - height_col + - h) * - width_col + - w] = mval; - } -} - -template -inline void ModulatedDeformableCol2imCoord( - const platform::DeviceContext& ctx, const T* data_col, const T* data_im, - const T* data_offset, const T* data_mask, - const std::vector im_shape, const std::vector col_shape, - const std::vector kernel_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* grad_offset, T* grad_mask) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableCol2imCoordGpuKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_col, data_im, data_offset, data_mask, im_shape[0], - im_shape[1], im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], - paddings[1], strides[0], strides[1], dilations[0], dilations[1], - channel_per_deformable_group, col_shape[1], - 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, - deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask); -} - -template -__global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, const T* data_im, const T* data_offset, - const T* data_mask, const int height, const int width, const int kernel_h, - const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, - const int width_col, T* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * 
width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - const T* data_mask_ptr = - data_mask + - (b_col * deformable_group + deformable_group_index) * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + - w_col; - const int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -inline void ModulatedDeformableIm2col( - const platform::DeviceContext& ctx, const T* data_im, const T* data_offset, - const T* data_mask, const std::vector im_shape, - const std::vector col_shape, - const std::vector filter_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableIm2colGpuKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_im, data_offset, data_mask, im_shape[1], im_shape[2], - filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], - strides[1], dilations[0], dilations[1], channel_per_deformable_group, - col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], - data_col); -} - -template -__global__ void FilterGradAddupGpuKernel(const int nthreads, const int n, - const int height, const int width, - const T* dweight_3d, T* filter_grad) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - filter_grad[i] = filter_grad[i] + dweight_3d[i]; - } -} - -template -class DeformableConvGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* output_grad = - ctx.Input(framework::GradVarName("Output")); - Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); - Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); - Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); - Tensor* mask_grad = ctx.Output(framework::GradVarName("Mask")); - - const Tensor* input = ctx.Input("Input"); - Tensor offset = *ctx.Input("Offset"); - Tensor mask = *ctx.Input("Mask"); - Tensor filter = *ctx.Input("Filter"); - if (!input_grad && !filter_grad && !offset_grad && !mask_grad) return; - - int groups = ctx.Attr("groups"); - int deformable_groups = ctx.Attr("deformable_groups"); - int im2col_step = ctx.Attr("im2col_step"); - std::vector strides = ctx.Attr>("strides"); - 
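// im2col_step controls how many images of the batch are unfolded into one
// column buffer at a time, so the loop below runs batch_size / im2col_step
// iterations. Within each iteration and for each group g, the backward pass
// reduces to two GEMMs:
//   col_grad (M x N) = weight^T (M x K) * out_grad (K x N)
//   dweight  (K x M) = out_grad (K x N) * col_buffer^T (N x M)
// with K = out_channels / groups, M = in_channels / groups * kh * kw and
// N = im2col_step * out_h * out_w, plus col2im / col2im_coord scatters that
// turn col_grad into the input, offset and mask gradients.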
std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - - auto& dev_ctx = ctx.cuda_device_context(); - const int batch_size = static_cast(input->dims()[0]); - - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector output_shape_vec(phi::vectorize(output_grad->dims())); - - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - - output_buffer.ShareDataWith(*output_grad); - - int64_t M = - input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = output_shape_vec[1] / groups; - - framework::DDim weight_3d_shape = {groups, K, M}; - framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, - N}; - framework::DDim col_buffer_3d_shape = {groups, M, N}; - framework::DDim filter_grad_shape = {groups, K, M}; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); - Tensor out_grad_4d; - out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); - - phi::funcs::SetConstant set_zero; - auto blas = phi::funcs::GetBlas(dev_ctx); - - col_buffer.mutable_data(ctx.GetPlace()); - col_buffer_3d.mutable_data(ctx.GetPlace()); - out_grad_4d.mutable_data(ctx.GetPlace()); - - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask.numel() / mask.dims()[0]; - - if (filter_grad) { - filter_grad->mutable_data(ctx.GetPlace()); - filter_grad->Resize(filter_grad_shape); - set_zero(dev_ctx, filter_grad, static_cast(0)); - } - - if (input_grad) { - input_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); - } - - if (offset_grad && mask_grad) { - offset_grad->mutable_data(ctx.GetPlace()); - mask_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, offset_grad, static_cast(0)); - set_zero(dev_ctx, mask_grad, static_cast(0)); - } - - for (int i = 0; i < batch_size / im2col_step; ++i) { - Tensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - - 
blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), - &col_buffer_3d_slice, T(0.0)); - } - col_buffer.Resize(col_shape); - - T* col_buffer_ptr = col_buffer.data(); - const T* input_ptr = input->data(); - const T* offset_ptr = offset.data(); - const T* mask_ptr = mask.data(); - - if (mask_grad && offset_grad) { - T* offset_grad_ptr = offset_grad->data(); - T* mask_grad_ptr = mask_grad->data(); - ModulatedDeformableCol2imCoord( - ctx.device_context(), col_buffer_ptr, - input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, deformable_groups, - offset_grad_ptr + i * im2col_step * input_offset_dim, - mask_grad_ptr + i * im2col_step * input_mask_dim); - } - if (input_grad) { - T* input_grad_ptr = input_grad->data(); - ModulatedDeformableCol2im( - ctx.device_context(), col_buffer_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, deformable_groups, - input_grad_ptr + i * im2col_step * input_dim); - input_grad->Resize(input->dims()); - } - - ModulatedDeformableIm2col( - ctx.device_context(), input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, - deformable_groups, col_buffer_ptr); - - col_buffer_3d.Resize(col_buffer_3d_shape); - - if (filter_grad) { - Tensor dweight_3d; - dweight_3d = - ctx.AllocateTmpTensor(filter_grad_shape, dev_ctx); - for (int g = 0; g < groups; ++g) { - Tensor out_grad_3d_slice = - out_grad_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size())); - - blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, - T(1.0), &dweight_3d_slice, T(0.0)); - } - FilterGradAddupGpuKernel< - T><<>>( - dweight_3d.numel(), groups, K, M, dweight_3d.data(), - filter_grad->data()); - } - } - if (filter_grad) { - filter_grad->Resize(filter.dims()); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CUDA = paddle::platform::CUDADeviceContext; - -REGISTER_OP_CUDA_KERNEL(deformable_conv_grad, - ops::DeformableConvGradCUDAKernel, - ops::DeformableConvGradCUDAKernel); diff --git a/paddle/fluid/operators/deformable_conv_op.h b/paddle/fluid/operators/deformable_conv_op.h deleted file mode 100644 index 1176b96987e..00000000000 --- a/paddle/fluid/operators/deformable_conv_op.h +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
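//
// The header deleted below held the CPU reference implementation: serial
// versions of the modulated im2col / col2im / col2im_coord loops plus
// DeformableConvGradCPUKernel. Per the diffstat, its role is presumably taken
// over by paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc,
// paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h and the shared
// paddle/phi/kernels/funcs/deformable_conv_functor.cc. The CPU loops mirror
// the CUDA kernels above but iterate over num_kernels in a plain for loop
// instead of a grid-stride launch.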
-// See the License for the specific language governing permissions and -// limitations under the License. -// -// Part of the following code in this file refs to -// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu -// -// Copyright (c) 2017 Microsoft -// Licensed under The Apache-2.0 License [see LICENSE for details] -// \file deformable_psroi_pooling.cu -// \brief -// \author Yi Li, Guodong Zhang, Jifeng Dai - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/deformable_conv_func.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using CPUDeviceContext = platform::CPUDeviceContext; - -template -void ModulatedDeformableCol2imCPUKernel( - const int num_kernels, const T* data_col, const T* data_offset, - const T* data_mask, const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, - const int batch_size, const int deformable_group, const int height_col, - const int width_col, T* grad_im) { - for (int thread = 0; thread < num_kernels; thread++) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = - (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = - thread / width_col / height_col / batch_size / kernel_w / kernel_h; - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const T* data_mask_ptr = data_mask + - (b * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - const T cur_inv_h_data = h_in + i * dilation_h + offset_h; - const T cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const T cur_top_grad = data_col[thread] * mask; - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { - if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && - cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = - ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - T weight = - DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, - cur_w + dx, height, width); - - *(grad_im + cur_bottom_grad_pos) = - *(grad_im + cur_bottom_grad_pos) + weight * 
cur_top_grad; - } - } - } - } -} - -template -static inline void ModulatedDeformableCol2imCPU( - const platform::CPUDeviceContext& ctx, const T* data_col, - const T* data_offset, const T* data_mask, - const std::vector im_shape, const std::vector col_shape, - const std::vector kernel_shape, const std::vector pad, - const std::vector stride, const std::vector dilation, - const int deformable_group, T* grad_im) { - int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - ModulatedDeformableCol2imCPUKernel( - num_kernels, data_col, data_offset, data_mask, im_shape[0], im_shape[1], - im_shape[2], kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], - stride[1], dilation[0], dilation[1], channel_per_deformable_group, - col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im); -} - -template -void ModulatedDeformableCol2imCoordCPUKernel( - const int num_kernels, const T* data_col, const T* data_im, - const T* data_offset, const T* data_mask, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int offset_channels, const int deformable_group, const int height_col, - const int width_col, T* grad_offset, T* grad_mask) { - for (int i = 0; i < num_kernels; i++) { - T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const T* data_col_ptr = data_col + - deformable_group_index * - channel_per_deformable_group * batch_size * - width_col * height_col; - const T* data_im_ptr = data_im + - (b * deformable_group + deformable_group_index) * - channel_per_deformable_group / kernel_h / - kernel_w * height * width; - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const T* data_mask_ptr = data_mask + - (b * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; - col_c += col_step) { - const int col_pos = - (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = - (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = - (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = - (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + - w_out); - const int data_mask_hw_ptr = - (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = 
data_mask_ptr[data_mask_hw_ptr]; - T inv_h = h_in + i * dilation_h + offset_h; - T inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { - inv_h = inv_w = -2; - } else { - mval += data_col_ptr[col_pos] * - DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, - height, width, inv_h, inv_w); - } - const T weight = DmcnGetCoordinateWeight( - inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, - width, bp_dir); - val += weight * data_col_ptr[col_pos] * mask; - cnt += 1; - } - grad_offset[i] = val; - if (offset_c % 2 == 0) - grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * - kernel_w + - offset_c / 2) * - height_col + - h) * - width_col + - w] = mval; - } -} - -template -static inline void ModulatedDeformableCol2imCoordCPU( - const platform::CPUDeviceContext& ctx, const T* data_col, const T* data_im, - const T* data_offset, const T* data_mask, - const std::vector im_shape, const std::vector col_shape, - const std::vector kernel_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* grad_offset, T* grad_mask) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - - ModulatedDeformableCol2imCoordCPUKernel( - num_kernels, data_col, data_im, data_offset, data_mask, im_shape[0], - im_shape[1], im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], - paddings[1], strides[0], strides[1], dilations[0], dilations[1], - channel_per_deformable_group, col_shape[1], - 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, - deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask); -} - -template -void ModulatedDeformableIm2colCPUKernel( - const int num_kernels, const T* data_im, const T* data_offset, - const T* data_mask, const int height, const int width, const int kernel_h, - const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, - const int width_col, T* data_col) { - for (int i = 0; i < num_kernels; i++) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - const T* data_mask_ptr = - data_mask + - (b_col * deformable_group + deformable_group_index) * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + 
j) + 1) * height_col + h_col) * width_col + - w_col; - const int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -static inline void ModulatedDeformableIm2colCPU( - const platform::CPUDeviceContext& ctx, const T* data_im, - const T* data_offset, const T* data_mask, - const std::vector im_shape, const std::vector col_shape, - const std::vector filter_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - // get outputs of im2col with offset by bilinear interpolation - ModulatedDeformableIm2colCPUKernel( - num_kernels, data_im, data_offset, data_mask, im_shape[1], im_shape[2], - filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], - strides[1], dilations[0], dilations[1], channel_per_deformable_group, - col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], - data_col); -} - -template -void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height, - const int width, const T* dweight_3d, - T* filter_grad) { - for (int i = 0; i < nthreads; i++) { - filter_grad[i] = filter_grad[i] + dweight_3d[i]; - } -} - -template -class DeformableConvGradCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* output_grad = - ctx.Input(framework::GradVarName("Output")); - Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); - Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); - Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); - Tensor* mask_grad = ctx.Output(framework::GradVarName("Mask")); - - const Tensor* input = ctx.Input("Input"); - Tensor offset = *ctx.Input("Offset"); - Tensor mask = *ctx.Input("Mask"); - Tensor filter = *ctx.Input("Filter"); - if (!input_grad && !filter_grad && !offset_grad && !mask_grad) return; - - int groups = ctx.Attr("groups"); - int deformable_groups = ctx.Attr("deformable_groups"); - int im2col_step = ctx.Attr("im2col_step"); - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - - auto& dev_ctx = ctx.template device_context(); - const int batch_size = static_cast(input->dims()[0]); - - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector output_shape_vec(phi::vectorize(output_grad->dims())); - - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < 
filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - - output_buffer.ShareDataWith(*output_grad); - - int64_t M = - input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = output_shape_vec[1] / groups; - - framework::DDim weight_3d_shape = {groups, K, M}; - framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, - N}; - framework::DDim col_buffer_3d_shape = {groups, M, N}; - framework::DDim filter_grad_shape = {groups, K, M}; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); - Tensor out_grad_4d; - out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); - - phi::funcs::SetConstant set_zero; - auto blas = phi::funcs::GetBlas(dev_ctx); - - col_buffer.mutable_data(ctx.GetPlace()); - col_buffer_3d.mutable_data(ctx.GetPlace()); - out_grad_4d.mutable_data(ctx.GetPlace()); - - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask.numel() / mask.dims()[0]; - - if (filter_grad) { - filter_grad->mutable_data(ctx.GetPlace()); - filter_grad->Resize(filter_grad_shape); - set_zero(dev_ctx, filter_grad, static_cast(0)); - } - - if (input_grad) { - input_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); - } - - if (offset_grad && mask_grad) { - offset_grad->mutable_data(ctx.GetPlace()); - mask_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, offset_grad, static_cast(0)); - set_zero(dev_ctx, mask_grad, static_cast(0)); - } - - for (int i = 0; i < batch_size / im2col_step; ++i) { - Tensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - - blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), - &col_buffer_3d_slice, T(0.0)); - } - col_buffer.Resize(col_shape); - - T* col_buffer_ptr = col_buffer.data(); - const T* input_ptr = input->data(); - const T* offset_ptr = offset.data(); - const T* mask_ptr = mask.data(); - - if (mask_grad && offset_grad) { - T* offset_grad_ptr = offset_grad->data(); - T* mask_grad_ptr = mask_grad->data(); - // get grad of offset and mask - ModulatedDeformableCol2imCoordCPU( - ctx.template device_context(), col_buffer_ptr, - input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - 
col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, deformable_groups, - offset_grad_ptr + i * im2col_step * input_offset_dim, - mask_grad_ptr + i * im2col_step * input_mask_dim); - } - if (input_grad) { - T* input_grad_ptr = input_grad->data(); - // get grad of input - ModulatedDeformableCol2imCPU( - ctx.template device_context(), col_buffer_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, deformable_groups, - input_grad_ptr + i * im2col_step * input_dim); - input_grad->Resize(input->dims()); - } - - ModulatedDeformableIm2colCPU( - ctx.template device_context(), - input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, - deformable_groups, col_buffer_ptr); - - col_buffer_3d.Resize(col_buffer_3d_shape); - - if (filter_grad) { - Tensor dweight_3d; - dweight_3d = ctx.AllocateTmpTensor( - filter_grad_shape, dev_ctx); - for (int g = 0; g < groups; ++g) { - Tensor out_grad_3d_slice = - out_grad_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size())); - - blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, - T(1.0), &dweight_3d_slice, T(0.0)); - } - // update grad of weights - FilterGradAddupCPUKernel(dweight_3d.numel(), groups, K, M, - dweight_3d.data(), filter_grad->data()); - } - } - if (filter_grad) { - filter_grad->Resize(filter.dims()); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index d1245a52743..0ec95cb54ba 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
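// The im2col and col2im kernels above sample the input at fractional,
// offset-shifted locations (h_in + i * dilation_h + offset_h,
// w_in + j * dilation_w + offset_w) via bilinear interpolation. The snippet
// below is a minimal standalone sketch of that sampling step for reference;
// BilinearSample is a hypothetical helper written only for illustration and
// is not the DmcnIm2colBilinear routine removed by this patch.
#include <cmath>
#include <cstdio>
#include <vector>

template <typename T>
T BilinearSample(const T* data, int height, int width, T h, T w) {
  // Samples fully outside the (-1, size) range are treated as zero, matching
  // the guard used in the kernels above.
  if (h <= -1 || h >= height || w <= -1 || w >= width) {
    return static_cast<T>(0);
  }
  int h_low = static_cast<int>(std::floor(h));
  int w_low = static_cast<int>(std::floor(w));
  int h_high = h_low + 1;
  int w_high = w_low + 1;
  T lh = h - h_low, lw = w - w_low;
  T hh = 1 - lh, hw = 1 - lw;
  auto at = [&](int y, int x) -> T {
    // Neighbours that fall off the image contribute zero.
    return (y >= 0 && y < height && x >= 0 && x < width)
               ? data[y * width + x]
               : static_cast<T>(0);
  };
  // Weighted sum of the four neighbouring pixels.
  return hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high) +
         lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high);
}

int main() {
  // 2x2 image; sampling at (0.5, 0.5) averages the four pixels.
  std::vector<float> im = {1.f, 2.f, 3.f, 4.f};
  std::printf("%f\n", BilinearSample(im.data(), 2, 2, 0.5f, 0.5f));  // 2.5
  return 0;
}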
-#include "paddle/fluid/operators/deformable_conv_v1_op.h" #include -#include "paddle/fluid/operators/conv_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -113,128 +115,6 @@ $$ class DeformableConvV1Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", - "deformable_conv_v1"); - OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset", - "deformable_conv_v1"); - OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", - "deformable_conv_v1"); - OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", - "deformable_conv_v1"); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - auto offset_dims = ctx->GetInputDim("Offset"); - - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - std::vector dilations = - ctx->Attrs().Get>("dilations"); - int groups = ctx->Attrs().Get("groups"); - int deformable_groups = ctx->Attrs().Get("deformable_groups"); - int im2col_step = ctx->Attrs().Get("im2col_step"); - - PADDLE_ENFORCE_EQ( - in_dims.size(), 4, - platform::errors::InvalidArgument( - "Conv input should be 4-D tensor, get %u", in_dims.size())); - PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(), - platform::errors::InvalidArgument( - "Conv input dimension and filter dimension should be " - "the same. the difference is [%d] vs [%d]", - in_dims.size(), filter_dims.size())); - PADDLE_ENFORCE_EQ( - in_dims.size() - strides.size(), 2U, - platform::errors::InvalidArgument( - "Conv input dimension and strides " - "dimension should be consistent., But received [%d]: [%d]", - in_dims.size(), strides.size())); - PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), - platform::errors::InvalidArgument( - "Conv paddings dimension and Conv strides dimension " - "should be the same. The difference is [%d] vs [%d]", - paddings.size(), strides.size())); - - PADDLE_ENFORCE_EQ( - in_dims[1], filter_dims[1] * groups, - platform::errors::InvalidArgument( - "The number of input channels should be equal to filter " - "channels * groups. The difference is [%d]: [%d]", - in_dims[1], filter_dims[1] * groups)); - PADDLE_ENFORCE_EQ( - filter_dims[0] % groups, 0, - platform::errors::InvalidArgument( - "The number of output channels should be divided by groups. But" - "received output channels: [%d], groups: [%d]", - filter_dims[0], groups)); - PADDLE_ENFORCE_EQ( - filter_dims[0] % deformable_groups, 0, - platform::errors::InvalidArgument( - "The number of output channels should be " - "divided by deformable groups. 
But received [%d]: [%d]", - filter_dims[0], deformable_groups)); - - if (in_dims[0] > im2col_step) { - PADDLE_ENFORCE_EQ(in_dims[0] % im2col_step, 0U, - platform::errors::InvalidArgument( - "Input batchsize must be smaller than or divide " - "im2col_step, But received [%d]: [%d]", - in_dims[0], im2col_step)); - } - - for (size_t i = 0; i < strides.size(); ++i) { - PADDLE_ENFORCE_GT(strides[i], 0U, platform::errors::InvalidArgument( - "stride %d size incorrect", i)); - } - for (size_t i = 0; i < dilations.size(); ++i) { - PADDLE_ENFORCE_GT(dilations[i], 0U, platform::errors::InvalidArgument( - "dilation %d size incorrect", i)); - } - - std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { - if ((!ctx->IsRuntime()) && - (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) { - output_shape.push_back(-1); - } else { - output_shape.push_back(ConvOutputSize(in_dims[i + 2], - filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); - } - } - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U, - platform::errors::InvalidArgument( - "output num_filter must divide deformable group " - "size. But received [%d]: [%d]", - output_shape[1], deformable_groups)); - PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2], - platform::errors::InvalidArgument( - "output height must equal to offset map height. " - "The difference is [%d]: [%d]", - output_shape[2], offset_dims[2])); - PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3], - platform::errors::InvalidArgument( - "output width must equal to offset map width. The " - "difference is [%d]: [%d]", - output_shape[3], offset_dims[3])); - PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, - platform::errors::InvalidArgument( - "offset filter must divide deformable group size. " - "But received [%d]: [%d]", - offset_dims[1], filter_dims[2] * filter_dims[3])); - PADDLE_ENFORCE_EQ( - offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), - deformable_groups, - platform::errors::InvalidArgument( - "offset filter must divide deformable group size. But received " - "[%d]: [%d]", - offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), - deformable_groups)); - } - ctx->SetOutputDim("Output", phi::make_ddim(output_shape)); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -300,15 +180,12 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(deformable_conv, DeformableConvV1InferShapeFunctor, + PD_INFER_META(phi::DeformableConvInferMeta)); + REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op, ops::DeformableConvV1OpMaker, ops::DeformableConvV1GradOpMaker, - ops::DeformableConvV1GradOpMaker); + ops::DeformableConvV1GradOpMaker, + DeformableConvV1InferShapeFunctor); REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp); - -REGISTER_OP_CPU_KERNEL(deformable_conv_v1, - ops::DeformableConvV1CPUKernel, - ops::DeformableConvV1CPUKernel); -REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad, - ops::DeformableConvV1GradCPUKernel, - ops::DeformableConvV1GradCPUKernel); diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cu b/paddle/fluid/operators/deformable_conv_v1_op.cu deleted file mode 100644 index 70e022157e8..00000000000 --- a/paddle/fluid/operators/deformable_conv_v1_op.cu +++ /dev/null @@ -1,604 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Part of the following code in this file refs to -// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu -// -// Copyright (c) 2017 Microsoft -// Licensed under The Apache-2.0 License [see LICENSE for details] -// \file deformable_psroi_pooling.cu -// \brief -// \author Yi Li, Guodong Zhang, Jifeng Dai - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/deformable_conv_filter.cu.h" -#include "paddle/fluid/operators/deformable_conv_func.h" -#include "paddle/fluid/operators/deformable_conv_v1_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using CUDADeviceContext = paddle::platform::CUDADeviceContext; - -static constexpr int kNumCUDAThread = 512; -static constexpr int kNumMaximumNumBlock = 4096; - -static inline int NumBlock(const int N) { - return std::min((N + kNumCUDAThread - 1) / kNumCUDAThread, - kNumMaximumNumBlock); -} - -template -__global__ void DeformableCol2imCUDAKernel( - const int nthreads, const T* data_col, const T* data_offset, - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int deformable_group, const int height_col, const int width_col, - T* grad_im) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t thread = index; thread < nthreads; thread += offset) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = - (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = - thread / width_col / height_col / batch_size / kernel_w / kernel_h; - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T cur_inv_h_data = h_in + i * dilation_h + offset_h; - const T cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const T 
cur_top_grad = data_col[thread]; - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { - if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && - cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = - ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - T weight = - DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, - cur_w + dx, height, width); - - platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos, - weight * cur_top_grad); - } - } - } - } -} - -template -inline void DeformableCol2im(const platform::CUDADeviceContext& ctx, - const T* data_col, const T* data_offset, - const std::vector im_shape, - const std::vector col_shape, - const std::vector kernel_shape, - const std::vector pad, - const std::vector stride, - const std::vector dilation, - const int deformable_group, T* grad_im) { - int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlock(num_kernels); - int threads = kNumCUDAThread; - - DeformableCol2imCUDAKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_col, data_offset, im_shape[0], im_shape[1], im_shape[2], - kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], stride[1], - dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], - deformable_group, col_shape[2], col_shape[3], grad_im); -} - -template -__global__ void DeformableCol2imCoordCUDAKernel( - const int nthreads, const T* data_col, const T* data_im, - const T* data_offset, const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, - const int batch_size, const int offset_channels, const int deformable_group, - const int height_col, const int width_col, T* grad_offset) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const T* data_col_ptr = data_col + - deformable_group_index * - channel_per_deformable_group * batch_size * - width_col * height_col; - const T* data_im_ptr = data_im + - (b * deformable_group + deformable_group_index) * - channel_per_deformable_group / kernel_h / - kernel_w * height * width; - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; - col_c += col_step) { - const int col_pos = - (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = - (col_pos / width_col / height_col / 
batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = - (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = - (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + - w_out); - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - T inv_h = h_in + i * dilation_h + offset_h; - T inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { - inv_h = inv_w = -2; - } else { - mval += data_col_ptr[col_pos] * - DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, - height, width, inv_h, inv_w); - } - const T weight = DmcnGetCoordinateWeight( - inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, - width, bp_dir); - val += weight * data_col_ptr[col_pos]; - cnt += 1; - } - grad_offset[i] = val; - } -} - -template -inline void DeformableCol2imCoord( - const platform::CUDADeviceContext& ctx, const T* data_col, const T* data_im, - const T* data_offset, const std::vector im_shape, - const std::vector col_shape, - const std::vector kernel_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* grad_offset) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - int blocks = NumBlock(num_kernels); - int threads = kNumCUDAThread; - - DeformableCol2imCoordCUDAKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_col, data_im, data_offset, im_shape[0], im_shape[1], - im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], paddings[1], - strides[0], strides[1], dilations[0], dilations[1], - channel_per_deformable_group, col_shape[1], - 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, - deformable_groups, col_shape[2], col_shape[3], grad_offset); -} - -template -__global__ void DeformableIm2colCUDAKernel( - const int nthreads, const T* data_im, const T* data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, - const int width_col, T* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + 
deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + - w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -inline void DeformableIm2col(const platform::CUDADeviceContext& ctx, - const T* data_im, const T* data_offset, - const std::vector im_shape, - const std::vector col_shape, - const std::vector filter_shape, - const std::vector paddings, - const std::vector strides, - const std::vector dilations, - const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - int blocks = NumBlock(num_kernels); - int threads = kNumCUDAThread; - - // get outputs of im2col with offset by bilinear interpolation - DeformableIm2colCUDAKernel<<< - blocks, threads, 0, - reinterpret_cast(ctx).stream()>>>( - num_kernels, data_im, data_offset, im_shape[1], im_shape[2], - filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], - strides[1], dilations[0], dilations[1], channel_per_deformable_group, - col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], - data_col); -} - -template -class DeformableConvV1CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* input = ctx.Input("Input"); - const Tensor offset = *ctx.Input("Offset"); - Tensor filter = *ctx.Input("Filter"); - Tensor* output = ctx.Output("Output"); - output->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.template device_context(); - - const int groups = ctx.Attr("groups"); - const int deformable_groups = ctx.Attr("deformable_groups"); - const int im2col_step = ctx.Attr("im2col_step"); - const std::vector strides = ctx.Attr>("strides"); - const std::vector paddings = ctx.Attr>("paddings"); - const std::vector dilations = ctx.Attr>("dilations"); - - const int batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector output_shape_vec(phi::vectorize(output->dims())); - - // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = - 
ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - - int64_t M = output_shape_vec[1] / groups; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = - input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K})); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer) - .Resize(phi::make_ddim({groups, K, N})); - Tensor output_4d; - output_4d.ShareDataWith(output_buffer) - .Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N})); - output_4d.mutable_data(ctx.GetPlace()); - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - - auto blas = phi::funcs::GetBlas(dev_ctx); - - const T* input_ptr = input->data(); - const T* offset_ptr = offset.data(); - col_buffer.mutable_data(ctx.GetPlace()); - T* col_buffer_ptr = col_buffer.data(); - - for (int i = 0; i < batch_size / im2col_step; ++i) { - DeformableIm2col(dev_ctx, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, filter_shape_vec, - paddings, strides, dilations, deformable_groups, - col_buffer_ptr); - - Tensor output_3d = output_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); - // get the product of pixel and weight - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size())); - - blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), - &output_3d_slice, T(0.0)); - } - } - output->ShareDataWith(output_buffer) - .Resize(phi::make_ddim(output_shape_vec)); - } -}; - -template -class DeformableConvV1GradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* output_grad = - ctx.Input(framework::GradVarName("Output")); - Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); - Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); - Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); - - const Tensor* input = ctx.Input("Input"); - Tensor offset = *ctx.Input("Offset"); - Tensor filter = *ctx.Input("Filter"); - if (!input_grad && !filter_grad && !offset_grad) return; - - int groups = ctx.Attr("groups"); - int deformable_groups = ctx.Attr("deformable_groups"); - int im2col_step = ctx.Attr("im2col_step"); - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - - auto& dev_ctx = ctx.template device_context(); - const int batch_size = static_cast(input->dims()[0]); - - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector 
output_shape_vec(phi::vectorize(output_grad->dims())); - - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = - ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - - output_buffer.ShareDataWith(*output_grad); - - int64_t M = - input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = output_shape_vec[1] / groups; - - framework::DDim weight_3d_shape = {groups, K, M}; - framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, - N}; - framework::DDim col_buffer_3d_shape = {groups, M, N}; - framework::DDim filter_grad_shape = {groups, K, M}; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); - Tensor out_grad_4d; - out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); - - phi::funcs::SetConstant set_zero; - auto blas = phi::funcs::GetBlas(dev_ctx); - - col_buffer.mutable_data(ctx.GetPlace()); - col_buffer_3d.mutable_data(ctx.GetPlace()); - out_grad_4d.mutable_data(ctx.GetPlace()); - - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - - if (filter_grad) { - filter_grad->mutable_data(ctx.GetPlace()); - filter_grad->Resize(filter_grad_shape); - set_zero(dev_ctx, filter_grad, static_cast(0)); - } - - if (input_grad) { - input_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); - } - - if (offset_grad) { - offset_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, offset_grad, static_cast(0)); - } - - for (int i = 0; i < batch_size / im2col_step; ++i) { - Tensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - - blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), - &col_buffer_3d_slice, T(0.0)); - } - col_buffer.Resize(col_shape); - - T* col_buffer_ptr = col_buffer.data(); - const T* input_ptr = input->data(); - const T* offset_ptr = offset.data(); - - if (offset_grad) { - T* offset_grad_ptr = offset_grad->data(); - // get grad of offset - DeformableCol2imCoord( - dev_ctx, col_buffer_ptr, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, 
deformable_groups, - offset_grad_ptr + i * im2col_step * input_offset_dim); - } - if (input_grad) { - T* input_grad_ptr = input_grad->data(); - // get grad of input - DeformableCol2im(dev_ctx, col_buffer_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, - filter_shape_vec, paddings, strides, dilations, - deformable_groups, - input_grad_ptr + i * im2col_step * input_dim); - input_grad->Resize(input->dims()); - } - - DeformableIm2col(dev_ctx, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, filter_shape_vec, - paddings, strides, dilations, deformable_groups, - col_buffer_ptr); - - col_buffer_3d.Resize(col_buffer_3d_shape); - - if (filter_grad) { - Tensor dweight_3d; - dweight_3d = ctx.AllocateTmpTensor( - filter_grad_shape, dev_ctx); - for (int g = 0; g < groups; ++g) { - Tensor out_grad_3d_slice = - out_grad_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size())); - - blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, - T(1.0), &dweight_3d_slice, T(0.0)); - } - FilterGradAddupCUDAKernel<<>>( - dweight_3d.numel(), groups, K, M, dweight_3d.data(), - filter_grad->data()); - } - } - if (filter_grad) { - filter_grad->Resize(filter.dims()); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(deformable_conv_v1, - ops::DeformableConvV1CUDAKernel, - ops::DeformableConvV1CUDAKernel); -REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad, - ops::DeformableConvV1GradCUDAKernel, - ops::DeformableConvV1GradCUDAKernel); diff --git a/paddle/fluid/operators/deformable_conv_v1_op.h b/paddle/fluid/operators/deformable_conv_v1_op.h deleted file mode 100644 index 8f4f9709603..00000000000 --- a/paddle/fluid/operators/deformable_conv_v1_op.h +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
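// The forward kernels in this patch reshape the filter to {groups, M, K}
// with M = out_channels / groups and K = in_channels / groups * k_h * k_w,
// and the im2col buffer to {groups, K, N} with
// N = im2col_step * out_h * out_w, then multiply slice by slice. The sketch
// below shows that per-group product on toy sizes; it is illustrative only
// and is not the BLAS MatMul path the kernels actually call.
#include <cstdio>
#include <vector>

int main() {
  const int groups = 2, M = 3, K = 4, N = 5;
  std::vector<float> weight(groups * M * K, 1.0f);  // stand-in filter data
  std::vector<float> col(groups * K * N, 2.0f);     // stand-in im2col data
  std::vector<float> out(groups * M * N, 0.0f);
  for (int g = 0; g < groups; ++g) {
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          // out[g] = weight[g] (M x K) * col[g] (K x N)
          acc += weight[(g * M + m) * K + k] * col[(g * K + k) * N + n];
        }
        out[(g * M + m) * N + n] = acc;  // each entry is K * 1 * 2 = 8
      }
    }
  }
  std::printf("out[0] = %f\n", out[0]);
  return 0;
}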
-// -// Part of the following code in this file refs to -// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu -// -// Copyright (c) 2017 Microsoft -// Licensed under The Apache-2.0 License [see LICENSE for details] -// \file deformable_psroi_pooling.cu -// \brief -// \author Yi Li, Guodong Zhang, Jifeng Dai - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/deformable_conv_func.h" -#include "paddle/fluid/operators/deformable_conv_op.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using CPUDeviceContext = platform::CPUDeviceContext; - -template -void DeformableCol2imCPUKernel( - const int num_kernels, const T* data_col, const T* data_offset, - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int deformable_group, const int height_col, const int width_col, - T* grad_im) { - for (int thread = 0; thread < num_kernels; thread++) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = - (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = - thread / width_col / height_col / batch_size / kernel_w / kernel_h; - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T cur_inv_h_data = h_in + i * dilation_h + offset_h; - const T cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const T cur_top_grad = data_col[thread]; - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { - if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && - cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = - ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - T weight = - DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, - cur_w + dx, height, width); - - *(grad_im + cur_bottom_grad_pos) = - *(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad; - } - } - } - } -} - -template -inline void DeformableCol2imCPU(const platform::CPUDeviceContext& ctx, - const T* data_col, const T* data_offset, - const std::vector im_shape, - const std::vector col_shape, - const std::vector kernel_shape, - const std::vector pad, - const std::vector stride, - const std::vector dilation, - const int deformable_group, T* grad_im) { 
- int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - DeformableCol2imCPUKernel( - num_kernels, data_col, data_offset, im_shape[0], im_shape[1], im_shape[2], - kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], stride[1], - dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], - deformable_group, col_shape[2], col_shape[3], grad_im); -} - -template -void DeformableCol2imCoordCPUKernel( - const int num_kernels, const T* data_col, const T* data_im, - const T* data_offset, const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, - const int batch_size, const int offset_channels, const int deformable_group, - const int height_col, const int width_col, T* grad_offset) { - for (int i = 0; i < num_kernels; i++) { - T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const T* data_col_ptr = data_col + - deformable_group_index * - channel_per_deformable_group * batch_size * - width_col * height_col; - const T* data_im_ptr = data_im + - (b * deformable_group + deformable_group_index) * - channel_per_deformable_group / kernel_h / - kernel_w * height * width; - const T* data_offset_ptr = data_offset + - (b * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * - width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; - col_c += col_step) { - const int col_pos = - (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = - (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = - (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = - (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + - w_out); - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - T inv_h = h_in + i * dilation_h + offset_h; - T inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { - inv_h = inv_w = -2; - } else { - mval += data_col_ptr[col_pos] * - DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, - height, width, inv_h, inv_w); - } - const T weight = DmcnGetCoordinateWeight( - inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, - width, bp_dir); - val += weight * data_col_ptr[col_pos]; - cnt += 1; - } - grad_offset[i] = val; - } -} - -template -inline void DeformableCol2imCoordCPU( - const platform::CPUDeviceContext& ctx, const T* data_col, const T* data_im, - const T* data_offset, const std::vector im_shape, - const std::vector col_shape, - 
const std::vector kernel_shape, const std::vector paddings, - const std::vector strides, const std::vector dilations, - const int deformable_groups, T* grad_offset) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - - DeformableCol2imCoordCPUKernel( - num_kernels, data_col, data_im, data_offset, im_shape[0], im_shape[1], - im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], paddings[1], - strides[0], strides[1], dilations[0], dilations[1], - channel_per_deformable_group, col_shape[1], - 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, - deformable_groups, col_shape[2], col_shape[3], grad_offset); -} - -template -void DeformableIm2colCPUKernel( - const int num_kernels, const T* data_im, const T* data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, - const int width_col, T* data_col) { - for (int i = 0; i < num_kernels; i++) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + - w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -inline void DeformableIm2colCPU(const platform::CPUDeviceContext& ctx, - const T* data_im, const T* data_offset, - const std::vector im_shape, - const std::vector col_shape, - const std::vector filter_shape, - const std::vector paddings, - const std::vector strides, - const std::vector dilations, - const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - // get outputs of im2col with offset by bilinear interpolation - DeformableIm2colCPUKernel( - num_kernels, data_im, data_offset, im_shape[1], im_shape[2], - filter_shape[2], filter_shape[3], 
paddings[0], paddings[1], strides[0], - strides[1], dilations[0], dilations[1], channel_per_deformable_group, - col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], - data_col); -} - -template -class DeformableConvV1CPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); - auto* offset = ctx.Input("Offset"); - Tensor filter = *ctx.Input("Filter"); - Tensor* output = ctx.Output("Output"); - output->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.template device_context(); - - const int groups = ctx.Attr("groups"); - const int deformable_groups = ctx.Attr("deformable_groups"); - const int im2col_step = ctx.Attr("im2col_step"); - const std::vector strides = ctx.Attr>("strides"); - const std::vector paddings = ctx.Attr>("paddings"); - const std::vector dilations = ctx.Attr>("dilations"); - - const int batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector output_shape_vec(phi::vectorize(output->dims())); - - // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - int64_t M = output_shape_vec[1] / groups; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = - input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K})); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer) - .Resize(phi::make_ddim({groups, K, N})); - Tensor output_4d; - output_4d.ShareDataWith(output_buffer) - .Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N})); - output_4d.mutable_data(ctx.GetPlace()); - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset->numel() / offset->dims()[0]; - auto blas = phi::funcs::GetBlas(dev_ctx); - const T* input_ptr = input->data(); - const T* offset_ptr = offset->data(); - col_buffer.mutable_data(ctx.GetPlace()); - T* col_buffer_ptr = col_buffer.data(); - for (int i = 0; i < batch_size / im2col_step; ++i) { - DeformableIm2colCPU(dev_ctx, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, - filter_shape_vec, paddings, strides, dilations, - deformable_groups, col_buffer_ptr); - Tensor output_3d = output_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); - // get the product of pixel and weight - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 
1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size())); - blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), - &output_3d_slice, T(0.0)); - } - } - output->ShareDataWith(output_buffer) - .Resize(phi::make_ddim(output_shape_vec)); - } -}; - -template -class DeformableConvV1GradCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* output_grad = - ctx.Input(framework::GradVarName("Output")); - Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); - Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); - Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); - - const Tensor* input = ctx.Input("Input"); - Tensor offset = *ctx.Input("Offset"); - Tensor filter = *ctx.Input("Filter"); - if (!input_grad && !filter_grad && !offset_grad) return; - - int groups = ctx.Attr("groups"); - int deformable_groups = ctx.Attr("deformable_groups"); - int im2col_step = ctx.Attr("im2col_step"); - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - - auto& dev_ctx = ctx.template device_context(); - const int batch_size = static_cast(input->dims()[0]); - - framework::DDim input_shape = - phi::slice_ddim(input->dims(), 1, input->dims().size()); - std::vector input_shape_vec = phi::vectorize(input_shape); - std::vector filter_shape_vec(phi::vectorize(filter.dims())); - std::vector output_shape_vec(phi::vectorize(output_grad->dims())); - - std::vector col_buffer_shape_vec(filter_shape_vec.size()); - col_buffer_shape_vec[0] = - input->dims()[1] * filter.dims()[2] * filter.dims()[3]; - col_buffer_shape_vec[1] = im2col_step; - for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { - col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec)); - std::vector output_buffer_shape_vec(1); - output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * - output_shape_vec[2] * output_shape_vec[3]; - framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec)); - Tensor col_buffer; - Tensor output_buffer; - col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); - output_buffer = - ctx.AllocateTmpTensor(output_shape, dev_ctx); - - output_buffer.ShareDataWith(*output_grad); - - int64_t M = - input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; - int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; - int64_t K = output_shape_vec[1] / groups; - - framework::DDim weight_3d_shape = {groups, K, M}; - framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, - N}; - framework::DDim col_buffer_3d_shape = {groups, M, N}; - framework::DDim filter_grad_shape = {groups, K, M}; - - Tensor weight_3d; - weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); - Tensor out_grad_4d; - out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); - Tensor col_buffer_3d; - col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); - - phi::funcs::SetConstant set_zero; - auto blas = phi::funcs::GetBlas(dev_ctx); - - col_buffer.mutable_data(ctx.GetPlace()); - 
col_buffer_3d.mutable_data(ctx.GetPlace()); - out_grad_4d.mutable_data(ctx.GetPlace()); - - int input_dim = input->numel() / input->dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - - if (filter_grad) { - filter_grad->mutable_data(ctx.GetPlace()); - filter_grad->Resize(filter_grad_shape); - set_zero(dev_ctx, filter_grad, static_cast(0)); - } - - if (input_grad) { - input_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); - } - - if (offset_grad) { - offset_grad->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, offset_grad, static_cast(0)); - } - - for (int i = 0; i < batch_size / im2col_step; ++i) { - Tensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); - for (int g = 0; g < groups; ++g) { - Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); - Tensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - - blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), - &col_buffer_3d_slice, T(0.0)); - } - col_buffer.Resize(col_shape); - - T* col_buffer_ptr = col_buffer.data(); - const T* input_ptr = input->data(); - const T* offset_ptr = offset.data(); - - if (offset_grad) { - T* offset_grad_ptr = offset_grad->data(); - // get grad of offset - DeformableCol2imCoordCPU( - dev_ctx, col_buffer_ptr, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, input_shape_vec, - col_buffer_shape_vec, filter_shape_vec, paddings, strides, - dilations, deformable_groups, - offset_grad_ptr + i * im2col_step * input_offset_dim); - } - if (input_grad) { - T* input_grad_ptr = input_grad->data(); - // get grad of input - DeformableCol2imCPU(dev_ctx, col_buffer_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, - filter_shape_vec, paddings, strides, dilations, - deformable_groups, - input_grad_ptr + i * im2col_step * input_dim); - input_grad->Resize(input->dims()); - } - - DeformableIm2colCPU(dev_ctx, input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - input_shape_vec, col_buffer_shape_vec, - filter_shape_vec, paddings, strides, dilations, - deformable_groups, col_buffer_ptr); - - col_buffer_3d.Resize(col_buffer_3d_shape); - - if (filter_grad) { - Tensor dweight_3d; - dweight_3d = ctx.AllocateTmpTensor( - filter_grad_shape, dev_ctx); - for (int g = 0; g < groups; ++g) { - Tensor out_grad_3d_slice = - out_grad_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - out_grad_3d.dims(), 1, out_grad_3d.dims().size())); - Tensor col_buffer_3d_slice = - col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( - col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - Tensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size())); - - blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, - T(1.0), &dweight_3d_slice, T(0.0)); - } - // update grad of weights - FilterGradAddupCPUKernel(dweight_3d.numel(), groups, K, M, - dweight_3d.data(), filter_grad->data()); - } - } - if (filter_grad) { - filter_grad->Resize(filter.dims()); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git 
a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7a00f91da2e..6c268dfb6c4 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -655,6 +655,7 @@ void BindImperative(py::module *m_ptr) { } else { act_name = name.cast(); } + VLOG(4) << "Init VarBase :" << act_name; new (&self) imperative::VarBase(act_name); self.SetPersistable(persistable); self.SetType(type); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 3faf42fe1ab..4790fa863f2 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -516,6 +516,215 @@ void ConcatInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } +inline int ConvOutputSize( + int input_size, int filter_size, int dilation, int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + PADDLE_ENFORCE_GT( + output_size, + 0, + phi::errors::InvalidArgument( + "The output's size is expected to be greater than 0. " + "But recieved: output's size is %d. The output's size is computed by " + "((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / " + "stride + 1), where input_size is %d, padding is %d, " + "filter_size is %d, dilation is %d, stride is %d.", + output_size, + input_size, + padding, + filter_size, + dilation, + stride)); + + return output_size; +} + +void DeformableConvInferMeta(const MetaTensor& x, + const MetaTensor& offset, + const MetaTensor& filter, + paddle::optional mask, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + MetaTensor* out, + MetaConfig config) { + auto in_dims = x.dims(); + auto offset_dims = offset.dims(); + auto filter_dims = filter.dims(); + + PADDLE_ENFORCE_EQ( + in_dims.size(), + 4, + phi::errors::InvalidArgument("Conv input should be 4-D tensor, get %u", + in_dims.size())); + PADDLE_ENFORCE_EQ(in_dims.size(), + filter_dims.size(), + phi::errors::InvalidArgument( + "Conv input dimension and filter dimension should be " + "the same. The difference is [%d]: [%d]", + in_dims.size(), + filter_dims.size())); + PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(), + 2U, + phi::errors::InvalidArgument( + "Conv input dimension and strides " + "dimension should be consistent. But received input " + "dimension:[%d], strides dimension:[%d]", + in_dims.size(), + strides.size())); + PADDLE_ENFORCE_EQ(paddings.size(), + strides.size(), + phi::errors::InvalidArgument( + "Conv paddings dimension and Conv strides dimension " + "should be the same. The difference is [%d]: [%d]", + paddings.size(), + strides.size())); + + PADDLE_ENFORCE_EQ( + in_dims[1], + filter_dims[1] * groups, + phi::errors::InvalidArgument( + "The number of input channels should be equal to filter " + "channels * groups. The difference is [%d]: [%d]", + in_dims[1], + filter_dims[1] * groups)); + PADDLE_ENFORCE_EQ( + filter_dims[0] % groups, + 0, + phi::errors::InvalidArgument( + "The number of output channels should be divided by groups. But " + "received output channels:[%d], groups:[%d]", + filter_dims[0], + groups)); + PADDLE_ENFORCE_EQ( + filter_dims[0] % deformable_groups, + 0, + phi::errors::InvalidArgument( + "The number of output channels should be " + "divided by deformable groups. 
The difference is [%d]: [%d]", + filter_dims[0] % groups, + 0)); + + if (in_dims[0] > im2col_step) { + PADDLE_ENFORCE_EQ( + in_dims[0] % im2col_step, + 0U, + phi::errors::InvalidArgument( + "Input batchsize must be smaller than or divide im2col_step. But " + "received Input batchsize:[%d], im2col_step:[%d]", + in_dims[0], + im2col_step)); + } + + for (size_t i = 0; i < strides.size(); ++i) { + PADDLE_ENFORCE_GT( + strides[i], + 0U, + phi::errors::InvalidArgument("stride %d size incorrect", i)); + } + for (size_t i = 0; i < dilations.size(); ++i) { + PADDLE_ENFORCE_GT( + dilations[i], + 0U, + phi::errors::InvalidArgument("dilation %d size incorrect", i)); + } + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + if (!config.is_runtime && + (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) { + output_shape.push_back(-1); + } else { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], + filter_dims[i + 2], + dilations[i], + paddings[i], + strides[i])); + } + } + + PADDLE_ENFORCE_EQ( + output_shape[1] % deformable_groups, + 0U, + phi::errors::InvalidArgument( + "output num_filter must divide deformable group size. But received " + "output num_filter:[%d], deformable group size:[%d]", + output_shape[1], + deformable_groups)); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(output_shape[2], + offset_dims[2], + phi::errors::InvalidArgument( + "output height must equal to offset map height. " + "The difference is [%d]: [%d]", + output_shape[2], + offset_dims[2])); + PADDLE_ENFORCE_EQ(output_shape[3], + offset_dims[3], + phi::errors::InvalidArgument( + "output width must equal to offset map width. The " + "difference is [%d]: [%d]", + output_shape[3], + offset_dims[3])); + + PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), + 0U, + phi::errors::InvalidArgument( + "offset filter must divide deformable group size. " + "But received [%d]: [%d]", + offset_dims[1], + filter_dims[2] * filter_dims[3])); + PADDLE_ENFORCE_EQ( + offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), + deformable_groups, + phi::errors::InvalidArgument( + "offset filter must divide deformable group size. But received " + "[%d]: [%d]", + offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), + deformable_groups)); + + if (mask) { + auto mask_dims = mask->dims(); + PADDLE_ENFORCE_EQ(output_shape[2], + mask_dims[2], + phi::errors::InvalidArgument( + "output height must equal to mask map height. The " + "difference is [%d] vs [%d]", + output_shape[2], + mask_dims[2])); + PADDLE_ENFORCE_EQ(output_shape[3], + mask_dims[3], + phi::errors::InvalidArgument( + "output width must equal to mask map width. The " + "difference is [%d] vs [%d]", + output_shape[3], + mask_dims[3])); + + PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), + 0U, + phi::errors::InvalidArgument( + "mask filter must divide deformable group size. " + "But received [%d]: [%d]", + mask_dims[1], + filter_dims[2] * filter_dims[3])); + PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]), + deformable_groups, + phi::errors::InvalidArgument( + "mask filter must divide deformable group size. 
" + "But received [%d]: [%d]", + mask_dims[1] / (filter_dims[2] * filter_dims[3]), + deformable_groups)); + } + } + + out->set_dims(phi::make_ddim(output_shape)); + out->set_dtype(x.dtype()); +} + void HierarchicalSigmoidInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& label, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e9b5d8c872f..9088f204812 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -120,6 +120,19 @@ void ConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void DeformableConvInferMeta(const MetaTensor& x, + const MetaTensor& offset, + const MetaTensor& filter, + paddle::optional mask, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void HierarchicalSigmoidInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& label, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 59540dbaefd..941ede31400 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -27,12 +27,14 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel +set(MANUAL_BUILD_KERNELS deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel) +kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) +kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) diff --git a/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc new file mode 100644 index 00000000000..f64b1d3291f --- /dev/null +++ b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc @@ -0,0 +1,333 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/deformable_conv_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h" + +namespace phi { + +template +inline void ModulatedDeformableCol2imCPUKernel( + const int num_kernels, + const T* data_col, + const T* data_offset, + const T* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + T* grad_im) { + for (int thread = 0; thread < num_kernels; thread++) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + T cur_top_grad = data_col[thread]; + if (data_mask) { + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + cur_top_grad *= mask; + } + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = DmcnGetGradientWeight(cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + + *(grad_im + cur_bottom_grad_pos) = + *(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad; + } + } + } + } +} + +template +void ModulatedDeformableCol2im(const Context& dev_ctx, + const T* data_col, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& pad, + const std::vector& stride, + const std::vector& dilation, + const int deformable_group, + T* grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * 
col_shape[3]; + + ModulatedDeformableCol2imCPUKernel(num_kernels, + data_col, + data_offset, + data_mask, + im_shape[0], + im_shape[1], + im_shape[2], + kernel_shape[2], + kernel_shape[3], + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + channel_per_deformable_group, + col_shape[1], + deformable_group, + col_shape[2], + col_shape[3], + grad_im); +} + +template +void ModulatedDeformableCol2imCoordCPUKernel( + const int num_kernels, + const T* data_col, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* grad_offset, + T* grad_mask) { + for (int i = 0; i < num_kernels; i++) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = + data_mask + ? 
data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col + : nullptr; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width, + width, + height, + width, + inv_h, + inv_w); + } + const T weight = + DmcnGetCoordinateWeight(inv_h, + inv_w, + height, + width, + data_im_ptr + cnt * height * width, + width, + bp_dir); + if (data_mask_ptr) { + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T mask = data_mask_ptr[data_mask_hw_ptr]; + val += weight * data_col_ptr[col_pos] * mask; + } else { + val += weight * data_col_ptr[col_pos]; + } + cnt += 1; + } + grad_offset[i] = val; + if (grad_mask && offset_c % 2 == 0) + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +template +void ModulatedDeformableCol2imCoord(const Context& dev_ctx, + const T* data_col, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* grad_offset, + T* grad_mask) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + + ModulatedDeformableCol2imCoordCPUKernel( + num_kernels, + data_col, + data_im, + data_offset, + data_mask, + im_shape[0], + im_shape[1], + im_shape[2], + kernel_shape[2], + kernel_shape[3], + paddings[0], + paddings[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + channel_per_deformable_group, + col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, + col_shape[2], + col_shape[3], + grad_offset, + grad_mask); +} + +template +void FilterGradAddup(const Context& dev_ctx, + const int nthreads, + const int n, + const int height, + const int width, + const T* dweight_3d, + T* filter_grad) { + for (int i = 0; i < nthreads; i++) { + filter_grad[i] = filter_grad[i] + dweight_3d[i]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(deformable_conv_grad, + CPU, + ALL_LAYOUT, + 
phi::DeformableConvGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/deformable_conv_kernel.cc b/paddle/phi/kernels/cpu/deformable_conv_kernel.cc index 0d61f7be68a..ea973ff53f7 100644 --- a/paddle/phi/kernels/cpu/deformable_conv_kernel.cc +++ b/paddle/phi/kernels/cpu/deformable_conv_kernel.cc @@ -18,126 +18,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h" -namespace phi { - -template -inline void ModulatedDeformableIm2colCPUKernel( - const int num_kernels, - const T* data_im, - const T* data_offset, - const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, - T* data_col) { - for (int i = 0; i < num_kernels; i++) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - const T* data_mask_ptr = - data_mask + - (b_col * deformable_group + deformable_group_index) * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + - w_col; - const int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -void ModulatedDeformableIm2col(const Context& dev_ctx, - const T* data_im, - const T* data_offset, - const T* data_mask, - const std::vector& im_shape, - const std::vector& col_shape, - const std::vector& filter_shape, - const std::vector& paddings, - const std::vector& strides, - const std::vector& dilations, - const int deformable_groups, - T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - // get outputs of im2col with offset by bilinear interpolation - ModulatedDeformableIm2colCPUKernel(num_kernels, - data_im, - 
data_offset, - data_mask, - im_shape[1], - im_shape[2], - filter_shape[2], - filter_shape[3], - paddings[0], - paddings[1], - strides[0], - strides[1], - dilations[0], - dilations[1], - channel_per_deformable_group, - col_shape[1], - im_shape[0], - deformable_groups, - col_shape[2], - col_shape[3], - data_col); -} - -} // namespace phi - PD_REGISTER_KERNEL(deformable_conv, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/deformable_conv_grad_kernel.h b/paddle/phi/kernels/deformable_conv_grad_kernel.h new file mode 100644 index 00000000000..85786cec4c3 --- /dev/null +++ b/paddle/phi/kernels/deformable_conv_grad_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DeformableConvGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& offset, + const DenseTensor& filter, + paddle::optional mask, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + DenseTensor* dx, + DenseTensor* offset_grad, + DenseTensor* filter_grad, + DenseTensor* mask_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/deformable_conv_kernel.h b/paddle/phi/kernels/deformable_conv_kernel.h index 3886e6801a3..fbbe5f62c6a 100644 --- a/paddle/phi/kernels/deformable_conv_kernel.h +++ b/paddle/phi/kernels/deformable_conv_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" namespace phi { @@ -23,7 +24,7 @@ void DeformableConvKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& offset, const DenseTensor& filter, - const DenseTensor& mask, + paddle::optional mask, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 942eecae168..b1f010cdff1 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(blas) add_subdirectory(lapack) add_subdirectory(detail) +math_library(deformable_conv_functor DEPS dense_tensor) math_library(concat_and_split_functor DEPS dense_tensor) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cc b/paddle/phi/kernels/funcs/deformable_conv_functor.cc new file mode 100644 index 00000000000..ea256e93bba --- /dev/null +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" + +namespace phi { +namespace funcs { + +template +inline void ModulatedDeformableIm2colCPUKernel( + const int num_kernels, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* data_col) { + for (int i = 0; i < num_kernels; i++) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const T* data_mask_ptr = + data_mask + ? 
data_mask + + (b_col * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col + : nullptr; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + if (data_mask_ptr) { + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + *data_col_ptr *= mask; + } + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +void ModulatedDeformableIm2col(const Context& dev_ctx, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + // get outputs of im2col with offset by bilinear interpolation + ModulatedDeformableIm2colCPUKernel(num_kernels, + data_im, + data_offset, + data_mask, + im_shape[1], + im_shape[2], + filter_shape[2], + filter_shape[3], + paddings[0], + paddings[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + channel_per_deformable_group, + col_shape[1], + im_shape[0], + deformable_groups, + col_shape[2], + col_shape[3], + data_col); +} + +template void ModulatedDeformableIm2col( + const phi::CPUContext& dev_ctx, + const float* data_im, + const float* data_offset, + const float* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + float* data_col); + +template void ModulatedDeformableIm2col( + const phi::CPUContext& dev_ctx, + const double* data_im, + const double* data_offset, + const double* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + double* data_col); + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cu b/paddle/phi/kernels/funcs/deformable_conv_functor.cu new file mode 100644 index 00000000000..8bfb46c6636 --- /dev/null +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cu @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" + +namespace phi { +namespace funcs { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__global__ void ModulatedDeformableIm2colGpuKernel( + const int nthreads, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* data_col) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const T* data_mask_ptr = + data_mask + ? 
data_mask + + (b_col * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col + : nullptr; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + if (data_mask_ptr) { + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + *data_col_ptr *= mask; + } + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +void ModulatedDeformableIm2col(const Context& dev_ctx, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableIm2colGpuKernel< + T><<>>(num_kernels, + data_im, + data_offset, + data_mask, + im_shape[1], + im_shape[2], + filter_shape[2], + filter_shape[3], + paddings[0], + paddings[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + channel_per_deformable_group, + col_shape[1], + im_shape[0], + deformable_groups, + col_shape[2], + col_shape[3], + data_col); +} + +template void ModulatedDeformableIm2col( + const phi::GPUContext& dev_ctx, + const float* data_im, + const float* data_offset, + const float* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + float* data_col); + +template void ModulatedDeformableIm2col( + const phi::GPUContext& dev_ctx, + const double* data_im, + const double* data_offset, + const double* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + double* data_col); + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.h b/paddle/phi/kernels/funcs/deformable_conv_functor.h new file mode 100644 index 00000000000..eecda729275 --- /dev/null +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace funcs { + +template +HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, + const int data_width, + const int height, + const int width, + T h, + T w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh; + T hw = 1 - lw; + + T v1 = + (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0; + T v2 = (h_low >= 0 && w_high <= width - 1) + ? bottom_data[h_low * data_width + w_high] + : 0; + T v3 = (h_high <= height - 1 && w_low >= 0) + ? bottom_data[h_high * data_width + w_low] + : 0; + T v4 = (h_high <= height - 1 && w_high <= width - 1) + ? bottom_data[h_high * data_width + w_high] + : 0; + + T w1 = hh * hw; + T w2 = hh * lw; + T w3 = lh * hw; + T w4 = lh * lw; + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +template +void ModulatedDeformableIm2col(const Context& dev_ctx, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* data_col); + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu new file mode 100644 index 00000000000..265d123dfea --- /dev/null +++ b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu @@ -0,0 +1,366 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
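// Editor's note (not part of the original patch): the CUDA kernels below use
// the usual grid-stride loop with kNumCUDAThreads = 512 threads per block and
// at most kNumMaximumNumBlocks = 4096 blocks; e.g. for num_kernels = 1000000,
// NumBlocks returns min((1000000 + 511) / 512, 4096) = 1954. Because several
// column positions can map to the same input pixel, the input-gradient
// scatter accumulates with CudaAtomicAdd rather than a plain store, unlike
// the single-threaded CPU path in cpu/deformable_conv_grad_kernel.cc.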
+ +#include "paddle/phi/kernels/deformable_conv_grad_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__global__ void ModulatedDeformableCol2imGpuKernel( + const int nthreads, + const T* data_col, + const T* data_offset, + const T* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + T* grad_im) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t thread = index; thread < nthreads; thread += offset) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + T cur_top_grad = data_col[thread]; + if (data_mask) { + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + cur_top_grad *= mask; + } + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = DmcnGetGradientWeight(cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + + paddle::platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos, + weight * cur_top_grad); + } + } + } + } +} + +template +void ModulatedDeformableCol2im(const Context& dev_ctx, + const T* data_col, + const T* data_offset, 
+ const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& pad, + const std::vector& stride, + const std::vector& dilation, + const int deformable_group, + T* grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableCol2imGpuKernel< + T><<>>(num_kernels, + data_col, + data_offset, + data_mask, + im_shape[0], + im_shape[1], + im_shape[2], + kernel_shape[2], + kernel_shape[3], + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + channel_per_deformable_group, + col_shape[1], + deformable_group, + col_shape[2], + col_shape[3], + grad_im); +} + +template +__global__ void ModulatedDeformableCol2imCoordGpuKernel( + const int nthreads, + const T* data_col, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* grad_offset, + T* grad_mask) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = + data_mask + ? 
data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col + : nullptr; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width, + width, + height, + width, + inv_h, + inv_w); + } + const T weight = + DmcnGetCoordinateWeight(inv_h, + inv_w, + height, + width, + data_im_ptr + cnt * height * width, + width, + bp_dir); + if (data_mask_ptr) { + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T mask = data_mask_ptr[data_mask_hw_ptr]; + val += weight * data_col_ptr[col_pos] * mask; + } else { + val += weight * data_col_ptr[col_pos]; + } + cnt += 1; + } + grad_offset[i] = val; + if (grad_mask && offset_c % 2 == 0) + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +template +void ModulatedDeformableCol2imCoord(const Context& dev_ctx, + const T* data_col, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* grad_offset, + T* grad_mask) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableCol2imCoordGpuKernel< + T><<>>( + num_kernels, + data_col, + data_im, + data_offset, + data_mask, + im_shape[0], + im_shape[1], + im_shape[2], + kernel_shape[2], + kernel_shape[3], + paddings[0], + paddings[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + channel_per_deformable_group, + col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, + col_shape[2], + col_shape[3], + grad_offset, + grad_mask); +} + +template +__global__ void FilterGradAddupGpuKernel(const int nthreads, + const int n, + const int height, + const int width, + const T* dweight_3d, + T* filter_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += 
offset) { + filter_grad[i] = filter_grad[i] + dweight_3d[i]; + } +} + +template +void FilterGradAddup(const Context& dev_ctx, + const int nthreads, + const int n, + const int height, + const int width, + const T* dweight_3d, + T* filter_grad) { + FilterGradAddupGpuKernel< + T><<>>( + nthreads, n, height, width, dweight_3d, filter_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(deformable_conv_grad, + GPU, + ALL_LAYOUT, + phi::DeformableConvGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/deformable_conv_kernel.cu b/paddle/phi/kernels/gpu/deformable_conv_kernel.cu index 1db6e1b7cf7..2476dcbafb9 100644 --- a/paddle/phi/kernels/gpu/deformable_conv_kernel.cu +++ b/paddle/phi/kernels/gpu/deformable_conv_kernel.cu @@ -16,142 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h" -namespace phi { - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaximumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); -} - -template -__global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, - const T* data_im, - const T* data_offset, - const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, - T* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T* data_col_ptr = - data_col + - ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = - data_offset + - (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * - kernel_w * height_col * width_col; - const T* data_mask_ptr = - data_mask + - (b_col * deformable_group + deformable_group_index) * kernel_h * - kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + - w_col; - const int data_mask_hw_ptr = - ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T val = static_cast(0); - const T h_im = h_in + i * dilation_h + offset_h; - const T w_im = w_in + j * dilation_w + offset_w; - if (h_im > -1 && w_im > -1 
&& h_im < height && w_im < width) { - val = - DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - } - } - } -} - -template -void ModulatedDeformableIm2col(const Context& dev_ctx, - const T* data_im, - const T* data_offset, - const T* data_mask, - const std::vector& im_shape, - const std::vector& col_shape, - const std::vector& filter_shape, - const std::vector& paddings, - const std::vector& strides, - const std::vector& dilations, - const int deformable_groups, - T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableIm2colGpuKernel< - T><<>>(num_kernels, - data_im, - data_offset, - data_mask, - im_shape[1], - im_shape[2], - filter_shape[2], - filter_shape[3], - paddings[0], - paddings[1], - strides[0], - strides[1], - dilations[0], - dilations[1], - channel_per_deformable_group, - col_shape[1], - im_shape[0], - deformable_groups, - col_shape[2], - col_shape[3], - data_col); -} - -} // namespace phi - PD_REGISTER_KERNEL(deformable_conv, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h new file mode 100644 index 00000000000..8d8e66a02f5 --- /dev/null +++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h @@ -0,0 +1,364 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" + +namespace phi { + +template +HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, + T argmax_w, + const int h, + const int w, + const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + weight = (h == argmax_h_low && w == argmax_w_low) + ? (h + 1 - argmax_h) * (w + 1 - argmax_w) + : weight; + weight = (h == argmax_h_low && w == argmax_w_high) + ? (h + 1 - argmax_h) * (argmax_w + 1 - w) + : weight; + weight = (h == argmax_h_high && w == argmax_w_low) + ? (argmax_h + 1 - h) * (w + 1 - argmax_w) + : weight; + weight = (h == argmax_h_high && w == argmax_w_high) + ? 
(argmax_h + 1 - h) * (argmax_w + 1 - w) + : weight; + + return weight; +} + +template +HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, + T argmax_w, + const int height, + const int width, + const T* im_data, + const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + weight += (argmax_h_low >= 0 && argmax_w_low >= 0) + ? -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low] + : 0; + + weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) + ? -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high] + : 0; + + weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) + ? (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + ? (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high] + : 0; + } else if (bp_dir == 1) { + weight += (argmax_h_low >= 0 && argmax_w_low >= 0) + ? -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low] + : 0; + weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) + ? (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) + ? -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + ? (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high] + : 0; + } + + return weight; +} + +template +void ModulatedDeformableCol2imCoord(const Context& dev_ctx, + const T* data_col, + const T* data_im, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + T* grad_offset, + T* grad_mask); + +template +void ModulatedDeformableCol2im(const Context& dev_ctx, + const T* data_col, + const T* data_offset, + const T* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& kernel_shape, + const std::vector& pad, + const std::vector& stride, + const std::vector& dilation, + const int deformable_group, + T* grad_im); + +template +void FilterGradAddup(const Context& dev_ctx, + const int nthreads, + const int n, + const int height, + const int width, + const T* dweight_3d, + T* filter_grad); + +template +void DeformableConvGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& offset, + const DenseTensor& filter, + paddle::optional mask, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + DenseTensor* dx, + DenseTensor* offset_grad, + DenseTensor* filter_grad, + DenseTensor* mask_grad) { + const int batch_size = static_cast(x.dims()[0]); + + DDim input_shape = phi::slice_ddim(x.dims(), 1, x.dims().size()); + std::vector input_shape_vec = phi::vectorize(input_shape); + std::vector filter_shape_vec(phi::vectorize(filter.dims())); + 
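// Editor's note (not part of the original patch): the shape bookkeeping that
// follows processes the batch in chunks of im2col_step samples. col_buffer is
// laid out as [in_channels * kh * kw, im2col_step, out_h, out_w] and, per
// convolution group, the backward pass computes
//   col_grad  = W^T * out_grad    (gradient w.r.t. the unfolded input)
//   dW       += out_grad * col^T  (gradient w.r.t. the filter)
// with M = (in_channels / groups) * kh * kw, N = im2col_step * out_h * out_w,
// and K = out_channels / groups, matching the two blas.MatMul calls below.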
std::vector output_shape_vec(phi::vectorize(out_grad.dims())); + + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = x.dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + + DenseTensor col_buffer = Empty(dev_ctx, col_buffer_shape_vec); + DenseTensor output_buffer; + output_buffer.ShareDataWith(out_grad).Resize( + make_ddim(output_buffer_shape_vec)); + + int64_t M = + input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = output_shape_vec[1] / groups; + + DDim weight_3d_shape = {groups, K, M}; + DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, N}; + DDim col_buffer_3d_shape = {groups, M, N}; + DDim filter_grad_shape = {groups, K, M}; + + DenseTensor weight_3d; + weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); + DenseTensor out_grad_4d; + out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); + DenseTensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); + + phi::funcs::SetConstant set_zero; + auto blas = phi::funcs::GetBlas(dev_ctx); + + int input_dim = x.numel() / x.dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; + + if (filter_grad) { + Full(dev_ctx, + {filter_grad_shape.Get(), filter_grad_shape.size()}, + 0, + filter_grad); + } + + if (dx) { + dev_ctx.template Alloc(dx); + set_zero(dev_ctx, dx, static_cast(0)); + } + + if (offset_grad) { + dev_ctx.template Alloc(offset_grad); + set_zero(dev_ctx, offset_grad, static_cast(0)); + + if (mask_grad) { + dev_ctx.template Alloc(mask_grad); + set_zero(dev_ctx, mask_grad, static_cast(0)); + } + } + + for (int i = 0; i < batch_size / im2col_step; ++i) { + DenseTensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( + phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); + for (int g = 0; g < groups; ++g) { + DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( + phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size())); + DenseTensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize( + phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + DenseTensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + blas.MatMul(weight_3d_slice, + true, + out_grad_3d_slice, + false, + T(1.0), + &col_buffer_3d_slice, + T(0.0)); + } + col_buffer.Resize(make_ddim(col_buffer_shape_vec)); + + T* col_buffer_ptr = col_buffer.data(); + const T* input_ptr = x.data(); + const T* offset_ptr = offset.data(); + const T* mask_data_ptr = + mask ? mask->data() + i * im2col_step * input_mask_dim : nullptr; + if (offset_grad) { + T* offset_grad_ptr = offset_grad->data(); + T* mask_grad_data_ptr = + mask_grad ? 
+          mask_grad ? mask_grad->data<T>() + i * im2col_step * input_mask_dim
+                    : nullptr;
+      // get grad of offset and mask
+      ModulatedDeformableCol2imCoord(
+          dev_ctx,
+          col_buffer_ptr,
+          input_ptr + i * im2col_step * input_dim,
+          offset_ptr + i * im2col_step * input_offset_dim,
+          mask_data_ptr,
+          input_shape_vec,
+          col_buffer_shape_vec,
+          filter_shape_vec,
+          paddings,
+          strides,
+          dilations,
+          deformable_groups,
+          offset_grad_ptr + i * im2col_step * input_offset_dim,
+          mask_grad_data_ptr);
+    }
+    if (dx) {
+      T* dx_ptr = dx->data<T>();
+      // get grad of input
+      ModulatedDeformableCol2im(dev_ctx,
+                                col_buffer_ptr,
+                                offset_ptr + i * im2col_step * input_offset_dim,
+                                mask_data_ptr,
+                                input_shape_vec,
+                                col_buffer_shape_vec,
+                                filter_shape_vec,
+                                paddings,
+                                strides,
+                                dilations,
+                                deformable_groups,
+                                dx_ptr + i * im2col_step * input_dim);
+      dx->Resize(x.dims());
+    }
+
+    funcs::ModulatedDeformableIm2col(
+        dev_ctx,
+        input_ptr + i * im2col_step * input_dim,
+        offset_ptr + i * im2col_step * input_offset_dim,
+        mask_data_ptr,
+        input_shape_vec,
+        col_buffer_shape_vec,
+        filter_shape_vec,
+        paddings,
+        strides,
+        dilations,
+        deformable_groups,
+        col_buffer_ptr);
+
+    col_buffer_3d.Resize(col_buffer_3d_shape);
+
+    if (filter_grad) {
+      DenseTensor dweight_3d = Empty<T>(
+          dev_ctx, {filter_grad_shape.Get(), filter_grad_shape.size()});
+      for (int g = 0; g < groups; ++g) {
+        DenseTensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize(
+            phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
+        DenseTensor col_buffer_3d_slice =
+            col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
+                col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
+        DenseTensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize(
+            phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size()));
+
+        blas.MatMul(out_grad_3d_slice,
+                    false,
+                    col_buffer_3d_slice,
+                    true,
+                    T(1.0),
+                    &dweight_3d_slice,
+                    T(0.0));
+      }
+
+      // update grad of weights
+      FilterGradAddup(dev_ctx,
+                      dweight_3d.numel(),
+                      groups,
+                      K,
+                      M,
+                      dweight_3d.data<T>(),
+                      filter_grad->data<T>());
+    }
+  }
+  if (filter_grad) {
+    filter_grad->Resize(filter.dims());
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
index d8795808a64..6c0457024dd 100644
--- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
+++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
@@ -18,66 +18,17 @@
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
+#include "paddle/utils/optional.h"
 
 namespace phi {
 
-template <typename T>
-HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
-                                const int data_width,
-                                const int height,
-                                const int width,
-                                T h,
-                                T w) {
-  int h_low = floor(h);
-  int w_low = floor(w);
-  int h_high = h_low + 1;
-  int w_high = w_low + 1;
-
-  T lh = h - h_low;
-  T lw = w - w_low;
-  T hh = 1 - lh;
-  T hw = 1 - lw;
-
-  T v1 =
-      (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
-  T v2 = (h_low >= 0 && w_high <= width - 1)
-             ? bottom_data[h_low * data_width + w_high]
-             : 0;
-  T v3 = (h_high <= height - 1 && w_low >= 0)
-             ? bottom_data[h_high * data_width + w_low]
-             : 0;
-  T v4 = (h_high <= height - 1 && w_high <= width - 1)
-             ? bottom_data[h_high * data_width + w_high]
-             : 0;
-
-  T w1 = hh * hw;
-  T w2 = hh * lw;
-  T w3 = lh * hw;
-  T w4 = lh * lw;
-
-  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
-}
-
-template <typename T, typename Context>
-void ModulatedDeformableIm2col(const Context& dev_ctx,
-                               const T* data_im,
-                               const T* data_offset,
-                               const T* data_mask,
-                               const std::vector<int64_t>& im_shape,
-                               const std::vector<int64_t>& col_shape,
-                               const std::vector<int64_t>& filter_shape,
-                               const std::vector<int>& paddings,
-                               const std::vector<int>& strides,
-                               const std::vector<int>& dilations,
-                               const int deformable_groups,
-                               T* data_col);
-
 template <typename T, typename Context>
 void DeformableConvKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& offset,
                           const DenseTensor& filter,
-                          const DenseTensor& mask,
+                          paddle::optional<const DenseTensor&> mask,
                           const std::vector<int>& strides,
                           const std::vector<int>& paddings,
                           const std::vector<int>& dilations,
@@ -125,28 +76,31 @@ void DeformableConvKernel(const Context& dev_ctx,
 
   int input_dim = x.numel() / x.dims()[0];
   int input_offset_dim = offset.numel() / offset.dims()[0];
-  int input_mask_dim = mask.numel() / mask.dims()[0];
-
-  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0;
 
   const T* input_ptr = x.data<T>();
   const T* offset_ptr = offset.data<T>();
-  const T* mask_ptr = mask.data<T>();
+  const T* mask_ptr = mask ? mask->data<T>() : nullptr;
   T* col_buffer_ptr = col_buffer.data<T>();
 
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+
   for (int i = 0; i < batch_size / im2col_step; ++i) {
-    ModulatedDeformableIm2col(dev_ctx,
-                              input_ptr + i * im2col_step * input_dim,
-                              offset_ptr + i * im2col_step * input_offset_dim,
-                              mask_ptr + i * im2col_step * input_mask_dim,
-                              input_shape_vec,
-                              col_buffer_shape_vec,
-                              filter_shape_vec,
-                              paddings,
-                              strides,
-                              dilations,
-                              deformable_groups,
-                              col_buffer_ptr);
+    const T* temp_mask_ptr =
+        mask_ptr ? mask_ptr + i * im2col_step * input_mask_dim : nullptr;
+    funcs::ModulatedDeformableIm2col(
+        dev_ctx,
+        input_ptr + i * im2col_step * input_dim,
+        offset_ptr + i * im2col_step * input_offset_dim,
+        temp_mask_ptr,
+        input_shape_vec,
+        col_buffer_shape_vec,
+        filter_shape_vec,
+        paddings,
+        strides,
+        dilations,
+        deformable_groups,
+        col_buffer_ptr);
     DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(
         phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
     // get the product of pixel and weight
diff --git a/paddle/phi/ops/compat/deformable_conv_sig.cc b/paddle/phi/ops/compat/deformable_conv_sig.cc
index e2a21673634..a84a0840090 100644
--- a/paddle/phi/ops/compat/deformable_conv_sig.cc
+++ b/paddle/phi/ops/compat/deformable_conv_sig.cc
@@ -29,6 +29,34 @@ KernelSignature DeformableConvOpArgumentMapping(
       {"Output"});
 }
 
+KernelSignature DeformableConvGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "deformable_conv_grad",
+      {"Input", "Offset", "Filter", "Mask", GradVarName("Output")},
+      {"strides",
+       "paddings",
+       "dilations",
+       "deformable_groups",
+       "groups",
+       "im2col_step"},
+      {GradVarName("Input"),
+       GradVarName("Offset"),
+       GradVarName("Filter"),
+       GradVarName("Mask")});
+}
+
 }  // namespace phi
 
+PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1, deformable_conv);
+PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1_grad, deformable_conv_grad);
+
 PD_REGISTER_ARG_MAPPING_FN(deformable_conv,
                            phi::DeformableConvOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(deformable_conv_grad,
+                           phi::DeformableConvGradOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1,
+                           phi::DeformableConvOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1_grad,
+                           phi::DeformableConvGradOpArgumentMapping);
-- 
GitLab
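
A quick way to exercise both operators rewired by this patch, sketched from the
Python side. This assumes the public paddle.vision.ops.deform_conv2d API of
Paddle 2.x (argument names may differ slightly between releases): passing mask
selects deformable_conv, omitting it selects deformable_conv_v1, which the
mappings above alias onto the same phi kernels.

    import paddle
    from paddle.vision.ops import deform_conv2d

    kh = kw = 3
    x = paddle.rand([2, 8, 16, 16])                 # NCHW input
    weight = paddle.rand([4, 8, kh, kw])            # [C_out, C_in / groups, kh, kw]
    offset = paddle.rand([2, 2 * kh * kw, 16, 16])  # (dy, dx) per sampling point
    mask = paddle.rand([2, kh * kw, 16, 16])        # modulation scalars
    for t in (x, weight, offset, mask):
        t.stop_gradient = False                     # let backward() reach the grad kernel

    out_v2 = deform_conv2d(x, offset, weight, padding=1, mask=mask)  # deformable_conv
    out_v1 = deform_conv2d(x, offset, weight, padding=1)             # deformable_conv_v1
    out_v2.sum().backward()        # runs the deformable_conv_grad kernel registered above
    print(out_v1.shape, out_v2.shape)               # [2, 4, 16, 16] for both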