未验证 提交 3d6a027c 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] add ms_deformable_attn cuda op (#7521)

上级 e1a8f660
......@@ -84,7 +84,7 @@ class DETR(BaseArch):
preds, self.inputs['im_shape'], self.inputs['scale_factor'])
return bbox, bbox_num
def get_loss(self, ):
def get_loss(self):
losses = self._forward()
losses.update({
'loss':
......
......@@ -492,19 +492,21 @@ class DETRBBoxPostProcess(object):
if scores.shape[1] > self.num_top_queries:
scores, index = paddle.topk(
scores, self.num_top_queries, axis=-1)
labels = paddle.stack(
[paddle.gather(l, i) for l, i in zip(labels, index)])
bbox_pred = paddle.stack(
[paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
batch_ind = paddle.arange(
end=scores.shape[0]).unsqueeze(-1).tile(
[1, self.num_top_queries])
index = paddle.stack([batch_ind, index], axis=-1)
labels = paddle.gather_nd(labels, index)
bbox_pred = paddle.gather_nd(bbox_pred, index)
else:
scores, index = paddle.topk(
scores.reshape([logits.shape[0], -1]),
self.num_top_queries,
axis=-1)
labels = index % logits.shape[2]
index = index // logits.shape[2]
bbox_pred = paddle.stack(
[paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
scores.flatten(1), self.num_top_queries, axis=-1)
labels = index % self.num_classes
index = index // self.num_classes
batch_ind = paddle.arange(end=scores.shape[0]).unsqueeze(-1).tile(
[1, self.num_top_queries])
index = paddle.stack([batch_ind, index], axis=-1)
bbox_pred = paddle.gather_nd(bbox_pred, index)
bbox_pred = paddle.concat(
[
......
......@@ -28,7 +28,7 @@ from paddle import ParamAttr
from ppdet.core.workspace import register
from ..layers import MultiHeadAttention
from .position_encoding import PositionEmbedding
from .utils import _get_clones, deformable_attention_core_func
from .utils import _get_clones, get_valid_ratio
from ..initializer import linear_init_, constant_, xavier_uniform_, normal_
__all__ = ['DeformableTransformer']
......@@ -63,6 +63,13 @@ class MSDeformableAttention(nn.Layer):
self.attention_weights = nn.Linear(embed_dim, self.total_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
try:
# use cuda op
from deformable_detr_ops import ms_deformable_attn
except:
# use paddle func
from .utils import deformable_attention_core_func as ms_deformable_attn
self.ms_deformable_attn_core = ms_deformable_attn
self._reset_parameters()
......@@ -95,6 +102,7 @@ class MSDeformableAttention(nn.Layer):
reference_points,
value,
value_spatial_shapes,
value_level_start_index,
value_mask=None):
"""
Args:
......@@ -103,6 +111,7 @@ class MSDeformableAttention(nn.Layer):
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
......@@ -131,8 +140,9 @@ class MSDeformableAttention(nn.Layer):
bs, Len_q, 1, self.num_levels, 1, 2
]) + sampling_offsets / offset_normalizer
output = deformable_attention_core_func(
value, value_spatial_shapes, sampling_locations, attention_weights)
output = self.ms_deformable_attn_core(
value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights)
output = self.output_proj(output)
return output
......@@ -185,12 +195,13 @@ class DeformableTransformerEncoderLayer(nn.Layer):
src,
reference_points,
spatial_shapes,
level_start_index,
src_mask=None,
pos_embed=None):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, pos_embed), reference_points, src,
spatial_shapes, src_mask)
spatial_shapes, level_start_index, src_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
......@@ -206,13 +217,12 @@ class DeformableTransformerEncoder(nn.Layer):
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios):
def get_reference_points(spatial_shapes, valid_ratios, offset=0.5):
valid_ratios = valid_ratios.unsqueeze(1)
reference_points = []
for i, (H, W) in enumerate(spatial_shapes.tolist()):
for i, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = paddle.meshgrid(
paddle.linspace(0.5, H - 0.5, H),
paddle.linspace(0.5, W - 0.5, W))
paddle.arange(end=H) + offset, paddle.arange(end=W) + offset)
ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] *
H)
ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] *
......@@ -225,6 +235,7 @@ class DeformableTransformerEncoder(nn.Layer):
def forward(self,
src,
spatial_shapes,
level_start_index,
src_mask=None,
pos_embed=None,
valid_ratios=None):
......@@ -235,8 +246,8 @@ class DeformableTransformerEncoder(nn.Layer):
reference_points = self.get_reference_points(spatial_shapes,
valid_ratios)
for layer in self.layers:
output = layer(output, reference_points, spatial_shapes, src_mask,
pos_embed)
output = layer(output, reference_points, spatial_shapes,
level_start_index, src_mask, pos_embed)
return output
......@@ -296,6 +307,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
memory_mask=None,
query_pos_embed=None):
# self attention
......@@ -307,7 +319,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
# cross attention
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, query_pos_embed), reference_points, memory,
memory_spatial_shapes, memory_mask)
memory_spatial_shapes, memory_level_start_index, memory_mask)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
......@@ -329,13 +341,15 @@ class DeformableTransformerDecoder(nn.Layer):
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
memory_mask=None,
query_pos_embed=None):
output = tgt
intermediate = []
for lid, layer in enumerate(self.layers):
output = layer(output, reference_points, memory,
memory_spatial_shapes, memory_mask, query_pos_embed)
memory_spatial_shapes, memory_level_start_index,
memory_mask, query_pos_embed)
if self.return_intermediate:
intermediate.append(output)
......@@ -447,14 +461,7 @@ class DeformableTransformer(nn.Layer):
def from_config(cls, cfg, input_shape):
return {'backbone_num_channels': [i.channels for i in input_shape], }
def _get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def forward(self, src_feats, src_mask=None):
def forward(self, src_feats, src_mask=None, *args, **kwargs):
srcs = []
for i in range(len(src_feats)):
srcs.append(self.input_proj[i](src_feats[i]))
......@@ -471,33 +478,38 @@ class DeformableTransformer(nn.Layer):
spatial_shapes = []
valid_ratios = []
for level, src in enumerate(srcs):
bs, c, h, w = src.shape
spatial_shapes.append([h, w])
bs, _, h, w = paddle.shape(src)
spatial_shapes.append(paddle.concat([h, w]))
src = src.flatten(2).transpose([0, 2, 1])
src_flatten.append(src)
if src_mask is not None:
mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
else:
mask = paddle.ones([bs, h, w])
valid_ratios.append(self._get_valid_ratio(mask))
pos_embed = self.position_embedding(mask).flatten(2).transpose(
[0, 2, 1])
lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape(
[1, 1, -1])
valid_ratios.append(get_valid_ratio(mask))
pos_embed = self.position_embedding(mask).flatten(1, 2)
lvl_pos_embed = pos_embed + self.level_embed.weight[level]
lvl_pos_embed_flatten.append(lvl_pos_embed)
mask = mask.flatten(1)
mask_flatten.append(mask)
src_flatten = paddle.concat(src_flatten, 1)
mask_flatten = paddle.concat(mask_flatten, 1)
mask_flatten = None if src_mask is None else paddle.concat(mask_flatten,
1)
lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
# [l, 2]
spatial_shapes = paddle.to_tensor(spatial_shapes, dtype='int64')
spatial_shapes = paddle.to_tensor(
paddle.stack(spatial_shapes).astype('int64'))
# [l], 每一个level的起始index
level_start_index = paddle.concat([
paddle.zeros(
[1], dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1]
])
# [b, l, 2]
valid_ratios = paddle.stack(valid_ratios, 1)
# encoder
memory = self.encoder(src_flatten, spatial_shapes, mask_flatten,
lvl_pos_embed_flatten, valid_ratios)
memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
mask_flatten, lvl_pos_embed_flatten, valid_ratios)
# prepare input for decoder
bs, _, c = memory.shape
......@@ -509,6 +521,6 @@ class DeformableTransformer(nn.Layer):
# decoder
hs = self.decoder(tgt, reference_points_input, memory, spatial_shapes,
mask_flatten, query_embed)
level_start_index, mask_flatten, query_embed)
return (hs, memory, reference_points)
......@@ -295,7 +295,7 @@ class DETRTransformer(nn.Layer):
def _convert_attention_mask(self, mask):
return (mask - 1.0) * 1e9
def forward(self, src, src_mask=None):
def forward(self, src, src_mask=None, *args, **kwargs):
r"""
Applies a Transformer model on the inputs.
......@@ -325,8 +325,7 @@ class DETRTransformer(nn.Layer):
src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
else:
src_mask = paddle.ones([bs, h, w])
pos_embed = self.position_embedding(src_mask).flatten(2).transpose(
[0, 2, 1])
pos_embed = self.position_embedding(src_mask).flatten(1, 2)
if self.training:
src_mask = self._convert_attention_mask(src_mask)
......
# Multi-scale deformable attention自定义OP编译
该自定义OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html)
## 1. 环境依赖
- Paddle >= 2.3.2
- gcc 8.2
## 2. 安装
请在当前路径下进行编译安装
```
cd PaddleDetection/ppdet/modeling/transformers/ext_op/
python setup_ms_deformable_attn_op.py install
```
编译完成后即可使用,以下为`ms_deformable_attn`的使用示例
```
# 引入自定义op
from deformable_detr_ops import ms_deformable_attn
# 构造fake input tensor
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
[0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
value_length = sum([(H * W).item() for H, W in spatial_shapes])
def get_test_tensors(channels):
value = paddle.rand(
[bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
sampling_locations = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points, 2],
dtype=paddle.float32)
attention_weights = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points],
dtype=paddle.float32) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
-2, keepdim=True)
return [value, sampling_locations, attention_weights]
value, sampling_locations, attention_weights = get_test_tensors(c)
output = ms_deformable_attn(value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights)
```
## 3. 单元测试
可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示:
```
python test_ms_deformable_attn_op.py
```
运行成功后,打印如下:
```
*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
*tensor1 True check_gradient_numerical(D=30)
*tensor2 True check_gradient_numerical(D=30)
*tensor3 True check_gradient_numerical(D=30)
*tensor1 True check_gradient_numerical(D=32)
*tensor2 True check_gradient_numerical(D=32)
*tensor3 True check_gradient_numerical(D=32)
*tensor1 True check_gradient_numerical(D=64)
*tensor2 True check_gradient_numerical(D=64)
*tensor3 True check_gradient_numerical(D=64)
*tensor1 True check_gradient_numerical(D=71)
*tensor2 True check_gradient_numerical(D=71)
*tensor3 True check_gradient_numerical(D=71)
*tensor1 True check_gradient_numerical(D=128)
*tensor2 True check_gradient_numerical(D=128)
*tensor3 True check_gradient_numerical(D=128)
*tensor1 True check_gradient_numerical(D=1024)
*tensor2 True check_gradient_numerical(D=1024)
*tensor3 True check_gradient_numerical(D=1024)
*tensor1 True check_gradient_numerical(D=1025)
*tensor2 True check_gradient_numerical(D=1025)
*tensor3 True check_gradient_numerical(D=1025)
*tensor1 True check_gradient_numerical(D=2048)
*tensor2 True check_gradient_numerical(D=2048)
*tensor3 True check_gradient_numerical(D=2048)
*tensor1 True check_gradient_numerical(D=3096)
*tensor2 True check_gradient_numerical(D=3096)
*tensor3 True check_gradient_numerical(D=3096)
```
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
#include <vector>
// declare GPU implementation
std::vector<paddle::Tensor>
MSDeformableAttnCUDAForward(const paddle::Tensor &value,
const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights);
std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
const paddle::Tensor &value_level_start_index,
const paddle::Tensor &sampling_locations,
const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out);
//// CPU not implemented
std::vector<std::vector<int64_t>>
MSDeformableAttnInferShape(std::vector<int64_t> value_shape,
std::vector<int64_t> value_spatial_shapes_shape,
std::vector<int64_t> value_level_start_index_shape,
std::vector<int64_t> sampling_locations_shape,
std::vector<int64_t> attention_weights_shape) {
return {{value_shape[0], sampling_locations_shape[1],
value_shape[2] * value_shape[3]}};
}
std::vector<paddle::DataType>
MSDeformableAttnInferDtype(paddle::DataType value_dtype,
paddle::DataType value_spatial_shapes_dtype,
paddle::DataType value_level_start_index_dtype,
paddle::DataType sampling_locations_dtype,
paddle::DataType attention_weights_dtype) {
return {value_dtype};
}
PD_BUILD_OP(ms_deformable_attn)
.Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
"AttentionWeights"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward))
.SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype));
PD_BUILD_GRAD_OP(ms_deformable_attn)
.Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
"AttentionWeights", paddle::Grad("Out")})
.Outputs({paddle::Grad("Value"), paddle::Grad("SpatialShapes"),
paddle::Grad("LevelIndex"), paddle::Grad("SamplingLocations"),
paddle::Grad("AttentionWeights")})
.SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward));
from paddle.utils.cpp_extension import CUDAExtension, setup
if __name__ == "__main__":
setup(
name='deformable_detr_ops',
ext_modules=CUDAExtension(
sources=['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu']))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import sys
import random
import numpy as np
import paddle
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.modeling.transformers.utils import deformable_attention_core_func
ms_deform_attn_core_paddle = deformable_attention_core_func
try:
gpu_index = int(sys.argv[1])
except:
gpu_index = 0
print(f'Use gpu {gpu_index} to test...')
paddle.set_device(f'gpu:{gpu_index}')
try:
from deformable_detr_ops import ms_deformable_attn
except Exception as e:
print('import deformable_detr_ops error', e)
sys.exit(-1)
paddle.seed(1)
random.seed(1)
np.random.seed(1)
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
level_start_index = paddle.concat((paddle.to_tensor(
[0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
value_length = sum([(H * W).item() for H, W in spatial_shapes])
def get_test_tensors(channels):
value = paddle.rand(
[bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
sampling_locations = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points, 2],
dtype=paddle.float32)
attention_weights = paddle.rand(
[bs, query_length, n_heads, n_levels, n_points],
dtype=paddle.float32) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
-2, keepdim=True)
return [value, sampling_locations, attention_weights]
@paddle.no_grad()
def check_forward_equal_with_paddle_float():
value, sampling_locations, attention_weights = get_test_tensors(c)
output_paddle = ms_deform_attn_core_paddle(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights).detach().cpu()
output_cuda = ms_deformable_attn(value, spatial_shapes, level_start_index,
sampling_locations,
attention_weights).detach().cpu()
fwdok = paddle.allclose(
output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item()
max_abs_err = (output_cuda - output_paddle).abs().max().item()
max_rel_err = (
(output_cuda - output_paddle).abs() / output_paddle.abs()).max().item()
print(
f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}'
)
def check_gradient_numerical(channels=4):
value_paddle, sampling_locations_paddle, attention_weights_paddle = get_test_tensors(
channels)
value_paddle.stop_gradient = False
sampling_locations_paddle.stop_gradient = False
attention_weights_paddle.stop_gradient = False
value_cuda = value_paddle.detach().clone()
sampling_locations_cuda = sampling_locations_paddle.detach().clone()
attention_weights_cuda = attention_weights_paddle.detach().clone()
value_cuda.stop_gradient = False
sampling_locations_cuda.stop_gradient = False
attention_weights_cuda.stop_gradient = False
output_paddle = ms_deform_attn_core_paddle(
value_paddle, spatial_shapes, level_start_index,
sampling_locations_paddle, attention_weights_paddle)
output_paddle.sum().backward()
output_cuda = ms_deformable_attn(value_cuda, spatial_shapes,
level_start_index, sampling_locations_cuda,
attention_weights_cuda)
output_cuda.sum().backward()
res = paddle.allclose(
value_paddle.grad, value_cuda.grad, rtol=1e-2, atol=1e-3).item()
print(f'*tensor1 {res} check_gradient_numerical(D={channels})')
res = paddle.allclose(
sampling_locations_paddle.grad,
sampling_locations_cuda.grad,
rtol=1e-2,
atol=1e-3).item()
print(f'*tensor2 {res} check_gradient_numerical(D={channels})')
res = paddle.allclose(
attention_weights_paddle.grad,
attention_weights_cuda.grad,
rtol=1e-2,
atol=1e-3).item()
print(f'*tensor3 {res} check_gradient_numerical(D={channels})')
if __name__ == '__main__':
check_forward_equal_with_paddle_float()
for channels in [30, 32, 64, 71, 128, 1024, 1025, 2048, 3096]:
check_gradient_numerical(channels)
......@@ -33,37 +33,34 @@ class PositionEmbedding(nn.Layer):
num_pos_feats=128,
temperature=10000,
normalize=True,
scale=None,
scale=2 * math.pi,
embed_type='sine',
num_embeddings=50,
offset=0.):
offset=0.,
eps=1e-6):
super(PositionEmbedding, self).__init__()
assert embed_type in ['sine', 'learned']
self.embed_type = embed_type
self.offset = offset
self.eps = 1e-6
self.eps = eps
if self.embed_type == 'sine':
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
elif self.embed_type == 'learned':
self.row_embed = nn.Embedding(num_embeddings, num_pos_feats)
self.col_embed = nn.Embedding(num_embeddings, num_pos_feats)
else:
raise ValueError(f"not supported {self.embed_type}")
raise ValueError(f"{self.embed_type} is not supported.")
def forward(self, mask):
"""
Args:
mask (Tensor): [B, H, W]
Returns:
pos (Tensor): [B, C, H, W]
pos (Tensor): [B, H, W, C]
"""
if self.embed_type == 'sine':
y_embed = mask.cumsum(1)
......@@ -86,20 +83,18 @@ class PositionEmbedding(nn.Layer):
pos_y = paddle.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
axis=4).flatten(3)
pos = paddle.concat((pos_y, pos_x), axis=3).transpose([0, 3, 1, 2])
return pos
return paddle.concat((pos_y, pos_x), axis=3)
elif self.embed_type == 'learned':
h, w = mask.shape[-2:]
i = paddle.arange(w)
j = paddle.arange(h)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = paddle.concat(
return paddle.concat(
[
x_emb.unsqueeze(0).tile([h, 1, 1]),
y_emb.unsqueeze(1).tile([1, w, 1]),
],
axis=-1).transpose([2, 0, 1]).unsqueeze(0)
return pos
axis=-1).unsqueeze(0)
else:
raise ValueError(f"not supported {self.embed_type}")
......@@ -38,15 +38,14 @@ def _get_clones(module, N):
def bbox_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.split(4, axis=-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return paddle.concat(b, axis=-1)
cxcy, wh = paddle.split(x, 2, axis=-1)
return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)
def bbox_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.split(4, axis=-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return paddle.concat(b, axis=-1)
x1, y1, x2, y2 = x.split(4, axis=-1)
return paddle.concat(
[(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1)
def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
......@@ -67,24 +66,27 @@ def inverse_sigmoid(x, eps=1e-6):
def deformable_attention_core_func(value, value_spatial_shapes,
sampling_locations, attention_weights):
value_level_start_index, sampling_locations,
attention_weights):
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
value_spatial_shapes (Tensor): [n_levels, 2]
value_level_start_index (Tensor): [n_levels]
sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, Len_v, n_head, c = value.shape
_, Len_q, n_head, n_levels, n_points, _ = sampling_locations.shape
bs, _, n_head, c = value.shape
_, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
value_list = value.split(value_spatial_shapes.prod(1).tolist(), axis=1)
value_list = value.split(
value_spatial_shapes.prod(1).split(n_levels), axis=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (h, w) in enumerate(value_spatial_shapes.tolist()):
for level, (h, w) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[level].flatten(2).transpose(
[0, 2, 1]).reshape([bs * n_head, c, h, w])
......@@ -107,3 +109,11 @@ def deformable_attention_core_func(value, value_spatial_shapes,
attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])
return output.transpose([0, 2, 1])
def get_valid_ratio(mask):
_, H, W = paddle.shape(mask)
valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
# [b, 2]
return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册