Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
3d6a027c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d6a027c
编写于
12月 20, 2022
作者:
S
shangliang Xu
提交者:
GitHub
12月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] add ms_deformable_attn cuda op (#7521)
上级
e1a8f660
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
1459 addition
and
72 deletion
+1459
-72
ppdet/modeling/architectures/detr.py
ppdet/modeling/architectures/detr.py
+1
-1
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+13
-11
ppdet/modeling/transformers/deformable_transformer.py
ppdet/modeling/transformers/deformable_transformer.py
+44
-32
ppdet/modeling/transformers/detr_transformer.py
ppdet/modeling/transformers/detr_transformer.py
+2
-3
ppdet/modeling/transformers/ext_op/README.md
ppdet/modeling/transformers/ext_op/README.md
+84
-0
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc
+65
-0
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu
+1073
-0
ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py
...deling/transformers/ext_op/setup_ms_deformable_attn_op.py
+7
-0
ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py
...odeling/transformers/ext_op/test_ms_deformable_attn_op.py
+140
-0
ppdet/modeling/transformers/position_encoding.py
ppdet/modeling/transformers/position_encoding.py
+9
-14
ppdet/modeling/transformers/utils.py
ppdet/modeling/transformers/utils.py
+21
-11
未找到文件。
ppdet/modeling/architectures/detr.py
浏览文件 @
3d6a027c
...
...
@@ -84,7 +84,7 @@ class DETR(BaseArch):
preds
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
return
bbox
,
bbox_num
def
get_loss
(
self
,
):
def
get_loss
(
self
):
losses
=
self
.
_forward
()
losses
.
update
({
'loss'
:
...
...
ppdet/modeling/post_process.py
浏览文件 @
3d6a027c
...
...
@@ -492,19 +492,21 @@ class DETRBBoxPostProcess(object):
if
scores
.
shape
[
1
]
>
self
.
num_top_queries
:
scores
,
index
=
paddle
.
topk
(
scores
,
self
.
num_top_queries
,
axis
=-
1
)
labels
=
paddle
.
stack
(
[
paddle
.
gather
(
l
,
i
)
for
l
,
i
in
zip
(
labels
,
index
)])
bbox_pred
=
paddle
.
stack
(
[
paddle
.
gather
(
b
,
i
)
for
b
,
i
in
zip
(
bbox_pred
,
index
)])
batch_ind
=
paddle
.
arange
(
end
=
scores
.
shape
[
0
]).
unsqueeze
(
-
1
).
tile
(
[
1
,
self
.
num_top_queries
])
index
=
paddle
.
stack
([
batch_ind
,
index
],
axis
=-
1
)
labels
=
paddle
.
gather_nd
(
labels
,
index
)
bbox_pred
=
paddle
.
gather_nd
(
bbox_pred
,
index
)
else
:
scores
,
index
=
paddle
.
topk
(
scores
.
reshape
([
logits
.
shape
[
0
],
-
1
]),
self
.
num_top_queries
,
axis
=-
1
)
labels
=
index
%
logits
.
shape
[
2
]
index
=
index
//
logits
.
shape
[
2
]
bbox_pred
=
paddle
.
stack
(
[
paddle
.
gather
(
b
,
i
)
for
b
,
i
in
zip
(
bbox_pred
,
index
)]
)
scores
.
flatten
(
1
),
self
.
num_top_queries
,
axis
=-
1
)
labels
=
index
%
self
.
num_classes
index
=
index
//
self
.
num_classes
batch_ind
=
paddle
.
arange
(
end
=
scores
.
shape
[
0
]).
unsqueeze
(
-
1
).
tile
(
[
1
,
self
.
num_top_queries
])
index
=
paddle
.
stack
([
batch_ind
,
index
],
axis
=-
1
)
bbox_pred
=
paddle
.
gather_nd
(
bbox_pred
,
index
)
bbox_pred
=
paddle
.
concat
(
[
...
...
ppdet/modeling/transformers/deformable_transformer.py
浏览文件 @
3d6a027c
...
...
@@ -28,7 +28,7 @@ from paddle import ParamAttr
from
ppdet.core.workspace
import
register
from
..layers
import
MultiHeadAttention
from
.position_encoding
import
PositionEmbedding
from
.utils
import
_get_clones
,
deformable_attention_core_func
from
.utils
import
_get_clones
,
get_valid_ratio
from
..initializer
import
linear_init_
,
constant_
,
xavier_uniform_
,
normal_
__all__
=
[
'DeformableTransformer'
]
...
...
@@ -63,6 +63,13 @@ class MSDeformableAttention(nn.Layer):
self
.
attention_weights
=
nn
.
Linear
(
embed_dim
,
self
.
total_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
output_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
try
:
# use cuda op
from
deformable_detr_ops
import
ms_deformable_attn
except
:
# use paddle func
from
.utils
import
deformable_attention_core_func
as
ms_deformable_attn
self
.
ms_deformable_attn_core
=
ms_deformable_attn
self
.
_reset_parameters
()
...
...
@@ -95,6 +102,7 @@ class MSDeformableAttention(nn.Layer):
reference_points
,
value
,
value_spatial_shapes
,
value_level_start_index
,
value_mask
=
None
):
"""
Args:
...
...
@@ -103,6 +111,7 @@ class MSDeformableAttention(nn.Layer):
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
...
...
@@ -131,8 +140,9 @@ class MSDeformableAttention(nn.Layer):
bs
,
Len_q
,
1
,
self
.
num_levels
,
1
,
2
])
+
sampling_offsets
/
offset_normalizer
output
=
deformable_attention_core_func
(
value
,
value_spatial_shapes
,
sampling_locations
,
attention_weights
)
output
=
self
.
ms_deformable_attn_core
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
)
output
=
self
.
output_proj
(
output
)
return
output
...
...
@@ -185,12 +195,13 @@ class DeformableTransformerEncoderLayer(nn.Layer):
src
,
reference_points
,
spatial_shapes
,
level_start_index
,
src_mask
=
None
,
pos_embed
=
None
):
# self attention
src2
=
self
.
self_attn
(
self
.
with_pos_embed
(
src
,
pos_embed
),
reference_points
,
src
,
spatial_shapes
,
src_mask
)
spatial_shapes
,
level_start_index
,
src_mask
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
# ffn
...
...
@@ -206,13 +217,12 @@ class DeformableTransformerEncoder(nn.Layer):
self
.
num_layers
=
num_layers
@
staticmethod
def
get_reference_points
(
spatial_shapes
,
valid_ratios
):
def
get_reference_points
(
spatial_shapes
,
valid_ratios
,
offset
=
0.5
):
valid_ratios
=
valid_ratios
.
unsqueeze
(
1
)
reference_points
=
[]
for
i
,
(
H
,
W
)
in
enumerate
(
spatial_shapes
.
tolist
()
):
for
i
,
(
H
,
W
)
in
enumerate
(
spatial_shapes
):
ref_y
,
ref_x
=
paddle
.
meshgrid
(
paddle
.
linspace
(
0.5
,
H
-
0.5
,
H
),
paddle
.
linspace
(
0.5
,
W
-
0.5
,
W
))
paddle
.
arange
(
end
=
H
)
+
offset
,
paddle
.
arange
(
end
=
W
)
+
offset
)
ref_y
=
ref_y
.
flatten
().
unsqueeze
(
0
)
/
(
valid_ratios
[:,
:,
i
,
1
]
*
H
)
ref_x
=
ref_x
.
flatten
().
unsqueeze
(
0
)
/
(
valid_ratios
[:,
:,
i
,
0
]
*
...
...
@@ -225,6 +235,7 @@ class DeformableTransformerEncoder(nn.Layer):
def
forward
(
self
,
src
,
spatial_shapes
,
level_start_index
,
src_mask
=
None
,
pos_embed
=
None
,
valid_ratios
=
None
):
...
...
@@ -235,8 +246,8 @@ class DeformableTransformerEncoder(nn.Layer):
reference_points
=
self
.
get_reference_points
(
spatial_shapes
,
valid_ratios
)
for
layer
in
self
.
layers
:
output
=
layer
(
output
,
reference_points
,
spatial_shapes
,
src_mask
,
pos_embed
)
output
=
layer
(
output
,
reference_points
,
spatial_shapes
,
level_start_index
,
src_mask
,
pos_embed
)
return
output
...
...
@@ -296,6 +307,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
reference_points
,
memory
,
memory_spatial_shapes
,
memory_level_start_index
,
memory_mask
=
None
,
query_pos_embed
=
None
):
# self attention
...
...
@@ -307,7 +319,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
# cross attention
tgt2
=
self
.
cross_attn
(
self
.
with_pos_embed
(
tgt
,
query_pos_embed
),
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
)
memory_spatial_shapes
,
memory_
level_start_index
,
memory_
mask
)
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
self
.
norm2
(
tgt
)
...
...
@@ -329,13 +341,15 @@ class DeformableTransformerDecoder(nn.Layer):
reference_points
,
memory
,
memory_spatial_shapes
,
memory_level_start_index
,
memory_mask
=
None
,
query_pos_embed
=
None
):
output
=
tgt
intermediate
=
[]
for
lid
,
layer
in
enumerate
(
self
.
layers
):
output
=
layer
(
output
,
reference_points
,
memory
,
memory_spatial_shapes
,
memory_mask
,
query_pos_embed
)
memory_spatial_shapes
,
memory_level_start_index
,
memory_mask
,
query_pos_embed
)
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
...
...
@@ -447,14 +461,7 @@ class DeformableTransformer(nn.Layer):
def
from_config
(
cls
,
cfg
,
input_shape
):
return
{
'backbone_num_channels'
:
[
i
.
channels
for
i
in
input_shape
],
}
def
_get_valid_ratio
(
self
,
mask
):
_
,
H
,
W
=
mask
.
shape
valid_ratio_h
=
paddle
.
sum
(
mask
[:,
:,
0
],
1
)
/
H
valid_ratio_w
=
paddle
.
sum
(
mask
[:,
0
,
:],
1
)
/
W
valid_ratio
=
paddle
.
stack
([
valid_ratio_w
,
valid_ratio_h
],
-
1
)
return
valid_ratio
def
forward
(
self
,
src_feats
,
src_mask
=
None
):
def
forward
(
self
,
src_feats
,
src_mask
=
None
,
*
args
,
**
kwargs
):
srcs
=
[]
for
i
in
range
(
len
(
src_feats
)):
srcs
.
append
(
self
.
input_proj
[
i
](
src_feats
[
i
]))
...
...
@@ -471,33 +478,38 @@ class DeformableTransformer(nn.Layer):
spatial_shapes
=
[]
valid_ratios
=
[]
for
level
,
src
in
enumerate
(
srcs
):
bs
,
c
,
h
,
w
=
src
.
shape
spatial_shapes
.
append
(
[
h
,
w
]
)
bs
,
_
,
h
,
w
=
paddle
.
shape
(
src
)
spatial_shapes
.
append
(
paddle
.
concat
([
h
,
w
])
)
src
=
src
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
src_flatten
.
append
(
src
)
if
src_mask
is
not
None
:
mask
=
F
.
interpolate
(
src_mask
.
unsqueeze
(
0
),
size
=
(
h
,
w
))[
0
]
else
:
mask
=
paddle
.
ones
([
bs
,
h
,
w
])
valid_ratios
.
append
(
self
.
_get_valid_ratio
(
mask
))
pos_embed
=
self
.
position_embedding
(
mask
).
flatten
(
2
).
transpose
(
[
0
,
2
,
1
])
lvl_pos_embed
=
pos_embed
+
self
.
level_embed
.
weight
[
level
].
reshape
(
[
1
,
1
,
-
1
])
valid_ratios
.
append
(
get_valid_ratio
(
mask
))
pos_embed
=
self
.
position_embedding
(
mask
).
flatten
(
1
,
2
)
lvl_pos_embed
=
pos_embed
+
self
.
level_embed
.
weight
[
level
]
lvl_pos_embed_flatten
.
append
(
lvl_pos_embed
)
mask
=
mask
.
flatten
(
1
)
mask_flatten
.
append
(
mask
)
src_flatten
=
paddle
.
concat
(
src_flatten
,
1
)
mask_flatten
=
paddle
.
concat
(
mask_flatten
,
1
)
mask_flatten
=
None
if
src_mask
is
None
else
paddle
.
concat
(
mask_flatten
,
1
)
lvl_pos_embed_flatten
=
paddle
.
concat
(
lvl_pos_embed_flatten
,
1
)
# [l, 2]
spatial_shapes
=
paddle
.
to_tensor
(
spatial_shapes
,
dtype
=
'int64'
)
spatial_shapes
=
paddle
.
to_tensor
(
paddle
.
stack
(
spatial_shapes
).
astype
(
'int64'
))
# [l], 每一个level的起始index
level_start_index
=
paddle
.
concat
([
paddle
.
zeros
(
[
1
],
dtype
=
'int64'
),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]
])
# [b, l, 2]
valid_ratios
=
paddle
.
stack
(
valid_ratios
,
1
)
# encoder
memory
=
self
.
encoder
(
src_flatten
,
spatial_shapes
,
mask_flatten
,
lvl_pos_embed_flatten
,
valid_ratios
)
memory
=
self
.
encoder
(
src_flatten
,
spatial_shapes
,
level_start_index
,
mask_flatten
,
lvl_pos_embed_flatten
,
valid_ratios
)
# prepare input for decoder
bs
,
_
,
c
=
memory
.
shape
...
...
@@ -509,6 +521,6 @@ class DeformableTransformer(nn.Layer):
# decoder
hs
=
self
.
decoder
(
tgt
,
reference_points_input
,
memory
,
spatial_shapes
,
mask_flatten
,
query_embed
)
level_start_index
,
mask_flatten
,
query_embed
)
return
(
hs
,
memory
,
reference_points
)
ppdet/modeling/transformers/detr_transformer.py
浏览文件 @
3d6a027c
...
...
@@ -295,7 +295,7 @@ class DETRTransformer(nn.Layer):
def
_convert_attention_mask
(
self
,
mask
):
return
(
mask
-
1.0
)
*
1e9
def
forward
(
self
,
src
,
src_mask
=
None
):
def
forward
(
self
,
src
,
src_mask
=
None
,
*
args
,
**
kwargs
):
r
"""
Applies a Transformer model on the inputs.
...
...
@@ -325,8 +325,7 @@ class DETRTransformer(nn.Layer):
src_mask
=
F
.
interpolate
(
src_mask
.
unsqueeze
(
0
),
size
=
(
h
,
w
))[
0
]
else
:
src_mask
=
paddle
.
ones
([
bs
,
h
,
w
])
pos_embed
=
self
.
position_embedding
(
src_mask
).
flatten
(
2
).
transpose
(
[
0
,
2
,
1
])
pos_embed
=
self
.
position_embedding
(
src_mask
).
flatten
(
1
,
2
)
if
self
.
training
:
src_mask
=
self
.
_convert_attention_mask
(
src_mask
)
...
...
ppdet/modeling/transformers/ext_op/README.md
0 → 100644
浏览文件 @
3d6a027c
# Multi-scale deformable attention自定义OP编译
该自定义OP的实现参考了 Paddle 官方文档[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html)。
## 1. 环境依赖
- Paddle >= 2.3.2
- gcc 8.2
## 2. 安装
请在当前路径下进行编译安装
```
cd PaddleDetection/ppdet/modeling/transformers/ext_op/
python setup_ms_deformable_attn_op.py install
```
编译完成后即可使用,以下为 `ms_deformable_attn` 的使用示例:
```
# 引入 paddle 与自定义op
import paddle
from deformable_detr_ops import ms_deformable_attn
# 构造fake input tensor
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
[0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
value_length = sum([(H * W).item() for H, W in spatial_shapes])
def get_test_tensors(channels):
value = paddle.rand(
[bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
sampling_locations = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points, 2],
dtype=paddle.float32)
attention_weights = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points],
dtype=paddle.float32) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
-2, keepdim=True)
return [value, sampling_locations, attention_weights]
value, sampling_locations, attention_weights = get_test_tensors(c)
output = ms_deformable_attn(value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights)
```
## 3. 单元测试
可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示:
```
python test_ms_deformable_attn_op.py
```
运行成功后,打印如下:
```
*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
*tensor1 True check_gradient_numerical(D=30)
*tensor2 True check_gradient_numerical(D=30)
*tensor3 True check_gradient_numerical(D=30)
*tensor1 True check_gradient_numerical(D=32)
*tensor2 True check_gradient_numerical(D=32)
*tensor3 True check_gradient_numerical(D=32)
*tensor1 True check_gradient_numerical(D=64)
*tensor2 True check_gradient_numerical(D=64)
*tensor3 True check_gradient_numerical(D=64)
*tensor1 True check_gradient_numerical(D=71)
*tensor2 True check_gradient_numerical(D=71)
*tensor3 True check_gradient_numerical(D=71)
*tensor1 True check_gradient_numerical(D=128)
*tensor2 True check_gradient_numerical(D=128)
*tensor3 True check_gradient_numerical(D=128)
*tensor1 True check_gradient_numerical(D=1024)
*tensor2 True check_gradient_numerical(D=1024)
*tensor3 True check_gradient_numerical(D=1024)
*tensor1 True check_gradient_numerical(D=1025)
*tensor2 True check_gradient_numerical(D=1025)
*tensor3 True check_gradient_numerical(D=1025)
*tensor1 True check_gradient_numerical(D=2048)
*tensor2 True check_gradient_numerical(D=2048)
*tensor3 True check_gradient_numerical(D=2048)
*tensor1 True check_gradient_numerical(D=3096)
*tensor2 True check_gradient_numerical(D=3096)
*tensor3 True check_gradient_numerical(D=3096)
```
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc
0 → 100644
浏览文件 @
3d6a027c
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
#include <vector>
// declare GPU implementation
// Forward pass on GPU. The definition lives in ms_deformable_attn_op.cu.
// Takes the value tensor, the per-level spatial shapes, the per-level start
// indices, the sampling locations and the attention weights, and returns the
// output tensor(s) wrapped in a vector.
std::vector<paddle::Tensor> MSDeformableAttnCUDAForward(
    const paddle::Tensor &value,
    const paddle::Tensor &value_spatial_shapes,
    const paddle::Tensor &value_level_start_index,
    const paddle::Tensor &sampling_locations,
    const paddle::Tensor &attention_weights);

// Backward pass on GPU. Receives the same five forward inputs plus the
// gradient flowing into the forward output (`grad_out`).
std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
    const paddle::Tensor &value,
    const paddle::Tensor &value_spatial_shapes,
    const paddle::Tensor &value_level_start_index,
    const paddle::Tensor &sampling_locations,
    const paddle::Tensor &attention_weights,
    const paddle::Tensor &grad_out);
//// CPU not implemented
// Shape inference for the ms_deformable_attn op.
//
// The single output has three dimensions:
//   [value_shape[0], sampling_locations_shape[1], value_shape[2] * value_shape[3]]
// The remaining input shapes are accepted to match the op's input list but do
// not influence the result.
std::vector<std::vector<int64_t>> MSDeformableAttnInferShape(
    std::vector<int64_t> value_shape,
    std::vector<int64_t> value_spatial_shapes_shape,
    std::vector<int64_t> value_level_start_index_shape,
    std::vector<int64_t> sampling_locations_shape,
    std::vector<int64_t> attention_weights_shape) {
  const int64_t batch_dim = value_shape[0];
  const int64_t query_dim = sampling_locations_shape[1];
  // Heads and per-head channels are merged into one trailing dimension.
  const int64_t merged_dim = value_shape[2] * value_shape[3];

  std::vector<int64_t> out_shape;
  out_shape.push_back(batch_dim);
  out_shape.push_back(query_dim);
  out_shape.push_back(merged_dim);

  std::vector<std::vector<int64_t>> all_shapes;
  all_shapes.push_back(out_shape);
  return all_shapes;
}
// Dtype inference for the ms_deformable_attn op: the single output takes the
// dtype of `value`; the other input dtypes are ignored.
std::vector<paddle::DataType> MSDeformableAttnInferDtype(
    paddle::DataType value_dtype,
    paddle::DataType value_spatial_shapes_dtype,
    paddle::DataType value_level_start_index_dtype,
    paddle::DataType sampling_locations_dtype,
    paddle::DataType attention_weights_dtype) {
  std::vector<paddle::DataType> out_dtypes;
  out_dtypes.push_back(value_dtype);
  return out_dtypes;
}
// Register the forward op: five inputs, one output, with the CUDA forward
// function as kernel and the shape/dtype inference functions declared above.
PD_BUILD_OP(ms_deformable_attn)
    .Inputs({"Value",
             "SpatialShapes",
             "LevelIndex",
             "SamplingLocations",
             "AttentionWeights"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward))
    .SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype));

// Register the gradient op: it consumes the five forward inputs plus the
// gradient of "Out" and declares one gradient output per forward input.
PD_BUILD_GRAD_OP(ms_deformable_attn)
    .Inputs({"Value",
             "SpatialShapes",
             "LevelIndex",
             "SamplingLocations",
             "AttentionWeights",
             paddle::Grad("Out")})
    .Outputs({paddle::Grad("Value"),
              paddle::Grad("SpatialShapes"),
              paddle::Grad("LevelIndex"),
              paddle::Grad("SamplingLocations"),
              paddle::Grad("AttentionWeights")})
    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward));
ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu
0 → 100644
浏览文件 @
3d6a027c
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
// Grid-stride loop: each thread starts at its global index and advances by the
// total number of launched threads until it reaches n.
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Threads per block used when launching the kernels in this file.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of `num_threads` threads needed to cover N work items
// (integer division of N + num_threads - 1 by num_threads).
inline int GET_BLOCKS(const int N, const int num_threads) {
  const int padded_items = N + num_threads - 1;
  return padded_items / num_threads;
}
// forward bilinear
// Bilinear interpolation of one scalar from a single feature level.
//
// `bottom_data` is indexed as a [height, width, nheads, channels] array; the
// value for head `m`, channel `c` at the fractional position (h, w) is blended
// from the four surrounding integer positions. Corners that fall outside
// [0, height-1] x [0, width-1] contribute zero.
template <typename data_t>
__device__ data_t deformable_attn_bilinear_forward(
    const data_t *&bottom_data, const int &height, const int &width,
    const int &nheads, const int &channels, const data_t &h, const data_t &w,
    const int &m, const int &c) {
  // Integer corners around the sample position.
  const int row_lo = floor(h);
  const int col_lo = floor(w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distance to the low corner, and its complement.
  const data_t frac_row = h - row_lo;
  const data_t frac_col = w - col_lo;
  const data_t inv_row = 1 - frac_row;
  const data_t inv_col = 1 - frac_col;

  // Element strides for one step along width / height.
  const int col_step = nheads * channels;
  const int row_step = width * col_step;
  const int row_lo_off = row_lo * row_step;
  const int row_hi_off = row_lo_off + row_step;
  const int col_lo_off = col_lo * col_step;
  const int col_hi_off = col_lo_off + col_step;
  const int lane_off = m * channels + c;

  const bool row_lo_ok = row_lo >= 0;
  const bool col_lo_ok = col_lo >= 0;
  const bool row_hi_ok = row_hi <= height - 1;
  const bool col_hi_ok = col_hi <= width - 1;

  // Only in-range corners are read from memory.
  data_t top_left = 0;
  if (row_lo_ok && col_lo_ok) {
    top_left = bottom_data[row_lo_off + col_lo_off + lane_off];
  }
  data_t top_right = 0;
  if (row_lo_ok && col_hi_ok) {
    top_right = bottom_data[row_lo_off + col_hi_off + lane_off];
  }
  data_t bottom_left = 0;
  if (row_hi_ok && col_lo_ok) {
    bottom_left = bottom_data[row_hi_off + col_lo_off + lane_off];
  }
  data_t bottom_right = 0;
  if (row_hi_ok && col_hi_ok) {
    bottom_right = bottom_data[row_hi_off + col_hi_off + lane_off];
  }

  const data_t wgt_tl = inv_row * inv_col;
  const data_t wgt_tr = inv_row * frac_col;
  const data_t wgt_bl = frac_row * inv_col;
  const data_t wgt_br = frac_row * frac_col;

  return (wgt_tl * top_left + wgt_tr * top_right + wgt_bl * bottom_left +
          wgt_br * bottom_right);
}
// forward kernel
// Forward kernel: one work item per output element.
//
// The flat index is decomposed (fastest to slowest) into channel, head, query
// and batch. For every level and sampling point of that (batch, query, head),
// the value map is bilinearly sampled and the samples are summed, each scaled
// by its attention weight. Sampling locations are stored as (w, h) pairs and
// are multiplied by the level's (spatial_w, spatial_h) before use.
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_forward(
    const int n, const data_t *data_value, const int64_t *data_spatial_shapes,
    const int64_t *data_level_start_index, const data_t *data_sampling_loc,
    const data_t *data_attn_weight, const int batch_size,
    const int value_length, const int num_heads, const int channels,
    const int num_levels, const int query_length, const int num_points,
    data_t *output_data_ptr) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the flat output index.
    const int channel_id = index % channels;
    const int bqh_id = index / channels;  // flat (batch, query, head) index
    const int head_id = bqh_id % num_heads;
    const int batch_id = (bqh_id / num_heads) / query_length;

    // Cursors into the attention weights and the (w, h) location pairs.
    int weight_cursor = bqh_id * num_levels * num_points;
    int loc_cursor = weight_cursor << 1;

    const int token_stride = num_heads * channels;
    const int batch_value_offset = batch_id * value_length * token_stride;

    data_t acc = 0;
    for (int level = 0; level < num_levels; ++level) {
      const int level_start = data_level_start_index[level];
      const int level_h = data_spatial_shapes[level << 1];
      const int level_w = data_spatial_shapes[(level << 1) + 1];
      // Must stay a named, non-const pointer variable: the bilinear helper
      // takes it by reference.
      const data_t *level_value =
          data_value + (batch_value_offset + level_start * token_stride);

      for (int point = 0; point < num_points; ++point) {
        const data_t loc_w = data_sampling_loc[loc_cursor];
        const data_t loc_h = data_sampling_loc[loc_cursor + 1];
        const data_t weight = data_attn_weight[weight_cursor];

        // Scale to the level's size and shift by half a cell.
        const data_t h_im = loc_h * level_h - 0.5;
        const data_t w_im = loc_w * level_w - 0.5;

        const bool inside =
            h_im > -1 && w_im > -1 && h_im < level_h && w_im < level_w;
        if (inside) {
          acc += deformable_attn_bilinear_forward(
                     level_value, level_h, level_w, num_heads, channels, h_im,
                     w_im, head_id, channel_id) *
                 weight;
        }

        weight_cursor += 1;
        loc_cursor += 2;
      }
    }
    output_data_ptr[index] = acc;
  }
}
#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
// forward
// Host entry point for the forward pass of multi-scale deformable attention.
//
// Inputs (all must be GPU tensors):
//   value                    float32, indexed as [bs, value_length, n_heads, channels]
//   value_spatial_shapes     int64, one (H, W) pair per level
//   value_level_start_index  int64, one start offset per level
//   sampling_locations       float32, dims 1 and 4 are query_length and n_points
//   attention_weights        float32
// Returns a single tensor of shape [bs, query_length, n_heads * channels].
//
// The kernel below is instantiated for float only, so dtypes and ranks are
// validated up front instead of relying on out-of-range shape indexing or on
// a later failure inside data<T>().
std::vector<paddle::Tensor> MSDeformableAttnCUDAForward(
    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
    const paddle::Tensor &value_level_start_index,
    const paddle::Tensor &sampling_locations,
    const paddle::Tensor &attention_weights) {
  CHECK_INPUT_GPU(value);
  CHECK_INPUT_GPU(value_spatial_shapes);
  CHECK_INPUT_GPU(value_level_start_index);
  CHECK_INPUT_GPU(sampling_locations);
  CHECK_INPUT_GPU(attention_weights);

  const std::vector<int64_t> value_dims = value.shape();
  const std::vector<int64_t> level_dims = value_spatial_shapes.shape();
  const std::vector<int64_t> loc_dims = sampling_locations.shape();

  // Guard every shape index used below.
  PD_CHECK(value_dims.size() >= 4,
           "value must be laid out as [bs, value_length, n_heads, channels].");
  PD_CHECK(level_dims.size() >= 1,
           "value_spatial_shapes must have at least one dimension.");
  PD_CHECK(loc_dims.size() >= 5,
           "sampling_locations must be laid out as "
           "[bs, query_length, n_heads, n_levels, n_points, 2].");

  // Only the float32 / int64 combination is implemented.
  PD_CHECK(value.dtype() == paddle::DataType::FLOAT32,
           "value must be a float32 Tensor.");
  PD_CHECK(sampling_locations.dtype() == paddle::DataType::FLOAT32,
           "sampling_locations must be a float32 Tensor.");
  PD_CHECK(attention_weights.dtype() == paddle::DataType::FLOAT32,
           "attention_weights must be a float32 Tensor.");
  PD_CHECK(value_spatial_shapes.dtype() == paddle::DataType::INT64,
           "value_spatial_shapes must be an int64 Tensor.");
  PD_CHECK(value_level_start_index.dtype() == paddle::DataType::INT64,
           "value_level_start_index must be an int64 Tensor.");

  const int batch_size = value_dims[0];
  const int value_length = value_dims[1];
  const int num_heads = value_dims[2];
  const int channels = value_dims[3];

  const int num_levels = level_dims[0];
  const int query_length = loc_dims[1];
  const int num_points = loc_dims[4];

  // NOTE(review): the output is allocated on a default-constructed
  // paddle::GPUPlace(); confirm this resolves to the device that holds
  // `value` in multi-GPU runs.
  auto output = paddle::full({batch_size, query_length, num_heads * channels},
                             0,
                             value.dtype(),
                             paddle::GPUPlace());

  const int num_kernels = batch_size * query_length * num_heads * channels;

  // A grid of zero blocks is an invalid launch configuration, so skip the
  // launch when there is nothing to compute and return the zero-filled output.
  if (num_kernels > 0) {
    deformable_attn_cuda_kernel_forward<float>
        <<<GET_BLOCKS(num_kernels, CUDA_NUM_THREADS),
           CUDA_NUM_THREADS,
           0,
           value.stream()>>>(num_kernels,
                             value.data<float>(),
                             value_spatial_shapes.data<int64_t>(),
                             value_level_start_index.data<int64_t>(),
                             sampling_locations.data<float>(),
                             attention_weights.data<float>(),
                             batch_size,
                             value_length,
                             num_heads,
                             channels,
                             num_levels,
                             query_length,
                             num_points,
                             output.data<float>());
  }
  return {output};
}
// backward bilinear
// Backward of the bilinear sample for one (head, channel) lane.
//
// For every in-range corner, the corner's share of the upstream gradient
// (top_grad * attn_weight * bilinear weight) is atomically added to
// `grad_value`. The gradients with respect to the attention weight and the
// sampling location are written (overwriting, not accumulating) to
// `grad_attn_weight[0]` and `grad_sampling_loc[0..1]` (w first, then h).
template <typename data_t>
__device__ void deformable_attn_bilinear_backward(
    const data_t *&bottom_data, const int &height, const int &width,
    const int &nheads, const int &channels, const data_t &h, const data_t &w,
    const int &m, const int &c, const data_t &top_grad,
    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  // Integer corners around the sample position.
  const int row_lo = floor(h);
  const int col_lo = floor(w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distance to the low corner, and its complement.
  const data_t frac_row = h - row_lo;
  const data_t frac_col = w - col_lo;
  const data_t inv_row = 1 - frac_row;
  const data_t inv_col = 1 - frac_col;

  // Element strides for one step along width / height.
  const int col_step = nheads * channels;
  const int row_step = width * col_step;
  const int row_lo_off = row_lo * row_step;
  const int row_hi_off = row_lo_off + row_step;
  const int col_lo_off = col_lo * col_step;
  const int col_hi_off = col_lo_off + col_step;
  const int lane_off = m * channels + c;

  const data_t wgt_tl = inv_row * inv_col;
  const data_t wgt_tr = inv_row * frac_col;
  const data_t wgt_bl = frac_row * inv_col;
  const data_t wgt_br = frac_row * frac_col;

  const data_t scaled_grad = top_grad * attn_weight;
  data_t d_row = 0;  // derivative of the sample w.r.t. the row coordinate
  data_t d_col = 0;  // derivative of the sample w.r.t. the column coordinate

  const bool row_lo_ok = row_lo >= 0;
  const bool col_lo_ok = col_lo >= 0;
  const bool row_hi_ok = row_hi <= height - 1;
  const bool col_hi_ok = col_hi <= width - 1;

  data_t top_left = 0;
  if (row_lo_ok && col_lo_ok) {
    const int off = row_lo_off + col_lo_off + lane_off;
    top_left = bottom_data[off];
    d_row -= inv_col * top_left;
    d_col -= inv_row * top_left;
    atomicAdd(&grad_value[off], wgt_tl * scaled_grad);
  }

  data_t top_right = 0;
  if (row_lo_ok && col_hi_ok) {
    const int off = row_lo_off + col_hi_off + lane_off;
    top_right = bottom_data[off];
    d_row -= frac_col * top_right;
    d_col += inv_row * top_right;
    atomicAdd(&grad_value[off], wgt_tr * scaled_grad);
  }

  data_t bottom_left = 0;
  if (row_hi_ok && col_lo_ok) {
    const int off = row_hi_off + col_lo_off + lane_off;
    bottom_left = bottom_data[off];
    d_row += inv_col * bottom_left;
    d_col -= frac_row * bottom_left;
    atomicAdd(&grad_value[off], wgt_bl * scaled_grad);
  }

  data_t bottom_right = 0;
  if (row_hi_ok && col_hi_ok) {
    const int off = row_hi_off + col_hi_off + lane_off;
    bottom_right = bottom_data[off];
    d_row += frac_col * bottom_right;
    d_col += frac_row * bottom_right;
    atomicAdd(&grad_value[off], wgt_br * scaled_grad);
  }

  const data_t sampled = (wgt_tl * top_left + wgt_tr * top_right +
                          wgt_bl * bottom_left + wgt_br * bottom_right);

  // Plain stores: the destination slots are overwritten, not accumulated.
  grad_attn_weight[0] = top_grad * sampled;
  grad_sampling_loc[0] = width * d_col * scaled_grad;
  grad_sampling_loc[1] = height * d_row * scaled_grad;
}
// Backward of the bilinear sample for one (head, channel) lane, accumulating
// every result with atomicAdd.
//
// Same computation as the non-"_gm" variant, except that the gradients for
// the attention weight and the sampling location are atomically added to
// `grad_attn_weight[0]` and `grad_sampling_loc[0..1]` instead of being stored.
template <typename data_t>
__device__ void deformable_attn_bilinear_backward_gm(
    const data_t *&bottom_data, const int &height, const int &width,
    const int &nheads, const int &channels, const data_t &h, const data_t &w,
    const int &m, const int &c, const data_t &top_grad,
    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  // Integer corners around the sample position.
  const int row_lo = floor(h);
  const int col_lo = floor(w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distance to the low corner, and its complement.
  const data_t frac_row = h - row_lo;
  const data_t frac_col = w - col_lo;
  const data_t inv_row = 1 - frac_row;
  const data_t inv_col = 1 - frac_col;

  // Element strides for one step along width / height.
  const int col_step = nheads * channels;
  const int row_step = width * col_step;
  const int row_lo_off = row_lo * row_step;
  const int row_hi_off = row_lo_off + row_step;
  const int col_lo_off = col_lo * col_step;
  const int col_hi_off = col_lo_off + col_step;
  const int lane_off = m * channels + c;

  const data_t wgt_tl = inv_row * inv_col;
  const data_t wgt_tr = inv_row * frac_col;
  const data_t wgt_bl = frac_row * inv_col;
  const data_t wgt_br = frac_row * frac_col;

  const data_t scaled_grad = top_grad * attn_weight;
  data_t d_row = 0;  // derivative of the sample w.r.t. the row coordinate
  data_t d_col = 0;  // derivative of the sample w.r.t. the column coordinate

  const bool row_lo_ok = row_lo >= 0;
  const bool col_lo_ok = col_lo >= 0;
  const bool row_hi_ok = row_hi <= height - 1;
  const bool col_hi_ok = col_hi <= width - 1;

  data_t top_left = 0;
  if (row_lo_ok && col_lo_ok) {
    const int off = row_lo_off + col_lo_off + lane_off;
    top_left = bottom_data[off];
    d_row -= inv_col * top_left;
    d_col -= inv_row * top_left;
    atomicAdd(&grad_value[off], wgt_tl * scaled_grad);
  }

  data_t top_right = 0;
  if (row_lo_ok && col_hi_ok) {
    const int off = row_lo_off + col_hi_off + lane_off;
    top_right = bottom_data[off];
    d_row -= frac_col * top_right;
    d_col += inv_row * top_right;
    atomicAdd(&grad_value[off], wgt_tr * scaled_grad);
  }

  data_t bottom_left = 0;
  if (row_hi_ok && col_lo_ok) {
    const int off = row_hi_off + col_lo_off + lane_off;
    bottom_left = bottom_data[off];
    d_row += inv_col * bottom_left;
    d_col -= frac_row * bottom_left;
    atomicAdd(&grad_value[off], wgt_bl * scaled_grad);
  }

  data_t bottom_right = 0;
  if (row_hi_ok && col_hi_ok) {
    const int off = row_hi_off + col_hi_off + lane_off;
    bottom_right = bottom_data[off];
    d_row += frac_col * bottom_right;
    d_col += frac_row * bottom_right;
    atomicAdd(&grad_value[off], wgt_br * scaled_grad);
  }

  const data_t sampled = (wgt_tl * top_left + wgt_tr * top_right +
                          wgt_bl * bottom_left + wgt_br * bottom_right);

  // Accumulate atomically into the destination slots.
  atomicAdd(grad_attn_weight, top_grad * sampled);
  atomicAdd(grad_sampling_loc, width * d_col * scaled_grad);
  atomicAdd(grad_sampling_loc + 1, height * d_row * scaled_grad);
}
// backward kernels
// channels > 1024
// Backward kernel using a per-block shared-memory reduction.
//
// Each work item handles one (batch, query, head, channel) element of the
// upstream gradient `grad_col`. For every level and sampling point it:
//   1. computes this thread's partial gradients into shared memory,
//   2. tree-reduces the partials of the whole block in shared memory,
//   3. lets thread 0 atomically add the block total to the global
//      `grad_sampling_loc` / `grad_attn_weight` slots.
// Gradients w.r.t. `value` are scattered by the bilinear helper via atomicAdd.
//
// Dynamic shared memory must hold 3 * blockDim.x elements of data_t:
// 2 * blockDim.x for the (w, h) location gradients followed by blockDim.x for
// the attention-weight gradients.
//
// The global gradient output positions are tracked with per-iteration local
// cursors. The pointer parameters themselves are never modified, so a thread
// that runs the grid-stride loop more than once starts every iteration from
// the correct base address.
//
// NOTE(review): the __syncthreads() calls sit inside the grid-stride loop, so
// all threads of a block must execute the same number of iterations; confirm
// the launch guarantees n is a multiple of blockDim.x.
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    extern __shared__ int _s[];
    data_t *cache_grad_sampling_loc = (data_t *)_s;
    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    const unsigned int tid = threadIdx.x;

    // Decompose the flat index: channel is fastest, then head, query, batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    // Cursors into the attention weights and the (w, h) location pairs.
    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;

    // Local output cursors (the fix): derived fresh from the untouched
    // pointer parameters on every loop iteration.
    const int grad_sampling_ptr = data_weight_ptr;
    data_t *grad_loc_out = grad_sampling_loc + (grad_sampling_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;

    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        // Scale to the level's size and shift by half a cell.
        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's shared-memory slots before (maybe) filling them.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;

        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              cache_grad_sampling_loc + (threadIdx.x << 1),
              cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // Tree reduction over the block. `spre` is the element count before
        // this halving step; when it is odd, thread 0 also folds in the one
        // leftover element at position 2 * s.
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
             s >>= 1, spre >>= 1) {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] +=
                cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre) {
              cache_grad_attn_weight[tid] +=
                  cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] +=
                  cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] +=
                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Thread 0 publishes the block total; atomics are used because the
        // global slot is accumulated into rather than overwritten.
        if (tid == 0) {
          atomicAdd(grad_loc_out, cache_grad_sampling_loc[0]);
          atomicAdd(grad_loc_out + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_weight_out, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// Backward kernel that passes the global gradient buffers straight to
// deformable_attn_bilinear_backward_gm (no shared-memory staging).
//
// Every `index` in [0, n) is split, fastest dimension first, into a channel
// (c_col), a head (m_col), a query and a batch (b_col) coordinate, and reads
// one upstream gradient grad_col[index]. For each (level, point) pair the
// kernel loads the sampling location (w at [2k], h at [2k + 1]) and the
// attention weight, maps the location to h_im = loc_h * spatial_h - 0.5 (and
// likewise for w), and calls the helper only when that position lies strictly
// inside (-1, spatial_h) x (-1, spatial_w).
//
// The output offsets are held in per-iteration local cursors. The previous
// version advanced the `grad_sampling_loc` / `grad_attn_weight` parameters in
// place, so any CUDA_KERNEL_LOOP iteration after a thread's first one would
// have started from an already shifted base address.
//
// NOTE(review): CUDA_KERNEL_LOOP and deformable_attn_bilinear_backward_gm are
// defined outside this block; the helper presumably accumulates atomically
// into the three gradient buffers -- confirm against its definition.
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_gm(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the flat index: channel, then head, then query, then batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;  // the query coordinate itself is not needed
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Local output cursors; the kernel parameters stay untouched so every
    // loop iteration starts from the real buffer base.
    data_t *grad_loc_out = grad_sampling_loc + (data_weight_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward_gm(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              grad_loc_out, grad_weight_out);
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// channels <= 1024
//
// Backward kernel with a compile-time block size that stages the
// sampling-location and attention-weight gradients in static shared memory
// and lets thread 0 sum them serially.
//
// Every `index` in [0, n) is split, fastest dimension first, into a channel
// (c_col), a head (m_col), a query and a batch (b_col) coordinate. For each
// (level, point) pair every thread zeroes its own shared slot, calls
// deformable_attn_bilinear_backward with pointers to that slot when the
// sampled position lies strictly inside (-1, spatial_h) x (-1, spatial_w),
// and after a barrier thread 0 adds up all `blockSize` slots and stores the
// totals with plain (non-atomic) writes.
//
// The output offsets are held in per-iteration local cursors. The previous
// version advanced the `grad_sampling_loc` / `grad_attn_weight` parameters in
// place, so any CUDA_KERNEL_LOOP iteration after a thread's first one would
// have started from an already shifted base address.
//
// NOTE(review): the plain stores and the barriers inside the loop assume that
// blockDim.x == blockSize and that all threads of a block share one
// (batch, query, head) triple, as the launcher in this file arranges by using
// `channels` threads per block -- confirm for any new call site.
template <typename data_t, unsigned int blockSize>
__global__ void
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    // Two entries (w, h) per thread for locations, one per thread for weights.
    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ data_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;

    // Decompose the flat index: channel, then head, then query, then batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;  // the query coordinate itself is not needed
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Local output cursors; the kernel parameters stay untouched so every
    // loop iteration starts from the real buffer base.
    data_t *grad_loc_out = grad_sampling_loc + (data_weight_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's staging slot; out-of-range samples contribute 0.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              cache_grad_sampling_loc + (threadIdx.x << 1),
              cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();
        if (tid == 0) {
          // Serial sum over all slots, in ascending thread order.
          data_t _grad_w = cache_grad_sampling_loc[0];
          data_t _grad_h = cache_grad_sampling_loc[1];
          data_t _grad_a = cache_grad_attn_weight[0];
          int sid = 2;
          for (unsigned int i = 1; i < blockSize; ++i) {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[i];
            sid += 2;
          }

          *grad_loc_out = _grad_w;
          *(grad_loc_out + 1) = _grad_h;
          *grad_weight_out = _grad_a;
        }
        // Keep the slots intact until thread 0 has finished reading them.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// Backward kernel with a compile-time block size that stages the
// sampling-location and attention-weight gradients in static shared memory
// and combines them with a halving (tree) reduction.
//
// Every `index` in [0, n) is split, fastest dimension first, into a channel
// (c_col), a head (m_col), a query and a batch (b_col) coordinate. For each
// (level, point) pair every thread zeroes its own shared slot, calls
// deformable_attn_bilinear_backward with pointers to that slot when the
// sampled position lies strictly inside (-1, spatial_h) x (-1, spatial_w),
// then the block halves the active width from blockSize / 2 down to 1, and
// thread 0 stores the totals with plain (non-atomic) writes.
//
// The output offsets are held in per-iteration local cursors. The previous
// version advanced the `grad_sampling_loc` / `grad_attn_weight` parameters in
// place, so any CUDA_KERNEL_LOOP iteration after a thread's first one would
// have started from an already shifted base address.
//
// NOTE(review): the halving loop only covers every slot when blockSize is a
// power of two (the launcher in this file instantiates 64..1024), and the
// plain stores assume blockDim.x == blockSize with one (batch, query, head)
// triple per block -- confirm for any new call site.
template <typename data_t, unsigned int blockSize>
__global__ void
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    // Two entries (w, h) per thread for locations, one per thread for weights.
    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ data_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;

    // Decompose the flat index: channel, then head, then query, then batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;  // the query coordinate itself is not needed
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Local output cursors; the kernel parameters stay untouched so every
    // loop iteration starts from the real buffer base.
    data_t *grad_loc_out = grad_sampling_loc + (data_weight_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's staging slot; out-of-range samples contribute 0.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              cache_grad_sampling_loc + (threadIdx.x << 1),
              cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // Tree reduction: thread t adds slot t + s into slot t, then s halves.
        for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] +=
                cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0) {
          *grad_loc_out = cache_grad_sampling_loc[0];
          *(grad_loc_out + 1) = cache_grad_sampling_loc[1];
          *grad_weight_out = cache_grad_attn_weight[0];
        }
        // Keep slot 0 intact until thread 0 has written it out.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// Backward kernel that stages the sampling-location and attention-weight
// gradients in dynamically sized shared memory and lets thread 0 sum them
// serially.
//
// The dynamic shared buffer is used as 2 * blockDim.x entries for the (w, h)
// location gradients followed by blockDim.x entries for the weight gradients,
// i.e. 3 * blockDim.x * sizeof(data_t) bytes must be supplied at launch.
//
// Every `index` in [0, n) is split, fastest dimension first, into a channel
// (c_col), a head (m_col), a query and a batch (b_col) coordinate. For each
// (level, point) pair every thread zeroes its own shared slot, calls
// deformable_attn_bilinear_backward with pointers to that slot when the
// sampled position lies strictly inside (-1, spatial_h) x (-1, spatial_w),
// and after a barrier thread 0 adds up all blockDim.x slots and stores the
// totals with plain (non-atomic) writes.
//
// The output offsets are held in per-iteration local cursors. The previous
// version advanced the `grad_sampling_loc` / `grad_attn_weight` parameters in
// place, so any CUDA_KERNEL_LOOP iteration after a thread's first one would
// have started from an already shifted base address.
//
// NOTE(review): the plain stores and the barriers inside the loop assume that
// all threads of a block share one (batch, query, head) triple, as the
// launcher in this file arranges by using `channels` threads per block --
// confirm for any new call site.
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v1(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    extern __shared__ int _s[];
    data_t *cache_grad_sampling_loc = (data_t *)_s;
    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;

    // Decompose the flat index: channel, then head, then query, then batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;  // the query coordinate itself is not needed
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Local output cursors; the kernel parameters stay untouched so every
    // loop iteration starts from the real buffer base.
    data_t *grad_loc_out = grad_sampling_loc + (data_weight_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's staging slot; out-of-range samples contribute 0.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              cache_grad_sampling_loc + (threadIdx.x << 1),
              cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();
        if (tid == 0) {
          // Serial sum over all slots, in ascending thread order.
          data_t _grad_w = cache_grad_sampling_loc[0];
          data_t _grad_h = cache_grad_sampling_loc[1];
          data_t _grad_a = cache_grad_attn_weight[0];
          int sid = 2;
          for (unsigned int i = 1; i < blockDim.x; ++i) {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[i];
            sid += 2;
          }

          *grad_loc_out = _grad_w;
          *(grad_loc_out + 1) = _grad_h;
          *grad_weight_out = _grad_a;
        }
        // Keep the slots intact until thread 0 has finished reading them.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// Backward kernel that stages the sampling-location and attention-weight
// gradients in dynamically sized shared memory and combines them with a
// halving (tree) reduction that also works for non-power-of-two block sizes.
//
// The dynamic shared buffer is used as 2 * blockDim.x entries for the (w, h)
// location gradients followed by blockDim.x entries for the weight gradients,
// i.e. 3 * blockDim.x * sizeof(data_t) bytes must be supplied at launch.
//
// Every `index` in [0, n) is split, fastest dimension first, into a channel
// (c_col), a head (m_col), a query and a batch (b_col) coordinate. For each
// (level, point) pair every thread zeroes its own shared slot, calls
// deformable_attn_bilinear_backward with pointers to that slot when the
// sampled position lies strictly inside (-1, spatial_h) x (-1, spatial_w),
// the block reduces the slots, and thread 0 stores the totals with plain
// (non-atomic) writes.
//
// The output offsets are held in per-iteration local cursors. The previous
// version advanced the `grad_sampling_loc` / `grad_attn_weight` parameters in
// place, so any CUDA_KERNEL_LOOP iteration after a thread's first one would
// have started from an already shifted base address.
//
// NOTE(review): the plain stores and the barriers inside the loop assume that
// all threads of a block share one (batch, query, head) triple, as the
// launcher in this file arranges by using `channels` threads per block --
// confirm for any new call site.
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2(
    const int n, const data_t *grad_col, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  CUDA_KERNEL_LOOP(index, n) {
    extern __shared__ int _s[];
    data_t *cache_grad_sampling_loc = (data_t *)_s;
    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;

    // Decompose the flat index: channel, then head, then query, then batch.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= query_length;  // the query coordinate itself is not needed
    const int b_col = _temp;

    const data_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_points;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Local output cursors; the kernel parameters stay untouched so every
    // loop iteration starts from the real buffer base.
    data_t *grad_loc_out = grad_sampling_loc + (data_weight_ptr << 1);
    data_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col) {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset =
          data_value_ptr_init_offset + level_start_id * qid_stride;
      const data_t *data_value_ptr = data_value + value_ptr_offset;
      data_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_points; ++p_col) {
        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const data_t weight = data_attn_weight[data_weight_ptr];

        const data_t h_im = loc_h * spatial_h - 0.5;
        const data_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's staging slot; out-of-range samples contribute 0.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
          deformable_attn_bilinear_backward(
              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
              cache_grad_sampling_loc + (threadIdx.x << 1),
              cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // `spre` is the number of live slots before this step and `s` the
        // number after it. When `spre` is odd, slot 2 * s has no partner, so
        // the thread with tid + 2 * s < spre (only tid 0) folds it in as well.
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
             s >>= 1, spre >>= 1) {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] +=
                cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre) {
              cache_grad_attn_weight[tid] +=
                  cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] +=
                  cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] +=
                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0) {
          *grad_loc_out = cache_grad_sampling_loc[0];
          *(grad_loc_out + 1) = cache_grad_sampling_loc[1];
          *grad_weight_out = cache_grad_attn_weight[0];
        }
        // Keep slot 0 intact until thread 0 has written it out.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += grad_weight_stride;
        grad_loc_out += grad_loc_stride;
      }
    }
  }
}
// backward branch
//
// Host-side launcher that picks one backward kernel from `channels`:
//   * channels > 1024 and a multiple of 1024 -> shm_reduce_v2_multi_blocks
//   * channels > 1024 otherwise              -> backward_gm (global memory)
//   * channels in {1, 2, 4, 8, 16, 32}       -> blocksize_aware_reduce_v1
//   * channels in {64, 128, ..., 1024}       -> blocksize_aware_reduce_v2
//   * any other channels < 64                -> shm_reduce_v1
//   * any other channels >= 64               -> shm_reduce_v2
// Every launch uses min(channels, CUDA_NUM_THREADS) threads per block and
// GET_BLOCKS(batch * query * heads * channels, threads) blocks on `stream`.
// The kernels without a compile-time block size get
// 3 * threads * sizeof(data_t) bytes of dynamic shared memory.
//
// The macro below only factors out the launch boilerplate shared by all
// branches. The kernel name is passed through __VA_ARGS__ so that template
// argument lists containing commas survive macro expansion.
#define DEFORMABLE_ATTN_LAUNCH_BACKWARD(shm_bytes, ...)                        \
  __VA_ARGS__<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,     \
                (shm_bytes), stream>>>(                                        \
      num_kernels, grad_out, data_value, data_spatial_shapes,                 \
      data_level_start_index, data_sampling_loc, data_attn_weight, batch_size, \
      value_length, num_heads, channels, num_levels, query_length, num_points, \
      grad_value, grad_sampling_loc, grad_attn_weight)

template <typename data_t>
void deformable_attn_cuda_backward(
    cudaStream_t stream, const data_t *grad_out, const data_t *data_value,
    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
    const data_t *data_sampling_loc, const data_t *data_attn_weight,
    const int batch_size, const int value_length, const int num_heads,
    const int channels, const int num_levels, const int query_length,
    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
    data_t *grad_attn_weight) {
  const int num_threads =
      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  // One work item per (batch, query, head, channel) element.
  const int num_kernels = batch_size * query_length * num_heads * channels;
  const int num_actual_kernels = num_kernels;
  // Room for two location gradients and one weight gradient per thread.
  const size_t dynamic_shm_bytes = num_threads * 3 * sizeof(data_t);

  if (channels > 1024) {
    if ((channels & 1023) == 0) {
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          dynamic_shm_bytes,
          deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks<
              data_t>);
    } else {
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_gm<data_t>);
    }
    return;
  }

  switch (channels) {
    case 1:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 1>);
      break;
    case 2:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 2>);
      break;
    case 4:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 4>);
      break;
    case 8:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 8>);
      break;
    case 16:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 16>);
      break;
    case 32:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<
                 data_t, 32>);
      break;
    case 64:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<
                 data_t, 64>);
      break;
    case 128:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<
                 data_t, 128>);
      break;
    case 256:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<
                 data_t, 256>);
      break;
    case 512:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<
                 data_t, 512>);
      break;
    case 1024:
      DEFORMABLE_ATTN_LAUNCH_BACKWARD(
          0, deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<
                 data_t, 1024>);
      break;
    default:
      if (channels < 64) {
        DEFORMABLE_ATTN_LAUNCH_BACKWARD(
            dynamic_shm_bytes,
            deformable_attn_cuda_kernel_backward_shm_reduce_v1<data_t>);
      } else {
        DEFORMABLE_ATTN_LAUNCH_BACKWARD(
            dynamic_shm_bytes,
            deformable_attn_cuda_kernel_backward_shm_reduce_v2<data_t>);
      }
  }
}

