Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
bba57cdd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bba57cdd
编写于
5月 30, 2019
作者:
B
Bai Yifan
提交者:
GitHub
5月 30, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add deformable conv v2 op,test=develop (#17145)
* unit commits, test=develop * update API.spec, test=develop
上级
aca53535
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
1522 addition
and
2 deletion
+1522
-2
cmake/operators.cmake
cmake/operators.cmake
+1
-1
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+4
-1
paddle/fluid/operators/deformable_conv_op.cc
paddle/fluid/operators/deformable_conv_op.cc
+278
-0
paddle/fluid/operators/deformable_conv_op.cu
paddle/fluid/operators/deformable_conv_op.cu
+743
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+173
-0
python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
...n/paddle/fluid/tests/unittests/test_deformable_conv_op.py
+294
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+28
-0
未找到文件。
cmake/operators.cmake
浏览文件 @
bba57cdd
...
...
@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach
(
manual_pybind_op
"compare_op"
"logical_op"
"nccl_op"
"tensor_array_read_write_op"
"tensorrt_engine_op"
"conv_fusion_op"
"fusion_transpose_flatten_concat_op"
"fusion_conv_inception_op"
"sync_batch_norm_op"
"dgc_op"
)
"fusion_transpose_flatten_concat_op"
"fusion_conv_inception_op"
"sync_batch_norm_op"
"d
eformable_conv_op"
"d
gc_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
set
(
pybind_flag 1
)
endif
()
...
...
paddle/fluid/API.spec
浏览文件 @
bba57cdd
...
...
@@ -237,6 +237,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', '94e2819b7c9715ea71b62e9c78f36b29'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
bba57cdd
...
...
@@ -47,7 +47,8 @@ if (WITH_DISTRIBUTE)
SET
(
OP_PREFETCH_DEPS
${
OP_PREFETCH_DEPS
}
parameter_prefetch
)
endif
()
register_operators
(
EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS
${
OP_HEADER_DEPS
}
${
OP_PREFETCH_DEPS
}
)
register_operators
(
EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op DEPS
${
OP_HEADER_DEPS
}
${
OP_PREFETCH_DEPS
}
)
if
(
WITH_GPU
)
# warpctc_op needs cudnn 7 above
...
...
@@ -65,6 +66,8 @@ if (WITH_GPU)
op_library
(
sync_batch_norm_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(sync_batch_norm);
\n
"
)
endif
()
op_library
(
deformable_conv_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(deformable_conv);
\n
"
)
else
()
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
endif
()
...
...
paddle/fluid/operators/deformable_conv_op.cc
0 → 100644
浏览文件 @
bba57cdd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/conv_op.h"
namespace
paddle
{
namespace
operators
{
class
DeformableConvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(Tensor) The input of deformable conv op. "
"The shape of input is "
"[N, channel_in, H, W]"
);
AddInput
(
"Offset"
,
"(Tensor) The input offset. "
"The shape of the offset is "
"[N, deformable_groups * kernel_w * kernel_h * 2, H, W"
);
AddInput
(
"Mask"
,
"(Tensor) The input mask. "
"The shape of the mask is "
"[N, deformable_groups * kernel_w * kernel_h, H, W]."
);
AddInput
(
"Filter"
,
"(Tensor) The Input Filter "
"The shape of the wight is "
"[num_filters, channel_in, kernel_h, kernel_w."
);
AddOutput
(
"Output"
,
"(Tensor) The output. "
"The shape of the output tensor is "
"[N, num_filters, out_height, out_width]]."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator."
)
.
SetDefault
({
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector<int> default:{0,0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator. "
)
.
SetDefault
({
0
,
0
});
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator."
)
.
SetDefault
({
1
,
1
});
AddAttr
<
int
>
(
"groups"
,
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the "
"filters "
"is only connected to the second half of the input channels."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
"deformable_groups"
,
"(int default:1), the number of the deformable groups."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
"im2col_step"
,
"im2col maximum number of image per computation"
)
.
SetDefault
(
64
);
AddComment
(
R"DOC(
**Deformable Convolution Operator**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
$$
y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k) * \\Delta m_k}
$$
Where $$\\Delta p_k$$ and $$\Delta m_k$$ are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to 'Deformable ConvNets v2: More Deformable, Better Results
'<https://arxiv.org/abs/1811.11168v2>
Example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
Offset shape: $(N, 2 * deformable_groups, * H_f * W_f, H_{out}, W_{out})$
Mask shape: $(N, deformable_groups * H_f * W_f, H_{out}, W_{out})$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
where $H_{out}, W_{out}$ must be equal to $H_{in}, W_{in}$ respectively.
Where
$$
H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC"
);
}
};
class
DeformableConvOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of DeformableConvOp "
"should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Offset"
),
"Input(Offset) of DeformableConvOp "
"should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) of DeformableConvOp "
"should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
"Input(Filter) of DeformableConvOp "
"should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Output"
),
"Output(Output) of DeformableConvOp "
"should not be null."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
auto
offset_dims
=
ctx
->
GetInputDim
(
"Offset"
);
auto
mask_dims
=
ctx
->
GetInputDim
(
"Mask"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dilations"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
int
deformable_groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"deformable_groups"
);
int
im2col_step
=
ctx
->
Attrs
().
Get
<
int
>
(
"im2col_step"
);
PADDLE_ENFORCE
(
in_dims
.
size
()
==
4
,
"Conv input should be 4-D tensor, get %u"
,
in_dims
.
size
());
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
filter_dims
.
size
(),
"Conv input dimension and filter dimension should be the same."
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
()
-
strides
.
size
(),
2U
,
"Conv input dimension and strides dimension should be consistent."
);
PADDLE_ENFORCE_EQ
(
paddings
.
size
(),
strides
.
size
(),
"Conv paddings dimension and Conv strides dimension "
"should be the same."
);
PADDLE_ENFORCE_EQ
(
in_dims
[
1
],
filter_dims
[
1
]
*
groups
,
"The number of input channels should be equal to filter "
"channels * groups."
);
PADDLE_ENFORCE_EQ
(
filter_dims
[
0
]
%
groups
,
0
,
"The number of output channels should be divided by groups."
);
PADDLE_ENFORCE_EQ
(
filter_dims
[
0
]
%
deformable_groups
,
0
,
"The number of output channels should be "
"divided by deformable groups."
);
if
(
in_dims
[
0
]
>
im2col_step
)
{
PADDLE_ENFORCE_EQ
(
in_dims
[
0
]
%
im2col_step
,
0U
,
"Input batchsize must be smaller than or divide im2col_step"
);
}
for
(
size_t
i
=
0
;
i
<
strides
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
strides
[
i
],
0U
,
"stride %d size incorrect"
,
i
);
}
for
(
size_t
i
=
0
;
i
<
dilations
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
dilations
[
i
],
0U
,
"dilation %d size incorrect"
,
i
);
}
std
::
vector
<
int64_t
>
output_shape
({
in_dims
[
0
],
filter_dims
[
0
]});
for
(
size_t
i
=
0
;
i
<
strides
.
size
();
++
i
)
{
output_shape
.
push_back
(
ConvOutputSize
(
in_dims
[
i
+
2
],
filter_dims
[
i
+
2
],
dilations
[
i
],
paddings
[
i
],
strides
[
i
]));
}
PADDLE_ENFORCE_EQ
(
output_shape
[
1
]
%
deformable_groups
,
0U
,
"output num_filter must divide deformable group size."
);
PADDLE_ENFORCE_EQ
(
output_shape
[
2
],
offset_dims
[
2
],
"output height must equal to offset map height."
);
PADDLE_ENFORCE_EQ
(
output_shape
[
3
],
offset_dims
[
3
],
"output width must equal to offset map width."
);
PADDLE_ENFORCE_EQ
(
offset_dims
[
1
]
%
(
filter_dims
[
2
]
*
filter_dims
[
3
]),
0U
,
"offset filter must divide deformable group size."
);
PADDLE_ENFORCE_EQ
(
offset_dims
[
1
]
/
(
2
*
filter_dims
[
2
]
*
filter_dims
[
3
]),
deformable_groups
,
"offset filter must divide deformable group size."
);
PADDLE_ENFORCE_EQ
(
output_shape
[
2
],
mask_dims
[
2
],
"output height must equal to mask map height."
);
PADDLE_ENFORCE_EQ
(
output_shape
[
3
],
mask_dims
[
3
],
"output width must equal to mask map width."
);
PADDLE_ENFORCE_EQ
(
mask_dims
[
1
]
%
(
filter_dims
[
2
]
*
filter_dims
[
3
]),
0U
,
"mask filter must divide deformable group size."
);
PADDLE_ENFORCE_EQ
(
mask_dims
[
1
]
/
(
filter_dims
[
2
]
*
filter_dims
[
3
]),
deformable_groups
,
"mask filter must divide deformable group size."
);
ctx
->
SetOutputDim
(
"Output"
,
framework
::
make_ddim
(
output_shape
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
}
};
class
DeformableConvGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
op
->
SetType
(
"deformable_conv_grad"
);
op
->
SetInput
(
"Input"
,
Input
(
"Input"
));
op
->
SetInput
(
"Filter"
,
Input
(
"Filter"
));
op
->
SetInput
(
"Offset"
,
Input
(
"Offset"
));
op
->
SetInput
(
"Mask"
,
Input
(
"Mask"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Output"
),
OutputGrad
(
"Output"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Input"
),
InputGrad
(
"Input"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Filter"
),
InputGrad
(
"Filter"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Offset"
),
InputGrad
(
"Offset"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Mask"
),
InputGrad
(
"Mask"
));
op
->
SetAttrMap
(
Attrs
());
return
op
;
}
};
class
DeformableConvGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
auto
offset_dims
=
ctx
->
GetInputDim
(
"Offset"
);
auto
mask_dims
=
ctx
->
GetInputDim
(
"Mask"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Output"
)),
"the gradient of output(Out) must not be null"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Input"
),
in_dims
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter"
),
filter_dims
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Offset"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Offset"
),
offset_dims
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Mask"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Mask"
),
mask_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
deformable_conv
,
ops
::
DeformableConvOp
,
ops
::
DeformableConvOpMaker
,
ops
::
DeformableConvGradOpDescMaker
);
REGISTER_OPERATOR
(
deformable_conv_grad
,
ops
::
DeformableConvGradOp
);
paddle/fluid/operators/deformable_conv_op.cu
0 → 100644
浏览文件 @
bba57cdd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaximumNumBlocks
=
4096
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaximumNumBlocks
);
}
template
<
typename
T
>
__device__
T
DmcnGetGradientWeight
(
T
argmax_h
,
T
argmax_w
,
const
int
h
,
const
int
w
,
const
int
height
,
const
int
width
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
T
weight
=
0
;
if
(
h
==
argmax_h_low
&&
w
==
argmax_w_low
)
weight
=
(
h
+
1
-
argmax_h
)
*
(
w
+
1
-
argmax_w
);
if
(
h
==
argmax_h_low
&&
w
==
argmax_w_high
)
weight
=
(
h
+
1
-
argmax_h
)
*
(
argmax_w
+
1
-
w
);
if
(
h
==
argmax_h_high
&&
w
==
argmax_w_low
)
weight
=
(
argmax_h
+
1
-
h
)
*
(
w
+
1
-
argmax_w
);
if
(
h
==
argmax_h_high
&&
w
==
argmax_w_high
)
weight
=
(
argmax_h
+
1
-
h
)
*
(
argmax_w
+
1
-
w
);
return
weight
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableCol2imGpuKernel
(
const
int
nthreads
,
const
T
*
data_col
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
grad_im
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
thread
=
index
;
thread
<
nthreads
;
thread
+=
offset
)
{
const
int
j
=
(
thread
/
width_col
/
height_col
/
batch_size
)
%
kernel_w
;
const
int
i
=
(
thread
/
width_col
/
height_col
/
batch_size
/
kernel_w
)
%
kernel_h
;
const
int
c
=
thread
/
width_col
/
height_col
/
batch_size
/
kernel_w
/
kernel_h
;
const
int
deformable_group_index
=
c
/
channel_per_deformable_group
;
int
w_out
=
thread
%
width_col
;
int
h_out
=
(
thread
/
width_col
)
%
height_col
;
int
b
=
(
thread
/
width_col
/
height_col
)
%
batch_size
;
int
w_in
=
w_out
*
stride_w
-
pad_w
;
int
h_in
=
h_out
*
stride_h
-
pad_h
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
const
T
cur_inv_h_data
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
cur_inv_w_data
=
w_in
+
j
*
dilation_w
+
offset_w
;
const
T
cur_top_grad
=
data_col
[
thread
]
*
mask
;
const
int
cur_h
=
static_cast
<
int
>
(
cur_inv_h_data
);
const
int
cur_w
=
static_cast
<
int
>
(
cur_inv_w_data
);
for
(
int
dy
=
-
2
;
dy
<=
2
;
dy
++
)
{
for
(
int
dx
=
-
2
;
dx
<=
2
;
dx
++
)
{
if
(
cur_h
+
dy
>=
0
&&
cur_h
+
dy
<
height
&&
cur_w
+
dx
>=
0
&&
cur_w
+
dx
<
width
&&
abs
(
cur_inv_h_data
-
(
cur_h
+
dy
))
<
1
&&
abs
(
cur_inv_w_data
-
(
cur_w
+
dx
))
<
1
)
{
int
cur_bottom_grad_pos
=
((
b
*
channels
+
c
)
*
height
+
cur_h
+
dy
)
*
width
+
cur_w
+
dx
;
T
weight
=
DmcnGetGradientWeight
(
cur_inv_h_data
,
cur_inv_w_data
,
cur_h
+
dy
,
cur_w
+
dx
,
height
,
width
);
atomicAdd
(
grad_im
+
cur_bottom_grad_pos
,
weight
*
cur_top_grad
);
}
}
}
}
}
template
<
typename
T
>
inline
void
ModulatedDeformableCol2im
(
const
platform
::
DeviceContext
&
ctx
,
const
T
*
data_col
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
kernel_shape
,
const
std
::
vector
<
int
>
pad
,
const
std
::
vector
<
int
>
stride
,
const
std
::
vector
<
int
>
dilation
,
const
int
deformable_group
,
T
*
grad_im
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_group
;
int
num_kernels
=
col_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
ModulatedDeformableCol2imGpuKernel
<
T
><<<
blocks
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
num_kernels
,
data_col
,
data_offset
,
data_mask
,
im_shape
[
0
],
im_shape
[
1
],
im_shape
[
2
],
kernel_shape
[
2
],
kernel_shape
[
3
],
pad
[
0
],
pad
[
1
],
stride
[
0
],
stride
[
1
],
dilation
[
0
],
dilation
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
deformable_group
,
col_shape
[
2
],
col_shape
[
3
],
grad_im
);
}
template
<
typename
T
>
__device__
T
DmcnGetCoordinateWeight
(
T
argmax_h
,
T
argmax_w
,
const
int
height
,
const
int
width
,
const
T
*
im_data
,
const
int
data_width
,
const
int
bp_dir
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
T
weight
=
0
;
if
(
bp_dir
==
0
)
{
if
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
];
if
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
weight
+=
-
1
*
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
weight
+=
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
];
}
else
if
(
bp_dir
==
1
)
{
if
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
];
if
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
];
}
return
weight
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableCol2imCoordGpuKernel
(
const
int
nthreads
,
const
T
*
data_col
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
offset_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
grad_offset
,
T
*
grad_mask
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
T
val
=
0
,
mval
=
0
;
const
int
w
=
i
%
width_col
;
const
int
h
=
(
i
/
width_col
)
%
height_col
;
const
int
c
=
(
i
/
width_col
/
height_col
)
%
offset_channels
;
const
int
b
=
(
i
/
width_col
/
height_col
)
/
offset_channels
;
const
int
deformable_group_index
=
c
/
(
2
*
kernel_h
*
kernel_w
);
const
int
col_step
=
kernel_h
*
kernel_w
;
int
cnt
=
0
;
const
T
*
data_col_ptr
=
data_col
+
deformable_group_index
*
channel_per_deformable_group
*
batch_size
*
width_col
*
height_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b
*
deformable_group
+
deformable_group_index
)
*
channel_per_deformable_group
/
kernel_h
/
kernel_w
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
int
offset_c
=
c
-
deformable_group_index
*
2
*
kernel_h
*
kernel_w
;
for
(
int
col_c
=
offset_c
/
2
;
col_c
<
channel_per_deformable_group
;
col_c
+=
col_step
)
{
const
int
col_pos
=
(((
col_c
*
batch_size
+
b
)
*
height_col
)
+
h
)
*
width_col
+
w
;
const
int
bp_dir
=
offset_c
%
2
;
int
j
=
(
col_pos
/
width_col
/
height_col
/
batch_size
)
%
kernel_w
;
int
i
=
(
col_pos
/
width_col
/
height_col
/
batch_size
/
kernel_w
)
%
kernel_h
;
int
w_out
=
col_pos
%
width_col
;
int
h_out
=
(
col_pos
/
width_col
)
%
height_col
;
int
w_in
=
w_out
*
stride_w
-
pad_w
;
int
h_in
=
h_out
*
stride_h
-
pad_h
;
const
int
data_offset_h_ptr
=
(((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
int
data_offset_w_ptr
=
(((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
int
data_mask_hw_ptr
=
(((
i
*
kernel_w
+
j
)
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
inv_h
=
h_in
+
i
*
dilation_h
+
offset_h
;
T
inv_w
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
inv_h
<=
-
1
||
inv_w
<=
-
1
||
inv_h
>=
height
||
inv_w
>=
width
)
{
inv_h
=
inv_w
=
-
2
;
}
else
{
mval
+=
data_col_ptr
[
col_pos
]
*
DmcnIm2colBilinear
(
data_im_ptr
+
cnt
*
height
*
width
,
width
,
height
,
width
,
inv_h
,
inv_w
);
}
const
T
weight
=
DmcnGetCoordinateWeight
(
inv_h
,
inv_w
,
height
,
width
,
data_im_ptr
+
cnt
*
height
*
width
,
width
,
bp_dir
);
val
+=
weight
*
data_col_ptr
[
col_pos
]
*
mask
;
cnt
+=
1
;
}
grad_offset
[
i
]
=
val
;
if
(
offset_c
%
2
==
0
)
grad_mask
[(((
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
+
offset_c
/
2
)
*
height_col
+
h
)
*
width_col
+
w
]
=
mval
;
}
}
template
<
typename
T
>
inline
void
ModulatedDeformableCol2imCoord
(
const
platform
::
DeviceContext
&
ctx
,
const
T
*
data_col
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
kernel_shape
,
const
std
::
vector
<
int
>
paddings
,
const
std
::
vector
<
int
>
strides
,
const
std
::
vector
<
int
>
dilations
,
const
int
deformable_groups
,
T
*
grad_offset
,
T
*
grad_mask
)
{
int
num_kernels
=
2
*
kernel_shape
[
2
]
*
kernel_shape
[
3
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
]
*
deformable_groups
;
int
channel_per_deformable_group
=
col_shape
[
0
]
/
deformable_groups
;
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
ModulatedDeformableCol2imCoordGpuKernel
<
T
><<<
blocks
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
num_kernels
,
data_col
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
0
],
im_shape
[
1
],
im_shape
[
2
],
kernel_shape
[
2
],
kernel_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
2
*
kernel_shape
[
2
]
*
kernel_shape
[
3
]
*
deformable_groups
,
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
grad_offset
,
grad_mask
);
}
template
<
typename
T
>
__device__
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
lh
=
h
-
h_low
;
T
lw
=
w
-
w_low
;
T
hh
=
1
-
lh
,
hw
=
1
-
lw
;
T
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
T
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
T
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
T
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
T
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
T
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableIm2colGpuKernel
(
const
int
nthreads
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
static_cast
<
T
>
(
0
);
const
T
h_im
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
w_im
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height
&&
w_im
<
width
)
{
val
=
DmcnIm2colBilinear
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
template
<
typename
T
>
inline
void
ModulatedDeformableIm2col
(
const
platform
::
DeviceContext
&
ctx
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
filter_shape
,
const
std
::
vector
<
int
>
paddings
,
const
std
::
vector
<
int
>
strides
,
const
std
::
vector
<
int
>
dilations
,
const
int
deformable_groups
,
T
*
data_col
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_groups
;
int
num_kernels
=
im_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
ModulatedDeformableIm2colGpuKernel
<
T
><<<
blocks
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
1
],
im_shape
[
2
],
filter_shape
[
2
],
filter_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
im_shape
[
0
],
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
data_col
);
}
template
<
typename
T
>
__global__
void
FilterGradAddupGpuKernel
(
const
int
nthreads
,
const
int
n
,
const
int
height
,
const
int
width
,
const
T
*
dweight_3d
,
T
*
filter_grad
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
filter_grad
[
i
]
=
filter_grad
[
i
]
+
dweight_3d
[
i
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
DeformableConvCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
const
Tensor
offset
=
*
ctx
.
Input
<
Tensor
>
(
"Offset"
);
const
Tensor
mask
=
*
ctx
.
Input
<
Tensor
>
(
"Mask"
);
Tensor
filter
=
*
ctx
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
const
int
deformable_groups
=
ctx
.
Attr
<
int
>
(
"deformable_groups"
);
const
int
im2col_step
=
ctx
.
Attr
<
int
>
(
"im2col_step"
);
const
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
const
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
const
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
std
::
vector
<
int64_t
>
filter_shape_vec
(
framework
::
vectorize
(
filter
.
dims
()));
std
::
vector
<
int64_t
>
output_shape_vec
(
framework
::
vectorize
(
output
->
dims
()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
input
->
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_buffer_shape_vec
));
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
framework
::
DDim
output_shape
(
framework
::
make_ddim
(
output_buffer_shape_vec
));
Tensor
col_buffer
;
Tensor
output_buffer
;
col_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
col_shape
,
dev_ctx
);
output_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
output_shape
,
dev_ctx
);
int64_t
M
=
output_shape_vec
[
1
]
/
groups
;
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
input
->
dims
()[
1
]
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
]
/
groups
;
Tensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
).
Resize
(
framework
::
make_ddim
({
groups
,
M
,
K
}));
Tensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
)
.
Resize
(
framework
::
make_ddim
({
groups
,
K
,
N
}));
Tensor
output_4d
;
output_4d
.
ShareDataWith
(
output_buffer
)
.
Resize
(
framework
::
make_ddim
({
batch_size
/
im2col_step
,
groups
,
M
,
N
}));
output_4d
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
DDim
input_shape
=
framework
::
slice_ddim
(
input
->
dims
(),
1
,
input
->
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
framework
::
vectorize
(
input_shape
);
int
input_dim
=
input
->
numel
()
/
input
->
dims
()[
0
];
int
input_offset_dim
=
offset
.
numel
()
/
offset
.
dims
()[
0
];
int
input_mask_dim
=
mask
.
numel
()
/
mask
.
dims
()[
0
];
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
const
T
*
input_ptr
=
input
->
data
<
T
>
();
const
T
*
offset_ptr
=
offset
.
data
<
T
>
();
const
T
*
mask_ptr
=
mask
.
data
<
T
>
();
col_buffer
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
col_buffer_ptr
=
col_buffer
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
ModulatedDeformableIm2col
(
ctx
.
device_context
(),
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
Tensor
output_3d
=
output_4d
.
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
output_4d
.
dims
(),
1
,
output_4d
.
dims
().
size
()));
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
weight_3d_slice
=
weight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
weight_3d
.
dims
(),
1
,
weight_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
Tensor
output_3d_slice
=
output_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
output_3d
.
dims
(),
1
,
output_3d
.
dims
().
size
()));
blas
.
MatMul
(
weight_3d_slice
,
false
,
col_buffer_3d_slice
,
false
,
T
(
1.0
),
&
output_3d_slice
,
T
(
0.0
));
}
}
output
->
ShareDataWith
(
output_buffer
)
.
Resize
(
framework
::
make_ddim
(
output_shape_vec
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DeformableConvGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
Tensor
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
Tensor
*
filter_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter"
));
Tensor
*
offset_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Offset"
));
Tensor
*
mask_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Mask"
));
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
Tensor
offset
=
*
ctx
.
Input
<
Tensor
>
(
"Offset"
);
Tensor
mask
=
*
ctx
.
Input
<
Tensor
>
(
"Mask"
);
Tensor
filter
=
*
ctx
.
Input
<
Tensor
>
(
"Filter"
);
if
(
!
input_grad
&&
!
filter_grad
&&
!
offset_grad
&&
!
mask_grad
)
return
;
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
int
deformable_groups
=
ctx
.
Attr
<
int
>
(
"deformable_groups"
);
int
im2col_step
=
ctx
.
Attr
<
int
>
(
"im2col_step"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
framework
::
DDim
input_shape
=
framework
::
slice_ddim
(
input
->
dims
(),
1
,
input
->
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
framework
::
vectorize
(
input_shape
);
std
::
vector
<
int64_t
>
filter_shape_vec
(
framework
::
vectorize
(
filter
.
dims
()));
std
::
vector
<
int64_t
>
output_shape_vec
(
framework
::
vectorize
(
output_grad
->
dims
()));
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
input
->
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_buffer_shape_vec
));
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
framework
::
DDim
output_shape
(
framework
::
make_ddim
(
output_buffer_shape_vec
));
Tensor
col_buffer
;
Tensor
output_buffer
;
col_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
col_shape
,
dev_ctx
);
output_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
output_shape
,
dev_ctx
);
output_buffer
.
ShareDataWith
(
*
output_grad
);
int64_t
M
=
input_shape_vec
[
0
]
/
groups
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
];
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
output_shape_vec
[
1
]
/
groups
;
framework
::
DDim
weight_3d_shape
=
{
groups
,
K
,
M
};
framework
::
DDim
out_grad_4d_shape
=
{
batch_size
/
im2col_step
,
groups
,
K
,
N
};
framework
::
DDim
col_buffer_3d_shape
=
{
groups
,
M
,
N
};
framework
::
DDim
filter_grad_shape
=
{
groups
,
K
,
M
};
Tensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
).
Resize
(
weight_3d_shape
);
Tensor
out_grad_4d
;
out_grad_4d
.
ShareDataWith
(
output_buffer
).
Resize
(
out_grad_4d_shape
);
Tensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
).
Resize
(
col_buffer_3d_shape
);
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
col_buffer
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
col_buffer_3d
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out_grad_4d
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
input_dim
=
input
->
numel
()
/
input
->
dims
()[
0
];
int
input_offset_dim
=
offset
.
numel
()
/
offset
.
dims
()[
0
];
int
input_mask_dim
=
mask
.
numel
()
/
mask
.
dims
()[
0
];
if
(
filter_grad
)
{
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
filter_grad
->
Resize
(
filter_grad_shape
);
set_zero
(
dev_ctx
,
filter_grad
,
static_cast
<
T
>
(
0
));
}
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
input_grad
,
static_cast
<
T
>
(
0
));
}
if
(
offset_grad
&&
mask_grad
)
{
offset_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mask_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
offset_grad
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
mask_grad
,
static_cast
<
T
>
(
0
));
}
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
Tensor
out_grad_3d
=
out_grad_4d
.
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
out_grad_4d
.
dims
(),
1
,
out_grad_4d
.
dims
().
size
()));
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
weight_3d_slice
=
weight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
weight_3d
.
dims
(),
1
,
weight_3d
.
dims
().
size
()));
Tensor
out_grad_3d_slice
=
out_grad_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
out_grad_3d
.
dims
(),
1
,
out_grad_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
blas
.
MatMul
(
weight_3d_slice
,
true
,
out_grad_3d_slice
,
false
,
T
(
1.0
),
&
col_buffer_3d_slice
,
T
(
0.0
));
}
col_buffer
.
Resize
(
col_shape
);
T
*
col_buffer_ptr
=
col_buffer
.
data
<
T
>
();
const
T
*
input_ptr
=
input
->
data
<
T
>
();
const
T
*
offset_ptr
=
offset
.
data
<
T
>
();
const
T
*
mask_ptr
=
mask
.
data
<
T
>
();
if
(
mask_grad
&&
offset_grad
)
{
T
*
offset_grad_ptr
=
offset_grad
->
data
<
T
>
();
T
*
mask_grad_ptr
=
mask_grad
->
data
<
T
>
();
ModulatedDeformableCol2imCoord
(
ctx
.
device_context
(),
col_buffer_ptr
,
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
offset_grad_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_grad_ptr
+
i
*
im2col_step
*
input_mask_dim
);
}
if
(
input_grad
)
{
T
*
input_grad_ptr
=
input_grad
->
data
<
T
>
();
ModulatedDeformableCol2im
(
ctx
.
device_context
(),
col_buffer_ptr
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
input_grad_ptr
+
i
*
im2col_step
*
input_dim
);
input_grad
->
Resize
(
input
->
dims
());
}
ModulatedDeformableIm2col
(
ctx
.
device_context
(),
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
col_buffer_3d
.
Resize
(
col_buffer_3d_shape
);
if
(
filter_grad
)
{
Tensor
dweight_3d
;
dweight_3d
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
filter_grad_shape
,
dev_ctx
);
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
out_grad_3d_slice
=
out_grad_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
out_grad_3d
.
dims
(),
1
,
out_grad_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
Tensor
dweight_3d_slice
=
dweight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
framework
::
slice_ddim
(
dweight_3d
.
dims
(),
1
,
dweight_3d
.
dims
().
size
()));
blas
.
MatMul
(
out_grad_3d_slice
,
false
,
col_buffer_3d_slice
,
true
,
T
(
1.0
),
&
dweight_3d_slice
,
T
(
0.0
));
}
FilterGradAddupGpuKernel
<
T
><<<
NumBlocks
(
dweight_3d
.
numel
()),
kNumCUDAThreads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dweight_3d
.
numel
(),
groups
,
K
,
M
,
dweight_3d
.
data
<
T
>
(),
filter_grad
->
data
<
T
>
());
}
}
if
(
filter_grad
)
{
filter_grad
->
Resize
(
filter
.
dims
());
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
deformable_conv
,
ops
::
DeformableConvCUDAKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
deformable_conv_grad
,
ops
::
DeformableConvGradCUDAKernel
<
CUDA
,
float
>
);
python/paddle/fluid/layers/nn.py
浏览文件 @
bba57cdd
...
...
@@ -202,6 +202,7 @@ __all__ = [
'continuous_value_model'
,
'where'
,
'sign'
,
'deformable_conv'
,
]
kIgnoreIndex
=
-
100
...
...
@@ -11745,3 +11746,175 @@ def sign(x):
helper
.
append_op
(
type
=
'sign'
,
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]})
return
out
def
deformable_conv
(
input
,
offset
,
mask
,
num_filters
,
filter_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
None
,
deformable_groups
=
None
,
im2col_step
=
None
,
param_attr
=
None
,
bias_attr
=
None
,
name
=
None
):
"""
**Deformable Convolution Layer**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ .
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})`
Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&=
\\
frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
\\\\
W_{out}&=
\\
frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
input (Variable): The input image with [N, C, H, W] format.
offset (Variable): The input coord offset of deformable convolution layer.
Mask (Variable): The input mask of deformable covolution layer.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the deformable conv layer. According to
grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
deformable_groups (int): The number of deformable group partitions.
Default: deformable_groups = 1.
im2col_step (int): Maximum number of images per im2col computation;
The total batch size should be divisable by this value or smaller
than this value; if you face out of memory problem, you can try
to use a smaller value here.
Default: im2col_step = 64.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of deformable conv. If it is set to None or one attribute of ParamAttr,
deformable conv will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the parameter is
initialized with :math:`Normal(0.0, std)`, and the
:math:`std` is :math:`(
\\
frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of
deformable conv layer. If it is set to False, no bias will be added
to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the deformable convolution
\
result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask,
num_filters=2, filter_size=3, padding=1)
"""
num_channels
=
input
.
shape
[
1
]
assert
param_attr
is
not
False
,
"param_attr should not be False here."
helper
=
LayerHelper
(
'deformable_conv'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
if
not
isinstance
(
input
,
Variable
):
raise
TypeError
(
"Input of deformable_conv must be Variable"
)
if
not
isinstance
(
offset
,
Variable
):
raise
TypeError
(
"Input Offset of deformable_conv must be Variable"
)
if
not
isinstance
(
mask
,
Variable
):
raise
TypeError
(
"Input Mask of deformable_conv must be Variable"
)
if
groups
is
None
:
num_filter_channels
=
num_channels
else
:
if
num_channels
%
groups
!=
0
:
raise
ValueError
(
"num_channels must be divisible by groups."
)
num_filter_channels
=
num_channels
//
groups
filter_size
=
utils
.
convert_to_list
(
filter_size
,
2
,
'filter_size'
)
stride
=
utils
.
convert_to_list
(
stride
,
2
,
'stride'
)
padding
=
utils
.
convert_to_list
(
padding
,
2
,
'padding'
)
dilation
=
utils
.
convert_to_list
(
dilation
,
2
,
'dilation'
)
input_shape
=
input
.
shape
filter_shape
=
[
num_filters
,
int
(
num_filter_channels
)]
+
filter_size
def
_get_default_param_initializer
():
filter_elem_num
=
filter_size
[
0
]
*
filter_size
[
1
]
*
num_channels
std
=
(
2.0
/
filter_elem_num
)
**
0.5
return
Normal
(
0.0
,
std
,
0
)
filter_param
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
filter_shape
,
dtype
=
dtype
,
default_initializer
=
_get_default_param_initializer
())
pre_bias
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
type
=
'deformable_conv'
,
inputs
=
{
'Input'
:
input
,
'Filter'
:
filter_param
,
'Offset'
:
offset
,
'Mask'
:
mask
,
},
outputs
=
{
"Output"
:
pre_bias
},
attrs
=
{
'strides'
:
stride
,
'paddings'
:
padding
,
'dilations'
:
dilation
,
'groups'
:
groups
,
'deformable_groups'
:
deformable_groups
,
'im2col_step'
:
im2col_step
,
})
output
=
helper
.
append_bias_op
(
pre_bias
,
dim_start
=
1
,
dim_end
=
2
)
return
output
python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
0 → 100644
浏览文件 @
bba57cdd
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
def
dmc_bilinear
(
data_im
,
height
,
width
,
h
,
w
):
h_low
=
int
(
np
.
floor
(
h
))
w_low
=
int
(
np
.
floor
(
w
))
h_high
=
h_low
+
1
w_high
=
w_low
+
1
lh
=
h
-
h_low
lw
=
w
-
w_low
hh
=
1
-
lh
hw
=
1
-
lw
v1
=
0
if
h_low
>=
0
and
w_low
>=
0
:
v1
=
data_im
[
h_low
,
w_low
]
v2
=
0
if
h_low
>=
0
and
w_high
<=
width
-
1
:
v2
=
data_im
[
h_low
,
w_high
]
v3
=
0
if
h_high
<=
height
-
1
and
w_low
>=
0
:
v3
=
data_im
[
h_high
,
w_low
]
v4
=
0
if
h_high
<=
height
-
1
and
w_high
<=
width
-
1
:
v4
=
data_im
[
h_high
,
w_high
]
w1
,
w2
,
w3
,
w4
=
hh
*
hw
,
hh
*
lw
,
lh
*
hw
,
lh
*
lw
val
=
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
return
val
def
dconv_im2col_gemm
(
input
,
offset
,
mask
,
filter
,
group
,
conv_param
):
in_n
,
in_c
,
in_h
,
in_w
=
input
.
shape
out_c
,
f_c
,
f_h
,
f_w
=
filter
.
shape
assert
offset
.
shape
==
(
in_n
,
2
*
f_h
*
f_w
,
in_h
,
in_w
)
assert
mask
.
shape
==
(
in_n
,
f_h
*
f_w
,
in_h
,
in_w
)
assert
f_c
*
group
==
in_c
assert
np
.
mod
(
out_c
,
group
)
==
0
stride
,
pad
,
dilation
=
conv_param
[
'stride'
],
conv_param
[
'pad'
],
\
conv_param
[
'dilation'
]
out_h
=
1
+
(
in_h
+
2
*
pad
[
0
]
-
(
dilation
[
0
]
*
(
f_h
-
1
)
+
1
))
//
stride
[
0
]
out_w
=
1
+
(
in_w
+
2
*
pad
[
1
]
-
(
dilation
[
1
]
*
(
f_w
-
1
)
+
1
))
//
stride
[
1
]
assert
out_h
==
in_h
assert
out_w
==
in_w
col_buffer
=
np
.
zeros
((
in_n
,
in_c
*
f_h
*
f_w
,
in_h
*
in_w
))
for
n
in
range
(
in_n
):
for
c
in
range
(
in_c
):
for
h
in
range
(
out_h
):
for
w
in
range
(
out_w
):
for
kh
in
range
(
f_h
):
for
kw
in
range
(
f_w
):
offset_h_table
=
\
offset
[
n
,
::
2
,
h
,
w
].
reshape
(
f_h
,
f_w
)
offset_w_table
=
\
offset
[
n
,
1
::
2
,
h
,
w
].
reshape
(
f_h
,
f_w
)
mask_table
=
\
mask
[
n
,
:,
h
,
w
].
reshape
(
f_h
,
f_w
)
offset_h
=
offset_h_table
[
kh
,
kw
]
offset_w
=
offset_w_table
[
kh
,
kw
]
val
=
0
im_h
=
h
*
stride
[
0
]
+
kh
*
dilation
[
0
]
\
+
offset_h
-
pad
[
0
]
im_w
=
w
*
stride
[
0
]
+
kw
*
dilation
[
0
]
\
+
offset_w
-
pad
[
1
]
if
im_h
>
-
1
and
im_w
>
-
1
and
\
im_h
<
in_h
and
im_w
<
in_h
:
val
=
dmc_bilinear
(
input
[
n
,
c
],
in_h
,
in_w
,
im_h
,
im_w
)
val_out
=
val
*
mask_table
[
kh
,
kw
]
col_buffer
[
n
,
c
*
f_h
*
f_w
+
kh
*
f_w
+
kw
,
h
*
in_w
+
w
]
=
val_out
out
=
np
.
zeros
((
in_n
,
group
,
int
(
out_c
//
group
),
out_h
*
out_w
))
weight
=
filter
.
reshape
(
group
,
int
(
out_c
//
group
),
f_c
*
f_h
*
f_w
)
col_buffer
=
col_buffer
.
reshape
(
(
in_n
,
group
,
int
(
in_c
//
group
*
f_h
*
f_w
),
in_h
*
in_w
))
for
n
in
range
(
in_n
):
for
g
in
range
(
group
):
out
[
n
,
g
]
=
np
.
matmul
(
weight
[
g
],
col_buffer
[
n
,
g
])
out
=
out
.
reshape
(
in_n
,
out_c
,
out_h
,
out_w
)
return
out
class
TestModulatedDeformableConvOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"deformable_conv"
self
.
dtype
=
np
.
float32
self
.
init_group
()
self
.
init_dilation
()
self
.
init_test_case
()
conv_param
=
{
'stride'
:
self
.
stride
,
'pad'
:
self
.
pad
,
'dilation'
:
self
.
dilations
}
input
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
self
.
dtype
)
offset
=
10
*
np
.
random
.
random
(
self
.
offset_size
).
astype
(
self
.
dtype
)
mask
=
10
*
np
.
random
.
random
(
self
.
mask_size
).
astype
(
self
.
dtype
)
filter
=
np
.
random
.
random
(
self
.
filter_size
).
astype
(
self
.
dtype
)
output
=
dconv_im2col_gemm
(
input
,
offset
,
mask
,
filter
,
self
.
groups
,
conv_param
)
output
=
output
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'Input'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
),
'Offset'
:
OpTest
.
np_dtype_to_fluid_dtype
(
offset
),
'Mask'
:
OpTest
.
np_dtype_to_fluid_dtype
(
mask
),
'Filter'
:
OpTest
.
np_dtype_to_fluid_dtype
(
filter
)
}
self
.
attrs
=
{
'strides'
:
self
.
stride
,
'paddings'
:
self
.
pad
,
'groups'
:
self
.
groups
,
'deformable_groups'
:
self
.
deformable_groups
,
'im2col_step'
:
self
.
im2col_step
,
'dilations'
:
self
.
dilations
,
}
self
.
outputs
=
{
'Output'
:
output
}
def
has_cuda
(
self
):
return
core
.
is_compiled_with_cuda
()
def
test_check_output
(
self
):
if
self
.
has_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
,
atol
=
1e-5
)
def
test_check_grad
(
self
):
if
self
.
has_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
{
'Input'
,
'Offset'
,
'Mask'
,
'Filter'
},
'Output'
,
max_relative_error
=
0.05
)
def
test_check_grad_no_filter
(
self
):
if
self
.
has_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'Input'
,
'Offset'
,
'Mask'
],
'Output'
,
max_relative_error
=
0.1
,
no_grad_set
=
set
([
'Filter'
]))
def
test_check_grad_no_input
(
self
):
if
self
.
has_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'Filter'
,
'Offset'
,
'Mask'
],
'Output'
,
max_relative_error
=
0.1
,
no_grad_set
=
set
([
'Input'
]))
def
test_check_grad_no_offset_no_mask
(
self
):
if
self
.
has_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'Input'
,
'Filter'
],
'Output'
,
max_relative_error
=
0.1
,
no_grad_set
=
set
([
'Offset'
,
'Mask'
]))
def
init_test_case
(
self
):
self
.
pad
=
[
1
,
1
]
self
.
stride
=
[
1
,
1
]
self
.
dilations
=
[
1
,
1
]
self
.
input_size
=
[
2
,
4
,
4
,
4
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
//
self
.
groups
self
.
filter_size
=
[
4
,
f_c
,
3
,
3
]
self
.
im2col_step
=
1
self
.
deformable_groups
=
1
offset_c
=
2
*
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
mask_c
=
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
self
.
offset_size
=
[
self
.
input_size
[
0
],
offset_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
mask_size
=
[
self
.
input_size
[
0
],
mask_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
def
init_dilation
(
self
):
self
.
dilations
=
[
1
,
1
]
def
init_group
(
self
):
self
.
groups
=
1
class
TestWithStride
(
TestModulatedDeformableConvOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
3
,
3
]
self
.
stride
=
[
2
,
2
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
//
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
self
.
im2col_step
=
1
self
.
deformable_groups
=
1
offset_c
=
2
*
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
mask_c
=
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
self
.
offset_size
=
[
self
.
input_size
[
0
],
offset_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
mask_size
=
[
self
.
input_size
[
0
],
mask_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
class
TestWithDilation
(
TestModulatedDeformableConvOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
2
,
2
]
self
.
stride
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
4
,
4
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
//
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
self
.
im2col_step
=
1
self
.
deformable_groups
=
1
offset_c
=
2
*
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
mask_c
=
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
self
.
offset_size
=
[
self
.
input_size
[
0
],
offset_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
mask_size
=
[
self
.
input_size
[
0
],
mask_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
def
init_dilation
(
self
):
self
.
dilations
=
[
2
,
2
]
class
TestWith1x1
(
TestModulatedDeformableConvOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
0
,
0
]
self
.
stride
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
//
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
1
,
1
]
self
.
im2col_step
=
1
self
.
deformable_groups
=
1
offset_c
=
2
*
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
mask_c
=
self
.
deformable_groups
*
self
.
filter_size
[
2
]
*
self
.
filter_size
[
3
]
self
.
offset_size
=
[
self
.
input_size
[
0
],
offset_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
self
.
mask_size
=
[
self
.
input_size
[
0
],
mask_c
,
self
.
input_size
[
2
],
self
.
input_size
[
3
]
]
class
TestWithGroup
(
TestModulatedDeformableConvOp
):
def
init_group
(
self
):
self
.
groups
=
2
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
bba57cdd
...
...
@@ -1957,6 +1957,34 @@ class TestBook(LayerTest):
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_deformable_conv
(
self
):
if
core
.
is_compiled_with_cuda
():
with
program_guard
(
fluid
.
default_main_program
(),
fluid
.
default_startup_program
()):
input
=
layers
.
data
(
name
=
'input'
,
append_batch_size
=
False
,
shape
=
[
2
,
3
,
32
,
32
],
dtype
=
"float32"
)
offset
=
layers
.
data
(
name
=
'offset'
,
append_batch_size
=
False
,
shape
=
[
2
,
18
,
32
,
32
],
dtype
=
"float32"
)
mask
=
layers
.
data
(
name
=
'mask'
,
append_batch_size
=
False
,
shape
=
[
2
,
9
,
32
,
32
],
dtype
=
"float32"
)
out
=
layers
.
deformable_conv
(
input
=
input
,
offset
=
offset
,
mask
=
mask
,
num_filters
=
2
,
filter_size
=
3
,
padding
=
1
)
return
(
out
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录