Unverified  Commit 3d6a027c  authored by: shangliang Xu  committed by: GitHub

[dev] add ms_deformable_attn cuda op (#7521)

Parent: e1a8f660
@@ -84,7 +84,7 @@ class DETR(BaseArch):
                 preds, self.inputs['im_shape'], self.inputs['scale_factor'])
             return bbox, bbox_num

-    def get_loss(self, ):
+    def get_loss(self):
         losses = self._forward()
         losses.update({
             'loss':
...
@@ -492,19 +492,21 @@ class DETRBBoxPostProcess(object):
         if scores.shape[1] > self.num_top_queries:
             scores, index = paddle.topk(
                 scores, self.num_top_queries, axis=-1)
-            labels = paddle.stack(
-                [paddle.gather(l, i) for l, i in zip(labels, index)])
-            bbox_pred = paddle.stack(
-                [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
+            batch_ind = paddle.arange(
+                end=scores.shape[0]).unsqueeze(-1).tile(
+                    [1, self.num_top_queries])
+            index = paddle.stack([batch_ind, index], axis=-1)
+            labels = paddle.gather_nd(labels, index)
+            bbox_pred = paddle.gather_nd(bbox_pred, index)
         else:
-            scores, index = paddle.topk(
-                scores.reshape([logits.shape[0], -1]),
-                self.num_top_queries,
-                axis=-1)
-            labels = index % logits.shape[2]
-            index = index // logits.shape[2]
-            bbox_pred = paddle.stack(
-                [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
+            scores, index = paddle.topk(
+                scores.flatten(1), self.num_top_queries, axis=-1)
+            labels = index % self.num_classes
+            index = index // self.num_classes
+            batch_ind = paddle.arange(end=scores.shape[0]).unsqueeze(-1).tile(
+                [1, self.num_top_queries])
+            index = paddle.stack([batch_ind, index], axis=-1)
+            bbox_pred = paddle.gather_nd(bbox_pred, index)
         bbox_pred = paddle.concat(
             [
...
@@ -28,7 +28,7 @@ from paddle import ParamAttr
 from ppdet.core.workspace import register
 from ..layers import MultiHeadAttention
 from .position_encoding import PositionEmbedding
-from .utils import _get_clones, deformable_attention_core_func
+from .utils import _get_clones, get_valid_ratio
 from ..initializer import linear_init_, constant_, xavier_uniform_, normal_

 __all__ = ['DeformableTransformer']
@@ -63,6 +63,13 @@ class MSDeformableAttention(nn.Layer):
         self.attention_weights = nn.Linear(embed_dim, self.total_points)
         self.value_proj = nn.Linear(embed_dim, embed_dim)
         self.output_proj = nn.Linear(embed_dim, embed_dim)
+        try:
+            # use cuda op
+            from deformable_detr_ops import ms_deformable_attn
+        except:
+            # use paddle func
+            from .utils import deformable_attention_core_func as ms_deformable_attn
+        self.ms_deformable_attn_core = ms_deformable_attn

         self._reset_parameters()
@@ -95,6 +102,7 @@ class MSDeformableAttention(nn.Layer):
                 reference_points,
                 value,
                 value_spatial_shapes,
+                value_level_start_index,
                 value_mask=None):
         """
         Args:
@@ -103,6 +111,7 @@ class MSDeformableAttention(nn.Layer):
                 bottom-right (1, 1), including padding area
             value (Tensor): [bs, value_length, C]
             value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
             value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements

         Returns:
@@ -131,8 +140,9 @@ class MSDeformableAttention(nn.Layer):
                 bs, Len_q, 1, self.num_levels, 1, 2
             ]) + sampling_offsets / offset_normalizer

-        output = deformable_attention_core_func(
-            value, value_spatial_shapes, sampling_locations, attention_weights)
+        output = self.ms_deformable_attn_core(
+            value, value_spatial_shapes, value_level_start_index,
+            sampling_locations, attention_weights)
         output = self.output_proj(output)

         return output
@@ -185,12 +195,13 @@ class DeformableTransformerEncoderLayer(nn.Layer):
                 src,
                 reference_points,
                 spatial_shapes,
+                level_start_index,
                 src_mask=None,
                 pos_embed=None):
         # self attention
         src2 = self.self_attn(
             self.with_pos_embed(src, pos_embed), reference_points, src,
-            spatial_shapes, src_mask)
+            spatial_shapes, level_start_index, src_mask)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         # ffn
@@ -206,13 +217,12 @@ class DeformableTransformerEncoder(nn.Layer):
         self.num_layers = num_layers

     @staticmethod
-    def get_reference_points(spatial_shapes, valid_ratios):
+    def get_reference_points(spatial_shapes, valid_ratios, offset=0.5):
         valid_ratios = valid_ratios.unsqueeze(1)
         reference_points = []
-        for i, (H, W) in enumerate(spatial_shapes.tolist()):
+        for i, (H, W) in enumerate(spatial_shapes):
             ref_y, ref_x = paddle.meshgrid(
-                paddle.linspace(0.5, H - 0.5, H),
-                paddle.linspace(0.5, W - 0.5, W))
+                paddle.arange(end=H) + offset, paddle.arange(end=W) + offset)
             ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] *
                                                     H)
             ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] *
@@ -225,6 +235,7 @@ class DeformableTransformerEncoder(nn.Layer):
     def forward(self,
                 src,
                 spatial_shapes,
+                level_start_index,
                 src_mask=None,
                 pos_embed=None,
                 valid_ratios=None):
@@ -235,8 +246,8 @@ class DeformableTransformerEncoder(nn.Layer):
         reference_points = self.get_reference_points(spatial_shapes,
                                                      valid_ratios)
         for layer in self.layers:
-            output = layer(output, reference_points, spatial_shapes, src_mask,
-                           pos_embed)
+            output = layer(output, reference_points, spatial_shapes,
+                           level_start_index, src_mask, pos_embed)

         return output
@@ -296,6 +307,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
                 reference_points,
                 memory,
                 memory_spatial_shapes,
+                memory_level_start_index,
                 memory_mask=None,
                 query_pos_embed=None):
         # self attention
@@ -307,7 +319,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
         # cross attention
         tgt2 = self.cross_attn(
             self.with_pos_embed(tgt, query_pos_embed), reference_points, memory,
-            memory_spatial_shapes, memory_mask)
+            memory_spatial_shapes, memory_level_start_index, memory_mask)
         tgt = tgt + self.dropout2(tgt2)
         tgt = self.norm2(tgt)
@@ -329,13 +341,15 @@ class DeformableTransformerDecoder(nn.Layer):
                 reference_points,
                 memory,
                 memory_spatial_shapes,
+                memory_level_start_index,
                 memory_mask=None,
                 query_pos_embed=None):
         output = tgt
         intermediate = []
         for lid, layer in enumerate(self.layers):
-            output = layer(output, reference_points, memory,
-                           memory_spatial_shapes, memory_mask, query_pos_embed)
+            output = layer(output, reference_points, memory,
+                           memory_spatial_shapes, memory_level_start_index,
+                           memory_mask, query_pos_embed)
             if self.return_intermediate:
                 intermediate.append(output)
@@ -447,14 +461,7 @@ class DeformableTransformer(nn.Layer):
     def from_config(cls, cfg, input_shape):
         return {'backbone_num_channels': [i.channels for i in input_shape], }

-    def _get_valid_ratio(self, mask):
-        _, H, W = mask.shape
-        valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
-        valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
-        valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1)
-        return valid_ratio
-
-    def forward(self, src_feats, src_mask=None):
+    def forward(self, src_feats, src_mask=None, *args, **kwargs):
         srcs = []
         for i in range(len(src_feats)):
             srcs.append(self.input_proj[i](src_feats[i]))
@@ -471,33 +478,38 @@
         spatial_shapes = []
         valid_ratios = []
         for level, src in enumerate(srcs):
-            bs, c, h, w = src.shape
-            spatial_shapes.append([h, w])
+            bs, _, h, w = paddle.shape(src)
+            spatial_shapes.append(paddle.concat([h, w]))
             src = src.flatten(2).transpose([0, 2, 1])
             src_flatten.append(src)
             if src_mask is not None:
                 mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
             else:
                 mask = paddle.ones([bs, h, w])
-            valid_ratios.append(self._get_valid_ratio(mask))
-            pos_embed = self.position_embedding(mask).flatten(2).transpose(
-                [0, 2, 1])
-            lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape(
-                [1, 1, -1])
+            valid_ratios.append(get_valid_ratio(mask))
+            pos_embed = self.position_embedding(mask).flatten(1, 2)
+            lvl_pos_embed = pos_embed + self.level_embed.weight[level]
             lvl_pos_embed_flatten.append(lvl_pos_embed)
             mask = mask.flatten(1)
             mask_flatten.append(mask)
         src_flatten = paddle.concat(src_flatten, 1)
-        mask_flatten = paddle.concat(mask_flatten, 1)
+        mask_flatten = None if src_mask is None else paddle.concat(mask_flatten,
+                                                                   1)
         lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
         # [l, 2]
-        spatial_shapes = paddle.to_tensor(spatial_shapes, dtype='int64')
+        spatial_shapes = paddle.to_tensor(
+            paddle.stack(spatial_shapes).astype('int64'))
+        # [l], start index of each level
+        level_start_index = paddle.concat([
+            paddle.zeros(
+                [1], dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1]
+        ])
         # [b, l, 2]
         valid_ratios = paddle.stack(valid_ratios, 1)
         # encoder
-        memory = self.encoder(src_flatten, spatial_shapes, mask_flatten,
-                              lvl_pos_embed_flatten, valid_ratios)
+        memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
+                              mask_flatten, lvl_pos_embed_flatten, valid_ratios)

         # prepare input for decoder
         bs, _, c = memory.shape
@@ -509,6 +521,6 @@
         # decoder
-        hs = self.decoder(tgt, reference_points_input, memory, spatial_shapes,
-                          mask_flatten, query_embed)
+        hs = self.decoder(tgt, reference_points_input, memory, spatial_shapes,
+                          level_start_index, mask_flatten, query_embed)
         return (hs, memory, reference_points)
@@ -295,7 +295,7 @@ class DETRTransformer(nn.Layer):
     def _convert_attention_mask(self, mask):
         return (mask - 1.0) * 1e9

-    def forward(self, src, src_mask=None):
+    def forward(self, src, src_mask=None, *args, **kwargs):
         r"""
         Applies a Transformer model on the inputs.
@@ -325,8 +325,7 @@
             src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
         else:
             src_mask = paddle.ones([bs, h, w])
-        pos_embed = self.position_embedding(src_mask).flatten(2).transpose(
-            [0, 2, 1])
+        pos_embed = self.position_embedding(src_mask).flatten(1, 2)

         if self.training:
             src_mask = self._convert_attention_mask(src_mask)
...
# Compiling the Multi-Scale Deformable Attention Custom Op

This custom op is implemented following PaddlePaddle's [custom external operator guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html).

## 1. Requirements

- Paddle >= 2.3.2
- gcc 8.2

## 2. Installation

Build and install the op from this directory:
```
cd PaddleDetection/ppdet/modeling/transformers/ext_op/
python setup_ms_deformable_attn_op.py install
```
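If the build succeeded, the op should now be importable; a quick smoke check (a minimal sketch, the module name `deformable_detr_ops` comes from the setup script):
```
# should import without error once the extension is installed
from deformable_detr_ops import ms_deformable_attn
```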
Once compiled, the op is ready to use. Below is a usage example for `ms_deformable_attn`:
```
import paddle

# import the custom op
from deformable_detr_ops import ms_deformable_attn

# build fake input tensors
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
[0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
value_length = sum([(H * W).item() for H, W in spatial_shapes])
def get_test_tensors(channels):
value = paddle.rand(
[bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
sampling_locations = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points, 2],
dtype=paddle.float32)
attention_weights = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points],
dtype=paddle.float32) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
-2, keepdim=True)
return [value, sampling_locations, attention_weights]
value, sampling_locations, attention_weights = get_test_tensors(c)
output = ms_deformable_attn(value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights)
```
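Since `MSDeformableAttention` falls back to the pure-Paddle `deformable_attention_core_func` when the compiled op is unavailable, the two implementations can be cross-checked numerically. A minimal sketch, reusing the tensors built above and assuming the fallback keeps the same five-argument signature that the module calls it with:
```
import numpy as np
from ppdet.modeling.transformers.utils import deformable_attention_core_func

# reference output from the pure-Paddle implementation
ref = deformable_attention_core_func(value, spatial_shapes, level_start_index,
                                     sampling_locations, attention_weights)
# the CUDA op should agree within float32 tolerance
np.testing.assert_allclose(output.numpy(), ref.numpy(), rtol=1e-5, atol=1e-7)
```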
## 3. Unit Test

The correctness of the custom op can be verified by running the unit test:
```
python test_ms_deformable_attn_op.py
```
On success, it prints the following:
```
*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
*tensor1 True check_gradient_numerical(D=30)
*tensor2 True check_gradient_numerical(D=30)
*tensor3 True check_gradient_numerical(D=30)
*tensor1 True check_gradient_numerical(D=32)
*tensor2 True check_gradient_numerical(D=32)
*tensor3 True check_gradient_numerical(D=32)
*tensor1 True check_gradient_numerical(D=64)
*tensor2 True check_gradient_numerical(D=64)
*tensor3 True check_gradient_numerical(D=64)
*tensor1 True check_gradient_numerical(D=71)
*tensor2 True check_gradient_numerical(D=71)
*tensor3 True check_gradient_numerical(D=71)
*tensor1 True check_gradient_numerical(D=128)
*tensor2 True check_gradient_numerical(D=128)
*tensor3 True check_gradient_numerical(D=128)
*tensor1 True check_gradient_numerical(D=1024)
*tensor2 True check_gradient_numerical(D=1024)
*tensor3 True check_gradient_numerical(D=1024)
*tensor1 True check_gradient_numerical(D=1025)
*tensor2 True check_gradient_numerical(D=1025)
*tensor3 True check_gradient_numerical(D=1025)
*tensor1 True check_gradient_numerical(D=2048)
*tensor2 True check_gradient_numerical(D=2048)
*tensor3 True check_gradient_numerical(D=2048)
*tensor1 True check_gradient_numerical(D=3096)
*tensor2 True check_gradient_numerical(D=3096)
*tensor3 True check_gradient_numerical(D=3096)
```
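// ms_deformable_attn_op.cc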
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
#include <vector>
// declare GPU implementation
std::vector<paddle::Tensor>
MSDeformableAttnCUDAForward(const paddle::Tensor &value,
const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights);
std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out);
// CPU not implemented
std::vector<std::vector<int64_t>>
MSDeformableAttnInferShape(std::vector<int64_t> value_shape,
std::vector<int64_t> value_spatial_shapes_shape,
std::vector<int64_t> value_level_start_index_shape,
std::vector<int64_t> sampling_locations_shape,
std::vector<int64_t> attention_weights_shape) {
return {{value_shape[0], sampling_locations_shape[1],
value_shape[2] * value_shape[3]}};
}
std::vector<paddle::DataType>
MSDeformableAttnInferDtype(paddle::DataType value_dtype,
paddle::DataType value_spatial_shapes_dtype,
paddle::DataType value_level_start_index_dtype,
paddle::DataType sampling_locations_dtype,
paddle::DataType attention_weights_dtype) {
return {value_dtype};
}
PD_BUILD_OP(ms_deformable_attn)
.Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
"AttentionWeights"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward))
.SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype));
PD_BUILD_GRAD_OP(ms_deformable_attn)
.Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
"AttentionWeights", paddle::Grad("Out")})
.Outputs({paddle::Grad("Value"), paddle::Grad("SpatialShapes"),
paddle::Grad("LevelIndex"), paddle::Grad("SamplingLocations"),
paddle::Grad("AttentionWeights")})
.SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward));
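// ms_deformable_attn_op.cu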
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
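// grid-stride loop: each thread strides over the flattened work items so a
// bounded grid can cover any problem size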
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
// forward bilinear
template <typename data_t>
__device__ data_t deformable_attn_bilinear_forward(
const data_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const data_t &h, const data_t &w,
const int &m, const int &c) {
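  // integer corner coordinates around the fractional sample point (h, w)
  // and the corresponding bilinear interpolation weights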
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const data_t lh = h - h_low;
const data_t lw = w - w_low;
const data_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
data_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
data_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
data_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
data_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
// forward kernel
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_forward(
const int n, const data_t *data_value, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const data_t *data_sampling_loc,
const data_t *data_attn_weight, const int batch_size,
const int value_length, const int num_heads, const int channels,
const int num_levels, const int query_length, const int num_points,
data_t *output_data_ptr) {
CUDA_KERNEL_LOOP(index, n) {
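    // decompose the flat thread index into (batch, query, head, channel);
    // each thread produces one output channel value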
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
data_t *data_ptr = output_data_ptr + index;
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
data_t col = 0;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const data_t *data_value_ptr = data_value + (data_value_ptr_init_offset +
level_start_id * qid_stride);
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
col += deformable_attn_bilinear_forward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels,
h_im, w_im, m_col, c_col) *
weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_ptr = col;
}
}
#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
// forward
std::vector<paddle::Tensor>
MSDeformableAttnCUDAForward(const paddle::Tensor &value,
const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights) {
CHECK_INPUT_GPU(value);
CHECK_INPUT_GPU(value_spatial_shapes);
CHECK_INPUT_GPU(value_level_start_index);
CHECK_INPUT_GPU(sampling_locations);
CHECK_INPUT_GPU(attention_weights);
const int batch_size = value.shape()[0];
const int value_length = value.shape()[1];
const int num_heads = value.shape()[2];
const int channels = value.shape()[3];
const int num_levels = value_spatial_shapes.shape()[0];
const int query_length = sampling_locations.shape()[1];
const int num_points = sampling_locations.shape()[4];
auto output = paddle::full({batch_size, query_length, num_heads * channels},
0, value.dtype(), paddle::GPUPlace());
const int num_kernels = batch_size * query_length * num_heads * channels;
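  // launch one thread per output element (batch x query x head x channel)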
deformable_attn_cuda_kernel_forward<float>
<<<GET_BLOCKS(num_kernels, CUDA_NUM_THREADS), CUDA_NUM_THREADS, 0,
value.stream()>>>(num_kernels, value.data<float>(),
value_spatial_shapes.data<int64_t>(),
value_level_start_index.data<int64_t>(),
sampling_locations.data<float>(),
attention_weights.data<float>(), batch_size,
value_length, num_heads, channels, num_levels,
query_length, num_points, output.data<float>());
return {output};
}
// backward bilinear
template <typename data_t>
__device__ void deformable_attn_bilinear_backward(
const data_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const data_t &h, const data_t &w,
const int &m, const int &c, const data_t &top_grad,
const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
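  // same corner/weight setup as the forward bilinear sample; grads w.r.t.
  // value are scattered with atomicAdd, while grads w.r.t. the sampling
  // location and attention weight go into the caller-provided slots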
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const data_t lh = h - h_low;
const data_t lw = w - w_low;
const data_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const data_t top_grad_value = top_grad * attn_weight;
data_t grad_h_weight = 0, grad_w_weight = 0;
data_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
data_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
data_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
data_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename data_t>
__device__ void deformable_attn_bilinear_backward_gm(
const data_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const data_t &h, const data_t &w,
const int &m, const int &c, const data_t &top_grad,
const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const data_t lh = h - h_low;
const data_t lw = w - w_low;
const data_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const data_t top_grad_value = top_grad * attn_weight;
data_t grad_h_weight = 0, grad_w_weight = 0;
data_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
data_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
data_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
data_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
// backward kernels
// channels > 1024
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
data_t *cache_grad_sampling_loc = (data_t *)_s;
data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
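      // tree-reduce the per-thread partial gradients held in shared memory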
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_gm(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// channels <= 1024
template <typename data_t, unsigned int blockSize>
__global__ void
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
__shared__ data_t cache_grad_sampling_loc[blockSize * 2];
__shared__ data_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
data_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockSize; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename data_t, unsigned int blockSize>
__global__ void
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
__shared__ data_t cache_grad_sampling_loc[blockSize * 2];
__shared__ data_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v1(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
data_t *cache_grad_sampling_loc = (data_t *)_s;
data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
data_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename data_t>
__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2(
const int n, const data_t *grad_col, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
CUDA_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
data_t *cache_grad_sampling_loc = (data_t *)_s;
data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % query_length;
_temp /= query_length;
const int b_col = _temp;
const data_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_points;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const data_t *data_value_ptr = data_value + value_ptr_offset;
data_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_points; ++p_col) {
const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const data_t weight = data_attn_weight[data_weight_ptr];
const data_t h_im = loc_h * spatial_h - 0.5;
const data_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
deformable_attn_bilinear_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// backward branch
template <typename data_t>
void deformable_attn_cuda_backward(
cudaStream_t stream, const data_t *grad_out, const data_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const data_t *data_sampling_loc, const data_t *data_attn_weight,
const int batch_size, const int value_length, const int num_heads,
const int channels, const int num_levels, const int query_length,
const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
data_t *grad_attn_weight) {
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
const int num_kernels = batch_size * query_length * num_heads * channels;
const int num_actual_kernels =
batch_size * query_length * num_heads * channels;
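  // kernel dispatch: for channels > 1024 use the multi-block shared-memory
  // reduction when channels is a multiple of 1024, otherwise global-memory
  // atomics; power-of-two channels <= 1024 use blocksize-templated kernels,
  // and the remaining sizes fall back to dynamic shared-memory reductions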
if (channels > 1024) {
if ((channels & 1023) == 0) {
deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks<data_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(data_t), stream>>>(
num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, value_length, num_heads, channels, num_levels,
query_length, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
deformable_attn_cuda_kernel_backward_gm<data_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
}
} else {
switch (channels) {
case 1:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 2:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 4:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 8:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 16:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 32:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 64:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 128:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 256:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 512:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
case 1024:
deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points,
grad_value, grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
deformable_attn_cuda_kernel_backward_shm_reduce_v1<data_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(data_t), stream>>>(
num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, value_length, num_heads, channels, num_levels,
query_length, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
deformable_attn_cuda_kernel_backward_shm_reduce_v2<data_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(data_t), stream>>>(
num_kernels, grad_out, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, value_length, num_heads, channels, num_levels,
query_length, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
}
}
}
// backward
std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out) {
CHECK_INPUT_GPU(value);
CHECK_INPUT_GPU(value_spatial_shapes);
CHECK_INPUT_GPU(value_level_start_index);
CHECK_INPUT_GPU(sampling_locations);
CHECK_INPUT_GPU(attention_weights);
CHECK_INPUT_GPU(grad_out);
const int batch_size = value.shape()[0];
const int value_length = value.shape()[1];
const int num_heads = value.shape()[2];
const int channels = value.shape()[3];
const int num_levels = value_spatial_shapes.shape()[0];
const int query_length = sampling_locations.shape()[1];
const int num_points = sampling_locations.shape()[4];
auto grad_value =
paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
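  // grad_spatial_shapes / grad_level_start_index are returned as zero-filled
  // placeholders: the integer-typed inputs carry no meaningful gradient and
  // the kernel launch below never writes to them.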
auto grad_spatial_shapes =
paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
auto grad_level_start_index =
paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
auto grad_sampling_locations =
paddle::full(sampling_locations.shape(), 0, sampling_locations.dtype(),
paddle::GPUPlace());
auto grad_attention_weights =
paddle::full(attention_weights.shape(), 0, attention_weights.dtype(),
paddle::GPUPlace());
deformable_attn_cuda_backward<float>(
value.stream(), grad_out.data<float>(), value.data<float>(),
value_spatial_shapes.data<int64_t>(),
value_level_start_index.data<int64_t>(), sampling_locations.data<float>(),
attention_weights.data<float>(), batch_size, value_length, num_heads,
channels, num_levels, query_length, num_points, grad_value.data<float>(),
grad_sampling_locations.data<float>(),
grad_attention_weights.data<float>());
return {grad_value, grad_spatial_shapes, grad_level_start_index,
grad_sampling_locations, grad_attention_weights};
}
from paddle.utils.cpp_extension import CUDAExtension, setup
if __name__ == "__main__":
setup(
name='deformable_detr_ops',
ext_modules=CUDAExtension(
sources=['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu']))
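For reference, a minimal build-and-import sketch (an assumption about the workflow, not something this commit spells out: it presumes the two sources above sit in the working directory and a CUDA build of Paddle is installed):

# Build and install the extension once, via the standard cpp_extension flow:
#   python setup.py install
# The op is then importable under the name declared in setup() above:
from deformable_detr_ops import ms_deformable_attn

The test script below exercises both the forward and backward paths of the compiled op against the pure-Paddle reference implementation.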
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import sys
import random
import numpy as np
import paddle
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.modeling.transformers.utils import deformable_attention_core_func
ms_deform_attn_core_paddle = deformable_attention_core_func
try:
gpu_index = int(sys.argv[1])
except (IndexError, ValueError):
gpu_index = 0
print(f'Using gpu {gpu_index} for testing...')
paddle.set_device(f'gpu:{gpu_index}')
try:
from deformable_detr_ops import ms_deformable_attn
except Exception as e:
print('import deformable_detr_ops error', e)
sys.exit(-1)
paddle.seed(1)
random.seed(1)
np.random.seed(1)
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
[0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
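# With the shapes above, spatial_shapes.prod(1) is [24, 6], so
# level_start_index evaluates to [0, 24] and value_length below to 30.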
value_length = sum([(H * W).item() for H, W in spatial_shapes])
def get_test_tensors(channels):
value = paddle.rand(
[bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
sampling_locations = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points, 2],
dtype=paddle.float32)
attention_weights = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points],
dtype=paddle.float32) + 1e-5
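    # normalize jointly over the (n_levels, n_points) axes so the weights
    # of each (batch, query, head) sum to 1, as the attention op expects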
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
-2, keepdim=True)
return [value, sampling_locations, attention_weights]
@paddle.no_grad()
def check_forward_equal_with_paddle_float():
value, sampling_locations, attention_weights = get_test_tensors(c)
output_paddle = ms_deform_attn_core_paddle(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights).detach().cpu()
output_cuda = ms_deformable_attn(value, spatial_shapes, level_start_index,
sampling_locations,
attention_weights).detach().cpu()
fwdok = paddle.allclose(
output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item()
max_abs_err = (output_cuda - output_paddle).abs().max().item()
max_rel_err = (
(output_cuda - output_paddle).abs() / output_paddle.abs()).max().item()
print(
f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}'
)
def check_gradient_numerical(channels=4):
value_paddle, sampling_locations_paddle, attention_weights_paddle = get_test_tensors(
channels)
value_paddle.stop_gradient = False
sampling_locations_paddle.stop_gradient = False
attention_weights_paddle.stop_gradient = False
value_cuda = value_paddle.detach().clone()
sampling_locations_cuda = sampling_locations_paddle.detach().clone()
attention_weights_cuda = attention_weights_paddle.detach().clone()
value_cuda.stop_gradient = False
sampling_locations_cuda.stop_gradient = False
attention_weights_cuda.stop_gradient = False
output_paddle = ms_deform_attn_core_paddle(
value_paddle, spatial_shapes, level_start_index,
sampling_locations_paddle, attention_weights_paddle)
output_paddle.sum().backward()
output_cuda = ms_deformable_attn(value_cuda, spatial_shapes,
level_start_index, sampling_locations_cuda,
attention_weights_cuda)
output_cuda.sum().backward()
res = paddle.allclose(
value_paddle.grad, value_cuda.grad, rtol=1e-2, atol=1e-3).item()
print(f'*tensor1 {res} check_gradient_numerical(D={channels})')
res = paddle.allclose(
sampling_locations_paddle.grad,
sampling_locations_cuda.grad,
rtol=1e-2,
atol=1e-3).item()
print(f'*tensor2 {res} check_gradient_numerical(D={channels})')
res = paddle.allclose(
attention_weights_paddle.grad,
attention_weights_cuda.grad,
rtol=1e-2,
atol=1e-3).item()
print(f'*tensor3 {res} check_gradient_numerical(D={channels})')
if __name__ == '__main__':
check_forward_equal_with_paddle_float()
for channels in [30, 32, 64, 71, 128, 1024, 1025, 2048, 3096]:
check_gradient_numerical(channels)
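Note on coverage: assuming the dispatch switch in the CUDA backward is keyed on the channel count (as its default branch suggests), this sweep exercises both paths: 32, 64, 128 and 1024 hit the blocksize-aware templated kernels, while 30, 71, 1025, 2048 and 3096 fall through to the generic shared-memory reduction branch.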
@@ -33,37 +33,34 @@ class PositionEmbedding(nn.Layer):
                  num_pos_feats=128,
                  temperature=10000,
                  normalize=True,
-                 scale=None,
+                 scale=2 * math.pi,
                  embed_type='sine',
                  num_embeddings=50,
-                 offset=0.):
+                 offset=0.,
+                 eps=1e-6):
         super(PositionEmbedding, self).__init__()
         assert embed_type in ['sine', 'learned']
         self.embed_type = embed_type
         self.offset = offset
-        self.eps = 1e-6
+        self.eps = eps
         if self.embed_type == 'sine':
             self.num_pos_feats = num_pos_feats
             self.temperature = temperature
             self.normalize = normalize
-            if scale is not None and normalize is False:
-                raise ValueError("normalize should be True if scale is passed")
-            if scale is None:
-                scale = 2 * math.pi
             self.scale = scale
         elif self.embed_type == 'learned':
             self.row_embed = nn.Embedding(num_embeddings, num_pos_feats)
             self.col_embed = nn.Embedding(num_embeddings, num_pos_feats)
         else:
-            raise ValueError(f"not supported {self.embed_type}")
+            raise ValueError(f"{self.embed_type} is not supported.")

     def forward(self, mask):
         """
         Args:
             mask (Tensor): [B, H, W]
         Returns:
-            pos (Tensor): [B, C, H, W]
+            pos (Tensor): [B, H, W, C]
         """
         if self.embed_type == 'sine':
             y_embed = mask.cumsum(1)
@@ -86,20 +83,18 @@ class PositionEmbedding(nn.Layer):
             pos_y = paddle.stack(
                 (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
                 axis=4).flatten(3)
-            pos = paddle.concat((pos_y, pos_x), axis=3).transpose([0, 3, 1, 2])
-            return pos
+            return paddle.concat((pos_y, pos_x), axis=3)
         elif self.embed_type == 'learned':
             h, w = mask.shape[-2:]
             i = paddle.arange(w)
             j = paddle.arange(h)
             x_emb = self.col_embed(i)
             y_emb = self.row_embed(j)
-            pos = paddle.concat(
+            return paddle.concat(
                 [
                     x_emb.unsqueeze(0).tile([h, 1, 1]),
                     y_emb.unsqueeze(1).tile([1, w, 1]),
                 ],
-                axis=-1).transpose([2, 0, 1]).unsqueeze(0)
-            return pos
+                axis=-1).unsqueeze(0)
         else:
             raise ValueError(f"not supported {self.embed_type}")
@@ -38,15 +38,14 @@ def _get_clones(module, N):

 def bbox_cxcywh_to_xyxy(x):
-    x_c, y_c, w, h = x.split(4, axis=-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
-    return paddle.concat(b, axis=-1)
+    cxcy, wh = paddle.split(x, 2, axis=-1)
+    return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)


 def bbox_xyxy_to_cxcywh(x):
-    x0, y0, x1, y1 = x.split(4, axis=-1)
-    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return paddle.concat(b, axis=-1)
+    x1, y1, x2, y2 = x.split(4, axis=-1)
+    return paddle.concat(
+        [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1)


 def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
@@ -67,24 +66,27 @@ def inverse_sigmoid(x, eps=1e-6):

 def deformable_attention_core_func(value, value_spatial_shapes,
-                                   sampling_locations, attention_weights):
+                                   value_level_start_index, sampling_locations,
+                                   attention_weights):
     """
     Args:
         value (Tensor): [bs, value_length, n_head, c]
         value_spatial_shapes (Tensor): [n_levels, 2]
+        value_level_start_index (Tensor): [n_levels]
         sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
         attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
     Returns:
         output (Tensor): [bs, Length_{query}, C]
     """
-    bs, Len_v, n_head, c = value.shape
-    _, Len_q, n_head, n_levels, n_points, _ = sampling_locations.shape
-    value_list = value.split(value_spatial_shapes.prod(1).tolist(), axis=1)
+    bs, _, n_head, c = value.shape
+    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
+    value_list = value.split(
+        value_spatial_shapes.prod(1).split(n_levels), axis=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
-    for level, (h, w) in enumerate(value_spatial_shapes.tolist()):
+    for level, (h, w) in enumerate(value_spatial_shapes):
         # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
         value_l_ = value_list[level].flatten(2).transpose(
             [0, 2, 1]).reshape([bs * n_head, c, h, w])
@@ -107,3 +109,11 @@ def deformable_attention_core_func(value, value_spatial_shapes,
         attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])
     return output.transpose([0, 2, 1])
+
+
+def get_valid_ratio(mask):
+    _, H, W = paddle.shape(mask)
+    valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
+    valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
+    # [b, 2]
+    return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
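A quick sanity check for the new get_valid_ratio helper (a sketch; the 6x4 mask is made up for illustration, and the import path assumes the file touched above):

import paddle
from ppdet.modeling.transformers.utils import get_valid_ratio

mask = paddle.zeros([1, 6, 4])  # [b, H, W]; ones mark valid pixels
mask[:, :3, :2] = 1.            # only the top-left 3x2 region is valid
print(get_valid_ratio(mask))    # [[0.5, 0.5]], i.e. [valid_W / W, valid_H / H]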