#undef DEFORMABLE_ATTN_LAUNCH_BACKWARD
// backward
//
// Paddle-facing entry point for the backward pass of the multi-scale
// deformable attention op.
//
// Reads the problem sizes from the input shapes (value dims 0..3, the first
// dim of value_spatial_shapes, and dims 1 and 4 of sampling_locations),
// allocates zero-filled gradient tensors, runs deformable_attn_cuda_backward
// on value's stream, and returns the five gradients in input order:
// {value, spatial_shapes, level_start_index, sampling_locations,
//  attention_weights}.
//
// NOTE(review): the kernel is always instantiated for float and the buffers
// are read with data<float>(); other dtypes are not dispatched here.
// NOTE(review): the gradients returned for value_spatial_shapes and
// value_level_start_index are zero tensors with value's shape and dtype that
// the kernel never writes -- presumably placeholders for the integer inputs;
// confirm against the op registration before changing their shape.
std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
    const paddle::Tensor &value_level_start_index,
    const paddle::Tensor &sampling_locations,
    const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out) {
  CHECK_INPUT_GPU(value);
  CHECK_INPUT_GPU(value_spatial_shapes);
  CHECK_INPUT_GPU(value_level_start_index);
  CHECK_INPUT_GPU(sampling_locations);
  CHECK_INPUT_GPU(attention_weights);
  CHECK_INPUT_GPU(grad_out);

  const auto value_dims = value.shape();
  const auto location_dims = sampling_locations.shape();

  const int batch_size = value_dims[0];
  const int value_length = value_dims[1];
  const int num_heads = value_dims[2];
  const int channels = value_dims[3];
  const int num_levels = value_spatial_shapes.shape()[0];
  const int query_length = location_dims[1];
  const int num_points = location_dims[4];

  // Zero-filled GPU tensor with the shape and dtype of `ref`.
  const auto zeros_shaped_like = [](const paddle::Tensor &ref) {
    return paddle::full(ref.shape(), 0, ref.dtype(), paddle::GPUPlace());
  };

  auto grad_value = zeros_shaped_like(value);
  auto grad_spatial_shapes = zeros_shaped_like(value);
  auto grad_level_start_index = zeros_shaped_like(value);
  auto grad_sampling_locations = zeros_shaped_like(sampling_locations);
  auto grad_attention_weights = zeros_shaped_like(attention_weights);

  deformable_attn_cuda_backward<float>(
      value.stream(), grad_out.data<float>(), value.data<float>(),
      value_spatial_shapes.data<int64_t>(),
      value_level_start_index.data<int64_t>(),
      sampling_locations.data<float>(), attention_weights.data<float>(),
      batch_size, value_length, num_heads, channels, num_levels, query_length,
      num_points, grad_value.data<float>(),
      grad_sampling_locations.data<float>(),
      grad_attention_weights.data<float>());

  return {grad_value, grad_spatial_shapes, grad_level_start_index,
          grad_sampling_locations, grad_attention_weights};
}
ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py
0 → 100644
浏览文件 @
3d6a027c
from paddle.utils.cpp_extension import CUDAExtension, setup

