From 3d6a027c2f9e00bd7595f45a02e88f9de12a500c Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Tue, 20 Dec 2022 11:16:03 +0800 Subject: [PATCH] [dev] add ms_deformable_attn cuda op (#7521) --- ppdet/modeling/architectures/detr.py | 2 +- ppdet/modeling/post_process.py | 24 +- .../transformers/deformable_transformer.py | 76 +- .../modeling/transformers/detr_transformer.py | 5 +- ppdet/modeling/transformers/ext_op/README.md | 84 ++ .../ext_op/ms_deformable_attn_op.cc | 65 + .../ext_op/ms_deformable_attn_op.cu | 1073 +++++++++++++++++ .../ext_op/setup_ms_deformable_attn_op.py | 7 + .../ext_op/test_ms_deformable_attn_op.py | 140 +++ .../transformers/position_encoding.py | 23 +- ppdet/modeling/transformers/utils.py | 32 +- 11 files changed, 1459 insertions(+), 72 deletions(-) create mode 100644 ppdet/modeling/transformers/ext_op/README.md create mode 100644 ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc create mode 100644 ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu create mode 100644 ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py create mode 100644 ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py diff --git a/ppdet/modeling/architectures/detr.py b/ppdet/modeling/architectures/detr.py index bb7a7c736..3fa32f5ca 100644 --- a/ppdet/modeling/architectures/detr.py +++ b/ppdet/modeling/architectures/detr.py @@ -84,7 +84,7 @@ class DETR(BaseArch): preds, self.inputs['im_shape'], self.inputs['scale_factor']) return bbox, bbox_num - def get_loss(self, ): + def get_loss(self): losses = self._forward() losses.update({ 'loss': diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 39a5ec0be..4f7d5f278 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -492,19 +492,21 @@ class DETRBBoxPostProcess(object): if scores.shape[1] > self.num_top_queries: scores, index = paddle.topk( scores, self.num_top_queries, axis=-1) - labels = paddle.stack( - [paddle.gather(l, i) for l, i in zip(labels, index)]) - bbox_pred = paddle.stack( - [paddle.gather(b, i) for b, i in zip(bbox_pred, index)]) + batch_ind = paddle.arange( + end=scores.shape[0]).unsqueeze(-1).tile( + [1, self.num_top_queries]) + index = paddle.stack([batch_ind, index], axis=-1) + labels = paddle.gather_nd(labels, index) + bbox_pred = paddle.gather_nd(bbox_pred, index) else: scores, index = paddle.topk( - scores.reshape([logits.shape[0], -1]), - self.num_top_queries, - axis=-1) - labels = index % logits.shape[2] - index = index // logits.shape[2] - bbox_pred = paddle.stack( - [paddle.gather(b, i) for b, i in zip(bbox_pred, index)]) + scores.flatten(1), self.num_top_queries, axis=-1) + labels = index % self.num_classes + index = index // self.num_classes + batch_ind = paddle.arange(end=scores.shape[0]).unsqueeze(-1).tile( + [1, self.num_top_queries]) + index = paddle.stack([batch_ind, index], axis=-1) + bbox_pred = paddle.gather_nd(bbox_pred, index) bbox_pred = paddle.concat( [ diff --git a/ppdet/modeling/transformers/deformable_transformer.py b/ppdet/modeling/transformers/deformable_transformer.py index db07e0327..3ed777a17 100644 --- a/ppdet/modeling/transformers/deformable_transformer.py +++ b/ppdet/modeling/transformers/deformable_transformer.py @@ -28,7 +28,7 @@ from paddle import ParamAttr from ppdet.core.workspace import register from ..layers import MultiHeadAttention from .position_encoding import PositionEmbedding -from .utils import _get_clones, deformable_attention_core_func +from .utils import _get_clones, get_valid_ratio 
from ..initializer import linear_init_, constant_, xavier_uniform_, normal_ __all__ = ['DeformableTransformer'] @@ -63,6 +63,13 @@ class MSDeformableAttention(nn.Layer): self.attention_weights = nn.Linear(embed_dim, self.total_points) self.value_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim) + try: + # use cuda op + from deformable_detr_ops import ms_deformable_attn + except: + # use paddle func + from .utils import deformable_attention_core_func as ms_deformable_attn + self.ms_deformable_attn_core = ms_deformable_attn self._reset_parameters() @@ -95,6 +102,7 @@ class MSDeformableAttention(nn.Layer): reference_points, value, value_spatial_shapes, + value_level_start_index, value_mask=None): """ Args: @@ -103,6 +111,7 @@ class MSDeformableAttention(nn.Layer): bottom-right (1, 1), including padding area value (Tensor): [bs, value_length, C] value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements Returns: @@ -131,8 +140,9 @@ class MSDeformableAttention(nn.Layer): bs, Len_q, 1, self.num_levels, 1, 2 ]) + sampling_offsets / offset_normalizer - output = deformable_attention_core_func( - value, value_spatial_shapes, sampling_locations, attention_weights) + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) output = self.output_proj(output) return output @@ -185,12 +195,13 @@ class DeformableTransformerEncoderLayer(nn.Layer): src, reference_points, spatial_shapes, + level_start_index, src_mask=None, pos_embed=None): # self attention src2 = self.self_attn( self.with_pos_embed(src, pos_embed), reference_points, src, - spatial_shapes, src_mask) + spatial_shapes, level_start_index, src_mask) src = src + self.dropout1(src2) src = self.norm1(src) # ffn @@ -206,13 +217,12 @@ class DeformableTransformerEncoder(nn.Layer): self.num_layers = num_layers @staticmethod - def get_reference_points(spatial_shapes, valid_ratios): + def get_reference_points(spatial_shapes, valid_ratios, offset=0.5): valid_ratios = valid_ratios.unsqueeze(1) reference_points = [] - for i, (H, W) in enumerate(spatial_shapes.tolist()): + for i, (H, W) in enumerate(spatial_shapes): ref_y, ref_x = paddle.meshgrid( - paddle.linspace(0.5, H - 0.5, H), - paddle.linspace(0.5, W - 0.5, W)) + paddle.arange(end=H) + offset, paddle.arange(end=W) + offset) ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] * H) ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] * @@ -225,6 +235,7 @@ class DeformableTransformerEncoder(nn.Layer): def forward(self, src, spatial_shapes, + level_start_index, src_mask=None, pos_embed=None, valid_ratios=None): @@ -235,8 +246,8 @@ class DeformableTransformerEncoder(nn.Layer): reference_points = self.get_reference_points(spatial_shapes, valid_ratios) for layer in self.layers: - output = layer(output, reference_points, spatial_shapes, src_mask, - pos_embed) + output = layer(output, reference_points, spatial_shapes, + level_start_index, src_mask, pos_embed) return output @@ -296,6 +307,7 @@ class DeformableTransformerDecoderLayer(nn.Layer): reference_points, memory, memory_spatial_shapes, + memory_level_start_index, memory_mask=None, query_pos_embed=None): # self attention @@ -307,7 +319,7 @@ class DeformableTransformerDecoderLayer(nn.Layer): # 
cross attention tgt2 = self.cross_attn( self.with_pos_embed(tgt, query_pos_embed), reference_points, memory, - memory_spatial_shapes, memory_mask) + memory_spatial_shapes, memory_level_start_index, memory_mask) tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) @@ -329,13 +341,15 @@ class DeformableTransformerDecoder(nn.Layer): reference_points, memory, memory_spatial_shapes, + memory_level_start_index, memory_mask=None, query_pos_embed=None): output = tgt intermediate = [] for lid, layer in enumerate(self.layers): output = layer(output, reference_points, memory, - memory_spatial_shapes, memory_mask, query_pos_embed) + memory_spatial_shapes, memory_level_start_index, + memory_mask, query_pos_embed) if self.return_intermediate: intermediate.append(output) @@ -447,14 +461,7 @@ class DeformableTransformer(nn.Layer): def from_config(cls, cfg, input_shape): return {'backbone_num_channels': [i.channels for i in input_shape], } - def _get_valid_ratio(self, mask): - _, H, W = mask.shape - valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H - valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W - valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1) - return valid_ratio - - def forward(self, src_feats, src_mask=None): + def forward(self, src_feats, src_mask=None, *args, **kwargs): srcs = [] for i in range(len(src_feats)): srcs.append(self.input_proj[i](src_feats[i])) @@ -471,33 +478,38 @@ class DeformableTransformer(nn.Layer): spatial_shapes = [] valid_ratios = [] for level, src in enumerate(srcs): - bs, c, h, w = src.shape - spatial_shapes.append([h, w]) + bs, _, h, w = paddle.shape(src) + spatial_shapes.append(paddle.concat([h, w])) src = src.flatten(2).transpose([0, 2, 1]) src_flatten.append(src) if src_mask is not None: mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0] else: mask = paddle.ones([bs, h, w]) - valid_ratios.append(self._get_valid_ratio(mask)) - pos_embed = self.position_embedding(mask).flatten(2).transpose( - [0, 2, 1]) - lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape( - [1, 1, -1]) + valid_ratios.append(get_valid_ratio(mask)) + pos_embed = self.position_embedding(mask).flatten(1, 2) + lvl_pos_embed = pos_embed + self.level_embed.weight[level] lvl_pos_embed_flatten.append(lvl_pos_embed) mask = mask.flatten(1) mask_flatten.append(mask) src_flatten = paddle.concat(src_flatten, 1) - mask_flatten = paddle.concat(mask_flatten, 1) + mask_flatten = None if src_mask is None else paddle.concat(mask_flatten, + 1) lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1) # [l, 2] - spatial_shapes = paddle.to_tensor(spatial_shapes, dtype='int64') + spatial_shapes = paddle.to_tensor( + paddle.stack(spatial_shapes).astype('int64')) + # [l], 每一个level的起始index + level_start_index = paddle.concat([ + paddle.zeros( + [1], dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1] + ]) # [b, l, 2] valid_ratios = paddle.stack(valid_ratios, 1) # encoder - memory = self.encoder(src_flatten, spatial_shapes, mask_flatten, - lvl_pos_embed_flatten, valid_ratios) + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, + mask_flatten, lvl_pos_embed_flatten, valid_ratios) # prepare input for decoder bs, _, c = memory.shape @@ -509,6 +521,6 @@ class DeformableTransformer(nn.Layer): # decoder hs = self.decoder(tgt, reference_points_input, memory, spatial_shapes, - mask_flatten, query_embed) + level_start_index, mask_flatten, query_embed) return (hs, memory, reference_points) diff --git a/ppdet/modeling/transformers/detr_transformer.py 
b/ppdet/modeling/transformers/detr_transformer.py
index a6f6a9363..ccbdb0a3d 100644
--- a/ppdet/modeling/transformers/detr_transformer.py
+++ b/ppdet/modeling/transformers/detr_transformer.py
@@ -295,7 +295,7 @@ class DETRTransformer(nn.Layer):
     def _convert_attention_mask(self, mask):
         return (mask - 1.0) * 1e9
 
-    def forward(self, src, src_mask=None):
+    def forward(self, src, src_mask=None, *args, **kwargs):
         r"""
         Applies a Transformer model on the inputs.
 
@@ -325,8 +325,7 @@ class DETRTransformer(nn.Layer):
             src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
         else:
             src_mask = paddle.ones([bs, h, w])
-        pos_embed = self.position_embedding(src_mask).flatten(2).transpose(
-            [0, 2, 1])
+        pos_embed = self.position_embedding(src_mask).flatten(1, 2)
 
         if self.training:
             src_mask = self._convert_attention_mask(src_mask)
diff --git a/ppdet/modeling/transformers/ext_op/README.md b/ppdet/modeling/transformers/ext_op/README.md
new file mode 100644
index 000000000..88f359913
--- /dev/null
+++ b/ppdet/modeling/transformers/ext_op/README.md
@@ -0,0 +1,84 @@
+# Compiling the multi-scale deformable attention custom op
+This custom op follows the PaddlePaddle guide on [custom external operators](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html).
+
+## 1. Requirements
+- Paddle >= 2.3.2
+- gcc 8.2
+
+## 2. Installation
+Compile and install the op from this directory:
+```
+cd PaddleDetection/ppdet/modeling/transformers/ext_op/
+python setup_ms_deformable_attn_op.py install
+```
+
+Once compiled, the op can be used directly. The following is a usage example of `ms_deformable_attn`:
+```
+import paddle
+
+# import the custom op
+from deformable_detr_ops import ms_deformable_attn
+
+# construct fake input tensors
+bs, n_heads, c = 2, 8, 8
+query_length, n_levels, n_points = 2, 2, 2
+spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
+level_start_index = paddle.concat((paddle.to_tensor(
+    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
+value_length = sum([(H * W).item() for H, W in spatial_shapes])
+
+def get_test_tensors(channels):
+    value = paddle.rand(
+        [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
+    sampling_locations = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points, 2],
+        dtype=paddle.float32)
+    attention_weights = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points],
+        dtype=paddle.float32) + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
+        -2, keepdim=True)
+    return [value, sampling_locations, attention_weights]
+
+value, sampling_locations, attention_weights = get_test_tensors(c)
+
+output = ms_deformable_attn(value,
+                            spatial_shapes,
+                            level_start_index,
+                            sampling_locations,
+                            attention_weights)
+```
+
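+The model code does not require the op to be installed: `MSDeformableAttention` in
+`deformable_transformer.py` tries to import the compiled op and falls back to the
+pure-Paddle `deformable_attention_core_func` when it is unavailable. A minimal
+sketch of that fallback pattern (using the full module path so it also works
+outside the package, and catching `ImportError` explicitly):
+```
+try:
+    # prefer the compiled CUDA op
+    from deformable_detr_ops import ms_deformable_attn
+except ImportError:
+    # fall back to the pure-Paddle reference implementation
+    from ppdet.modeling.transformers.utils import \
+        deformable_attention_core_func as ms_deformable_attn
+```
+Both callables take the same arguments (`value`, `spatial_shapes`,
+`level_start_index`, `sampling_locations`, `attention_weights`), so the
+surrounding model code is identical whichever implementation is picked up.
+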
+## 3. Unit test
+Run the unit test to verify that the custom op behaves correctly:
+```
+python test_ms_deformable_attn_op.py
+```
+On success, it prints the following:
+```
+*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
+*tensor1 True check_gradient_numerical(D=30)
+*tensor2 True check_gradient_numerical(D=30)
+*tensor3 True check_gradient_numerical(D=30)
+*tensor1 True check_gradient_numerical(D=32)
+*tensor2 True check_gradient_numerical(D=32)
+*tensor3 True check_gradient_numerical(D=32)
+*tensor1 True check_gradient_numerical(D=64)
+*tensor2 True check_gradient_numerical(D=64)
+*tensor3 True check_gradient_numerical(D=64)
+*tensor1 True check_gradient_numerical(D=71)
+*tensor2 True check_gradient_numerical(D=71)
+*tensor3 True check_gradient_numerical(D=71)
+*tensor1 True check_gradient_numerical(D=128)
+*tensor2 True check_gradient_numerical(D=128)
+*tensor3 True check_gradient_numerical(D=128)
+*tensor1 True check_gradient_numerical(D=1024)
+*tensor2 True check_gradient_numerical(D=1024)
+*tensor3 True check_gradient_numerical(D=1024)
+*tensor1 True check_gradient_numerical(D=1025)
+*tensor2 True check_gradient_numerical(D=1025)
+*tensor3 True check_gradient_numerical(D=1025)
+*tensor1 True check_gradient_numerical(D=2048)
+*tensor2 True check_gradient_numerical(D=2048)
+*tensor3 True check_gradient_numerical(D=2048)
+*tensor1 True check_gradient_numerical(D=3096)
+*tensor2 True check_gradient_numerical(D=3096)
+*tensor3 True check_gradient_numerical(D=3096)
+```
diff --git a/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc b/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc
new file mode 100644
index 000000000..d1758adbc
--- /dev/null
+++ b/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cc
@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/extension.h" + +#include + +// declare GPU implementation +std::vector +MSDeformableAttnCUDAForward(const paddle::Tensor &value, + const paddle::Tensor &value_spatial_shapes, + const paddle::Tensor &value_level_start_index, + const paddle::Tensor &sampling_locations, + const paddle::Tensor &attention_weights); + +std::vector MSDeformableAttnCUDABackward( + const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes, + const paddle::Tensor &value_level_start_index, + const paddle::Tensor &sampling_locations, + const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out); + +//// CPU not implemented + +std::vector> +MSDeformableAttnInferShape(std::vector value_shape, + std::vector value_spatial_shapes_shape, + std::vector value_level_start_index_shape, + std::vector sampling_locations_shape, + std::vector attention_weights_shape) { + return {{value_shape[0], sampling_locations_shape[1], + value_shape[2] * value_shape[3]}}; +} + +std::vector +MSDeformableAttnInferDtype(paddle::DataType value_dtype, + paddle::DataType value_spatial_shapes_dtype, + paddle::DataType value_level_start_index_dtype, + paddle::DataType sampling_locations_dtype, + paddle::DataType attention_weights_dtype) { + return {value_dtype}; +} + +PD_BUILD_OP(ms_deformable_attn) + .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations", + "AttentionWeights"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward)) + .SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype)); + +PD_BUILD_GRAD_OP(ms_deformable_attn) + .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations", + "AttentionWeights", paddle::Grad("Out")}) + .Outputs({paddle::Grad("Value"), paddle::Grad("SpatialShapes"), + paddle::Grad("LevelIndex"), paddle::Grad("SamplingLocations"), + paddle::Grad("AttentionWeights")}) + .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward)); diff --git a/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu b/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu new file mode 100644 index 000000000..d5a8d1618 --- /dev/null +++ b/ppdet/modeling/transformers/ext_op/ms_deformable_attn_op.cu @@ -0,0 +1,1073 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/extension.h" + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) { + return (N + num_threads - 1) / num_threads; +} + +// forward bilinear +template +__device__ data_t deformable_attn_bilinear_forward( + const data_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const data_t &h, const data_t &w, + const int &m, const int &c) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const data_t lh = h - h_low; + const data_t lw = w - w_low; + const data_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + data_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + data_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + data_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + data_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +// forward kernel +template +__global__ void deformable_attn_cuda_kernel_forward( + const int n, const data_t *data_value, const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, const data_t *data_sampling_loc, + const data_t *data_attn_weight, const int batch_size, + const int value_length, const int num_heads, const int channels, + const int num_levels, const int query_length, const int num_points, + data_t *output_data_ptr) { + CUDA_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= query_length; + const int b_col = _temp; + + data_t *data_ptr = output_data_ptr + index; + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + data_t col = 0; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const data_t *data_value_ptr = data_value + (data_value_ptr_init_offset + + level_start_id * qid_stride); + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = 
data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + col += deformable_attn_bilinear_forward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, + h_im, w_im, m_col, c_col) * + weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_ptr = col; + } +} + +#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") +// forward +std::vector +MSDeformableAttnCUDAForward(const paddle::Tensor &value, + const paddle::Tensor &value_spatial_shapes, + const paddle::Tensor &value_level_start_index, + const paddle::Tensor &sampling_locations, + const paddle::Tensor &attention_weights) { + + CHECK_INPUT_GPU(value); + CHECK_INPUT_GPU(value_spatial_shapes); + CHECK_INPUT_GPU(value_level_start_index); + CHECK_INPUT_GPU(sampling_locations); + CHECK_INPUT_GPU(attention_weights); + + const int batch_size = value.shape()[0]; + const int value_length = value.shape()[1]; + const int num_heads = value.shape()[2]; + const int channels = value.shape()[3]; + + const int num_levels = value_spatial_shapes.shape()[0]; + const int query_length = sampling_locations.shape()[1]; + const int num_points = sampling_locations.shape()[4]; + + auto output = paddle::full({batch_size, query_length, num_heads * channels}, + 0, value.dtype(), paddle::GPUPlace()); + + const int num_kernels = batch_size * query_length * num_heads * channels; + deformable_attn_cuda_kernel_forward + <<>>(num_kernels, value.data(), + value_spatial_shapes.data(), + value_level_start_index.data(), + sampling_locations.data(), + attention_weights.data(), batch_size, + value_length, num_heads, channels, num_levels, + query_length, num_points, output.data()); + return {output}; +} + +// backward bilinear +template +__device__ void deformable_attn_bilinear_backward( + const data_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const data_t &h, const data_t &w, + const int &m, const int &c, const data_t &top_grad, + const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const data_t lh = h - h_low; + const data_t lw = w - w_low; + const data_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const data_t top_grad_value = top_grad * attn_weight; + data_t grad_h_weight = 0, grad_w_weight = 0; + + data_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + data_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * 
top_grad_value); + } + data_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + data_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + +template +__device__ void deformable_attn_bilinear_backward_gm( + const data_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const data_t &h, const data_t &w, + const int &m, const int &c, const data_t &top_grad, + const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const data_t lh = h - h_low; + const data_t lw = w - w_low; + const data_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const data_t top_grad_value = top_grad * attn_weight; + data_t grad_h_weight = 0, grad_w_weight = 0; + + data_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + data_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + data_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + data_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + +// backward kernels +// channels > 1024 +template +__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t 
*data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + data_t *cache_grad_sampling_loc = (data_t *)_s; + data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= query_length; + const int b_col = _temp; + + const data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; 
+ data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void deformable_attn_cuda_kernel_backward_gm( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= query_length; + const int b_col = _temp; + + const data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +// channels <= 1024 +template +__global__ void +deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + __shared__ data_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ data_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = 
_temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= query_length; + const int b_col = _temp; + + const data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + data_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockSize; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void +deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + __shared__ data_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ data_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= 
query_length; + const int b_col = _temp; + + const data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v1( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + data_t *cache_grad_sampling_loc = (data_t *)_s; + data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % 
query_length; + _temp /= query_length; + const int b_col = _temp; + + const data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + data_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2( + const int n, const data_t *grad_col, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + CUDA_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + data_t *cache_grad_sampling_loc = (data_t *)_s; + data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % query_length; + _temp /= query_length; + const int b_col = _temp; + + const 
data_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_points; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * value_length * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const data_t *data_value_ptr = data_value + value_ptr_offset; + data_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_points; ++p_col) { + const data_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const data_t weight = data_attn_weight[data_weight_ptr]; + + const data_t h_im = loc_h * spatial_h - 0.5; + const data_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + deformable_attn_bilinear_backward( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +// backward branch +template +void deformable_attn_cuda_backward( + cudaStream_t stream, const data_t *grad_out, const data_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const data_t *data_sampling_loc, const data_t *data_attn_weight, + const int batch_size, const int value_length, const int num_heads, + const int channels, const int num_levels, const int query_length, + const int num_points, data_t *grad_value, data_t *grad_sampling_loc, + data_t *grad_attn_weight) { + const int num_threads = + (channels > CUDA_NUM_THREADS) ? 
CUDA_NUM_THREADS : channels; + const int num_kernels = batch_size * query_length * num_heads * channels; + const int num_actual_kernels = + batch_size * query_length * num_heads * channels; + if (channels > 1024) { + if ((channels & 1023) == 0) { + deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, value_length, num_heads, channels, num_levels, + query_length, num_points, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + deformable_attn_cuda_kernel_backward_gm + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + } + } else { + switch (channels) { + case 1: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 2: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 4: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 8: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 16: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 32: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 64: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 128: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_out, data_value, 
data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 256: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 512: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + case 1024: + deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, + grad_value, grad_sampling_loc, grad_attn_weight); + break; + default: + if (channels < 64) { + deformable_attn_cuda_kernel_backward_shm_reduce_v1 + <<>>( + num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, value_length, num_heads, channels, num_levels, + query_length, num_points, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + deformable_attn_cuda_kernel_backward_shm_reduce_v2 + <<>>( + num_kernels, grad_out, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, value_length, num_heads, channels, num_levels, + query_length, num_points, grad_value, grad_sampling_loc, + grad_attn_weight); + } + } + } +} + +// backward +std::vector MSDeformableAttnCUDABackward( + const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes, + const paddle::Tensor &value_level_start_index, + const paddle::Tensor &sampling_locations, + const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out) { + + CHECK_INPUT_GPU(value); + CHECK_INPUT_GPU(value_spatial_shapes); + CHECK_INPUT_GPU(value_level_start_index); + CHECK_INPUT_GPU(sampling_locations); + CHECK_INPUT_GPU(attention_weights); + CHECK_INPUT_GPU(grad_out); + + const int batch_size = value.shape()[0]; + const int value_length = value.shape()[1]; + const int num_heads = value.shape()[2]; + const int channels = value.shape()[3]; + + const int num_levels = value_spatial_shapes.shape()[0]; + const int query_length = sampling_locations.shape()[1]; + const int num_points = sampling_locations.shape()[4]; + + auto grad_value = + paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace()); + auto grad_spatial_shapes = + paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace()); + auto grad_level_start_index = + paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace()); + auto grad_sampling_locations = + paddle::full(sampling_locations.shape(), 0, sampling_locations.dtype(), + paddle::GPUPlace()); + auto grad_attention_weights = + paddle::full(attention_weights.shape(), 0, attention_weights.dtype(), + paddle::GPUPlace()); + + deformable_attn_cuda_backward( + value.stream(), grad_out.data(), value.data(), + value_spatial_shapes.data(), + 
value_level_start_index.data(), sampling_locations.data(), + attention_weights.data(), batch_size, value_length, num_heads, + channels, num_levels, query_length, num_points, grad_value.data(), + grad_sampling_locations.data(), + grad_attention_weights.data()); + + return {grad_value, grad_spatial_shapes, grad_level_start_index, + grad_sampling_locations, grad_attention_weights}; +} diff --git a/ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py b/ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py new file mode 100644 index 000000000..7c3c38667 --- /dev/null +++ b/ppdet/modeling/transformers/ext_op/setup_ms_deformable_attn_op.py @@ -0,0 +1,7 @@ +from paddle.utils.cpp_extension import CUDAExtension, setup + +if __name__ == "__main__": + setup( + name='deformable_detr_ops', + ext_modules=CUDAExtension( + sources=['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu'])) diff --git a/ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py b/ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py new file mode 100644 index 000000000..67476fac4 --- /dev/null +++ b/ppdet/modeling/transformers/ext_op/test_ms_deformable_attn_op.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import sys +import random +import numpy as np +import paddle +# add python path of PaddleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.modeling.transformers.utils import deformable_attention_core_func +ms_deform_attn_core_paddle = deformable_attention_core_func + +try: + gpu_index = int(sys.argv[1]) +except: + gpu_index = 0 +print(f'Use gpu {gpu_index} to test...') +paddle.set_device(f'gpu:{gpu_index}') + +try: + from deformable_detr_ops import ms_deformable_attn +except Exception as e: + print('import deformable_detr_ops error', e) + sys.exit(-1) + +paddle.seed(1) +random.seed(1) +np.random.seed(1) + +bs, n_heads, c = 2, 8, 8 +query_length, n_levels, n_points = 2, 2, 2 +spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64) +level_start_index = paddle.concat((paddle.to_tensor( + [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1])) +value_length = sum([(H * W).item() for H, W in spatial_shapes]) + + +def get_test_tensors(channels): + value = paddle.rand( + [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01 + sampling_locations = paddle.rand( + [bs, query_length, n_heads, n_levels, n_points, 2], + dtype=paddle.float32) + attention_weights = paddle.rand( + [bs, query_length, n_heads, n_levels, n_points], + dtype=paddle.float32) + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum( + -2, keepdim=True) + + return [value, sampling_locations, attention_weights] + + +@paddle.no_grad() +def check_forward_equal_with_paddle_float(): + value, sampling_locations, attention_weights = get_test_tensors(c) + + output_paddle = ms_deform_attn_core_paddle( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights).detach().cpu() + output_cuda = ms_deformable_attn(value, spatial_shapes, level_start_index, + sampling_locations, + attention_weights).detach().cpu() + fwdok = paddle.allclose( + output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item() + max_abs_err = (output_cuda - output_paddle).abs().max().item() + max_rel_err = ( + (output_cuda - output_paddle).abs() / output_paddle.abs()).max().item() + + print( + f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}' + ) + + +def check_gradient_numerical(channels=4): + value_paddle, sampling_locations_paddle, attention_weights_paddle = get_test_tensors( + channels) + value_paddle.stop_gradient = False + sampling_locations_paddle.stop_gradient = False + attention_weights_paddle.stop_gradient = False + + value_cuda = value_paddle.detach().clone() + sampling_locations_cuda = sampling_locations_paddle.detach().clone() + attention_weights_cuda = attention_weights_paddle.detach().clone() + value_cuda.stop_gradient = False + sampling_locations_cuda.stop_gradient = False + attention_weights_cuda.stop_gradient = False + + output_paddle = ms_deform_attn_core_paddle( + value_paddle, spatial_shapes, level_start_index, + sampling_locations_paddle, attention_weights_paddle) + output_paddle.sum().backward() + + output_cuda = ms_deformable_attn(value_cuda, spatial_shapes, + level_start_index, sampling_locations_cuda, + attention_weights_cuda) + output_cuda.sum().backward() + + res = paddle.allclose( + value_paddle.grad, value_cuda.grad, rtol=1e-2, atol=1e-3).item() +
print(f'*tensor1 {res} check_gradient_numerical(D={channels})') + + res = paddle.allclose( + sampling_locations_paddle.grad, + sampling_locations_cuda.grad, + rtol=1e-2, + atol=1e-3).item() + print(f'*tensor2 {res} check_gradient_numerical(D={channels})') + + res = paddle.allclose( + attention_weights_paddle.grad, + attention_weights_cuda.grad, + rtol=1e-2, + atol=1e-3).item() + print(f'*tensor3 {res} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_paddle_float() + + for channels in [30, 32, 64, 71, 128, 1024, 1025, 2048, 3096]: + check_gradient_numerical(channels) diff --git a/ppdet/modeling/transformers/position_encoding.py b/ppdet/modeling/transformers/position_encoding.py index dffd9ce9b..a2c326097 100644 --- a/ppdet/modeling/transformers/position_encoding.py +++ b/ppdet/modeling/transformers/position_encoding.py @@ -33,37 +33,34 @@ class PositionEmbedding(nn.Layer): num_pos_feats=128, temperature=10000, normalize=True, - scale=None, + scale=2 * math.pi, embed_type='sine', num_embeddings=50, - offset=0.): + offset=0., + eps=1e-6): super(PositionEmbedding, self).__init__() assert embed_type in ['sine', 'learned'] self.embed_type = embed_type self.offset = offset - self.eps = 1e-6 + self.eps = eps if self.embed_type == 'sine': self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi self.scale = scale elif self.embed_type == 'learned': self.row_embed = nn.Embedding(num_embeddings, num_pos_feats) self.col_embed = nn.Embedding(num_embeddings, num_pos_feats) else: - raise ValueError(f"not supported {self.embed_type}") + raise ValueError(f"{self.embed_type} is not supported.") def forward(self, mask): """ Args: mask (Tensor): [B, H, W] Returns: - pos (Tensor): [B, C, H, W] + pos (Tensor): [B, H, W, C] """ if self.embed_type == 'sine': y_embed = mask.cumsum(1) @@ -86,20 +83,18 @@ class PositionEmbedding(nn.Layer): pos_y = paddle.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4).flatten(3) - pos = paddle.concat((pos_y, pos_x), axis=3).transpose([0, 3, 1, 2]) - return pos + return paddle.concat((pos_y, pos_x), axis=3) elif self.embed_type == 'learned': h, w = mask.shape[-2:] i = paddle.arange(w) j = paddle.arange(h) x_emb = self.col_embed(i) y_emb = self.row_embed(j) - pos = paddle.concat( + return paddle.concat( [ x_emb.unsqueeze(0).tile([h, 1, 1]), y_emb.unsqueeze(1).tile([1, w, 1]), ], - axis=-1).transpose([2, 0, 1]).unsqueeze(0) - return pos + axis=-1).unsqueeze(0) else: raise ValueError(f"not supported {self.embed_type}") diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index 7f3afa87e..026cf2e55 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -38,15 +38,14 @@ def _get_clones(module, N): def bbox_cxcywh_to_xyxy(x): - x_c, y_c, w, h = x.split(4, axis=-1) - b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] - return paddle.concat(b, axis=-1) + cxcy, wh = paddle.split(x, 2, axis=-1) + return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1) def bbox_xyxy_to_cxcywh(x): - x0, y0, x1, y1 = x.split(4, axis=-1) - b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] - return paddle.concat(b, axis=-1) + x1, y1, x2, y2 = x.split(4, axis=-1) + return paddle.concat( + [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], 
axis=-1) def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0): @@ -67,24 +66,27 @@ def inverse_sigmoid(x, eps=1e-6): def deformable_attention_core_func(value, value_spatial_shapes, - sampling_locations, attention_weights): + value_level_start_index, sampling_locations, + attention_weights): """ Args: value (Tensor): [bs, value_length, n_head, c] value_spatial_shapes (Tensor): [n_levels, 2] + value_level_start_index (Tensor): [n_levels] sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] Returns: output (Tensor): [bs, Length_{query}, C] """ - bs, Len_v, n_head, c = value.shape - _, Len_q, n_head, n_levels, n_points, _ = sampling_locations.shape + bs, _, n_head, c = value.shape + _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape - value_list = value.split(value_spatial_shapes.prod(1).tolist(), axis=1) + value_list = value.split( + value_spatial_shapes.prod(1).split(n_levels), axis=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] - for level, (h, w) in enumerate(value_spatial_shapes.tolist()): + for level, (h, w) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[level].flatten(2).transpose( [0, 2, 1]).reshape([bs * n_head, c, h, w]) @@ -107,3 +109,11 @@ def deformable_attention_core_func(value, value_spatial_shapes, attention_weights).sum(-1).reshape([bs, n_head * c, Len_q]) return output.transpose([0, 2, 1]) + + +def get_valid_ratio(mask): + _, H, W = paddle.shape(mask) + valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H + valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W + # [b, 2] + return paddle.stack([valid_ratio_w, valid_ratio_h], -1) -- GitLab
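Usage sketch (not part of the patch): after building the extension with `python setup_ms_deformable_attn_op.py install`, the compiled `ms_deformable_attn` op and the pure-Paddle `deformable_attention_core_func` take the same five arguments, so a caller can fall back when the custom op is unavailable. Shapes follow the unit test above; the try/except fallback below is illustrative only.

import paddle

from ppdet.modeling.transformers.utils import deformable_attention_core_func

try:
    # available only after: python setup_ms_deformable_attn_op.py install
    from deformable_detr_ops import ms_deformable_attn
except ImportError:
    ms_deformable_attn = None

# toy sizes, matching test_ms_deformable_attn_op.py
bs, n_heads, c, len_q, n_levels, n_points = 2, 8, 8, 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
len_v = int(spatial_shapes.prod(1).sum().item())

value = paddle.rand([bs, len_v, n_heads, c])
sampling_locations = paddle.rand([bs, len_q, n_heads, n_levels, n_points, 2])
attention_weights = paddle.rand([bs, len_q, n_heads, n_levels, n_points])
attention_weights /= attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)

attn_fn = (ms_deformable_attn if ms_deformable_attn is not None
           else deformable_attention_core_func)
output = attn_fn(value, spatial_shapes, level_start_index,
                 sampling_locations, attention_weights)
print(output.shape)  # [bs, len_q, n_heads * c] -> [2, 2, 64]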
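Similarly, a minimal check of the new `get_valid_ratio` helper (assuming the mask marks valid, non-padded pixels with 1; the 8x8 padding below is a made-up example):

import paddle

from ppdet.modeling.transformers.utils import get_valid_ratio

# one image padded to 8x8 whose valid region is 6 rows x 4 columns
mask = paddle.zeros([1, 8, 8])
mask[:, :6, :4] = 1.
print(get_valid_ratio(mask))  # [[0.50, 0.75]] i.e. (valid_W / W, valid_H / H)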