Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
27337af0
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
27337af0
编写于
9月 24, 2020
作者:
Z
zhangwen31
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[host][kernel]feat: add deformable_conv v2 host kernel
上级
80452148
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
808 addition
and
0 deletion
+808
-0
lite/kernels/host/CMakeLists.txt
lite/kernels/host/CMakeLists.txt
+1
-0
lite/kernels/host/deformable_conv_compute.cc
lite/kernels/host/deformable_conv_compute.cc
+187
-0
lite/kernels/host/deformable_conv_compute.h
lite/kernels/host/deformable_conv_compute.h
+33
-0
lite/kernels/host/deformable_conv_op.h
lite/kernels/host/deformable_conv_op.h
+587
-0
未找到文件。
lite/kernels/host/CMakeLists.txt
浏览文件 @
27337af0
...
...
@@ -10,6 +10,7 @@ add_kernel(expand_compute_host Host basic SRCS expand_compute.cc DEPS ${lite_ker
add_kernel
(
expand_as_compute_host Host basic SRCS expand_as_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
fill_constant_compute_host Host basic SRCS fill_constant_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
fill_constant_batch_size_like_compute_host Host basic SRCS fill_constant_batch_size_like_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
deformable_conv_compute_host Host basic SRCS deformable_conv_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
shape_compute_host Host extra SRCS shape_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
is_empty_compute_host Host extra SRCS is_empty_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS
${
lite_kernel_deps
}
)
...
...
lite/kernels/host/deformable_conv_compute.cc
0 → 100644
浏览文件 @
27337af0
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/deformable_conv_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/host/deformable_conv_op.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
host
{
// todo: use blas if necessary
/**
* naive row majored mat mul
*/
template
<
class
T
>
void
MatMul
(
const
Tensor
&
mat_a
,
const
Tensor
&
mat_b
,
T
alpha
,
Tensor
*
mat_out
,
T
beta
)
{
auto
dim_a
=
mat_a
.
dims
();
auto
dim_b
=
mat_b
.
dims
();
auto
dim_out
=
mat_out
->
dims
();
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
dim_a
[
1
];
auto
*
pA
=
mat_a
.
data
<
T
>
();
auto
*
pB
=
mat_b
.
data
<
T
>
();
auto
*
pC
=
mat_out
->
mutable_data
<
T
>
();
for
(
int
i
=
0
;
i
<
M
;
++
i
)
{
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
T
sum
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
sum
+=
pA
[
i
*
K
+
k
]
*
pB
[
k
*
N
+
j
];
}
pC
[
i
*
N
+
j
]
=
sum
*
alpha
+
beta
;
}
}
}
/**
* @note this function is modified from paddle fluid
* paddle commit id: f4c750d721a1226738bea382f6c0cf725cca8481
*
* check "paddle/fluid/operators/deformable_conv_op.h"
* if necessary
*/
template
<
>
void
DeformableConvComputeHost
<
PRECISION
(
kFloat
),
PRECISION
(
kFloat
)
>::
Run
()
{
const
auto
&
param
=
this
->
Param
<
operators
::
DeformableConvParam
>
();
// this implementation only support v2
// to support v1, you could follow
// "paddle/fluid/operators/deformable_conv_v1_op.h"
const
auto
*
input
=
param
.
x
;
const
auto
*
offset
=
param
.
offset
;
const
auto
*
mask
=
param
.
mask
;
const
auto
&
filter
=
*
param
.
conv_param
.
filter
;
auto
*
output
=
param
.
output
;
const
int
groups
=
param
.
conv_param
.
groups
;
const
int
deformable_groups
=
param
.
deformable_groups
;
const
int
im2col_step
=
param
.
im2col_step
;
const
std
::
vector
<
int
>&
strides
=
param
.
conv_param
.
strides
;
const
std
::
vector
<
int
>&
paddings
=
*
param
.
conv_param
.
paddings
;
const
std
::
vector
<
int
>&
dilations
=
*
param
.
conv_param
.
dilations
;
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
std
::
vector
<
int64_t
>
filter_shape_vec
(
filter
.
dims
().
Vectorize
());
std
::
vector
<
int64_t
>
output_shape_vec
(
output
->
dims
().
Vectorize
());
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
input
->
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
DDim
col_shape
(
col_buffer_shape_vec
);
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
DDim
output_shape
(
output_buffer_shape_vec
);
Tensor
col_buffer
;
Tensor
output_buffer
;
col_buffer
.
Resize
(
col_shape
);
col_buffer
.
mutable_data
<
float
>
();
output_buffer
.
Resize
(
output_shape
);
output_buffer
.
mutable_data
<
float
>
();
int64_t
M
=
output_shape_vec
[
1
]
/
groups
;
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
input
->
dims
()[
1
]
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
]
/
groups
;
Tensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
);
weight_3d
.
Resize
(
DDim
({
groups
,
M
,
K
}));
Tensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
);
col_buffer_3d
.
Resize
(
DDim
({
groups
,
K
,
N
}));
Tensor
output_4d
;
output_4d
.
ShareDataWith
(
output_buffer
);
output_4d
.
Resize
(
DDim
({
batch_size
/
im2col_step
,
groups
,
M
,
N
}));
output_4d
.
mutable_data
<
float
>
();
DDim
input_shape
=
input
->
dims
().
Slice
(
1
,
input
->
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
input_shape
.
Vectorize
();
int
input_dim
=
input
->
numel
()
/
input
->
dims
()[
0
];
int
input_offset_dim
=
offset
->
numel
()
/
offset
->
dims
()[
0
];
int
input_mask_dim
=
mask
->
numel
()
/
mask
->
dims
()[
0
];
const
float
*
input_ptr
=
input
->
data
<
float
>
();
const
float
*
offset_ptr
=
offset
->
data
<
float
>
();
const
float
*
mask_ptr
=
mask
->
data
<
float
>
();
col_buffer
.
mutable_data
<
float
>
();
float
*
col_buffer_ptr
=
col_buffer
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
ModulatedDeformableIm2colCPU
<
float
>
(
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
Tensor
output_3d
=
output_4d
.
Slice
<
float
>
(
i
,
i
+
1
);
output_3d
.
Resize
(
DDim
(
output_4d
.
dims
()).
Slice
(
1
,
output_4d
.
dims
().
size
()));
// get the product of pixel and weight
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
weight_3d_slice
=
weight_3d
.
Slice
<
float
>
(
g
,
g
+
1
);
weight_3d_slice
.
Resize
(
DDim
(
weight_3d
.
dims
()).
Slice
(
1
,
weight_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
<
float
>
(
g
,
g
+
1
);
col_buffer_3d_slice
.
Resize
(
DDim
(
col_buffer_3d
.
dims
()).
Slice
(
1
,
col_buffer_3d
.
dims
().
size
()));
Tensor
output_3d_slice
=
output_3d
.
Slice
<
float
>
(
g
,
g
+
1
);
output_3d_slice
.
Resize
(
DDim
(
output_3d
.
dims
()).
Slice
(
1
,
output_3d
.
dims
().
size
()));
MatMul
<
float
>
(
weight_3d_slice
,
col_buffer_3d_slice
,
1.0
f
,
&
output_3d_slice
,
0.0
f
);
}
}
output
->
ShareDataWith
(
output_buffer
);
output
->
Resize
(
DDim
(
output_shape_vec
));
}
}
// namespace host
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
using
DeformableConvFp32Host
=
paddle
::
lite
::
kernels
::
host
::
DeformableConvComputeHost
<
PRECISION
(
kFloat
),
PRECISION
(
kFloat
)
>
;
REGISTER_LITE_KERNEL
(
deformable_conv
,
kHost
,
kFloat
,
kNCHW
,
DeformableConvFp32Host
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
BindInput
(
"Mask"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
BindInput
(
"Offset"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
))})
.
Finalize
();
lite/kernels/host/deformable_conv_compute.h
0 → 100644
浏览文件 @
27337af0
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
host
{
template
<
PrecisionType
Ptype
,
PrecisionType
OutType
>
class
DeformableConvComputeHost
:
public
KernelLite
<
TARGET
(
kHost
),
Ptype
>
{
public:
void
Run
()
override
;
~
DeformableConvComputeHost
()
=
default
;
};
}
// namespace host
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/host/deformable_conv_op.h
0 → 100644
浏览文件 @
27337af0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
/**
* @note: all code in this file are copied from paddle fluid
* paddle commit id: f4c750d721a1226738bea382f6c0cf725cca8481
*
* check "paddle/fluid/operators/deformable_conv_op.h"
* and "paddle/fluid/operators/deformable_conv_func.h"
* if necessary
*/
#pragma once
#include <math.h>
#include <algorithm>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
host
{
template
<
typename
T
>
HOSTDEVICE
T
DmcnGetGradientWeight
(
T
argmax_h
,
T
argmax_w
,
const
int
h
,
const
int
w
,
const
int
height
,
const
int
width
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
T
weight
=
0
;
weight
=
(
h
==
argmax_h_low
&&
w
==
argmax_w_low
)
?
(
h
+
1
-
argmax_h
)
*
(
w
+
1
-
argmax_w
)
:
weight
;
weight
=
(
h
==
argmax_h_low
&&
w
==
argmax_w_high
)
?
(
h
+
1
-
argmax_h
)
*
(
argmax_w
+
1
-
w
)
:
weight
;
weight
=
(
h
==
argmax_h_high
&&
w
==
argmax_w_low
)
?
(
argmax_h
+
1
-
h
)
*
(
w
+
1
-
argmax_w
)
:
weight
;
weight
=
(
h
==
argmax_h_high
&&
w
==
argmax_w_high
)
?
(
argmax_h
+
1
-
h
)
*
(
argmax_w
+
1
-
w
)
:
weight
;
return
weight
;
}
template
<
typename
T
>
HOSTDEVICE
T
DmcnGetCoordinateWeight
(
T
argmax_h
,
T
argmax_w
,
const
int
height
,
const
int
width
,
const
T
*
im_data
,
const
int
data_width
,
const
int
bp_dir
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
T
weight
=
0
;
if
(
bp_dir
==
0
)
{
weight
+=
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
?
-
1
*
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
]
:
0
;
weight
+=
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
?
-
1
*
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
]
:
0
;
weight
+=
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
?
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
]
:
0
;
weight
+=
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
?
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
]
:
0
;
}
else
if
(
bp_dir
==
1
)
{
weight
+=
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
?
-
1
*
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
]
:
0
;
weight
+=
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
?
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
]
:
0
;
weight
+=
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
?
-
1
*
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
]
:
0
;
weight
+=
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
?
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
]
:
0
;
}
return
weight
;
}
template
<
typename
T
>
HOSTDEVICE
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
lh
=
h
-
h_low
;
T
lw
=
w
-
w_low
;
T
hh
=
1
-
lh
;
T
hw
=
1
-
lw
;
T
v1
=
(
h_low
>=
0
&&
w_low
>=
0
)
?
bottom_data
[
h_low
*
data_width
+
w_low
]
:
0
;
T
v2
=
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
?
bottom_data
[
h_low
*
data_width
+
w_high
]
:
0
;
T
v3
=
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
?
bottom_data
[
h_high
*
data_width
+
w_low
]
:
0
;
T
v4
=
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
?
bottom_data
[
h_high
*
data_width
+
w_high
]
:
0
;
T
w1
=
hh
*
hw
;
T
w2
=
hh
*
lw
;
T
w3
=
lh
*
hw
;
T
w4
=
lh
*
lw
;
return
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
;
}
template
<
typename
T
>
void
ModulatedDeformableCol2imCPUKernel
(
const
int
num_kernels
,
const
T
*
data_col
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
grad_im
)
{
for
(
int
thread
=
0
;
thread
<
num_kernels
;
thread
++
)
{
const
int
j
=
(
thread
/
width_col
/
height_col
/
batch_size
)
%
kernel_w
;
const
int
i
=
(
thread
/
width_col
/
height_col
/
batch_size
/
kernel_w
)
%
kernel_h
;
const
int
c
=
thread
/
width_col
/
height_col
/
batch_size
/
kernel_w
/
kernel_h
;
const
int
deformable_group_index
=
c
/
channel_per_deformable_group
;
int
w_out
=
thread
%
width_col
;
int
h_out
=
(
thread
/
width_col
)
%
height_col
;
int
b
=
(
thread
/
width_col
/
height_col
)
%
batch_size
;
int
w_in
=
w_out
*
stride_w
-
pad_w
;
int
h_in
=
h_out
*
stride_h
-
pad_h
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_out
)
*
width_col
+
w_out
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
const
T
cur_inv_h_data
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
cur_inv_w_data
=
w_in
+
j
*
dilation_w
+
offset_w
;
const
T
cur_top_grad
=
data_col
[
thread
]
*
mask
;
const
int
cur_h
=
static_cast
<
int
>
(
cur_inv_h_data
);
const
int
cur_w
=
static_cast
<
int
>
(
cur_inv_w_data
);
for
(
int
dy
=
-
2
;
dy
<=
2
;
dy
++
)
{
for
(
int
dx
=
-
2
;
dx
<=
2
;
dx
++
)
{
if
(
cur_h
+
dy
>=
0
&&
cur_h
+
dy
<
height
&&
cur_w
+
dx
>=
0
&&
cur_w
+
dx
<
width
&&
abs
(
cur_inv_h_data
-
(
cur_h
+
dy
))
<
1
&&
abs
(
cur_inv_w_data
-
(
cur_w
+
dx
))
<
1
)
{
int
cur_bottom_grad_pos
=
((
b
*
channels
+
c
)
*
height
+
cur_h
+
dy
)
*
width
+
cur_w
+
dx
;
T
weight
=
DmcnGetGradientWeight
(
cur_inv_h_data
,
cur_inv_w_data
,
cur_h
+
dy
,
cur_w
+
dx
,
height
,
width
);
*
(
grad_im
+
cur_bottom_grad_pos
)
=
*
(
grad_im
+
cur_bottom_grad_pos
)
+
weight
*
cur_top_grad
;
}
}
}
}
}
template
<
typename
T
>
static
inline
void
ModulatedDeformableCol2imCPU
(
const
T
*
data_col
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
kernel_shape
,
const
std
::
vector
<
int
>
pad
,
const
std
::
vector
<
int
>
stride
,
const
std
::
vector
<
int
>
dilation
,
const
int
deformable_group
,
T
*
grad_im
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_group
;
int
num_kernels
=
col_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
ModulatedDeformableCol2imCPUKernel
(
num_kernels
,
data_col
,
data_offset
,
data_mask
,
im_shape
[
0
],
im_shape
[
1
],
im_shape
[
2
],
kernel_shape
[
2
],
kernel_shape
[
3
],
pad
[
0
],
pad
[
1
],
stride
[
0
],
stride
[
1
],
dilation
[
0
],
dilation
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
deformable_group
,
col_shape
[
2
],
col_shape
[
3
],
grad_im
);
}
template
<
typename
T
>
void
ModulatedDeformableCol2imCoordCPUKernel
(
const
int
num_kernels
,
const
T
*
data_col
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
offset_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
grad_offset
,
T
*
grad_mask
)
{
for
(
int
i
=
0
;
i
<
num_kernels
;
i
++
)
{
T
val
=
0
,
mval
=
0
;
const
int
w
=
i
%
width_col
;
const
int
h
=
(
i
/
width_col
)
%
height_col
;
const
int
c
=
(
i
/
width_col
/
height_col
)
%
offset_channels
;
const
int
b
=
(
i
/
width_col
/
height_col
)
/
offset_channels
;
const
int
deformable_group_index
=
c
/
(
2
*
kernel_h
*
kernel_w
);
const
int
col_step
=
kernel_h
*
kernel_w
;
int
cnt
=
0
;
const
T
*
data_col_ptr
=
data_col
+
deformable_group_index
*
channel_per_deformable_group
*
batch_size
*
width_col
*
height_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b
*
deformable_group
+
deformable_group_index
)
*
channel_per_deformable_group
/
kernel_h
/
kernel_w
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
int
offset_c
=
c
-
deformable_group_index
*
2
*
kernel_h
*
kernel_w
;
for
(
int
col_c
=
offset_c
/
2
;
col_c
<
channel_per_deformable_group
;
col_c
+=
col_step
)
{
const
int
col_pos
=
(((
col_c
*
batch_size
+
b
)
*
height_col
)
+
h
)
*
width_col
+
w
;
const
int
bp_dir
=
offset_c
%
2
;
int
j
=
(
col_pos
/
width_col
/
height_col
/
batch_size
)
%
kernel_w
;
int
i
=
(
col_pos
/
width_col
/
height_col
/
batch_size
/
kernel_w
)
%
kernel_h
;
int
w_out
=
col_pos
%
width_col
;
int
h_out
=
(
col_pos
/
width_col
)
%
height_col
;
int
w_in
=
w_out
*
stride_w
-
pad_w
;
int
h_in
=
h_out
*
stride_h
-
pad_h
;
const
int
data_offset_h_ptr
=
(((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
int
data_offset_w_ptr
=
(((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
int
data_mask_hw_ptr
=
(((
i
*
kernel_w
+
j
)
*
height_col
+
h_out
)
*
width_col
+
w_out
);
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
inv_h
=
h_in
+
i
*
dilation_h
+
offset_h
;
T
inv_w
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
inv_h
<=
-
1
||
inv_w
<=
-
1
||
inv_h
>=
height
||
inv_w
>=
width
)
{
inv_h
=
inv_w
=
-
2
;
}
else
{
mval
+=
data_col_ptr
[
col_pos
]
*
DmcnIm2colBilinear
(
data_im_ptr
+
cnt
*
height
*
width
,
width
,
height
,
width
,
inv_h
,
inv_w
);
}
const
T
weight
=
DmcnGetCoordinateWeight
(
inv_h
,
inv_w
,
height
,
width
,
data_im_ptr
+
cnt
*
height
*
width
,
width
,
bp_dir
);
val
+=
weight
*
data_col_ptr
[
col_pos
]
*
mask
;
cnt
+=
1
;
}
grad_offset
[
i
]
=
val
;
if
(
offset_c
%
2
==
0
)
grad_mask
[(((
b
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
+
offset_c
/
2
)
*
height_col
+
h
)
*
width_col
+
w
]
=
mval
;
}
}
template
<
typename
T
>
static
inline
void
ModulatedDeformableCol2imCoordCPU
(
const
T
*
data_col
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
kernel_shape
,
const
std
::
vector
<
int
>
paddings
,
const
std
::
vector
<
int
>
strides
,
const
std
::
vector
<
int
>
dilations
,
const
int
deformable_groups
,
T
*
grad_offset
,
T
*
grad_mask
)
{
int
num_kernels
=
2
*
kernel_shape
[
2
]
*
kernel_shape
[
3
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
]
*
deformable_groups
;
int
channel_per_deformable_group
=
col_shape
[
0
]
/
deformable_groups
;
ModulatedDeformableCol2imCoordCPUKernel
(
num_kernels
,
data_col
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
0
],
im_shape
[
1
],
im_shape
[
2
],
kernel_shape
[
2
],
kernel_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
2
*
kernel_shape
[
2
]
*
kernel_shape
[
3
]
*
deformable_groups
,
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
grad_offset
,
grad_mask
);
}
template
<
typename
T
>
void
ModulatedDeformableIm2colCPUKernel
(
const
int
num_kernels
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
for
(
int
i
=
0
;
i
<
num_kernels
;
i
++
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
static_cast
<
T
>
(
0
);
const
T
h_im
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
w_im
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height
&&
w_im
<
width
)
{
val
=
DmcnIm2colBilinear
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
template
<
typename
T
>
static
inline
void
ModulatedDeformableIm2colCPU
(
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>
im_shape
,
const
std
::
vector
<
int64_t
>
col_shape
,
const
std
::
vector
<
int64_t
>
filter_shape
,
const
std
::
vector
<
int
>
paddings
,
const
std
::
vector
<
int
>
strides
,
const
std
::
vector
<
int
>
dilations
,
const
int
deformable_groups
,
T
*
data_col
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_groups
;
int
num_kernels
=
im_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
// get outputs of im2col with offset by bilinear interpolation
ModulatedDeformableIm2colCPUKernel
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
1
],
im_shape
[
2
],
filter_shape
[
2
],
filter_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
im_shape
[
0
],
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
data_col
);
}
template
<
typename
T
>
void
FilterGradAddupCPUKernel
(
const
int
nthreads
,
const
int
n
,
const
int
height
,
const
int
width
,
const
T
*
dweight_3d
,
T
*
filter_grad
)
{
for
(
int
i
=
0
;
i
<
nthreads
;
i
++
)
{
filter_grad
[
i
]
=
filter_grad
[
i
]
+
dweight_3d
[
i
];
}
}
}
// namespace host
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录