# Build script for the `deformable_detr_ops` custom-op package; the sources
# are resolved relative to the directory the script is run from.
if __name__ == "__main__":
    op_sources = ['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu']
    extension = CUDAExtension(sources=op_sources)
    setup(name='deformable_detr_ops', ext_modules=extension)
ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py
0 → 100644
浏览文件 @
3d6a027c
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
os
import
sys
import
random
import
numpy
as
np
import
paddle
# Add the PaddleDetection repository root (five levels above this file) to
# sys.path so that `ppdet` is importable when the script is run directly.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

from ppdet.modeling.transformers.utils import deformable_attention_core_func
ms_deform_attn_core_paddle = deformable_attention_core_func

# Optional first CLI argument selects the GPU; fall back to GPU 0 when it is
# missing or not an integer. Only those two failures are caught so that
# unrelated errors (e.g. KeyboardInterrupt) are not silently swallowed.
try:
    gpu_index = int(sys.argv[1])
except (IndexError, ValueError):
    gpu_index = 0
print(f'Use gpu {gpu_index} to test...')
paddle.set_device(f'gpu:{gpu_index}')

# The compiled extension is optional at import time; report and exit with a
# non-zero status if it has not been built/installed.
try:
    from deformable_detr_ops import ms_deformable_attn
except Exception as e:
    print('import deformable_detr_ops error', e)
    sys.exit(-1)
