未验证 提交 00efd1d8 编写于 作者: C chengjuntao 提交者: GitHub

add deformable conv v1 op and cpu version of deformable conv v2 (#18500)

* add deformable conv v1 op, test=develop
上级 40c66f8d
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -285,7 +285,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'c03490ffaa1b78258747157c313db4cd'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', 'b1e1487760295e1ff55307b880a99e18'))
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'fa2f457a81714430c5677c2d68744728'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'modulated', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, True, None)), ('document', '335193ac57d41d7199f8d26d30c069b1'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))
......
......@@ -55,7 +55,7 @@ if (NOT WITH_MKL)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
sync_batch_norm_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......@@ -73,8 +73,6 @@ if (WITH_GPU)
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif()
op_library(deformable_conv_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
template <typename T>
__global__ void FilterGradAddupCUDAKernel(const int nthreads, const int n,
const int height, const int width,
const T* dweight_3d, T* filter_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
filter_grad[i] = filter_grad[i] + dweight_3d[i];
}
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
template <typename T>
HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h,
const int w, const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
weight = (h == argmax_h_low && w == argmax_w_low)
? (h + 1 - argmax_h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_low && w == argmax_w_high)
? (h + 1 - argmax_h) * (argmax_w + 1 - w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_low)
? (argmax_h + 1 - h) * (w + 1 - argmax_w)
: weight;
weight = (h == argmax_h_high && w == argmax_w_high)
? (argmax_h + 1 - h) * (argmax_w + 1 - w)
: weight;
return weight;
}
template <typename T>
HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
const int width, const T* im_data,
const int data_width, const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
} else if (bp_dir == 1) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low]
: 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low]
: 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high]
: 0;
}
return weight;
}
template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
const int height, const int width, T h, T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
T v1 =
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0;
T v3 = (h_high <= height - 1 && w_low >= 0)
? bottom_data[h_high * data_width + w_low]
: 0;
T v4 = (h_high <= height - 1 && w_high <= width - 1)
? bottom_data[h_high * data_width + w_high]
: 0;
T w1 = hh * hw;
T w2 = hh * lw;
T w3 = lh * hw;
T w4 = lh * lw;
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/deformable_conv_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h"
namespace paddle {
......@@ -197,7 +199,6 @@ class DeformableConvOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups,
"mask filter must divide deformable group size.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......@@ -274,5 +275,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
ops::DeformableConvOpMaker,
ops::DeformableConvGradOpDescMaker);
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel<float>,
ops::DeformableConvCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCPUKernel<float>,
ops::DeformableConvGradCPUKernel<double>);
......@@ -24,6 +24,7 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/deformable_conv_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/deformable_conv_v1_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h"
namespace paddle {
namespace operators {
class DeformableConvV1OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input of deformable conv op. "
"The shape of input is "
"[N, channel_in, H, W]");
AddInput("Offset",
"(Tensor) The input offset. "
"The shape of the offset is "
"[N, deformable_groups * kernel_w * kernel_h * 2, H, W");
AddInput("Filter",
"(Tensor) The Input Filter "
"The shape of the wight is "
"[num_filters, channel_in, kernel_h, kernel_w.");
AddOutput("Output",
"(Tensor) The output. "
"The shape of the output tensor is "
"[N, num_filters, out_height, out_width]].");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0,0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator. ")
.SetDefault({0, 0});
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the "
"filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<int>("deformable_groups",
"(int default:1), the number of the deformable groups.")
.SetDefault(1);
AddAttr<int>("im2col_step",
"im2col maximum number of image per computation")
.SetDefault(64);
AddComment(R"DOC(
**Deformable Convolution v1 Operator**
Deformable Convolution is a new method based Convolution which feature has offset
in spatial location.
1. Get offset of each pixel in feature map with convolution layers which number
of channels should be double of weight size.
2. Add offset to pixel to get new location and the new value which are computed
directly through bilinear interpolation with four nearest pixel.
3. Get the product of pixel and weight as result
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
$$
y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k)}
$$
Where $$\\Delta p_k$$ is the learnable offset for the k-th location, respectively.
Refer to 'https://arxiv.org/abs/1703.06211 '<https://arxiv.org/abs/1703.06211>
Example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
Offset shape: $(N, 2 * deformable_groups, * H_f * W_f, H_{out}, W_{out})$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
where $H_{out}, W_{out}$ must be equal to $H_{in}, W_{in}$ respectively.
Where
$$
H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
}
};
class DeformableConvV1Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput("Offset"), true,
"Input(Offset) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of DeformableConvOp "
"should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
int im2col_step = ctx->Attrs().Get<int>("im2col_step");
PADDLE_ENFORCE_EQ(in_dims.size(), 4,
"Conv input should be 4-D tensor, get %u",
in_dims.size());
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same.");
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
"Conv input dimension and strides dimension should be consistent.");
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension "
"should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");
PADDLE_ENFORCE_EQ(filter_dims[0] % deformable_groups, 0,
"The number of output channels should be "
"divided by deformable groups.");
if (in_dims[0] > im2col_step) {
PADDLE_ENFORCE_EQ(
in_dims[0] % im2col_step, 0U,
"Input batchsize must be smaller than or divide im2col_step");
}
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE_GT(strides[i], 0U, "stride %d size incorrect", i);
}
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_GT(dilations[i], 0U, "dilation %d size incorrect", i);
}
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
"output num_filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
"output height must equal to offset map height.");
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
"output width must equal to offset map width.");
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
"offset filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
"offset filter must divide deformable group size.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
class DeformableConvV1GradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("deformable_conv_v1_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("Offset", Input("Offset"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset"));
op->SetAttrMap(Attrs());
return op;
}
};
class DeformableConvV1GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Output")), true,
"the gradient of output(Out) must not be null");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
if (ctx->HasOutput(framework::GradVarName("Offset"))) {
ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
ops::DeformableConvV1OpMaker,
ops::DeformableConvV1GradOpDescMaker);
REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CPUKernel<float>);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCPUKernel<float>);
此差异已折叠。
此差异已折叠。
......@@ -13196,20 +13196,30 @@ def deformable_conv(input,
im2col_step=None,
param_attr=None,
bias_attr=None,
modulated=True,
name=None):
"""
**Deformable Convolution Layer**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
Deformable Convolution v2:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Deformable Convolution v1:
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ .
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k)}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location,
which :math:`\Delta m_k` is one in deformable convolution v1. Please refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ and `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.
Example:
- Input:
......@@ -13235,7 +13245,7 @@ def deformable_conv(input,
Args:
input (Variable): The input image with [N, C, H, W] format.
offset (Variable): The input coord offset of deformable convolution layer.
offset (Variable): The input coordinate offset of deformable convolution layer.
Mask (Variable): The input mask of deformable covolution layer.
num_filters(int): The number of filter. It is as same as the output
image channel.
......@@ -13274,6 +13284,8 @@ def deformable_conv(input,
to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
modulated (bool): Make sure which version should be used between v1 and v2, where v2 is \
used while True. Default: True.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
......@@ -13285,12 +13297,22 @@ def deformable_conv(input,
Examples:
.. code-block:: python
#deformable conv v2:
import paddle.fluid as fluid
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask,
num_filters=2, filter_size=3, padding=1)
num_filters=2, filter_size=3, padding=1, modulated=True)
#deformable conv v1:
import paddle.fluid as fluid
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=None,
num_filters=2, filter_size=3, padding=1, modulated=False)
"""
num_channels = input.shape[1]
......@@ -13303,8 +13325,6 @@ def deformable_conv(input,
raise TypeError("Input of deformable_conv must be Variable")
if not isinstance(offset, Variable):
raise TypeError("Input Offset of deformable_conv must be Variable")
if not isinstance(mask, Variable):
raise TypeError("Input Mask of deformable_conv must be Variable")
if groups is None:
num_filter_channels = num_channels
......@@ -13334,23 +13354,42 @@ def deformable_conv(input,
pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='deformable_conv',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
'Mask': mask,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
if modulated:
helper.append_op(
type='deformable_conv',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
'Mask': mask,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
else:
helper.append_op(
type='deformable_conv_v1',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return output
......
......@@ -145,48 +145,35 @@ class TestModulatedDeformableConvOp(OpTest):
}
self.outputs = {'Output': output}
def has_cuda(self):
return core.is_compiled_with_cuda()
def test_check_output(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
self.check_output(atol=1e-5)
def test_check_grad(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, {'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.05)
self.check_grad(
{'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.05)
def test_check_grad_no_filter(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
self.check_grad(
['Input', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Input']))
self.check_grad(
['Filter', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Input']))
def test_check_grad_no_offset_no_mask(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input', 'Filter'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
self.check_grad(
['Input', 'Filter'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
def init_test_case(self):
self.pad = [1, 1]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
def dmc_bilinear(data_im, height, width, h, w):
h_low = int(np.floor(h))
w_low = int(np.floor(w))
h_high = h_low + 1
w_high = w_low + 1
lh = h - h_low
lw = w - w_low
hh = 1 - lh
hw = 1 - lw
v1 = 0
if h_low >= 0 and w_low >= 0:
v1 = data_im[h_low, w_low]
v2 = 0
if h_low >= 0 and w_high <= width - 1:
v2 = data_im[h_low, w_high]
v3 = 0
if h_high <= height - 1 and w_low >= 0:
v3 = data_im[h_high, w_low]
v4 = 0
if h_high <= height - 1 and w_high <= width - 1:
v4 = data_im[h_high, w_high]
w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
def dconv_im2col_gemm(input, offset, filter, group, conv_param):
in_n, in_c, in_h, in_w = input.shape
out_c, f_c, f_h, f_w = filter.shape
assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
conv_param['dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
assert out_h == in_h
assert out_w == in_w
col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
for n in range(in_n):
for c in range(in_c):
for h in range(out_h):
for w in range(out_w):
for kh in range(f_h):
for kw in range(f_w):
offset_h_table = \
offset[n, ::2, h, w].reshape(f_h, f_w)
offset_w_table = \
offset[n, 1::2, h, w].reshape(f_h, f_w)
offset_h = offset_h_table[kh, kw]
offset_w = offset_w_table[kh, kw]
val = 0
im_h = h * stride[0] + kh * dilation[0] \
+ offset_h - pad[0]
im_w = w * stride[0] + kw * dilation[0] \
+ offset_w - pad[1]
if im_h > -1 and im_w > -1 and \
im_h < in_h and im_w < in_h:
val = dmc_bilinear(input[n, c], in_h, in_w,
im_h, im_w)
val_out = val
col_buffer[n, c * f_h * f_w + kh * f_w + kw, h *
in_w + w] = val_out
out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
col_buffer = col_buffer.reshape(
(in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
for n in range(in_n):
for g in range(group):
out[n, g] = np.matmul(weight[g], col_buffer[n, g])
out = out.reshape(in_n, out_c, out_h, out_w)
return out
class TestModulatedDeformableConvOp(OpTest):
def setUp(self):
self.op_type = "deformable_conv_v1"
self.dtype = np.float32
self.init_group()
self.init_dilation()
self.init_test_case()
conv_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = dconv_im2col_gemm(input, offset, filter, self.groups,
conv_param)
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'deformable_groups': self.deformable_groups,
'im2col_step': self.im2col_step,
'dilations': self.dilations,
}
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input', 'Offset'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 4, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [4, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [3, 3]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [2, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [2, 2]
class TestWith1x1(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
class TestWithGroup(TestModulatedDeformableConvOp):
def init_group(self):
self.groups = 2
if __name__ == '__main__':
unittest.main()
......@@ -2276,32 +2276,31 @@ class TestBook(LayerTest):
print(str(program))
def test_deformable_conv(self):
if core.is_compiled_with_cuda():
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
mask = layers.data(
name='mask',
append_batch_size=False,
shape=[2, 9, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=2,
filter_size=3,
padding=1)
return (out)
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
mask = layers.data(
name='mask',
append_batch_size=False,
shape=[2, 9, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=2,
filter_size=3,
padding=1)
return (out)
def test_unfold(self):
with self.static_graph():
......@@ -2338,6 +2337,29 @@ class TestBook(LayerTest):
trans_std=0.1)
return (out)
def test_deformable_conv_v1(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=None,
num_filters=2,
filter_size=3,
padding=1,
modulated=False)
return (out)
def test_retinanet_target_assign(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册