未验证 提交 bba57cdd 编写于 作者: B Bai Yifan 提交者: GitHub

Add deformable conv v2 op,test=develop (#17145)

* unit commits, test=develop

* update API.spec, test=develop
上级 aca53535
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -237,6 +237,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', '94e2819b7c9715ea71b62e9c78f36b29'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))
......
......@@ -47,7 +47,8 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......@@ -65,6 +66,8 @@ if (WITH_GPU)
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif()
op_library(deformable_conv_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/conv_op.h"
namespace paddle {
namespace operators {
class DeformableConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input of deformable conv op. "
"The shape of input is "
"[N, channel_in, H, W]");
AddInput("Offset",
"(Tensor) The input offset. "
"The shape of the offset is "
"[N, deformable_groups * kernel_w * kernel_h * 2, H, W");
AddInput("Mask",
"(Tensor) The input mask. "
"The shape of the mask is "
"[N, deformable_groups * kernel_w * kernel_h, H, W].");
AddInput("Filter",
"(Tensor) The Input Filter "
"The shape of the wight is "
"[num_filters, channel_in, kernel_h, kernel_w.");
AddOutput("Output",
"(Tensor) The output. "
"The shape of the output tensor is "
"[N, num_filters, out_height, out_width]].");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0,0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator. ")
.SetDefault({0, 0});
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the "
"filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<int>("deformable_groups",
"(int default:1), the number of the deformable groups.")
.SetDefault(1);
AddAttr<int>("im2col_step",
"im2col maximum number of image per computation")
.SetDefault(64);
AddComment(R"DOC(
**Deformable Convolution Operator**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
$$
y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k) * \\Delta m_k}
$$
Where $$\\Delta p_k$$ and $$\Delta m_k$$ are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to 'Deformable ConvNets v2: More Deformable, Better Results
'<https://arxiv.org/abs/1811.11168v2>
Example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
Offset shape: $(N, 2 * deformable_groups, * H_f * W_f, H_{out}, W_{out})$
Mask shape: $(N, deformable_groups * H_f * W_f, H_{out}, W_{out})$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
where $H_{out}, W_{out}$ must be equal to $H_{in}, W_{in}$ respectively.
Where
$$
H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
}
};
class DeformableConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE(ctx->HasInput("Offset"),
"Input(Offset) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE(ctx->HasInput("Mask"),
"Input(Mask) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of DeformableConvOp "
"should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of DeformableConvOp "
"should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
auto mask_dims = ctx->GetInputDim("Mask");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
int im2col_step = ctx->Attrs().Get<int>("im2col_step");
PADDLE_ENFORCE(in_dims.size() == 4,
"Conv input should be 4-D tensor, get %u", in_dims.size());
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same.");
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
"Conv input dimension and strides dimension should be consistent.");
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension "
"should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");
PADDLE_ENFORCE_EQ(filter_dims[0] % deformable_groups, 0,
"The number of output channels should be "
"divided by deformable groups.");
if (in_dims[0] > im2col_step) {
PADDLE_ENFORCE_EQ(
in_dims[0] % im2col_step, 0U,
"Input batchsize must be smaller than or divide im2col_step");
}
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE_GT(strides[i], 0U, "stride %d size incorrect", i);
}
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_GT(dilations[i], 0U, "dilation %d size incorrect", i);
}
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
"output num_filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
"output height must equal to offset map height.");
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
"output width must equal to offset map width.");
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
"offset filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
"offset filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2],
"output height must equal to mask map height.");
PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3],
"output width must equal to mask map width.");
PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
"mask filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups,
"mask filter must divide deformable group size.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
class DeformableConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("deformable_conv_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("Offset", Input("Offset"));
op->SetInput("Mask", Input("Mask"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset"));
op->SetOutput(framework::GradVarName("Mask"), InputGrad("Mask"));
op->SetAttrMap(Attrs());
return op;
}
};
class DeformableConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
auto offset_dims = ctx->GetInputDim("Offset");
auto mask_dims = ctx->GetInputDim("Mask");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
"the gradient of output(Out) must not be null");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
if (ctx->HasOutput(framework::GradVarName("Offset"))) {
ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims);
}
if (ctx->HasOutput(framework::GradVarName("Mask"))) {
ctx->SetOutputDim(framework::GradVarName("Mask"), mask_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
ops::DeformableConvOpMaker,
ops::DeformableConvGradOpDescMaker);
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
此差异已折叠。
......@@ -202,6 +202,7 @@ __all__ = [
'continuous_value_model',
'where',
'sign',
'deformable_conv',
]
kIgnoreIndex = -100
......@@ -11745,3 +11746,175 @@ def sign(x):
helper.append_op(type='sign', inputs={'X': [x]}, outputs={'Out': [out]})
return out
def deformable_conv(input,
offset,
mask,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
deformable_groups=None,
im2col_step=None,
param_attr=None,
bias_attr=None,
name=None):
"""
**Deformable Convolution Layer**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ .
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})`
Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
input (Variable): The input image with [N, C, H, W] format.
offset (Variable): The input coord offset of deformable convolution layer.
Mask (Variable): The input mask of deformable covolution layer.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the deformable conv layer. According to
grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
deformable_groups (int): The number of deformable group partitions.
Default: deformable_groups = 1.
im2col_step (int): Maximum number of images per im2col computation;
The total batch size should be divisable by this value or smaller
than this value; if you face out of memory problem, you can try
to use a smaller value here.
Default: im2col_step = 64.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of deformable conv. If it is set to None or one attribute of ParamAttr,
deformable conv will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the parameter is
initialized with :math:`Normal(0.0, std)`, and the
:math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of
deformable conv layer. If it is set to False, no bias will be added
to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the deformable convolution \
result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask,
num_filters=2, filter_size=3, padding=1)
"""
num_channels = input.shape[1]
assert param_attr is not False, "param_attr should not be False here."
helper = LayerHelper('deformable_conv', **locals())
dtype = helper.input_dtype()
if not isinstance(input, Variable):
raise TypeError("Input of deformable_conv must be Variable")
if not isinstance(offset, Variable):
raise TypeError("Input Offset of deformable_conv must be Variable")
if not isinstance(mask, Variable):
raise TypeError("Input Mask of deformable_conv must be Variable")
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
input_shape = input.shape
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[1] * num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
default_initializer=_get_default_param_initializer())
pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='deformable_conv',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
'Mask': mask,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return output
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
def dmc_bilinear(data_im, height, width, h, w):
h_low = int(np.floor(h))
w_low = int(np.floor(w))
h_high = h_low + 1
w_high = w_low + 1
lh = h - h_low
lw = w - w_low
hh = 1 - lh
hw = 1 - lw
v1 = 0
if h_low >= 0 and w_low >= 0:
v1 = data_im[h_low, w_low]
v2 = 0
if h_low >= 0 and w_high <= width - 1:
v2 = data_im[h_low, w_high]
v3 = 0
if h_high <= height - 1 and w_low >= 0:
v3 = data_im[h_high, w_low]
v4 = 0
if h_high <= height - 1 and w_high <= width - 1:
v4 = data_im[h_high, w_high]
w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
in_n, in_c, in_h, in_w = input.shape
out_c, f_c, f_h, f_w = filter.shape
assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
assert mask.shape == (in_n, f_h * f_w, in_h, in_w)
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
conv_param['dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
assert out_h == in_h
assert out_w == in_w
col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
for n in range(in_n):
for c in range(in_c):
for h in range(out_h):
for w in range(out_w):
for kh in range(f_h):
for kw in range(f_w):
offset_h_table = \
offset[n, ::2, h, w].reshape(f_h, f_w)
offset_w_table = \
offset[n, 1::2, h, w].reshape(f_h, f_w)
mask_table = \
mask[n, :, h, w].reshape(f_h, f_w)
offset_h = offset_h_table[kh, kw]
offset_w = offset_w_table[kh, kw]
val = 0
im_h = h * stride[0] + kh * dilation[0] \
+ offset_h - pad[0]
im_w = w * stride[0] + kw * dilation[0] \
+ offset_w - pad[1]
if im_h > -1 and im_w > -1 and \
im_h < in_h and im_w < in_h:
val = dmc_bilinear(input[n, c], in_h, in_w,
im_h, im_w)
val_out = val * mask_table[kh, kw]
col_buffer[n, c * f_h * f_w + kh * f_w + kw, h *
in_w + w] = val_out
out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
col_buffer = col_buffer.reshape(
(in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
for n in range(in_n):
for g in range(group):
out[n, g] = np.matmul(weight[g], col_buffer[n, g])
out = out.reshape(in_n, out_c, out_h, out_w)
return out
class TestModulatedDeformableConvOp(OpTest):
def setUp(self):
self.op_type = "deformable_conv"
self.dtype = np.float32
self.init_group()
self.init_dilation()
self.init_test_case()
conv_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
conv_param)
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'deformable_groups': self.deformable_groups,
'im2col_step': self.im2col_step,
'dilations': self.dilations,
}
self.outputs = {'Output': output}
def has_cuda(self):
return core.is_compiled_with_cuda()
def test_check_output(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, {'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.05)
def test_check_grad_no_filter(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Input']))
def test_check_grad_no_offset_no_mask(self):
if self.has_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input', 'Filter'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 4, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [4, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [3, 3]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [2, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [2, 2]
class TestWith1x1(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
class TestWithGroup(TestModulatedDeformableConvOp):
def init_group(self):
self.groups = 2
if __name__ == '__main__':
unittest.main()
......@@ -1957,6 +1957,34 @@ class TestBook(LayerTest):
self.assertIsNotNone(out)
print(str(program))
def test_deformable_conv(self):
if core.is_compiled_with_cuda():
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
mask = layers.data(
name='mask',
append_batch_size=False,
shape=[2, 9, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=2,
filter_size=3,
padding=1)
return (out)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册