# Fix every RNG used by the test so runs are reproducible.
paddle.seed(1)
random.seed(1)
np.random.seed(1)

# Problem sizes shared by all checks below.
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2

# (H, W) of each feature level.
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)

# Offset of each level inside the flattened value tensor: 0 followed by the
# running sum of H * W over all but the last level.
level_start_index = paddle.concat(
    (paddle.to_tensor(
        [0], dtype=paddle.int64),
     spatial_shapes.prod(1).cumsum(0)[:-1]))

# Total number of spatial positions over all levels.
value_length = sum((H * W).item() for H, W in spatial_shapes)
def get_test_tensors(channels):
    """Create random float32 inputs for the deformable attention op.

    Args:
        channels (int): size of the last dimension of ``value``.

    Returns:
        list: ``[value, sampling_locations, attention_weights]`` where
            ``value`` is ``[bs, value_length, n_heads, channels]`` scaled by
            0.01, ``sampling_locations`` is
            ``[bs, query_length, n_heads, n_levels, n_points, 2]`` and
            ``attention_weights`` is
            ``[bs, query_length, n_heads, n_levels, n_points]``, divided by
            its sum over the last two axes.
    """
    value_shape = [bs, value_length, n_heads, channels]
    weight_shape = [bs, query_length, n_heads, n_levels, n_points]
    location_shape = weight_shape + [2]

    # Draw order (value, locations, weights) is kept so seeded runs match.
    value = paddle.rand(value_shape, dtype=paddle.float32) * 0.01
    sampling_locations = paddle.rand(location_shape, dtype=paddle.float32)
    # The small offset keeps the normaliser below strictly positive.
    attention_weights = paddle.rand(weight_shape, dtype=paddle.float32) + 1e-5
    normaliser = attention_weights.sum(
        -1, keepdim=True).sum(-2, keepdim=True)
    attention_weights /= normaliser

    return [value, sampling_locations, attention_weights]
