From e8aeb802a901730ccc11ae1653a4a71249f9b46e Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Mon, 19 Jul 2021 18:56:24 +0800 Subject: [PATCH] [transformer] add Deformable DETR base code (#3718) --- ppdet/modeling/heads/detr_head.py | 79 ++- ppdet/modeling/post_process.py | 21 +- ppdet/modeling/transformers/__init__.py | 2 + .../transformers/deformable_transformer.py | 514 ++++++++++++++++++ .../transformers/position_encoding.py | 12 +- ppdet/modeling/transformers/utils.py | 51 +- 6 files changed, 667 insertions(+), 12 deletions(-) create mode 100644 ppdet/modeling/transformers/deformable_transformer.py diff --git a/ppdet/modeling/heads/detr_head.py b/ppdet/modeling/heads/detr_head.py index 5b55642e4..303e814d9 100644 --- a/ppdet/modeling/heads/detr_head.py +++ b/ppdet/modeling/heads/detr_head.py @@ -21,9 +21,10 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register import pycocotools.mask as mask_util -from ..initializer import linear_init_ +from ..initializer import linear_init_, constant_ +from ..transformers.utils import inverse_sigmoid -__all__ = ['DETRHead'] +__all__ = ['DETRHead', 'DeformableDETRHead'] class MLP(nn.Layer): @@ -275,3 +276,77 @@ class DETRHead(nn.Layer): gt_mask=gt_mask) else: return (outputs_bbox[-1], outputs_logit[-1], outputs_seg) + + +@register +class DeformableDETRHead(nn.Layer): + __shared__ = ['num_classes', 'hidden_dim'] + __inject__ = ['loss'] + + def __init__(self, + num_classes=80, + hidden_dim=512, + nhead=8, + num_mlp_layers=3, + loss='DETRLoss'): + super(DeformableDETRHead, self).__init__() + self.num_classes = num_classes + self.hidden_dim = hidden_dim + self.nhead = nhead + self.loss = loss + + self.score_head = nn.Linear(hidden_dim, self.num_classes) + self.bbox_head = MLP(hidden_dim, + hidden_dim, + output_dim=4, + num_layers=num_mlp_layers) + + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.score_head) + constant_(self.score_head.bias, -4.595) + constant_(self.bbox_head.layers[-1].weight) + bias = paddle.zeros_like(self.bbox_head.layers[-1].bias) + bias[2:] = -2.0 + self.bbox_head.layers[-1].bias.set_value(bias) + + @classmethod + def from_config(cls, cfg, hidden_dim, nhead, input_shape): + return {'hidden_dim': hidden_dim, 'nhead': nhead} + + def forward(self, out_transformer, body_feats, inputs=None): + r""" + Args: + out_transformer (Tuple): (feats: [num_levels, batch_size, + num_queries, hidden_dim], + memory: [batch_size, + \sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim], + reference_points: [batch_size, num_queries, 2]) + body_feats (List(Tensor)): list[[B, C, H, W]] + inputs (dict): dict(inputs) + """ + feats, memory, reference_points = out_transformer + reference_points = inverse_sigmoid(reference_points.unsqueeze(0)) + outputs_bbox = self.bbox_head(feats) + + # It's equivalent to "outputs_bbox[:, :, :, :2] += reference_points", + # but the gradient is wrong in paddle. + outputs_bbox = paddle.concat( + [ + outputs_bbox[:, :, :, :2] + reference_points, + outputs_bbox[:, :, :, 2:] + ], + axis=-1) + + outputs_bbox = F.sigmoid(outputs_bbox) + outputs_logit = self.score_head(feats) + + if self.training: + assert inputs is not None + assert 'gt_bbox' in inputs and 'gt_class' in inputs + + return self.loss(outputs_bbox, outputs_logit, inputs['gt_bbox'], + inputs['gt_class']) + else: + return (outputs_bbox[-1], outputs_logit[-1], None) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 698f95f03..95f51a9a8 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -532,12 +532,23 @@ class DETRBBoxPostProcess(object): scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax( logits)[:, :, :-1] - scores, labels = scores.max(-1), scores.argmax(-1) - if scores.shape[1] > self.num_top_queries: - scores, index = paddle.topk(scores, self.num_top_queries, axis=-1) - labels = paddle.stack( - [paddle.gather(l, i) for l, i in zip(labels, index)]) + if not self.use_focal_loss: + scores, labels = scores.max(-1), scores.argmax(-1) + if scores.shape[1] > self.num_top_queries: + scores, index = paddle.topk( + scores, self.num_top_queries, axis=-1) + labels = paddle.stack( + [paddle.gather(l, i) for l, i in zip(labels, index)]) + bbox_pred = paddle.stack( + [paddle.gather(b, i) for b, i in zip(bbox_pred, index)]) + else: + scores, index = paddle.topk( + scores.reshape([logits.shape[0], -1]), + self.num_top_queries, + axis=-1) + labels = index % logits.shape[2] + index = index // logits.shape[2] bbox_pred = paddle.stack( [paddle.gather(b, i) for b, i in zip(bbox_pred, index)]) diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index 8bdcf4c26..4aed815d7 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -16,8 +16,10 @@ from . import detr_transformer from . import utils from . import matchers from . import position_encoding +from . import deformable_transformer from .detr_transformer import * from .utils import * from .matchers import * from .position_encoding import * +from .deformable_transformer import * diff --git a/ppdet/modeling/transformers/deformable_transformer.py b/ppdet/modeling/transformers/deformable_transformer.py new file mode 100644 index 000000000..2ed3ae5f2 --- /dev/null +++ b/ppdet/modeling/transformers/deformable_transformer.py @@ -0,0 +1,514 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention +from .position_encoding import PositionEmbedding +from .utils import _get_clones, deformable_attention_core_func +from ..initializer import linear_init_, constant_, xavier_uniform_, normal_ + +__all__ = ['DeformableTransformer'] + + +class MSDeformableAttention(nn.Layer): + def __init__(self, + embed_dim=256, + num_heads=8, + num_levels=4, + num_points=4, + lr_mult=0.1): + """ + Multi-Scale Deformable Attention Module + """ + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.total_points = num_heads * num_levels * num_points + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear( + embed_dim, + self.total_points * 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult)) + + self.attention_weights = nn.Linear(embed_dim, self.total_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + self._reset_parameters() + + def _reset_parameters(self): + # sampling_offsets + constant_(self.sampling_offsets.weight) + thetas = paddle.arange( + self.num_heads, + dtype=paddle.float32) * (2.0 * math.pi / self.num_heads) + grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True) + grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile( + [1, self.num_levels, self.num_points, 1]) + scaling = paddle.arange( + 1, self.num_points + 1, + dtype=paddle.float32).reshape([1, 1, -1, 1]) + grid_init *= scaling + self.sampling_offsets.bias.set_value(grid_init.flatten()) + # attention_weights + constant_(self.attention_weights.weight) + constant_(self.attention_weights.bias) + # proj + xavier_uniform_(self.value_proj.weight) + constant_(self.value_proj.bias) + xavier_uniform_(self.output_proj.weight) + constant_(self.output_proj.bias) + + def forward(self, + query, + reference_points, + value, + value_spatial_shapes, + value_mask=None): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + assert int(value_spatial_shapes.prod(1).sum()) == Len_v + + value = self.value_proj(value) + if value_mask is not None: + value_mask = value_mask.astype(value.dtype).unsqueeze(-1) + value *= value_mask + value = value.reshape([bs, Len_v, self.num_heads, self.head_dim]) + + sampling_offsets = self.sampling_offsets(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2]) + attention_weights = self.attention_weights(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels * self.num_points]) + attention_weights = F.softmax(attention_weights, -1).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points]) + + offset_normalizer = value_spatial_shapes.flip([1]).reshape( + [1, 1, 1, self.num_levels, 1, 2]) + sampling_locations = reference_points.reshape([ + bs, Len_q, 1, self.num_levels, 1, 2 + ]) + sampling_offsets / offset_normalizer + + output = deformable_attention_core_func( + value, value_spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + + return output + + +class DeformableTransformerEncoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None): + super(DeformableTransformerEncoderLayer, self).__init__() + # self attention + self.self_attn = MSDeformableAttention(d_model, n_head, n_levels, + n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, + src, + reference_points, + spatial_shapes, + src_mask=None, + pos_embed=None): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, pos_embed), reference_points, src, + spatial_shapes, src_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Layer): + def __init__(self, encoder_layer, num_layers): + super(DeformableTransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios): + valid_ratios = valid_ratios.unsqueeze(1) + reference_points = [] + for i, (H, W) in enumerate(spatial_shapes.tolist()): + ref_y, ref_x = paddle.meshgrid( + paddle.linspace(0.5, H - 0.5, H), + paddle.linspace(0.5, W - 0.5, W)) + ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] * + H) + ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] * + W) + reference_points.append(paddle.stack((ref_x, ref_y), axis=-1)) + reference_points = paddle.concat(reference_points, 1).unsqueeze(2) + reference_points = reference_points * valid_ratios + return reference_points + + def forward(self, + src, + spatial_shapes, + src_mask=None, + pos_embed=None, + valid_ratios=None): + output = src + if valid_ratios is None: + valid_ratios = paddle.ones( + [src.shape[0], spatial_shapes.shape[0], 2]) + reference_points = self.get_reference_points(spatial_shapes, + valid_ratios) + for layer in self.layers: + output = layer(output, reference_points, spatial_shapes, src_mask, + pos_embed) + + return output + + +class DeformableTransformerDecoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None): + super(DeformableTransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # cross attention + self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, + n_points) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_mask=None, + query_pos_embed=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + tgt2 = self.self_attn(q, k, value=tgt) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), reference_points, memory, + memory_spatial_shapes, memory_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Layer): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super(DeformableTransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_mask=None, + query_pos_embed=None): + output = tgt + intermediate = [] + for lid, layer in enumerate(self.layers): + output = layer(output, reference_points, memory, + memory_spatial_shapes, memory_mask, query_pos_embed) + + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return paddle.stack(intermediate) + + return output.unsqueeze(0) + + +@register +class DeformableTransformer(nn.Layer): + __shared__ = ['hidden_dim'] + + def __init__(self, + num_queries=300, + position_embed_type='sine', + return_intermediate_dec=True, + backbone_num_channels=[512, 1024, 2048], + num_feature_levels=4, + num_encoder_points=4, + num_decoder_points=4, + hidden_dim=256, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + lr_mult=0.1, + weight_attr=None, + bias_attr=None): + super(DeformableTransformer, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' + assert len(backbone_num_channels) <= num_feature_levels + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.num_feature_levels = num_feature_levels + + encoder_layer = DeformableTransformerEncoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, + num_feature_levels, num_encoder_points, weight_attr, bias_attr) + self.encoder = DeformableTransformerEncoder(encoder_layer, + num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, + num_feature_levels, num_decoder_points, weight_attr, bias_attr) + self.decoder = DeformableTransformerDecoder( + decoder_layer, num_decoder_layers, return_intermediate_dec) + + self.level_embed = nn.Embedding(num_feature_levels, hidden_dim) + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_embed = nn.Embedding(num_queries, hidden_dim) + + self.reference_points = nn.Linear( + hidden_dim, + 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult)) + + self.input_proj = nn.LayerList() + for in_channels in backbone_num_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2D( + in_channels, + hidden_dim, + kernel_size=1, + weight_attr=weight_attr, + bias_attr=bias_attr), + nn.GroupNorm(32, hidden_dim))) + in_channels = backbone_num_channels[-1] + for _ in range(num_feature_levels - len(backbone_num_channels)): + self.input_proj.append( + nn.Sequential( + nn.Conv2D( + in_channels, + hidden_dim, + kernel_size=3, + stride=2, + padding=1, + weight_attr=weight_attr, + bias_attr=bias_attr), + nn.GroupNorm(32, hidden_dim))) + in_channels = hidden_dim + + self.position_embedding = PositionEmbedding( + hidden_dim // 2, + normalize=True if position_embed_type == 'sine' else False, + embed_type=position_embed_type, + offset=-0.5) + + self._reset_parameters() + + def _reset_parameters(self): + normal_(self.level_embed.weight) + normal_(self.tgt_embed.weight) + normal_(self.query_pos_embed.weight) + xavier_uniform_(self.reference_points.weight) + constant_(self.reference_points.bias) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + constant_(l[0].bias) + + @classmethod + def from_config(cls, cfg, input_shape): + return {'backbone_num_channels': [i.channels for i in input_shape], } + + def _get_valid_ratio(self, mask): + mask = mask.astype(paddle.float32) + _, H, W = mask.shape + valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H + valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W + valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, src_feats, src_mask=None): + srcs = [] + for i in range(len(src_feats)): + srcs.append(self.input_proj[i](src_feats[i])) + if self.num_feature_levels > len(srcs): + len_srcs = len(srcs) + for i in range(len_srcs, self.num_feature_levels): + if i == len_srcs: + srcs.append(self.input_proj[i](src_feats[-1])) + else: + srcs.append(self.input_proj[i](srcs[-1])) + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + valid_ratios = [] + for level, src in enumerate(srcs): + bs, c, h, w = src.shape + spatial_shapes.append([h, w]) + src = src.flatten(2).transpose([0, 2, 1]) + src_flatten.append(src) + if src_mask is not None: + mask = F.interpolate( + src_mask.unsqueeze(0).astype(src.dtype), + size=(h, w))[0].astype('bool') + else: + mask = paddle.ones([bs, h, w], dtype='bool') + valid_ratios.append(self._get_valid_ratio(mask)) + pos_embed = self.position_embedding(mask).flatten(2).transpose( + [0, 2, 1]) + lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape( + [1, 1, -1]) + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask = mask.astype(src.dtype).flatten(1) + mask_flatten.append(mask) + src_flatten = paddle.concat(src_flatten, 1) + mask_flatten = paddle.concat(mask_flatten, 1) + lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1) + # [l, 2] + spatial_shapes = paddle.to_tensor(spatial_shapes, dtype='int64') + # [b, l, 2] + valid_ratios = paddle.stack(valid_ratios, 1) + + # encoder + memory = self.encoder(src_flatten, spatial_shapes, mask_flatten, + lvl_pos_embed_flatten, valid_ratios) + + # prepare input for decoder + bs, _, c = memory.shape + query_embed = self.query_pos_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + tgt = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + reference_points = F.sigmoid(self.reference_points(query_embed)) + reference_points_input = reference_points.unsqueeze( + 2) * valid_ratios.unsqueeze(1) + + # decoder + hs = self.decoder(tgt, reference_points_input, memory, spatial_shapes, + mask_flatten, query_embed) + + return (hs, memory, reference_points) diff --git a/ppdet/modeling/transformers/position_encoding.py b/ppdet/modeling/transformers/position_encoding.py index 2644d36e1..52067ffc8 100644 --- a/ppdet/modeling/transformers/position_encoding.py +++ b/ppdet/modeling/transformers/position_encoding.py @@ -32,11 +32,14 @@ class PositionEmbedding(nn.Layer): normalize=True, scale=None, embed_type='sine', - num_embeddings=50): + num_embeddings=50, + offset=0.): super(PositionEmbedding, self).__init__() assert embed_type in ['sine', 'learned'] self.embed_type = embed_type + self.offset = offset + self.eps = 1e-6 if self.embed_type == 'sine': self.num_pos_feats = num_pos_feats self.temperature = temperature @@ -65,9 +68,10 @@ class PositionEmbedding(nn.Layer): y_embed = mask.cumsum(1, dtype='float32') x_embed = mask.cumsum(2, dtype='float32') if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + y_embed = (y_embed + self.offset) / ( + y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / ( + x_embed[:, :, -1:] + self.eps) * self.scale dim_t = 2 * (paddle.arange(self.num_pos_feats) // 2).astype('float32') diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index 5756cfe85..d4fa3efa3 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -25,7 +25,8 @@ from ..bbox_utils import bbox_overlaps __all__ = [ '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy', - 'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss' + 'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid', + 'deformable_attention_core_func' ] @@ -55,3 +56,51 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0): alpha_t = alpha * label + (1 - alpha) * (1 - label) loss = alpha_t * loss return loss.mean(1).sum() / normalizer + + +def inverse_sigmoid(x, eps=1e-6): + x = x.clip(min=0., max=1.) + return paddle.log(x / (1 - x + eps) + eps) + + +def deformable_attention_core_func(value, value_spatial_shapes, + sampling_locations, attention_weights): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor): [n_levels, 2] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_v, n_head, c = value.shape + _, Len_q, n_head, n_levels, n_points, _ = sampling_locations.shape + + value_list = value.split(value_spatial_shapes.prod(1).tolist(), axis=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes.tolist()): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[level].flatten(2).transpose( + [0, 2, 1]).reshape([bs * n_head, c, h, w]) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level].transpose( + [0, 2, 1, 3, 4]).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape( + [bs * n_head, 1, Len_q, n_levels * n_points]) + output = (paddle.stack( + sampling_value_list, axis=-2).flatten(-2) * + attention_weights).sum(-1).reshape([bs, n_head * c, Len_q]) + + return output.transpose([0, 2, 1]) -- GitLab