Unverified commit a8e5c9be, authored by zyfncg, committed by GitHub

move deformable_conv forward kernel to phi (#40700)

Parent c46f2ddb
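
This change removes the forward DeformableConv kernels from the fluid operator files (deformable_conv_op.cc and deformable_conv_op.cu) and re-implements them in the phi kernel library: a shared implementation in paddle/phi/kernels/impl/deformable_conv_kernel_impl.h, CPU and GPU im2col launchers, PD_REGISTER_KERNEL registrations for float and double, and an argument-mapping function that translates the operator's inputs and attributes to the new kernel signature. Below is a minimal sketch of calling the new phi kernel directly; the helper function, attribute values, and tensor setup are illustrative assumptions, not part of this commit.

#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/deformable_conv_kernel.h"

// Hypothetical helper for illustration only. `out` must already carry the
// inferred output dims (normally set by the op's InferMeta), because the
// kernel reads out->dims() before writing the result.
void RunDeformableConvForward(const phi::CPUContext& dev_ctx,
                              const phi::DenseTensor& x,       // input, NCHW
                              const phi::DenseTensor& offset,  // learned offsets
                              const phi::DenseTensor& filter,  // conv weights
                              const phi::DenseTensor& mask,    // modulation mask
                              phi::DenseTensor* out) {
  // Example attribute values; in the framework they come from the op attributes.
  const std::vector<int> strides = {1, 1};
  const std::vector<int> paddings = {0, 0};
  const std::vector<int> dilations = {1, 1};
  const int deformable_groups = 1;
  const int groups = 1;
  const int im2col_step = 1;
  phi::DeformableConvKernel<float, phi::CPUContext>(
      dev_ctx, x, offset, filter, mask, strides, paddings, dilations,
      deformable_groups, groups, im2col_step, out);
}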
......@@ -338,8 +338,6 @@ REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel<float>,
ops::DeformableConvCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCPUKernel<float>,
ops::DeformableConvGradCPUKernel<double>);
......@@ -446,108 +446,6 @@ __global__ void FilterGradAddupGpuKernel(const int nthreads, const int n,
}
}
template <typename DeviceContext, typename T>
class DeformableConvCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor offset = *ctx.Input<Tensor>("Offset");
const Tensor mask = *ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.cuda_device_context();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(output->dims()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] =
input->dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec));
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec));
Tensor col_buffer;
Tensor output_buffer;
col_buffer = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
output_buffer =
ctx.AllocateTmpTensor<T, DeviceContext>(output_shape, dev_ctx);
int64_t M = output_shape_vec[1] / groups;
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K =
input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups;
Tensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K}));
Tensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer)
.Resize(phi::make_ddim({groups, K, N}));
Tensor output_4d;
output_4d.ShareDataWith(output_buffer)
.Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N}));
output_4d.mutable_data<T>(ctx.GetPlace());
framework::DDim input_shape =
phi::slice_ddim(input->dims(), 1, input->dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
const T* input_ptr = input->data<T>();
const T* offset_ptr = offset.data<T>();
const T* mask_ptr = mask.data<T>();
col_buffer.mutable_data<T>(ctx.GetPlace());
T* col_buffer_ptr = col_buffer.data<T>();
for (int i = 0; i < batch_size / im2col_step; ++i) {
ModulatedDeformableIm2col(
ctx.device_context(), input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec,
col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations,
deformable_groups, col_buffer_ptr);
Tensor output_3d = output_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
for (int g = 0; g < groups; ++g) {
Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
Tensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0),
&output_3d_slice, T(0.0));
}
}
output->ShareDataWith(output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
}
};
template <typename DeviceContext, typename T>
class DeformableConvGradCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -740,9 +638,6 @@ class DeformableConvGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(deformable_conv,
ops::DeformableConvCUDAKernel<CUDA, float>,
ops::DeformableConvCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCUDAKernel<CUDA, float>,
ops::DeformableConvGradCUDAKernel<CUDA, double>);
......@@ -318,102 +318,6 @@ void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height,
}
}
template <typename T>
class DeformableConvCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* offset = ctx.Input<Tensor>("Offset");
auto* mask = ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(output->dims()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] =
input->dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec));
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec));
Tensor col_buffer;
Tensor output_buffer;
col_buffer = ctx.AllocateTmpTensor<T, CPUDeviceContext>(col_shape, dev_ctx);
output_buffer =
ctx.AllocateTmpTensor<T, CPUDeviceContext>(output_shape, dev_ctx);
int64_t M = output_shape_vec[1] / groups;
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K =
input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups;
Tensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K}));
Tensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer)
.Resize(phi::make_ddim({groups, K, N}));
Tensor output_4d;
output_4d.ShareDataWith(output_buffer)
.Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N}));
output_4d.mutable_data<T>(ctx.GetPlace());
framework::DDim input_shape =
phi::slice_ddim(input->dims(), 1, input->dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset->numel() / offset->dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
auto blas = phi::funcs::GetBlas<CPUDeviceContext, T>(dev_ctx);
const T* input_ptr = input->data<T>();
const T* offset_ptr = offset->data<T>();
const T* mask_ptr = mask->data<T>();
col_buffer.mutable_data<T>(ctx.GetPlace());
T* col_buffer_ptr = col_buffer.data<T>();
for (int i = 0; i < batch_size / im2col_step; ++i) {
ModulatedDeformableIm2colCPU(
dev_ctx, input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec,
col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations,
deformable_groups, col_buffer_ptr);
Tensor output_3d = output_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
// get the product of pixel and weight
for (int g = 0; g < groups; ++g) {
Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
Tensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0),
&output_3d_slice, T(0.0));
}
}
output->ShareDataWith(output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
}
};
template <typename T>
class DeformableConvGradCPUKernel : public framework::OpKernel<T> {
public:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace phi {
template <typename T>
inline void ModulatedDeformableIm2colCPUKernel(
const int num_kernels,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
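// Each index i maps to one (input channel, batch slice, output row, output
// column) position; the inner loops below write the kernel_h * kernel_w
// bilinearly sampled and mask-modulated values into the matching column of
// data_col.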
for (int i = 0; i < num_kernels; i++) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
// get outputs of im2col with offset by bilinear interpolation
ModulatedDeformableIm2colCPUKernel(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv,
CPU,
ALL_LAYOUT,
phi::DeformableConvKernel,
float,
double) {}
......@@ -18,7 +18,7 @@
namespace phi {
template <typename Functor, typename Context>
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DeformableConvKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
const DenseTensor& mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
const int nthreads,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
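// Grid-stride loop: each thread covers multiple im2col elements when nthreads
// exceeds the total number of launched threads.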
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads;
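// Launch one thread per im2col element, capping the grid at kNumMaximumNumBlocks;
// the grid-stride loop inside the kernel handles any remainder.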
ModulatedDeformableIm2colGpuKernel<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv,
GPU,
ALL_LAYOUT,
phi::DeformableConvKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
const int data_width,
const int height,
const int width,
T h,
T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
T v1 =
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0;
T v3 = (h_high <= height - 1 && w_low >= 0)
? bottom_data[h_high * data_width + w_low]
: 0;
T v4 = (h_high <= height - 1 && w_high <= width - 1)
? bottom_data[h_high * data_width + w_high]
: 0;
T w1 = hh * hw;
T w2 = hh * lw;
T w3 = lh * hw;
T w4 = lh * lw;
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
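// Worked example (hypothetical values, not part of this header): sampling a
// 2x2 image at the fractional location (h, w) = (0.25, 0.75) with the helper
// above, where img = {0, 1, 2, 3} in row-major order and
// data_width = height = width = 2:
//   DmcnIm2colBilinear(img, 2, 2, 2, (T)0.25, (T)0.75)
//     = 0.75*0.25*0 + 0.75*0.75*1 + 0.25*0.25*2 + 0.25*0.75*3 = 1.25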
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col);
template <typename T, typename Context>
void DeformableConvKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
const DenseTensor& mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* out) {
const int batch_size = static_cast<int>(x.dims()[0]);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] = x.dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
DenseTensor col_buffer = Empty<T>(dev_ctx, col_buffer_shape_vec);
DenseTensor output_buffer = Empty<T>(dev_ctx, output_buffer_shape_vec);
int64_t M = output_shape_vec[1] / groups;
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K = x.dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups;
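// Per-group GEMM sizes: each filter slice is (M x K), each im2col buffer slice
// is (K x N), and each output slice is (M x N), with M = out_channels / groups,
// K = in_channels * k_h * k_w / groups, and N = im2col_step * o_h * o_w.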
DenseTensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K}));
DenseTensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer)
.Resize(phi::make_ddim({groups, K, N}));
DenseTensor output_4d;
output_4d.ShareDataWith(output_buffer)
.Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N}));
DDim input_shape = phi::slice_ddim(x.dims(), 1, x.dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
int input_dim = x.numel() / x.dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
const T* input_ptr = x.data<T>();
const T* offset_ptr = offset.data<T>();
const T* mask_ptr = mask.data<T>();
T* col_buffer_ptr = col_buffer.data<T>();
for (int i = 0; i < batch_size / im2col_step; ++i) {
ModulatedDeformableIm2col(dev_ctx,
input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
input_shape_vec,
col_buffer_shape_vec,
filter_shape_vec,
paddings,
strides,
dilations,
deformable_groups,
col_buffer_ptr);
DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
// get the product of pixel and weight
for (int g = 0; g < groups; ++g) {
DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
DenseTensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
DenseTensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
blas.MatMul(weight_3d_slice,
false,
col_buffer_3d_slice,
false,
T(1.0),
&output_3d_slice,
T(0.0));
}
}
out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec));
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
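// Maps the fluid operator's inputs ("Input", "Offset", "Filter", "Mask"),
// attributes, and output ("Output") onto the phi kernel; the order listed here
// must follow the parameter order of phi::DeformableConvKernel.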
KernelSignature DeformableConvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("deformable_conv",
{"Input", "Offset", "Filter", "Mask"},
{"strides",
"paddings",
"dilations",
"deformable_groups",
"groups",
"im2col_step"},
{"Output"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(deformable_conv,
phi::DeformableConvOpArgumentMapping);