@paddle.no_grad()
def check_forward_equal_with_paddle_float():
    """Compare the CUDA op's forward output with the Paddle reference.

    Runs both implementations on the same random inputs (``c`` channels),
    then prints whether they agree within ``rtol=1e-2`` / ``atol=1e-3``
    together with the maximum absolute and relative errors.
    """
    value, sampling_locations, attention_weights = get_test_tensors(c)
    op_inputs = (value, spatial_shapes, level_start_index, sampling_locations,
                 attention_weights)

    output_paddle = ms_deform_attn_core_paddle(*op_inputs).detach().cpu()
    output_cuda = ms_deformable_attn(*op_inputs).detach().cpu()

    fwdok = paddle.allclose(
        output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item()
    abs_err = (output_cuda - output_paddle).abs()
    max_abs_err = abs_err.max().item()
    # Relative error is measured against the reference output.
    max_rel_err = (abs_err / output_paddle.abs()).max().item()

    print(
        f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}'
    )
def check_gradient_numerical(channels=4):
    """Compare the CUDA op's gradients with the Paddle reference.

    Both implementations receive identical inputs (the CUDA side works on
    detached clones), ``sum()`` of each output is back-propagated, and the
    gradients of value, sampling locations and attention weights are compared
    with ``rtol=1e-2`` / ``atol=1e-3``. One result line is printed per tensor.

    Args:
        channels (int): size of the last dimension of ``value``.
    """

    def _trainable_clone(tensor):
        # Independent leaf copy that records gradients.
        duplicate = tensor.detach().clone()
        duplicate.stop_gradient = False
        return duplicate

    def _grads_match(reference, candidate):
        return paddle.allclose(
            reference.grad, candidate.grad, rtol=1e-2, atol=1e-3).item()

    (value_paddle, sampling_locations_paddle,
     attention_weights_paddle) = get_test_tensors(channels)
    for leaf in (value_paddle, sampling_locations_paddle,
                 attention_weights_paddle):
        leaf.stop_gradient = False

    value_cuda = _trainable_clone(value_paddle)
    sampling_locations_cuda = _trainable_clone(sampling_locations_paddle)
    attention_weights_cuda = _trainable_clone(attention_weights_paddle)

    # Reference implementation first, then the custom CUDA op.
    output_paddle = ms_deform_attn_core_paddle(
        value_paddle, spatial_shapes, level_start_index,
        sampling_locations_paddle, attention_weights_paddle)
    output_paddle.sum().backward()

    output_cuda = ms_deformable_attn(value_cuda, spatial_shapes,
                                     level_start_index,
                                     sampling_locations_cuda,
                                     attention_weights_cuda)
    output_cuda.sum().backward()

    res = _grads_match(value_paddle, value_cuda)
    print(f'*tensor1 {res} check_gradient_numerical(D={channels})')

    res = _grads_match(sampling_locations_paddle, sampling_locations_cuda)
    print(f'*tensor2 {res} check_gradient_numerical(D={channels})')

    res = _grads_match(attention_weights_paddle, attention_weights_cuda)
    print(f'*tensor3 {res} check_gradient_numerical(D={channels})')
