未验证 提交 918b7542 编写于 作者: S shangliang Xu 提交者: GitHub

add DINO (#7583)

* add DINO

add msdeformable attention cuda op

fix export bug

* fix decoder norm
上级 ca1f803a
# DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection
## Introduction
[DINO](https://arxiv.org/abs/2203.03605) is an object detection model based on DETR. We reproduced the model of the paper.
## Model Zoo
| Backbone | Model | Epochs | Box AP | Config | Download |
|:------:|:---------------:|:------:|:------:|:---------------------------------------:|:--------------------------------------------------------------------------------:|
| R-50 | dino_r50_4scale | 12 | 49.3 | [config](./dino_r50_4scale_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/dino_r50_4scale_1x_coco.pdparams) |
| R-50 | dino_r50_4scale | 24 | 50.8 | [config](./dino_r50_4scale_2x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/dino_r50_4scale_2x_coco.pdparams) |
**Notes:**
- DINO is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.
- DINO uses 4GPU to train.
GPU multi-card training
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/dino/dino_r50_4scale_1x_coco.yml --fleet --eval
```
## Custom Operator
- Multi-scale deformable attention custom operator see [here](../../ppdet/modeling/transformers/ext_op).
## Citations
```
@misc{zhang2022dino,
title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
year={2022},
eprint={2203.03605},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
hidden_dim: 256
use_focal_loss: True
DETR:
backbone: ResNet
transformer: DINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1, 2, 3]
lr_mult_list: [0.0, 0.1, 0.1, 0.1]
num_stages: 4
DINOTransformer:
num_queries: 900
position_embed_type: sine
num_levels: 4
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
dropout: 0.0
activation: relu
pe_temperature: 20
pe_offset: 0.0
num_denoising: 100
label_noise_ratio: 0.5
box_noise_scale: 1.0
learnt_init_query: True
DINOHead:
loss:
name: DINOLoss
loss_coeff: {class: 1, bbox: 5, giou: 2}
aux_loss: True
matcher:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
num_top_queries: 300
worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
RandomSizeCrop: { min_size: 384, max_size: 600 },
RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- NormalizeBox: {}
- BboxXYXY2XYWH: {}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 4
shuffle: true
drop_last: true
collate_batch: false
use_shared_memory: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
epoch: 12
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [11]
use_warmup: false
OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
epoch: 24
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [20]
use_warmup: false
OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/dino_r50.yml',
'_base_/dino_reader.yml',
]
weights: output/dino_r50_4scale_1x_coco/model_final
find_unused_parameters: True
log_iter: 100
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_2x.yml',
'_base_/dino_r50.yml',
'_base_/dino_reader.yml',
]
weights: output/dino_r50_4scale_2x_coco/model_final
find_unused_parameters: True
log_iter: 100
......@@ -69,7 +69,7 @@ class DETR(BaseArch):
# Transformer
pad_mask = self.inputs['pad_mask'] if self.training else None
out_transformer = self.transformer(body_feats, pad_mask)
out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
# DETR Head
if self.training:
......
......@@ -24,7 +24,7 @@ import pycocotools.mask as mask_util
from ..initializer import linear_init_, constant_
from ..transformers.utils import inverse_sigmoid
__all__ = ['DETRHead', 'DeformableDETRHead']
__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead']
class MLP(nn.Layer):
......@@ -362,3 +362,43 @@ class DeformableDETRHead(nn.Layer):
inputs['gt_class'])
else:
return (outputs_bbox[-1], outputs_logit[-1], None)
@register
class DINOHead(nn.Layer):
__inject__ = ['loss']
def __init__(self, loss='DINOLoss'):
super(DINOHead, self).__init__()
self.loss = loss
def forward(self, out_transformer, body_feats, inputs=None):
(dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits,
dn_meta) = out_transformer
if self.training:
assert inputs is not None
assert 'gt_bbox' in inputs and 'gt_class' in inputs
if dn_meta is not None:
dn_out_bboxes, dec_out_bboxes = paddle.split(
dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
dn_out_logits, dec_out_logits = paddle.split(
dec_out_logits, dn_meta['dn_num_split'], axis=2)
else:
dn_out_bboxes, dn_out_logits = None, None
out_bboxes = paddle.concat(
[enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
out_logits = paddle.concat(
[enc_topk_logits.unsqueeze(0), dec_out_logits])
return self.loss(
out_bboxes,
out_logits,
inputs['gt_bbox'],
inputs['gt_class'],
dn_out_bboxes=dn_out_bboxes,
dn_out_logits=dn_out_logits,
dn_meta=dn_meta)
else:
return (dec_out_bboxes[-1], dec_out_logits[-1], None)
......@@ -23,7 +23,7 @@ from ppdet.core.workspace import register
from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
__all__ = ['DETRLoss']
__all__ = ['DETRLoss', 'DINOLoss']
@register
......@@ -67,9 +67,17 @@ class DETRLoss(nn.Layer):
self.loss_coeff['class'][-1] = loss_coeff['no_object']
self.giou_loss = GIoULoss()
def _get_loss_class(self, logits, gt_class, match_indices, bg_index,
num_gts):
def _get_loss_class(self,
logits,
gt_class,
match_indices,
bg_index,
num_gts,
postfix=""):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = "loss_class" + postfix
if logits is None:
return {name_class: paddle.zeros([1])}
target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
bs, num_query_objects = target_label.shape
if sum(len(a) for a in gt_class) > 0:
......@@ -82,36 +90,46 @@ class DETRLoss(nn.Layer):
target_label = F.one_hot(target_label,
self.num_classes + 1)[..., :-1]
return {
'loss_class': self.loss_coeff['class'] * sigmoid_focal_loss(
name_class: self.loss_coeff['class'] * sigmoid_focal_loss(
logits, target_label, num_gts / num_query_objects)
if self.use_focal_loss else F.cross_entropy(
logits, target_label, weight=self.loss_coeff['class'])
}
def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts):
def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
postfix=""):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = "loss_bbox" + postfix
name_giou = "loss_giou" + postfix
if boxes is None:
return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_bbox) == 0:
loss['loss_bbox'] = paddle.to_tensor([0.])
loss['loss_giou'] = paddle.to_tensor([0.])
loss[name_bbox] = paddle.to_tensor([0.])
loss[name_giou] = paddle.to_tensor([0.])
return loss
src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
match_indices)
loss['loss_bbox'] = self.loss_coeff['bbox'] * F.l1_loss(
loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
src_bbox, target_bbox, reduction='sum') / num_gts
loss['loss_giou'] = self.giou_loss(
loss[name_giou] = self.giou_loss(
bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
loss['loss_giou'] = loss['loss_giou'].sum() / num_gts
loss['loss_giou'] = self.loss_coeff['giou'] * loss['loss_giou']
loss[name_giou] = loss[name_giou].sum() / num_gts
loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
return loss
def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts):
def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
postfix=""):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = "loss_mask" + postfix
name_dice = "loss_dice" + postfix
if masks is None:
return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_mask) == 0:
loss['loss_mask'] = paddle.to_tensor([0.])
loss['loss_dice'] = paddle.to_tensor([0.])
loss[name_mask] = paddle.to_tensor([0.])
loss[name_dice] = paddle.to_tensor([0.])
return loss
src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
......@@ -120,12 +138,12 @@ class DETRLoss(nn.Layer):
src_masks.unsqueeze(0),
size=target_masks.shape[-2:],
mode="bilinear")[0]
loss['loss_mask'] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
loss[name_mask] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
src_masks,
target_masks,
paddle.to_tensor(
[num_gts], dtype='float32'))
loss['loss_dice'] = self.loss_coeff['dice'] * self._dice_loss(
loss[name_dice] = self.loss_coeff['dice'] * self._dice_loss(
src_masks, target_masks, num_gts)
return loss
......@@ -138,25 +156,40 @@ class DETRLoss(nn.Layer):
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_gts
def _get_loss_aux(self, boxes, logits, gt_bbox, gt_class, bg_index,
num_gts):
def _get_loss_aux(self,
boxes,
logits,
gt_bbox,
gt_class,
bg_index,
num_gts,
match_indices=None,
postfix=""):
if boxes is None and logits is None:
return {
"loss_class_aux" + postfix: paddle.paddle.zeros([1]),
"loss_bbox_aux" + postfix: paddle.paddle.zeros([1]),
"loss_giou_aux" + postfix: paddle.paddle.zeros([1])
}
loss_class = []
loss_bbox = []
loss_giou = []
for aux_boxes, aux_logits in zip(boxes, logits):
if match_indices is None:
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class)
loss_class.append(
self._get_loss_class(aux_logits, gt_class, match_indices,
bg_index, num_gts)['loss_class'])
bg_index, num_gts, postfix)['loss_class' +
postfix])
loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
num_gts)
loss_bbox.append(loss_['loss_bbox'])
loss_giou.append(loss_['loss_giou'])
num_gts, postfix)
loss_bbox.append(loss_['loss_bbox' + postfix])
loss_giou.append(loss_['loss_giou' + postfix])
loss = {
'loss_class_aux': paddle.add_n(loss_class),
'loss_bbox_aux': paddle.add_n(loss_bbox),
'loss_giou_aux': paddle.add_n(loss_giou)
"loss_class_aux" + postfix: paddle.add_n(loss_class),
"loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
"loss_giou_aux" + postfix: paddle.add_n(loss_giou)
}
return loss
......@@ -191,40 +224,105 @@ class DETRLoss(nn.Layer):
gt_bbox,
gt_class,
masks=None,
gt_mask=None):
gt_mask=None,
postfix="",
**kwargs):
r"""
Args:
boxes (Tensor): [l, b, query, 4]
logits (Tensor): [l, b, query, num_classes]
boxes (Tensor|None): [l, b, query, 4]
logits (Tensor|None): [l, b, query, num_classes]
gt_bbox (List(Tensor)): list[[n, 4]]
gt_class (List(Tensor)): list[[n, 1]]
masks (Tensor, optional): [b, query, h, w]
gt_mask (List(Tensor), optional): list[[n, H, W]]
postfix (str): postfix of loss name
"""
match_indices = self.matcher(boxes[-1].detach(), logits[-1].detach(),
gt_bbox, gt_class)
if "match_indices" in kwargs:
match_indices = kwargs["match_indices"]
else:
match_indices = self.matcher(boxes[-1].detach(),
logits[-1].detach(), gt_bbox, gt_class)
num_gts = sum(len(a) for a in gt_bbox)
try:
# TODO: Paddle does not have a "paddle.distributed.is_initialized()"
num_gts = paddle.to_tensor([num_gts], dtype=paddle.float32)
num_gts = paddle.to_tensor([num_gts], dtype="float32")
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(num_gts)
num_gts = paddle.clip(
num_gts / paddle.distributed.get_world_size(), min=1).item()
except:
num_gts = max(num_gts.item(), 1)
num_gts /= paddle.distributed.get_world_size()
num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)
total_loss = dict()
total_loss.update(
self._get_loss_class(logits[-1], gt_class, match_indices,
self.num_classes, num_gts))
self._get_loss_class(logits[
-1] if logits is not None else None, gt_class, match_indices,
self.num_classes, num_gts, postfix))
total_loss.update(
self._get_loss_bbox(boxes[-1], gt_bbox, match_indices, num_gts))
self._get_loss_bbox(boxes[-1] if boxes is not None else None,
gt_bbox, match_indices, num_gts, postfix))
if masks is not None and gt_mask is not None:
total_loss.update(
self._get_loss_mask(masks, gt_mask, match_indices, num_gts))
self._get_loss_mask(masks if masks is not None else None,
gt_mask, match_indices, num_gts, postfix))
if self.aux_loss:
if "match_indices" not in kwargs:
match_indices = None
total_loss.update(
self._get_loss_aux(boxes[:-1], logits[:-1], gt_bbox, gt_class,
self.num_classes, num_gts))
self._get_loss_aux(
boxes[:-1] if boxes is not None else None, logits[:-1]
if logits is not None else None, gt_bbox, gt_class,
self.num_classes, num_gts, match_indices, postfix))
return total_loss
@register
class DINOLoss(DETRLoss):
def forward(self,
boxes,
logits,
gt_bbox,
gt_class,
masks=None,
gt_mask=None,
postfix="",
dn_out_bboxes=None,
dn_out_logits=None,
dn_meta=None,
**kwargs):
total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
gt_class)
# denoising training loss
if dn_meta is not None:
dn_positive_idx, dn_num_group = \
dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
assert len(gt_class) == len(dn_positive_idx)
# denoising match indices
dn_match_indices = []
for i in range(len(gt_class)):
num_gt = len(gt_class[i])
if num_gt > 0:
gt_idx = paddle.arange(end=num_gt, dtype="int64")
gt_idx = gt_idx.unsqueeze(0).tile(
[dn_num_group, 1]).flatten()
assert len(gt_idx) == len(dn_positive_idx[i])
dn_match_indices.append((dn_positive_idx[i], gt_idx))
else:
dn_match_indices.append((paddle.zeros(
[0], dtype="int64"), paddle.zeros(
[0], dtype="int64")))
else:
dn_match_indices, dn_num_group = None, 1.
dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
gt_bbox,
gt_class,
postfix="_dn",
match_indices=dn_match_indices,
dn_num_group=dn_num_group)
total_loss.update(dn_loss)
return total_loss
......@@ -17,9 +17,11 @@ from . import utils
from . import matchers
from . import position_encoding
from . import deformable_transformer
from . import dino_transformer
from .detr_transformer import *
from .utils import *
from .matchers import *
from .position_encoding import *
from .deformable_transformer import *
from .dino_transformer import *
......@@ -131,14 +131,24 @@ class MSDeformableAttention(nn.Layer):
[bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2])
attention_weights = self.attention_weights(query).reshape(
[bs, Len_q, self.num_heads, self.num_levels * self.num_points])
attention_weights = F.softmax(attention_weights, -1).reshape(
attention_weights = F.softmax(attention_weights).reshape(
[bs, Len_q, self.num_heads, self.num_levels, self.num_points])
if reference_points.shape[-1] == 2:
offset_normalizer = value_spatial_shapes.flip([1]).reshape(
[1, 1, 1, self.num_levels, 1, 2])
sampling_locations = reference_points.reshape([
bs, Len_q, 1, self.num_levels, 1, 2
]) + sampling_offsets / offset_normalizer
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] + sampling_offsets /
self.num_points * reference_points[:, :, None, :, None, 2:] *
0.5)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".
format(reference_points.shape[-1]))
output = self.ms_deformable_attn_core(
value, value_spatial_shapes, value_level_start_index,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Modified from detrex (https://github.com/IDEA-Research/detrex)
# Copyright 2022 The IDEA Authors. All rights reserved.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ..layers import MultiHeadAttention
from .position_encoding import PositionEmbedding
from ..heads.detr_head import MLP
from .deformable_transformer import MSDeformableAttention
from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_,
bias_init_with_prob)
from .utils import (_get_clones, get_valid_ratio,
get_contrastive_denoising_training_group,
get_sine_pos_embed, inverse_sigmoid)
__all__ = ['DINOTransformer']
class DINOTransformerEncoderLayer(nn.Layer):
def __init__(self,
d_model=256,
n_head=8,
dim_feedforward=1024,
dropout=0.,
activation="relu",
n_levels=4,
n_points=4,
weight_attr=None,
bias_attr=None):
super(DINOTransformerEncoderLayer, self).__init__()
# self attention
self.self_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points, 1.0)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.activation = getattr(F, activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
xavier_uniform_(self.linear1.weight)
xavier_uniform_(self.linear2.weight)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self,
src,
reference_points,
spatial_shapes,
level_start_index,
src_mask=None,
query_pos_embed=None):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, query_pos_embed), reference_points, src,
spatial_shapes, level_start_index, src_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DINOTransformerEncoder(nn.Layer):
def __init__(self, encoder_layer, num_layers):
super(DINOTransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, offset=0.5):
valid_ratios = valid_ratios.unsqueeze(1)
reference_points = []
for i, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = paddle.meshgrid(
paddle.arange(end=H) + offset, paddle.arange(end=W) + offset)
ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] *
H)
ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] *
W)
reference_points.append(paddle.stack((ref_x, ref_y), axis=-1))
reference_points = paddle.concat(reference_points, 1).unsqueeze(2)
reference_points = reference_points * valid_ratios
return reference_points
def forward(self,
feat,
spatial_shapes,
level_start_index,
feat_mask=None,
query_pos_embed=None,
valid_ratios=None):
if valid_ratios is None:
valid_ratios = paddle.ones(
[feat.shape[0], spatial_shapes.shape[0], 2])
reference_points = self.get_reference_points(spatial_shapes,
valid_ratios)
for layer in self.layers:
feat = layer(feat, reference_points, spatial_shapes,
level_start_index, feat_mask, query_pos_embed)
return feat
class DINOTransformerDecoderLayer(nn.Layer):
def __init__(self,
d_model=256,
n_head=8,
dim_feedforward=1024,
dropout=0.,
activation="relu",
n_levels=4,
n_points=4,
weight_attr=None,
bias_attr=None):
super(DINOTransformerDecoderLayer, self).__init__()
# self attention
self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
# cross attention
self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points, 1.0)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.activation = getattr(F, activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
xavier_uniform_(self.linear1.weight)
xavier_uniform_(self.linear2.weight)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
def forward(self,
tgt,
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
attn_mask=None,
memory_mask=None,
query_pos_embed=None):
# self attention
q = k = self.with_pos_embed(tgt, query_pos_embed)
if attn_mask is not None:
attn_mask = attn_mask.astype('bool')
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# cross attention
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, query_pos_embed), reference_points, memory,
memory_spatial_shapes, memory_level_start_index, memory_mask)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# ffn
tgt2 = self.forward_ffn(tgt)
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
class DINOTransformerDecoder(nn.Layer):
def __init__(self,
hidden_dim,
decoder_layer,
num_layers,
return_intermediate=True):
super(DINOTransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = nn.LayerNorm(
hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
def forward(self,
tgt,
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
bbox_head,
query_pos_head,
valid_ratios=None,
attn_mask=None,
memory_mask=None):
if valid_ratios is None:
valid_ratios = paddle.ones(
[memory.shape[0], memory_spatial_shapes.shape[0], 2])
output = tgt
intermediate = []
inter_ref_bboxes = []
for i, layer in enumerate(self.layers):
reference_points_input = reference_points.unsqueeze(
2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1)
query_pos_embed = get_sine_pos_embed(
reference_points_input[..., 0, :], self.hidden_dim // 2)
query_pos_embed = query_pos_head(query_pos_embed)
output = layer(output, reference_points_input, memory,
memory_spatial_shapes, memory_level_start_index,
attn_mask, memory_mask, query_pos_embed)
inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
reference_points))
if self.return_intermediate:
intermediate.append(self.norm(output))
inter_ref_bboxes.append(inter_ref_bbox)
reference_points = inter_ref_bbox.detach()
if self.return_intermediate:
return paddle.stack(intermediate), paddle.stack(inter_ref_bboxes)
return output, reference_points
@register
class DINOTransformer(nn.Layer):
__shared__ = ['num_classes', 'hidden_dim']
def __init__(self,
num_classes=80,
hidden_dim=256,
num_queries=900,
position_embed_type='sine',
return_intermediate_dec=True,
backbone_feat_channels=[512, 1024, 2048],
num_levels=4,
num_encoder_points=4,
num_decoder_points=4,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.,
activation="relu",
pe_temperature=10000,
pe_offset=-0.5,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0,
learnt_init_query=True,
eps=1e-2):
super(DINOTransformer, self).__init__()
assert position_embed_type in ['sine', 'learned'], \
f'ValueError: position_embed_type not supported {position_embed_type}!'
assert len(backbone_feat_channels) <= num_levels
self.hidden_dim = hidden_dim
self.nhead = nhead
self.num_levels = num_levels
self.num_classes = num_classes
self.num_queries = num_queries
self.eps = eps
self.num_decoder_layers = num_decoder_layers
# backbone feature projection
self._build_input_proj_layer(backbone_feat_channels)
# Transformer module
encoder_layer = DINOTransformerEncoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
num_encoder_points)
self.encoder = DINOTransformerEncoder(encoder_layer, num_encoder_layers)
decoder_layer = DINOTransformerDecoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
num_decoder_points)
self.decoder = DINOTransformerDecoder(hidden_dim, decoder_layer,
num_decoder_layers,
return_intermediate_dec)
# denoising part
self.denoising_class_embed = nn.Embedding(
num_classes,
hidden_dim,
weight_attr=ParamAttr(initializer=nn.initializer.Normal()))
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
# position embedding
self.position_embedding = PositionEmbedding(
hidden_dim // 2,
temperature=pe_temperature,
normalize=True if position_embed_type == 'sine' else False,
embed_type=position_embed_type,
offset=pe_offset)
self.level_embed = nn.Embedding(num_levels, hidden_dim)
# decoder embedding
self.learnt_init_query = learnt_init_query
if learnt_init_query:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
self.query_pos_head = MLP(2 * hidden_dim,
hidden_dim,
hidden_dim,
num_layers=2)
# encoder head
self.enc_output = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(
hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
# decoder head
self.dec_score_head = nn.LayerList([
nn.Linear(hidden_dim, num_classes)
for _ in range(num_decoder_layers)
])
self.dec_bbox_head = nn.LayerList([
MLP(hidden_dim, hidden_dim, 4, num_layers=3)
for _ in range(num_decoder_layers)
])
self._reset_parameters()
def _reset_parameters(self):
# class and bbox head init
bias_cls = bias_init_with_prob(0.01)
linear_init_(self.enc_score_head)
constant_(self.enc_score_head.bias, bias_cls)
constant_(self.enc_bbox_head.layers[-1].weight)
constant_(self.enc_bbox_head.layers[-1].bias)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
linear_init_(cls_)
constant_(cls_.bias, bias_cls)
constant_(reg_.layers[-1].weight)
constant_(reg_.layers[-1].bias)
linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight)
normal_(self.level_embed.weight)
xavier_uniform_(self.tgt_embed.weight)
xavier_uniform_(self.query_pos_head.layers[0].weight)
xavier_uniform_(self.query_pos_head.layers[1].weight)
for l in self.input_proj:
xavier_uniform_(l[0].weight)
constant_(l[0].bias)
@classmethod
def from_config(cls, cfg, input_shape):
return {'backbone_feat_channels': [i.channels for i in input_shape], }
def _build_input_proj_layer(self, backbone_feat_channels):
self.input_proj = nn.LayerList()
for in_channels in backbone_feat_channels:
self.input_proj.append(
nn.Sequential(
('conv', nn.Conv2D(
in_channels, self.hidden_dim, kernel_size=1)),
('norm', nn.GroupNorm(
32,
self.hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
in_channels = backbone_feat_channels[-1]
for _ in range(self.num_levels - len(backbone_feat_channels)):
self.input_proj.append(
nn.Sequential(
('conv', nn.Conv2D(
in_channels,
self.hidden_dim,
kernel_size=3,
stride=2,
padding=1)), ('norm', nn.GroupNorm(
32,
self.hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
in_channels = self.hidden_dim
def _get_encoder_input(self, feats, pad_mask=None):
# get projection features
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
if self.num_levels > len(proj_feats):
len_srcs = len(proj_feats)
for i in range(len_srcs, self.num_levels):
if i == len_srcs:
proj_feats.append(self.input_proj[i](feats[-1]))
else:
proj_feats.append(self.input_proj[i](proj_feats[-1]))
# get encoder inputs
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
valid_ratios = []
for i, feat in enumerate(proj_feats):
bs, _, h, w = paddle.shape(feat)
spatial_shapes.append(paddle.concat([h, w]))
# [b,c,h,w] -> [b,h*w,c]
feat_flatten.append(feat.flatten(2).transpose([0, 2, 1]))
if pad_mask is not None:
mask = F.interpolate(pad_mask.unsqueeze(0), size=(h, w))[0]
else:
mask = paddle.ones([bs, h, w])
valid_ratios.append(get_valid_ratio(mask))
# [b, h*w, c]
pos_embed = self.position_embedding(mask).flatten(1, 2)
lvl_pos_embed = pos_embed + self.level_embed.weight[i]
lvl_pos_embed_flatten.append(lvl_pos_embed)
if pad_mask is not None:
# [b, h*w]
mask_flatten.append(mask.flatten(1))
# [b, l, c]
feat_flatten = paddle.concat(feat_flatten, 1)
# [b, l]
mask_flatten = None if pad_mask is None else paddle.concat(mask_flatten,
1)
# [b, l, c]
lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
# [num_levels, 2]
spatial_shapes = paddle.to_tensor(
paddle.stack(spatial_shapes).astype('int64'))
# [l], 每一个level的起始index
level_start_index = paddle.concat([
paddle.zeros(
[1], dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1]
])
# [b, num_levels, 2]
valid_ratios = paddle.stack(valid_ratios, 1)
return (feat_flatten, spatial_shapes, level_start_index, mask_flatten,
lvl_pos_embed_flatten, valid_ratios)
def forward(self, feats, pad_mask=None, gt_meta=None):
# input projection and embedding
(feat_flatten, spatial_shapes, level_start_index, mask_flatten,
lvl_pos_embed_flatten,
valid_ratios) = self._get_encoder_input(feats, pad_mask)
# encoder
memory = self.encoder(feat_flatten, spatial_shapes, level_start_index,
mask_flatten, lvl_pos_embed_flatten, valid_ratios)
# prepare denoising training
if self.training:
denoising_class, denoising_bbox, attn_mask, dn_meta = \
get_contrastive_denoising_training_group(gt_meta,
self.num_classes,
self.num_queries,
self.denoising_class_embed.weight,
self.num_denoising,
self.label_noise_ratio,
self.box_noise_scale)
else:
denoising_class, denoising_bbox, attn_mask, dn_meta = None, None, None, None
target, init_ref_points, enc_topk_bboxes, enc_topk_logits = \
self._get_decoder_input(
memory, spatial_shapes, mask_flatten, denoising_class,
denoising_bbox)
# decoder
inter_feats, inter_ref_bboxes = self.decoder(
target, init_ref_points, memory, spatial_shapes, level_start_index,
self.dec_bbox_head, self.query_pos_head, valid_ratios, attn_mask,
mask_flatten)
out_bboxes = []
out_logits = []
for i in range(self.num_decoder_layers):
out_logits.append(self.dec_score_head[i](inter_feats[i]))
if i == 0:
out_bboxes.append(
F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
inverse_sigmoid(init_ref_points)))
else:
out_bboxes.append(
F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
inverse_sigmoid(inter_ref_bboxes[i - 1])))
out_bboxes = paddle.stack(out_bboxes)
out_logits = paddle.stack(out_logits)
return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
dn_meta)
def _get_encoder_output_anchors(self,
memory,
spatial_shapes,
memory_mask=None,
grid_size=0.05):
output_anchors = []
idx = 0
for lvl, (h, w) in enumerate(spatial_shapes):
if memory_mask is not None:
mask_ = memory_mask[:, idx:idx + h * w].reshape([-1, h, w])
valid_H = paddle.sum(mask_[:, :, 0], 1)
valid_W = paddle.sum(mask_[:, 0, :], 1)
else:
valid_H, valid_W = h, w
grid_y, grid_x = paddle.meshgrid(
paddle.arange(
end=h, dtype=memory.dtype),
paddle.arange(
end=w, dtype=memory.dtype))
grid_xy = paddle.stack([grid_x, grid_y], -1)
valid_WH = paddle.stack([valid_W, valid_H], -1).reshape(
[-1, 1, 1, 2]).astype(grid_xy.dtype)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl)
output_anchors.append(
paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
idx += h * w
output_anchors = paddle.concat(output_anchors, 1)
valid_mask = ((output_anchors > self.eps) *
(output_anchors < 1 - self.eps)).all(-1, keepdim=True)
output_anchors = paddle.log(output_anchors / (1 - output_anchors))
if memory_mask is not None:
valid_mask = (valid_mask * (memory_mask.unsqueeze(-1) > 0)) > 0
output_anchors = paddle.where(valid_mask, output_anchors,
paddle.to_tensor(float("inf")))
memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.))
output_memory = self.enc_output(memory)
return output_memory, output_anchors
def _get_decoder_input(self,
memory,
spatial_shapes,
memory_mask=None,
denoising_class=None,
denoising_bbox=None):
bs, _, _ = memory.shape
# prepare input for decoder
output_memory, output_anchors = self._get_encoder_output_anchors(
memory, spatial_shapes, memory_mask)
enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(
output_memory) + output_anchors
_, topk_ind = paddle.topk(
enc_outputs_class.max(-1), self.num_queries, axis=1)
# extract region proposal boxes
batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype)
batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries])
topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1)
topk_coords_unact = paddle.gather_nd(enc_outputs_coord_unact,
topk_ind) # unsigmoided.
reference_points = enc_topk_bboxes = F.sigmoid(topk_coords_unact)
if denoising_bbox is not None:
reference_points = paddle.concat([denoising_bbox, enc_topk_bboxes],
1)
enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind)
# extract region features
if self.learnt_init_query:
target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
else:
target = paddle.gather_nd(output_memory, topk_ind).detach()
if denoising_class is not None:
target = paddle.concat([denoising_class, target], 1)
return target, reference_points.detach(
), enc_topk_bboxes, enc_topk_logits
......@@ -16,6 +16,7 @@ python setup_ms_deformable_attn_op.py install
```
# 引入自定义op
from deformable_detr_ops import ms_deformable_attn
# 构造fake input tensor
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
......
......@@ -14,12 +14,15 @@
#
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from detrex (https://github.com/IDEA-Research/detrex)
# Copyright 2022 The IDEA Authors. All rights reserved.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -117,3 +120,146 @@ def get_valid_ratio(mask):
valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
# [b, 2]
return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
def get_contrastive_denoising_training_group(targets,
num_classes,
num_queries,
class_embed,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0):
if num_denoising <= 0:
return None, None, None, None
num_gts = [len(t) for t in targets["gt_class"]]
max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
num_group = num_denoising // max_gt_num
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(targets["gt_class"])
input_query_class = paddle.full(
[bs, max_gt_num], num_classes, dtype='int32')
input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
pad_gt_mask = paddle.zeros([bs, max_gt_num])
for i in range(bs):
num_gt = num_gts[i]
if num_gt > 0:
input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
pad_gt_mask[i, :num_gt] = 1
# each group has positive and negative queries.
input_query_class = input_query_class.tile([1, 2 * num_group])
input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
# positive and negative mask
negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1])
negative_gt_mask[:, max_gt_num:] = 1
negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
positive_gt_mask = 1 - negative_gt_mask
# contrastive denoising training positive index
positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]
dn_positive_idx = paddle.split(dn_positive_idx,
[n * num_group for n in num_gts])
# total denoising queries
num_denoising = int(max_gt_num * 2 * num_group)
if label_noise_ratio > 0:
input_query_class = input_query_class.flatten()
pad_gt_mask = pad_gt_mask.flatten()
# half of bbox prob
mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
# randomly put a new one here
new_label = paddle.randint_like(
chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
input_query_class.scatter_(chosen_idx, new_label)
input_query_class.reshape_([bs, num_denoising])
pad_gt_mask.reshape_([bs, num_denoising])
if box_noise_scale > 0:
known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
diff = paddle.tile(input_query_bbox[..., 2:] * 0.5,
[1, 1, 2]) * box_noise_scale
rand_sign = paddle.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
rand_part = paddle.rand(input_query_bbox.shape)
rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
1 - negative_gt_mask)
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
input_query_bbox.clip_(min=0.0, max=1.0)
class_embed = paddle.concat(
[class_embed, paddle.zeros([1, class_embed.shape[-1]])])
input_query_class = paddle.gather(
class_embed, input_query_class.flatten(),
axis=0).reshape([bs, num_denoising, -1])
tgt_size = num_denoising + num_queries
attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
# match query cannot see the reconstruct
attn_mask[num_denoising:, :num_denoising] = True
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
2 * (i + 1):num_denoising] = True
if i == num_group - 1:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
i * 2] = True
else:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
2 * (i + 1):num_denoising] = True
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
2 * i] = True
attn_mask = ~attn_mask
dn_meta = {
"dn_positive_idx": dn_positive_idx,
"dn_num_group": num_group,
"dn_num_split": [num_denoising, num_queries]
}
return input_query_class, input_query_bbox, attn_mask, dn_meta
def get_sine_pos_embed(pos_tensor,
num_pos_feats=128,
temperature=10000,
exchange_xy=True):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): Shape as `(None, n)`.
num_pos_feats (int): projected shape for each float in the tensor. Default: 128
temperature (int): The temperature used for scaling
the position embedding. Default: 10000.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is `[x, y]`, the results will # noqa
be `[pos(y), pos(x)]`. Defaults: True.
Returns:
torch.Tensor: Returned position embedding # noqa
with shape `(None, n * num_pos_feats)`.
"""
scale = 2. * math.pi
dim_t = 2. * paddle.floor_divide(
paddle.arange(num_pos_feats), paddle.to_tensor(2))
dim_t = scale / temperature**(dim_t / num_pos_feats)
def sine_func(x):
x *= dim_t
return paddle.stack(
(x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten(2)
pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = paddle.concat(pos_res, axis=2)
return pos_res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册