Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a8e5c9be
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a8e5c9be
编写于
3月 19, 2022
作者:
Z
zyfncg
提交者:
GitHub
3月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move deformable_conv forward kernel to phi (#40700)
上级
c46f2ddb
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
549 addition
and
205 deletion
+549
-205
paddle/fluid/operators/deformable_conv_op.cc
paddle/fluid/operators/deformable_conv_op.cc
+0
-2
paddle/fluid/operators/deformable_conv_op.cu
paddle/fluid/operators/deformable_conv_op.cu
+0
-105
paddle/fluid/operators/deformable_conv_op.h
paddle/fluid/operators/deformable_conv_op.h
+0
-96
paddle/phi/kernels/cpu/deformable_conv_kernel.cc
paddle/phi/kernels/cpu/deformable_conv_kernel.cc
+146
-0
paddle/phi/kernels/cumsum_kernel.h
paddle/phi/kernels/cumsum_kernel.h
+1
-1
paddle/phi/kernels/deformable_conv_kernel.h
paddle/phi/kernels/deformable_conv_kernel.h
+35
-0
paddle/phi/kernels/gpu/deformable_conv_kernel.cu
paddle/phi/kernels/gpu/deformable_conv_kernel.cu
+160
-0
paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
+173
-0
paddle/phi/ops/compat/cumprod_sig.cc
paddle/phi/ops/compat/cumprod_sig.cc
+0
-1
paddle/phi/ops/compat/deformable_conv_sig.cc
paddle/phi/ops/compat/deformable_conv_sig.cc
+34
-0
未找到文件。
paddle/fluid/operators/deformable_conv_op.cc
浏览文件 @
a8e5c9be
...
...
@@ -338,8 +338,6 @@ REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
REGISTER_OPERATOR
(
deformable_conv_grad
,
ops
::
DeformableConvGradOp
);
REGISTER_OP_CPU_KERNEL
(
deformable_conv
,
ops
::
DeformableConvCPUKernel
<
float
>
,
ops
::
DeformableConvCPUKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
deformable_conv_grad
,
ops
::
DeformableConvGradCPUKernel
<
float
>
,
ops
::
DeformableConvGradCPUKernel
<
double
>
);
paddle/fluid/operators/deformable_conv_op.cu
浏览文件 @
a8e5c9be
...
...
@@ -446,108 +446,6 @@ __global__ void FilterGradAddupGpuKernel(const int nthreads, const int n,
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
DeformableConvCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
const
Tensor
offset
=
*
ctx
.
Input
<
Tensor
>
(
"Offset"
);
const
Tensor
mask
=
*
ctx
.
Input
<
Tensor
>
(
"Mask"
);
Tensor
filter
=
*
ctx
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
const
int
deformable_groups
=
ctx
.
Attr
<
int
>
(
"deformable_groups"
);
const
int
im2col_step
=
ctx
.
Attr
<
int
>
(
"im2col_step"
);
const
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
const
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
const
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
std
::
vector
<
int64_t
>
filter_shape_vec
(
phi
::
vectorize
(
filter
.
dims
()));
std
::
vector
<
int64_t
>
output_shape_vec
(
phi
::
vectorize
(
output
->
dims
()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
input
->
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
phi
::
make_ddim
(
col_buffer_shape_vec
));
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
framework
::
DDim
output_shape
(
phi
::
make_ddim
(
output_buffer_shape_vec
));
Tensor
col_buffer
;
Tensor
output_buffer
;
col_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
col_shape
,
dev_ctx
);
output_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
DeviceContext
>
(
output_shape
,
dev_ctx
);
int64_t
M
=
output_shape_vec
[
1
]
/
groups
;
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
input
->
dims
()[
1
]
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
]
/
groups
;
Tensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
).
Resize
(
phi
::
make_ddim
({
groups
,
M
,
K
}));
Tensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
)
.
Resize
(
phi
::
make_ddim
({
groups
,
K
,
N
}));
Tensor
output_4d
;
output_4d
.
ShareDataWith
(
output_buffer
)
.
Resize
(
phi
::
make_ddim
({
batch_size
/
im2col_step
,
groups
,
M
,
N
}));
output_4d
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
DDim
input_shape
=
phi
::
slice_ddim
(
input
->
dims
(),
1
,
input
->
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
phi
::
vectorize
(
input_shape
);
int
input_dim
=
input
->
numel
()
/
input
->
dims
()[
0
];
int
input_offset_dim
=
offset
.
numel
()
/
offset
.
dims
()[
0
];
int
input_mask_dim
=
mask
.
numel
()
/
mask
.
dims
()[
0
];
auto
blas
=
phi
::
funcs
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
const
T
*
input_ptr
=
input
->
data
<
T
>
();
const
T
*
offset_ptr
=
offset
.
data
<
T
>
();
const
T
*
mask_ptr
=
mask
.
data
<
T
>
();
col_buffer
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
col_buffer_ptr
=
col_buffer
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
ModulatedDeformableIm2col
(
ctx
.
device_context
(),
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
Tensor
output_3d
=
output_4d
.
Slice
(
i
,
i
+
1
).
Resize
(
phi
::
slice_ddim
(
output_4d
.
dims
(),
1
,
output_4d
.
dims
().
size
()));
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
weight_3d_slice
=
weight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
weight_3d
.
dims
(),
1
,
weight_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
Tensor
output_3d_slice
=
output_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
output_3d
.
dims
(),
1
,
output_3d
.
dims
().
size
()));
blas
.
MatMul
(
weight_3d_slice
,
false
,
col_buffer_3d_slice
,
false
,
T
(
1.0
),
&
output_3d_slice
,
T
(
0.0
));
}
}
output
->
ShareDataWith
(
output_buffer
)
.
Resize
(
phi
::
make_ddim
(
output_shape_vec
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DeformableConvGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -740,9 +638,6 @@ class DeformableConvGradCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
deformable_conv
,
ops
::
DeformableConvCUDAKernel
<
CUDA
,
float
>
,
ops
::
DeformableConvCUDAKernel
<
CUDA
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
deformable_conv_grad
,
ops
::
DeformableConvGradCUDAKernel
<
CUDA
,
float
>
,
ops
::
DeformableConvGradCUDAKernel
<
CUDA
,
double
>
);
paddle/fluid/operators/deformable_conv_op.h
浏览文件 @
a8e5c9be
...
...
@@ -318,102 +318,6 @@ void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height,
}
}
template
<
typename
T
>
class
DeformableConvCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
offset
=
ctx
.
Input
<
Tensor
>
(
"Offset"
);
auto
*
mask
=
ctx
.
Input
<
Tensor
>
(
"Mask"
);
Tensor
filter
=
*
ctx
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
CPUDeviceContext
>();
const
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
const
int
deformable_groups
=
ctx
.
Attr
<
int
>
(
"deformable_groups"
);
const
int
im2col_step
=
ctx
.
Attr
<
int
>
(
"im2col_step"
);
const
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
const
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
const
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
std
::
vector
<
int64_t
>
filter_shape_vec
(
phi
::
vectorize
(
filter
.
dims
()));
std
::
vector
<
int64_t
>
output_shape_vec
(
phi
::
vectorize
(
output
->
dims
()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
input
->
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
phi
::
make_ddim
(
col_buffer_shape_vec
));
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
framework
::
DDim
output_shape
(
phi
::
make_ddim
(
output_buffer_shape_vec
));
Tensor
col_buffer
;
Tensor
output_buffer
;
col_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
CPUDeviceContext
>
(
col_shape
,
dev_ctx
);
output_buffer
=
ctx
.
AllocateTmpTensor
<
T
,
CPUDeviceContext
>
(
output_shape
,
dev_ctx
);
int64_t
M
=
output_shape_vec
[
1
]
/
groups
;
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
input
->
dims
()[
1
]
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
]
/
groups
;
Tensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
).
Resize
(
phi
::
make_ddim
({
groups
,
M
,
K
}));
Tensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
)
.
Resize
(
phi
::
make_ddim
({
groups
,
K
,
N
}));
Tensor
output_4d
;
output_4d
.
ShareDataWith
(
output_buffer
)
.
Resize
(
phi
::
make_ddim
({
batch_size
/
im2col_step
,
groups
,
M
,
N
}));
output_4d
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
DDim
input_shape
=
phi
::
slice_ddim
(
input
->
dims
(),
1
,
input
->
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
phi
::
vectorize
(
input_shape
);
int
input_dim
=
input
->
numel
()
/
input
->
dims
()[
0
];
int
input_offset_dim
=
offset
->
numel
()
/
offset
->
dims
()[
0
];
int
input_mask_dim
=
mask
->
numel
()
/
mask
->
dims
()[
0
];
auto
blas
=
phi
::
funcs
::
GetBlas
<
CPUDeviceContext
,
T
>
(
dev_ctx
);
const
T
*
input_ptr
=
input
->
data
<
T
>
();
const
T
*
offset_ptr
=
offset
->
data
<
T
>
();
const
T
*
mask_ptr
=
mask
->
data
<
T
>
();
col_buffer
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
col_buffer_ptr
=
col_buffer
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
ModulatedDeformableIm2colCPU
(
dev_ctx
,
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
Tensor
output_3d
=
output_4d
.
Slice
(
i
,
i
+
1
).
Resize
(
phi
::
slice_ddim
(
output_4d
.
dims
(),
1
,
output_4d
.
dims
().
size
()));
// get the product of pixel and weight
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
Tensor
weight_3d_slice
=
weight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
weight_3d
.
dims
(),
1
,
weight_3d
.
dims
().
size
()));
Tensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
Tensor
output_3d_slice
=
output_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
output_3d
.
dims
(),
1
,
output_3d
.
dims
().
size
()));
blas
.
MatMul
(
weight_3d_slice
,
false
,
col_buffer_3d_slice
,
false
,
T
(
1.0
),
&
output_3d_slice
,
T
(
0.0
));
}
}
output
->
ShareDataWith
(
output_buffer
)
.
Resize
(
phi
::
make_ddim
(
output_shape_vec
));
}
};
template
<
typename
T
>
class
DeformableConvGradCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
paddle/phi/kernels/cpu/deformable_conv_kernel.cc
0 → 100644
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace
phi
{
template
<
typename
T
>
inline
void
ModulatedDeformableIm2colCPUKernel
(
const
int
num_kernels
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
for
(
int
i
=
0
;
i
<
num_kernels
;
i
++
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
static_cast
<
T
>
(
0
);
const
T
h_im
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
w_im
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height
&&
w_im
<
width
)
{
val
=
DmcnIm2colBilinear
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
ModulatedDeformableIm2col
(
const
Context
&
dev_ctx
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>&
im_shape
,
const
std
::
vector
<
int64_t
>&
col_shape
,
const
std
::
vector
<
int64_t
>&
filter_shape
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
dilations
,
const
int
deformable_groups
,
T
*
data_col
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_groups
;
int
num_kernels
=
im_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
// get outputs of im2col with offset by bilinear interpolation
ModulatedDeformableIm2colCPUKernel
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
1
],
im_shape
[
2
],
filter_shape
[
2
],
filter_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
im_shape
[
0
],
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
data_col
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
deformable_conv
,
CPU
,
ALL_LAYOUT
,
phi
::
DeformableConvKernel
,
float
,
double
)
{}
paddle/phi/kernels/cumsum_kernel.h
浏览文件 @
a8e5c9be
...
...
@@ -18,7 +18,7 @@
namespace
phi
{
template
<
typename
Functor
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
CumsumKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
axis
,
...
...
paddle/phi/kernels/deformable_conv_kernel.h
0 → 100644
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DeformableConvKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
offset
,
const
DenseTensor
&
filter
,
const
DenseTensor
&
mask
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
int
deformable_groups
,
int
groups
,
int
im2col_step
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/gpu/deformable_conv_kernel.cu
0 → 100644
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/impl/deformable_conv_kernel_impl.h"
namespace
phi
{
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaximumNumBlocks
=
4096
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaximumNumBlocks
);
}
template
<
typename
T
>
__global__
void
ModulatedDeformableIm2colGpuKernel
(
const
int
nthreads
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
T
*
data_col
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
const
int
w_col
=
i
%
width_col
;
const
int
h_col
=
(
i
/
width_col
)
%
height_col
;
const
int
b_col
=
(
i
/
width_col
)
/
height_col
%
batch_size
;
const
int
c_im
=
(
i
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
T
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
T
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
const
T
*
data_mask_ptr
=
data_mask
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_mask_hw_ptr
=
((
i
*
kernel_w
+
j
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
T
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
T
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
const
T
mask
=
data_mask_ptr
[
data_mask_hw_ptr
];
T
val
=
static_cast
<
T
>
(
0
);
const
T
h_im
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
T
w_im
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height
&&
w_im
<
width
)
{
val
=
DmcnIm2colBilinear
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
*
mask
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
ModulatedDeformableIm2col
(
const
Context
&
dev_ctx
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>&
im_shape
,
const
std
::
vector
<
int64_t
>&
col_shape
,
const
std
::
vector
<
int64_t
>&
filter_shape
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
dilations
,
const
int
deformable_groups
,
T
*
data_col
)
{
int
channel_per_deformable_group
=
im_shape
[
0
]
/
deformable_groups
;
int
num_kernels
=
im_shape
[
0
]
*
col_shape
[
1
]
*
col_shape
[
2
]
*
col_shape
[
3
];
int
blocks
=
NumBlocks
(
num_kernels
);
int
threads
=
kNumCUDAThreads
;
ModulatedDeformableIm2colGpuKernel
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
num_kernels
,
data_im
,
data_offset
,
data_mask
,
im_shape
[
1
],
im_shape
[
2
],
filter_shape
[
2
],
filter_shape
[
3
],
paddings
[
0
],
paddings
[
1
],
strides
[
0
],
strides
[
1
],
dilations
[
0
],
dilations
[
1
],
channel_per_deformable_group
,
col_shape
[
1
],
im_shape
[
0
],
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
data_col
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
deformable_conv
,
GPU
,
ALL_LAYOUT
,
phi
::
DeformableConvKernel
,
float
,
double
)
{}
paddle/phi/kernels/impl/deformable_conv_kernel_impl.h
0 → 100644
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
>
HOSTDEVICE
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
lh
=
h
-
h_low
;
T
lw
=
w
-
w_low
;
T
hh
=
1
-
lh
;
T
hw
=
1
-
lw
;
T
v1
=
(
h_low
>=
0
&&
w_low
>=
0
)
?
bottom_data
[
h_low
*
data_width
+
w_low
]
:
0
;
T
v2
=
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
?
bottom_data
[
h_low
*
data_width
+
w_high
]
:
0
;
T
v3
=
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
?
bottom_data
[
h_high
*
data_width
+
w_low
]
:
0
;
T
v4
=
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
?
bottom_data
[
h_high
*
data_width
+
w_high
]
:
0
;
T
w1
=
hh
*
hw
;
T
w2
=
hh
*
lw
;
T
w3
=
lh
*
hw
;
T
w4
=
lh
*
lw
;
return
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
;
}
template
<
typename
T
,
typename
Context
>
void
ModulatedDeformableIm2col
(
const
Context
&
dev_ctx
,
const
T
*
data_im
,
const
T
*
data_offset
,
const
T
*
data_mask
,
const
std
::
vector
<
int64_t
>&
im_shape
,
const
std
::
vector
<
int64_t
>&
col_shape
,
const
std
::
vector
<
int64_t
>&
filter_shape
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
dilations
,
const
int
deformable_groups
,
T
*
data_col
);
template
<
typename
T
,
typename
Context
>
void
DeformableConvKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
offset
,
const
DenseTensor
&
filter
,
const
DenseTensor
&
mask
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
int
deformable_groups
,
int
groups
,
int
im2col_step
,
DenseTensor
*
out
)
{
const
int
batch_size
=
static_cast
<
int
>
(
x
.
dims
()[
0
]);
std
::
vector
<
int64_t
>
filter_shape_vec
(
phi
::
vectorize
(
filter
.
dims
()));
std
::
vector
<
int64_t
>
output_shape_vec
(
phi
::
vectorize
(
out
->
dims
()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std
::
vector
<
int64_t
>
col_buffer_shape_vec
(
filter_shape_vec
.
size
());
col_buffer_shape_vec
[
0
]
=
x
.
dims
()[
1
]
*
filter
.
dims
()[
2
]
*
filter
.
dims
()[
3
];
col_buffer_shape_vec
[
1
]
=
im2col_step
;
for
(
size_t
j
=
0
;
j
<
filter_shape_vec
.
size
()
-
2
;
++
j
)
{
col_buffer_shape_vec
[
j
+
2
]
=
output_shape_vec
[
j
+
2
];
}
std
::
vector
<
int64_t
>
output_buffer_shape_vec
(
1
);
output_buffer_shape_vec
[
0
]
=
batch_size
*
output_shape_vec
[
1
]
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
DenseTensor
col_buffer
=
Empty
<
T
>
(
dev_ctx
,
col_buffer_shape_vec
);
DenseTensor
output_buffer
=
Empty
<
T
>
(
dev_ctx
,
output_buffer_shape_vec
);
int64_t
M
=
output_shape_vec
[
1
]
/
groups
;
int64_t
N
=
im2col_step
*
output_shape_vec
[
2
]
*
output_shape_vec
[
3
];
int64_t
K
=
x
.
dims
()[
1
]
*
filter_shape_vec
[
2
]
*
filter_shape_vec
[
3
]
/
groups
;
DenseTensor
weight_3d
;
weight_3d
.
ShareDataWith
(
filter
).
Resize
(
phi
::
make_ddim
({
groups
,
M
,
K
}));
DenseTensor
col_buffer_3d
;
col_buffer_3d
.
ShareDataWith
(
col_buffer
)
.
Resize
(
phi
::
make_ddim
({
groups
,
K
,
N
}));
DenseTensor
output_4d
;
output_4d
.
ShareDataWith
(
output_buffer
)
.
Resize
(
phi
::
make_ddim
({
batch_size
/
im2col_step
,
groups
,
M
,
N
}));
DDim
input_shape
=
phi
::
slice_ddim
(
x
.
dims
(),
1
,
x
.
dims
().
size
());
std
::
vector
<
int64_t
>
input_shape_vec
=
phi
::
vectorize
(
input_shape
);
int
input_dim
=
x
.
numel
()
/
x
.
dims
()[
0
];
int
input_offset_dim
=
offset
.
numel
()
/
offset
.
dims
()[
0
];
int
input_mask_dim
=
mask
.
numel
()
/
mask
.
dims
()[
0
];
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
const
T
*
input_ptr
=
x
.
data
<
T
>
();
const
T
*
offset_ptr
=
offset
.
data
<
T
>
();
const
T
*
mask_ptr
=
mask
.
data
<
T
>
();
T
*
col_buffer_ptr
=
col_buffer
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
/
im2col_step
;
++
i
)
{
ModulatedDeformableIm2col
(
dev_ctx
,
input_ptr
+
i
*
im2col_step
*
input_dim
,
offset_ptr
+
i
*
im2col_step
*
input_offset_dim
,
mask_ptr
+
i
*
im2col_step
*
input_mask_dim
,
input_shape_vec
,
col_buffer_shape_vec
,
filter_shape_vec
,
paddings
,
strides
,
dilations
,
deformable_groups
,
col_buffer_ptr
);
DenseTensor
output_3d
=
output_4d
.
Slice
(
i
,
i
+
1
).
Resize
(
phi
::
slice_ddim
(
output_4d
.
dims
(),
1
,
output_4d
.
dims
().
size
()));
// get the product of pixel and weight
for
(
int
g
=
0
;
g
<
groups
;
++
g
)
{
DenseTensor
weight_3d_slice
=
weight_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
weight_3d
.
dims
(),
1
,
weight_3d
.
dims
().
size
()));
DenseTensor
col_buffer_3d_slice
=
col_buffer_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
col_buffer_3d
.
dims
(),
1
,
col_buffer_3d
.
dims
().
size
()));
DenseTensor
output_3d_slice
=
output_3d
.
Slice
(
g
,
g
+
1
).
Resize
(
phi
::
slice_ddim
(
output_3d
.
dims
(),
1
,
output_3d
.
dims
().
size
()));
blas
.
MatMul
(
weight_3d_slice
,
false
,
col_buffer_3d_slice
,
false
,
T
(
1.0
),
&
output_3d_slice
,
T
(
0.0
));
}
}
out
->
ShareDataWith
(
output_buffer
).
Resize
(
phi
::
make_ddim
(
output_shape_vec
));
}
}
// namespace phi
paddle/phi/ops/compat/cumprod_sig.cc
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
...
...
paddle/phi/ops/compat/deformable_conv_sig.cc
0 → 100644
浏览文件 @
a8e5c9be
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
DeformableConvOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"deformable_conv"
,
{
"Input"
,
"Offset"
,
"Filter"
,
"Mask"
},
{
"strides"
,
"paddings"
,
"dilations"
,
"deformable_groups"
,
"groups"
,
"im2col_step"
},
{
"Output"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
deformable_conv
,
phi
::
DeformableConvOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录