if __name__ == '__main__':
    # Forward check once, then the gradient check for a spread of channel
    # counts (below 64, powers of two, odd sizes, and values above 1024).
    check_forward_equal_with_paddle_float()

    channel_sizes = (30, 32, 64, 71, 128, 1024, 1025, 2048, 3096)
    for channels in channel_sizes:
        check_gradient_numerical(channels)
ppdet/modeling/transformers/position_encoding.py
浏览文件 @
3d6a027c
...
...
@@ -33,37 +33,34 @@ class PositionEmbedding(nn.Layer):
num_pos_feats
=
128
,
temperature
=
10000
,
normalize
=
True
,
scale
=
None
,
scale
=
2
*
math
.
pi
,
embed_type
=
'sine'
,
num_embeddings
=
50
,
offset
=
0.
):
offset
=
0.
,
eps
=
1e-6
):
super
(
PositionEmbedding
,
self
).
__init__
()
assert
embed_type
in
[
'sine'
,
'learned'
]
self
.
embed_type
=
embed_type
self
.
offset
=
offset
self
.
eps
=
1e-6
self
.
eps
=
eps
if
self
.
embed_type
==
'sine'
:
self
.
num_pos_feats
=
num_pos_feats
self
.
temperature
=
temperature
self
.
normalize
=
normalize
if
scale
is
not
None
and
normalize
is
False
:
raise
ValueError
(
"normalize should be True if scale is passed"
)
if
scale
is
None
:
scale
=
2
*
math
.
pi
self
.
scale
=
scale
elif
self
.
embed_type
==
'learned'
:
self
.
row_embed
=
nn
.
Embedding
(
num_embeddings
,
num_pos_feats
)
self
.
col_embed
=
nn
.
Embedding
(
num_embeddings
,
num_pos_feats
)
else
:
raise
ValueError
(
f
"
not supported
{
self
.
embed_type
}
"
)
raise
ValueError
(
f
"
{
self
.
embed_type
}
is not supported.
"
)
def
forward
(
self
,
mask
):
"""
Args:
mask (Tensor): [B, H, W]
Returns:
pos (Tensor): [B,
C, H, W
]
pos (Tensor): [B,
H, W, C
]
"""
if
self
.
embed_type
==
'sine'
:
y_embed
=
mask
.
cumsum
(
1
)
...
...
@@ -86,20 +83,18 @@ class PositionEmbedding(nn.Layer):
pos_y
=
paddle
.
stack
(
(
pos_y
[:,
:,
:,
0
::
2
].
sin
(),
pos_y
[:,
:,
:,
1
::
2
].
cos
()),
axis
=
4
).
flatten
(
3
)
pos
=
paddle
.
concat
((
pos_y
,
pos_x
),
axis
=
3
).
transpose
([
0
,
3
,
1
,
2
])
return
pos
return
paddle
.
concat
((
pos_y
,
pos_x
),
axis
=
3
)
elif
self
.
embed_type
==
'learned'
:
h
,
w
=
mask
.
shape
[
-
2
:]
i
=
paddle
.
arange
(
w
)
j
=
paddle
.
arange
(
h
)
x_emb
=
self
.
col_embed
(
i
)
y_emb
=
self
.
row_embed
(
j
)
pos
=
paddle
.
concat
(
return
paddle
.
concat
(
[
x_emb
.
unsqueeze
(
0
).
tile
([
h
,
1
,
1
]),
y_emb
.
unsqueeze
(
1
).
tile
([
1
,
w
,
1
]),
],
axis
=-
1
).
transpose
([
2
,
0
,
1
]).
unsqueeze
(
0
)
return
pos
axis
=-
1
).
unsqueeze
(
0
)
else
:
raise
ValueError
(
f
"not supported
{
self
.
embed_type
}
"
)
ppdet/modeling/transformers/utils.py
浏览文件 @
3d6a027c
...
...
@@ -38,15 +38,14 @@ def _get_clones(module, N):
def
bbox_cxcywh_to_xyxy
(
x
):
x_c
,
y_c
,
w
,
h
=
x
.
split
(
4
,
axis
=-
1
)
b
=
[(
x_c
-
0.5
*
w
),
(
y_c
-
0.5
*
h
),
(
x_c
+
0.5
*
w
),
(
y_c
+
0.5
*
h
)]
return
paddle
.
concat
(
b
,
axis
=-
1
)
cxcy
,
wh
=
paddle
.
split
(
x
,
2
,
axis
=-
1
)
return
paddle
.
concat
([
cxcy
-
0.5
*
wh
,
cxcy
+
0.5
*
wh
],
axis
=-
1
)
def
bbox_xyxy_to_cxcywh
(
x
):
x
0
,
y0
,
x1
,
y1
=
x
.
split
(
4
,
axis
=-
1
)
b
=
[(
x0
+
x1
)
/
2
,
(
y0
+
y1
)
/
2
,
(
x1
-
x0
),
(
y1
-
y0
)]
return
paddle
.
concat
(
b
,
axis
=-
1
)
x
1
,
y1
,
x2
,
y2
=
x
.
split
(
4
,
axis
=-
1
)
return
paddle
.
concat
(
[(
x1
+
x2
)
/
2
,
(
y1
+
y2
)
/
2
,
(
x2
-
x1
),
(
y2
-
y1
)]
,
axis
=-
1
)
def
sigmoid_focal_loss
(
logit
,
label
,
normalizer
=
1.0
,
alpha
=
0.25
,
gamma
=
2.0
):
...
...
@@ -67,24 +66,27 @@ def inverse_sigmoid(x, eps=1e-6):
def
deformable_attention_core_func
(
value
,
value_spatial_shapes
,
sampling_locations
,
attention_weights
):
value_level_start_index
,
sampling_locations
,
attention_weights
):
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
value_spatial_shapes (Tensor): [n_levels, 2]
value_level_start_index (Tensor): [n_levels]
sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs
,
Len_v
,
n_head
,
c
=
value
.
shape
_
,
Len_q
,
n_head
,
n_levels
,
n_points
,
_
=
sampling_locations
.
shape
bs
,
_
,
n_head
,
c
=
value
.
shape
_
,
Len_q
,
_
,
n_levels
,
n_points
,
_
=
sampling_locations
.
shape
value_list
=
value
.
split
(
value_spatial_shapes
.
prod
(
1
).
tolist
(),
axis
=
1
)
value_list
=
value
.
split
(
value_spatial_shapes
.
prod
(
1
).
split
(
n_levels
),
axis
=
1
)
sampling_grids
=
2
*
sampling_locations
-
1
sampling_value_list
=
[]
for
level
,
(
h
,
w
)
in
enumerate
(
value_spatial_shapes
.
tolist
()
):
for
level
,
(
h
,
w
)
in
enumerate
(
value_spatial_shapes
):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_
=
value_list
[
level
].
flatten
(
2
).
transpose
(
[
0
,
2
,
1
]).
reshape
([
bs
*
n_head
,
c
,
h
,
w
])
...
...
@@ -107,3 +109,11 @@ def deformable_attention_core_func(value, value_spatial_shapes,
attention_weights
).
sum
(
-
1
).
reshape
([
bs
,
n_head
*
c
,
Len_q
])
return
output
.
transpose
([
0
,
2
,
1
])
def get_valid_ratio(mask):
    """Return per-sample width and height ratios computed from ``mask``.

    The first column of ``mask`` is summed along the height axis and
    divided by ``H``; the first row is summed along the width axis and
    divided by ``W``.

    Args:
        mask (Tensor): [b, H, W]; presumably non-zero marks valid
            (non-padded) positions -- confirm against the caller.

    Returns:
        Tensor: [b, 2], ordered as (width ratio, height ratio).
    """
    _, height, width = paddle.shape(mask)
    first_column = mask[:, :, 0]
    first_row = mask[:, 0, :]
    ratio_h = paddle.sum(first_column, axis=1) / height
    ratio_w = paddle.sum(first_row, axis=1) / width
    # [b, 2]
    return paddle.stack([ratio_w, ratio_h], axis=-1)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录