未验证 提交 7e3752bb 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Move deformable_conv and deformable_conv_v1 to phi (#40794)

* move deformable_conv_grad to phi

* move infershape of deformable_conv to phi

* adjust some code format

* move deformable_conv_v1 to phi
上级 778008d7
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
#pragma once
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
template <typename T>
HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h,
const int w, const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
weight = (h == argmax_h_low && w == argmax_w_low)
? (h + 1 - argmax_h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_low && w == argmax_w_high)
? (h + 1 - argmax_h) * (argmax_w + 1 - w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_low)
? (argmax_h + 1 - h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_high)
? (argmax_h + 1 - h) * (argmax_w + 1 - w)
: weight;
return weight;
}
template <typename T>
HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
const int width, const T* im_data,
const int data_width, const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
} else if (bp_dir == 1) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
}
return weight;
}
template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
const int height, const int width, T h, T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
T v1 =
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0;
T v3 = (h_high <= height - 1 && w_low >= 0)
? bottom_data[h_high * data_width + w_low]
: 0;
T v4 = (h_high <= height - 1 && w_high <= width - 1)
? bottom_data[h_high * data_width + w_high]
: 0;
T w1 = hh * hw;
T w2 = hh * lw;
T w3 = lh * hw;
T w4 = lh * lw;
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
......@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/deformable_conv_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -108,158 +110,6 @@ $$
class DeformableConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "deformable_conv");
OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
"deformable_conv)");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "deformable_conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
"deformable_conv");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output",
"deformable_conv");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
auto mask_dims = ctx->GetInputDim("Mask");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
int im2col_step = ctx->Attrs().Get<int>("im2col_step");
PADDLE_ENFORCE_EQ(
in_dims.size(), 4,
platform::errors::InvalidArgument(
"Conv input should be 4-D tensor, get %u", in_dims.size()));
PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"Conv input dimension and filter dimension should be "
"the same. The difference is [%d]: [%d]",
in_dims.size(), filter_dims.size()));
PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(), 2U,
platform::errors::InvalidArgument(
"Conv input dimension and strides "
"dimension should be consistent. But received input "
"dimension:[%d], strides dimension:[%d]",
in_dims.size(), strides.size()));
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
platform::errors::InvalidArgument(
"Conv paddings dimension and Conv strides dimension "
"should be the same. The difference is [%d]: [%d]",
paddings.size(), strides.size()));
PADDLE_ENFORCE_EQ(
in_dims[1], filter_dims[1] * groups,
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter "
"channels * groups. The difference is [%d]: [%d]",
in_dims[1], filter_dims[1] * groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
platform::errors::InvalidArgument(
"The number of output channels should be divided by groups. But "
"received output channels:[%d], groups:[%d]",
filter_dims[0], groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % deformable_groups, 0,
platform::errors::InvalidArgument(
"The number of output channels should be "
"divided by deformable groups. The difference is [%d]: [%d]",
filter_dims[0] % groups, 0));
if (in_dims[0] > im2col_step) {
PADDLE_ENFORCE_EQ(
in_dims[0] % im2col_step, 0U,
platform::errors::InvalidArgument(
"Input batchsize must be smaller than or divide im2col_step. But "
"received Input batchsize:[%d], im2col_step:[%d]",
in_dims[0], im2col_step));
}
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE_GT(strides[i], 0U, platform::errors::InvalidArgument(
"stride %d size incorrect", i));
}
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_GT(dilations[i], 0U, platform::errors::InvalidArgument(
"dilation %d size incorrect", i));
}
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2],
filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
}
PADDLE_ENFORCE_EQ(
output_shape[1] % deformable_groups, 0U,
platform::errors::InvalidArgument(
"output num_filter must divide deformable group size. But received "
"output num_filter:[%d], deformable group size:[%d]",
output_shape[1], deformable_groups));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
platform::errors::InvalidArgument(
"output height must equal to offset map height. "
"The difference is [%d]: [%d]",
output_shape[2], offset_dims[2]));
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
platform::errors::InvalidArgument(
"output width must equal to offset map width. The "
"difference is [%d]: [%d]",
output_shape[3], offset_dims[3]));
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size. "
"But received [%d]: [%d]",
offset_dims[1], filter_dims[2] * filter_dims[3]));
PADDLE_ENFORCE_EQ(
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size. But received "
"[%d]: [%d]",
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups));
PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2],
platform::errors::InvalidArgument(
"output height must equal to mask map height. The "
"difference is [%d] vs [%d]",
output_shape[2], mask_dims[2]));
PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3],
platform::errors::InvalidArgument(
"output width must equal to mask map width. The "
"difference is [%d] vs [%d]",
output_shape[3], mask_dims[3]));
PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
platform::errors::InvalidArgument(
"mask filter must divide deformable group size. "
"But received [%d]: [%d]",
mask_dims[1], filter_dims[2] * filter_dims[3]));
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups,
platform::errors::InvalidArgument(
"mask filter must divide deformable group size. "
"But received [%d]: [%d]",
mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups));
}
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -331,13 +181,13 @@ class DeformableConvGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(deformable_conv, DeformableConvInferShapeFunctor,
PD_INFER_META(phi::DeformableConvInferMeta));
REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
ops::DeformableConvOpMaker,
ops::DeformableConvGradOpMaker<paddle::framework::OpDesc>,
ops::DeformableConvGradOpMaker<paddle::imperative::OpBase>);
ops::DeformableConvGradOpMaker<paddle::imperative::OpBase>,
DeformableConvInferShapeFunctor);
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCPUKernel<float>,
ops::DeformableConvGradCPUKernel<double>);
此差异已折叠。
此差异已折叠。
......@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/deformable_conv_v1_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -113,128 +115,6 @@ $$
class DeformableConvV1Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"deformable_conv_v1");
OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
"deformable_conv_v1");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
"deformable_conv_v1");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output",
"deformable_conv_v1");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
int im2col_step = ctx->Attrs().Get<int>("im2col_step");
PADDLE_ENFORCE_EQ(
in_dims.size(), 4,
platform::errors::InvalidArgument(
"Conv input should be 4-D tensor, get %u", in_dims.size()));
PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"Conv input dimension and filter dimension should be "
"the same. the difference is [%d] vs [%d]",
in_dims.size(), filter_dims.size()));
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
platform::errors::InvalidArgument(
"Conv input dimension and strides "
"dimension should be consistent., But received [%d]: [%d]",
in_dims.size(), strides.size()));
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
platform::errors::InvalidArgument(
"Conv paddings dimension and Conv strides dimension "
"should be the same. The difference is [%d] vs [%d]",
paddings.size(), strides.size()));
PADDLE_ENFORCE_EQ(
in_dims[1], filter_dims[1] * groups,
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter "
"channels * groups. The difference is [%d]: [%d]",
in_dims[1], filter_dims[1] * groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
platform::errors::InvalidArgument(
"The number of output channels should be divided by groups. But"
"received output channels: [%d], groups: [%d]",
filter_dims[0], groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % deformable_groups, 0,
platform::errors::InvalidArgument(
"The number of output channels should be "
"divided by deformable groups. But received [%d]: [%d]",
filter_dims[0], deformable_groups));
if (in_dims[0] > im2col_step) {
PADDLE_ENFORCE_EQ(in_dims[0] % im2col_step, 0U,
platform::errors::InvalidArgument(
"Input batchsize must be smaller than or divide "
"im2col_step, But received [%d]: [%d]",
in_dims[0], im2col_step));
}
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE_GT(strides[i], 0U, platform::errors::InvalidArgument(
"stride %d size incorrect", i));
}
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_GT(dilations[i], 0U, platform::errors::InvalidArgument(
"dilation %d size incorrect", i));
}
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2],
filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
platform::errors::InvalidArgument(
"output num_filter must divide deformable group "
"size. But received [%d]: [%d]",
output_shape[1], deformable_groups));
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
platform::errors::InvalidArgument(
"output height must equal to offset map height. "
"The difference is [%d]: [%d]",
output_shape[2], offset_dims[2]));
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
platform::errors::InvalidArgument(
"output width must equal to offset map width. The "
"difference is [%d]: [%d]",
output_shape[3], offset_dims[3]));
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size. "
"But received [%d]: [%d]",
offset_dims[1], filter_dims[2] * filter_dims[3]));
PADDLE_ENFORCE_EQ(
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size. But received "
"[%d]: [%d]",
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups));
}
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -300,15 +180,12 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(deformable_conv, DeformableConvV1InferShapeFunctor,
PD_INFER_META(phi::DeformableConvInferMeta));
REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
ops::DeformableConvV1OpMaker,
ops::DeformableConvV1GradOpMaker<paddle::framework::OpDesc>,
ops::DeformableConvV1GradOpMaker<paddle::imperative::OpBase>);
ops::DeformableConvV1GradOpMaker<paddle::imperative::OpBase>,
DeformableConvV1InferShapeFunctor);
REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CPUKernel<float>,
ops::DeformableConvV1CPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCPUKernel<float>,
ops::DeformableConvV1GradCPUKernel<double>);
......@@ -655,6 +655,7 @@ void BindImperative(py::module *m_ptr) {
} else {
act_name = name.cast<std::string>();
}
VLOG(4) << "Init VarBase :" << act_name;
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable);
self.SetType(type);
......
......@@ -516,6 +516,215 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
inline int ConvOutputSize(
int input_size, int filter_size, int dilation, int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(
output_size,
0,
phi::errors::InvalidArgument(
"The output's size is expected to be greater than 0. "
"But recieved: output's size is %d. The output's size is computed by "
"((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / "
"stride + 1), where input_size is %d, padding is %d, "
"filter_size is %d, dilation is %d, stride is %d.",
output_size,
input_size,
padding,
filter_size,
dilation,
stride));
return output_size;
}
void DeformableConvInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
paddle::optional<const MetaTensor&> mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
MetaTensor* out,
MetaConfig config) {
auto in_dims = x.dims();
auto offset_dims = offset.dims();
auto filter_dims = filter.dims();
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument("Conv input should be 4-D tensor, get %u",
in_dims.size()));
PADDLE_ENFORCE_EQ(in_dims.size(),
filter_dims.size(),
phi::errors::InvalidArgument(
"Conv input dimension and filter dimension should be "
"the same. The difference is [%d]: [%d]",
in_dims.size(),
filter_dims.size()));
PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(),
2U,
phi::errors::InvalidArgument(
"Conv input dimension and strides "
"dimension should be consistent. But received input "
"dimension:[%d], strides dimension:[%d]",
in_dims.size(),
strides.size()));
PADDLE_ENFORCE_EQ(paddings.size(),
strides.size(),
phi::errors::InvalidArgument(
"Conv paddings dimension and Conv strides dimension "
"should be the same. The difference is [%d]: [%d]",
paddings.size(),
strides.size()));
PADDLE_ENFORCE_EQ(
in_dims[1],
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input channels should be equal to filter "
"channels * groups. The difference is [%d]: [%d]",
in_dims[1],
filter_dims[1] * groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
phi::errors::InvalidArgument(
"The number of output channels should be divided by groups. But "
"received output channels:[%d], groups:[%d]",
filter_dims[0],
groups));
PADDLE_ENFORCE_EQ(
filter_dims[0] % deformable_groups,
0,
phi::errors::InvalidArgument(
"The number of output channels should be "
"divided by deformable groups. The difference is [%d]: [%d]",
filter_dims[0] % groups,
0));
if (in_dims[0] > im2col_step) {
PADDLE_ENFORCE_EQ(
in_dims[0] % im2col_step,
0U,
phi::errors::InvalidArgument(
"Input batchsize must be smaller than or divide im2col_step. But "
"received Input batchsize:[%d], im2col_step:[%d]",
in_dims[0],
im2col_step));
}
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE_GT(
strides[i],
0U,
phi::errors::InvalidArgument("stride %d size incorrect", i));
}
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0U,
phi::errors::InvalidArgument("dilation %d size incorrect", i));
}
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
if (!config.is_runtime &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2],
filter_dims[i + 2],
dilations[i],
paddings[i],
strides[i]));
}
}
PADDLE_ENFORCE_EQ(
output_shape[1] % deformable_groups,
0U,
phi::errors::InvalidArgument(
"output num_filter must divide deformable group size. But received "
"output num_filter:[%d], deformable group size:[%d]",
output_shape[1],
deformable_groups));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(output_shape[2],
offset_dims[2],
phi::errors::InvalidArgument(
"output height must equal to offset map height. "
"The difference is [%d]: [%d]",
output_shape[2],
offset_dims[2]));
PADDLE_ENFORCE_EQ(output_shape[3],
offset_dims[3],
phi::errors::InvalidArgument(
"output width must equal to offset map width. The "
"difference is [%d]: [%d]",
output_shape[3],
offset_dims[3]));
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]),
0U,
phi::errors::InvalidArgument(
"offset filter must divide deformable group size. "
"But received [%d]: [%d]",
offset_dims[1],
filter_dims[2] * filter_dims[3]));
PADDLE_ENFORCE_EQ(
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
phi::errors::InvalidArgument(
"offset filter must divide deformable group size. But received "
"[%d]: [%d]",
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups));
if (mask) {
auto mask_dims = mask->dims();
PADDLE_ENFORCE_EQ(output_shape[2],
mask_dims[2],
phi::errors::InvalidArgument(
"output height must equal to mask map height. The "
"difference is [%d] vs [%d]",
output_shape[2],
mask_dims[2]));
PADDLE_ENFORCE_EQ(output_shape[3],
mask_dims[3],
phi::errors::InvalidArgument(
"output width must equal to mask map width. The "
"difference is [%d] vs [%d]",
output_shape[3],
mask_dims[3]));
PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]),
0U,
phi::errors::InvalidArgument(
"mask filter must divide deformable group size. "
"But received [%d]: [%d]",
mask_dims[1],
filter_dims[2] * filter_dims[3]));
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups,
phi::errors::InvalidArgument(
"mask filter must divide deformable group size. "
"But received [%d]: [%d]",
mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups));
}
}
out->set_dims(phi::make_ddim(output_shape));
out->set_dtype(x.dtype());
}
void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& label,
......
......@@ -120,6 +120,19 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DeformableConvInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
paddle::optional<const MetaTensor&> mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
MetaTensor* out,
MetaConfig config = MetaConfig());
void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& label,
......
......@@ -27,12 +27,14 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel
set(MANUAL_BUILD_KERNELS deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel
hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h"
namespace phi {
template <typename T>
inline void ModulatedDeformableCol2imCPUKernel(
const int num_kernels,
const T* data_col,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int deformable_group,
const int height_col,
const int width_col,
T* grad_im) {
for (int thread = 0; thread < num_kernels; thread++) {
const int j = (thread / width_col / height_col / batch_size) % kernel_w;
const int i =
(thread / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
thread / width_col / height_col / batch_size / kernel_w / kernel_h;
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = thread % width_col;
int h_out = (thread / width_col) % height_col;
int b = (thread / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col *
width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
T cur_top_grad = data_col[thread];
if (data_mask) {
const T* data_mask_ptr = data_mask +
(b * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
cur_top_grad *= mask;
}
const int cur_h = static_cast<int>(cur_inv_h_data);
const int cur_w = static_cast<int>(cur_inv_w_data);
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);
*(grad_im + cur_bottom_grad_pos) =
*(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad;
}
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& pad,
const std::vector<int>& stride,
const std::vector<int>& dilation,
const int deformable_group,
T* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
ModulatedDeformableCol2imCPUKernel(num_kernels,
data_col,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
pad[0],
pad[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
channel_per_deformable_group,
col_shape[1],
deformable_group,
col_shape[2],
col_shape[3],
grad_im);
}
template <typename T>
void ModulatedDeformableCol2imCoordCPUKernel(
const int num_kernels,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int offset_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* grad_offset,
T* grad_mask) {
for (int i = 0; i < num_kernels; i++) {
T val = 0, mval = 0;
const int w = i % width_col;
const int h = (i / width_col) % height_col;
const int c = (i / width_col / height_col) % offset_channels;
const int b = (i / width_col / height_col) / offset_channels;
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T* data_col_ptr = data_col +
deformable_group_index *
channel_per_deformable_group * batch_size *
width_col * height_col;
const T* data_im_ptr = data_im +
(b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h /
kernel_w * height * width;
const T* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col *
width_col;
const T* data_mask_ptr =
data_mask
? data_mask +
(b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col
: nullptr;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = offset_c / 2; col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
} else {
mval += data_col_ptr[col_pos] *
funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
width,
height,
width,
inv_h,
inv_w);
}
const T weight =
DmcnGetCoordinateWeight(inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
if (data_mask_ptr) {
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T mask = data_mask_ptr[data_mask_hw_ptr];
val += weight * data_col_ptr[col_pos] * mask;
} else {
val += weight * data_col_ptr[col_pos];
}
cnt += 1;
}
grad_offset[i] = val;
if (grad_mask && offset_c % 2 == 0)
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
offset_c / 2) *
height_col +
h) *
width_col +
w] = mval;
}
}
template <typename T, typename Context>
void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* grad_offset,
T* grad_mask) {
int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
col_shape[2] * col_shape[3] * deformable_groups;
int channel_per_deformable_group = col_shape[0] / deformable_groups;
ModulatedDeformableCol2imCoordCPUKernel(
num_kernels,
data_col,
data_im,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
2 * kernel_shape[2] * kernel_shape[3] * deformable_groups,
deformable_groups,
col_shape[2],
col_shape[3],
grad_offset,
grad_mask);
}
template <typename T, typename Context>
void FilterGradAddup(const Context& dev_ctx,
const int nthreads,
const int n,
const int height,
const int width,
const T* dweight_3d,
T* filter_grad) {
for (int i = 0; i < nthreads; i++) {
filter_grad[i] = filter_grad[i] + dweight_3d[i];
}
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv_grad,
CPU,
ALL_LAYOUT,
phi::DeformableConvGradKernel,
float,
double) {}
......@@ -18,126 +18,6 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace phi {
template <typename T>
inline void ModulatedDeformableIm2colCPUKernel(
const int num_kernels,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
for (int i = 0; i < num_kernels; i++) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
// get outputs of im2col with offset by bilinear interpolation
ModulatedDeformableIm2colCPUKernel(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv,
CPU,
ALL_LAYOUT,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DeformableConvGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
paddle::optional<const DenseTensor&> mask,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* dx,
DenseTensor* offset_grad,
DenseTensor* filter_grad,
DenseTensor* mask_grad);
} // namespace phi
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
......@@ -23,7 +24,7 @@ void DeformableConvKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
const DenseTensor& mask,
paddle::optional<const DenseTensor&> mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
......
......@@ -3,6 +3,7 @@ add_subdirectory(blas)
add_subdirectory(lapack)
add_subdirectory(detail)
math_library(deformable_conv_functor DEPS dense_tensor)
math_library(concat_and_split_functor DEPS dense_tensor)
math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace phi {
namespace funcs {
template <typename T>
inline void ModulatedDeformableIm2colCPUKernel(
const int num_kernels,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
for (int i = 0; i < num_kernels; i++) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask
? data_mask +
(b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col
: nullptr;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
if (data_mask_ptr) {
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
*data_col_ptr *= mask;
}
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
// get outputs of im2col with offset by bilinear interpolation
ModulatedDeformableIm2colCPUKernel(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
template void ModulatedDeformableIm2col(
const phi::CPUContext& dev_ctx,
const float* data_im,
const float* data_offset,
const float* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
float* data_col);
template void ModulatedDeformableIm2col(
const phi::CPUContext& dev_ctx,
const double* data_im,
const double* data_offset,
const double* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
double* data_col);
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace funcs {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
const int nthreads,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask
? data_mask +
(b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col
: nullptr;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
if (data_mask_ptr) {
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
*data_col_ptr *= mask;
}
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads;
ModulatedDeformableIm2colGpuKernel<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const float* data_im,
const float* data_offset,
const float* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
float* data_col);
template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const double* data_im,
const double* data_offset,
const double* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
double* data_col);
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace funcs {
template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
const int data_width,
const int height,
const int width,
T h,
T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
T v1 =
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0;
T v3 = (h_high <= height - 1 && w_low >= 0)
? bottom_data[h_high * data_width + w_low]
: 0;
T v4 = (h_high <= height - 1 && w_high <= width - 1)
? bottom_data[h_high * data_width + w_high]
: 0;
T w1 = hh * hw;
T w2 = hh * lw;
T w3 = lh * hw;
T w4 = lh * lw;
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col);
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ModulatedDeformableCol2imGpuKernel(
const int nthreads,
const T* data_col,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int deformable_group,
const int height_col,
const int width_col,
T* grad_im) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t thread = index; thread < nthreads; thread += offset) {
const int j = (thread / width_col / height_col / batch_size) % kernel_w;
const int i =
(thread / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
thread / width_col / height_col / batch_size / kernel_w / kernel_h;
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = thread % width_col;
int h_out = (thread / width_col) % height_col;
int b = (thread / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col *
width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
T cur_top_grad = data_col[thread];
if (data_mask) {
const T* data_mask_ptr = data_mask +
(b * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
cur_top_grad *= mask;
}
const int cur_h = static_cast<int>(cur_inv_h_data);
const int cur_w = static_cast<int>(cur_inv_w_data);
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);
paddle::platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
}
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& pad,
const std::vector<int>& stride,
const std::vector<int>& dilation,
const int deformable_group,
T* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads;
ModulatedDeformableCol2imGpuKernel<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels,
data_col,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
pad[0],
pad[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
channel_per_deformable_group,
col_shape[1],
deformable_group,
col_shape[2],
col_shape[3],
grad_im);
}
template <typename T>
__global__ void ModulatedDeformableCol2imCoordGpuKernel(
const int nthreads,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int offset_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* grad_offset,
T* grad_mask) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
T val = 0, mval = 0;
const int w = i % width_col;
const int h = (i / width_col) % height_col;
const int c = (i / width_col / height_col) % offset_channels;
const int b = (i / width_col / height_col) / offset_channels;
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T* data_col_ptr = data_col +
deformable_group_index *
channel_per_deformable_group * batch_size *
width_col * height_col;
const T* data_im_ptr = data_im +
(b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h /
kernel_w * height * width;
const T* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col *
width_col;
const T* data_mask_ptr =
data_mask
? data_mask +
(b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col
: nullptr;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = offset_c / 2; col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
} else {
mval += data_col_ptr[col_pos] *
funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
width,
height,
width,
inv_h,
inv_w);
}
const T weight =
DmcnGetCoordinateWeight(inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
if (data_mask_ptr) {
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T mask = data_mask_ptr[data_mask_hw_ptr];
val += weight * data_col_ptr[col_pos] * mask;
} else {
val += weight * data_col_ptr[col_pos];
}
cnt += 1;
}
grad_offset[i] = val;
if (grad_mask && offset_c % 2 == 0)
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
offset_c / 2) *
height_col +
h) *
width_col +
w] = mval;
}
}
template <typename T, typename Context>
void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* grad_offset,
T* grad_mask) {
int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
col_shape[2] * col_shape[3] * deformable_groups;
int channel_per_deformable_group = col_shape[0] / deformable_groups;
int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads;
ModulatedDeformableCol2imCoordGpuKernel<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
num_kernels,
data_col,
data_im,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
2 * kernel_shape[2] * kernel_shape[3] * deformable_groups,
deformable_groups,
col_shape[2],
col_shape[3],
grad_offset,
grad_mask);
}
template <typename T>
__global__ void FilterGradAddupGpuKernel(const int nthreads,
const int n,
const int height,
const int width,
const T* dweight_3d,
T* filter_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
filter_grad[i] = filter_grad[i] + dweight_3d[i];
}
}
template <typename T, typename Context>
void FilterGradAddup(const Context& dev_ctx,
const int nthreads,
const int n,
const int height,
const int width,
const T* dweight_3d,
T* filter_grad) {
FilterGradAddupGpuKernel<
T><<<NumBlocks(nthreads), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
nthreads, n, height, width, dweight_3d, filter_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv_grad,
GPU,
ALL_LAYOUT,
phi::DeformableConvGradKernel,
float,
double) {}
......@@ -16,142 +16,8 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
const int nthreads,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
const int b_col = (i / width_col) / height_col % batch_size;
const int c_im = (i / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr =
data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const T* data_mask_ptr =
data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T, typename Context>
void ModulatedDeformableIm2col(const Context& dev_ctx,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* data_col) {
int channel_per_deformable_group = im_shape[0] / deformable_groups;
int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads;
ModulatedDeformableIm2colGpuKernel<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels,
data_im,
data_offset,
data_mask,
im_shape[1],
im_shape[2],
filter_shape[2],
filter_shape[3],
paddings[0],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
channel_per_deformable_group,
col_shape[1],
im_shape[0],
deformable_groups,
col_shape[2],
col_shape[3],
data_col);
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv,
GPU,
ALL_LAYOUT,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
namespace phi {
template <typename T>
HOSTDEVICE T DmcnGetGradientWeight(T argmax_h,
T argmax_w,
const int h,
const int w,
const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
weight = (h == argmax_h_low && w == argmax_w_low)
? (h + 1 - argmax_h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_low && w == argmax_w_high)
? (h + 1 - argmax_h) * (argmax_w + 1 - w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_low)
? (argmax_h + 1 - h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_high)
? (argmax_h + 1 - h) * (argmax_w + 1 - w)
: weight;
return weight;
}
template <typename T>
HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h,
T argmax_w,
const int height,
const int width,
const T* im_data,
const int data_width,
const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
} else if (bp_dir == 1) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
}
return weight;
}
template <typename T, typename Context>
void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
T* grad_offset,
T* grad_mask);
template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
const T* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& kernel_shape,
const std::vector<int>& pad,
const std::vector<int>& stride,
const std::vector<int>& dilation,
const int deformable_group,
T* grad_im);
template <typename T, typename Context>
void FilterGradAddup(const Context& dev_ctx,
const int nthreads,
const int n,
const int height,
const int width,
const T* dweight_3d,
T* filter_grad);
template <typename T, typename Context>
void DeformableConvGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
paddle::optional<const DenseTensor&> mask,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* dx,
DenseTensor* offset_grad,
DenseTensor* filter_grad,
DenseTensor* mask_grad) {
const int batch_size = static_cast<int>(x.dims()[0]);
DDim input_shape = phi::slice_ddim(x.dims(), 1, x.dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(out_grad.dims()));
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] = x.dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
DenseTensor col_buffer = Empty<T>(dev_ctx, col_buffer_shape_vec);
DenseTensor output_buffer;
output_buffer.ShareDataWith(out_grad).Resize(
make_ddim(output_buffer_shape_vec));
int64_t M =
input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3];
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K = output_shape_vec[1] / groups;
DDim weight_3d_shape = {groups, K, M};
DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, N};
DDim col_buffer_3d_shape = {groups, M, N};
DDim filter_grad_shape = {groups, K, M};
DenseTensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(weight_3d_shape);
DenseTensor out_grad_4d;
out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape);
DenseTensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape);
phi::funcs::SetConstant<Context, T> set_zero;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
int input_dim = x.numel() / x.dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0;
if (filter_grad) {
Full<T>(dev_ctx,
{filter_grad_shape.Get(), filter_grad_shape.size()},
0,
filter_grad);
}
if (dx) {
dev_ctx.template Alloc<T>(dx);
set_zero(dev_ctx, dx, static_cast<T>(0));
}
if (offset_grad) {
dev_ctx.template Alloc<T>(offset_grad);
set_zero(dev_ctx, offset_grad, static_cast<T>(0));
if (mask_grad) {
dev_ctx.template Alloc<T>(mask_grad);
set_zero(dev_ctx, mask_grad, static_cast<T>(0));
}
}
for (int i = 0; i < batch_size / im2col_step; ++i) {
DenseTensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size()));
for (int g = 0; g < groups; ++g) {
DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
DenseTensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
DenseTensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
blas.MatMul(weight_3d_slice,
true,
out_grad_3d_slice,
false,
T(1.0),
&col_buffer_3d_slice,
T(0.0));
}
col_buffer.Resize(make_ddim(col_buffer_shape_vec));
T* col_buffer_ptr = col_buffer.data<T>();
const T* input_ptr = x.data<T>();
const T* offset_ptr = offset.data<T>();
const T* mask_data_ptr =
mask ? mask->data<T>() + i * im2col_step * input_mask_dim : nullptr;
if (offset_grad) {
T* offset_grad_ptr = offset_grad->data<T>();
T* mask_grad_data_ptr =
mask_grad ? mask_grad->data<T>() + i * im2col_step * input_mask_dim
: nullptr;
// get grad of offset and mask
ModulatedDeformableCol2imCoord(
dev_ctx,
col_buffer_ptr,
input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_data_ptr,
input_shape_vec,
col_buffer_shape_vec,
filter_shape_vec,
paddings,
strides,
dilations,
deformable_groups,
offset_grad_ptr + i * im2col_step * input_offset_dim,
mask_grad_data_ptr);
}
if (dx) {
T* dx_ptr = dx->data<T>();
// get grad of input
ModulatedDeformableCol2im(dev_ctx,
col_buffer_ptr,
offset_ptr + i * im2col_step * input_offset_dim,
mask_data_ptr,
input_shape_vec,
col_buffer_shape_vec,
filter_shape_vec,
paddings,
strides,
dilations,
deformable_groups,
dx_ptr + i * im2col_step * input_dim);
dx->Resize(x.dims());
}
funcs::ModulatedDeformableIm2col(
dev_ctx,
input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_data_ptr,
input_shape_vec,
col_buffer_shape_vec,
filter_shape_vec,
paddings,
strides,
dilations,
deformable_groups,
col_buffer_ptr);
col_buffer_3d.Resize(col_buffer_3d_shape);
if (filter_grad) {
DenseTensor dweight_3d = Empty<T>(
dev_ctx, {filter_grad_shape.Get(), filter_grad_shape.size()});
for (int g = 0; g < groups; ++g) {
DenseTensor out_grad_3d_slice = out_grad_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
DenseTensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
DenseTensor dweight_3d_slice = dweight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(dweight_3d.dims(), 1, dweight_3d.dims().size()));
blas.MatMul(out_grad_3d_slice,
false,
col_buffer_3d_slice,
true,
T(1.0),
&dweight_3d_slice,
T(0.0));
}
// update grad of weights
FilterGradAddup<T>(dev_ctx,
dweight_3d.numel(),
groups,
K,
M,
dweight_3d.data<T>(),
filter_grad->data<T>());
}
}
if (filter_grad) {
filter_grad->Resize(filter.dims());
}
}
} // namespace phi
......@@ -29,6 +29,34 @@ KernelSignature DeformableConvOpArgumentMapping(
{"Output"});
}
KernelSignature DeformableConvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"deformable_conv_grad",
{"Input", "Offset", "Filter", "Mask", GradVarName("Output")},
{"strides",
"paddings",
"dilations",
"deformable_groups",
"groups",
"im2col_step"},
{GradVarName("Input"),
GradVarName("Offset"),
GradVarName("Filter"),
GradVarName("Mask")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1, deformable_conv);
PD_REGISTER_BASE_KERNEL_NAME(deformable_conv_v1_grad, deformable_conv_grad);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv,
phi::DeformableConvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_grad,
phi::DeformableConvGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1,
phi::DeformableConvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(deformable_conv_v1_grad,
phi::DeformableConvGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册