Unverified commit 918b7542, authored by shangliang Xu, committed by GitHub

add DINO (#7583)

* add DINO

add msdeformable attention cuda op

fix export bug

* fix decoder norm
Parent ca1f803a
# DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection
## Introduction
[DINO](https://arxiv.org/abs/2203.03605) is an object detection model based on DETR. We reproduce the model reported in the paper.
## Model Zoo
| Backbone | Model | Epochs | Box AP | Config | Download |
|:------:|:---------------:|:------:|:------:|:---------------------------------------:|:--------------------------------------------------------------------------------:|
| R-50 | dino_r50_4scale | 12 | 49.3 | [config](./dino_r50_4scale_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/dino_r50_4scale_1x_coco.pdparams) |
| R-50 | dino_r50_4scale | 24 | 50.8 | [config](./dino_r50_4scale_2x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/dino_r50_4scale_2x_coco.pdparams) |
**Notes:**
- DINO is trained on the COCO train2017 dataset and evaluated on val2017; the reported metric is `mAP(IoU=0.5:0.95)`.
- DINO is trained with 4 GPUs.

Multi-GPU training:
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/dino/dino_r50_4scale_1x_coco.yml --fleet --eval
```
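
Evaluation, inference, and export follow the standard PaddleDetection entry points. A hedged sketch (the exact flags follow the usual `tools/` scripts and are not part of this commit; the demo image path is illustrative):

```bash
# evaluate on COCO val2017
python tools/eval.py -c configs/dino/dino_r50_4scale_1x_coco.yml \
    -o weights=https://paddledet.bj.bcebos.com/models/dino_r50_4scale_1x_coco.pdparams

# run inference on a single image
python tools/infer.py -c configs/dino/dino_r50_4scale_1x_coco.yml \
    -o weights=https://paddledet.bj.bcebos.com/models/dino_r50_4scale_1x_coco.pdparams \
    --infer_img=demo/000000014439.jpg

# export for deployment (the commit message mentions an export fix)
python tools/export_model.py -c configs/dino/dino_r50_4scale_1x_coco.yml \
    -o weights=output/dino_r50_4scale_1x_coco/model_final
```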
## Custom Operator
- For the multi-scale deformable attention custom operator, see [here](../../ppdet/modeling/transformers/ext_op).
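
Building and importing the op, as also shown in the op's README fragment later in this commit (a CUDA-enabled PaddlePaddle build is assumed):

```bash
cd ppdet/modeling/transformers/ext_op
python setup_ms_deformable_attn_op.py install
```

After installation the op can be imported with `from deformable_detr_ops import ms_deformable_attn`.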
## Citations
```
@misc{zhang2022dino,
title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
year={2022},
eprint={2203.03605},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
hidden_dim: 256
use_focal_loss: True
DETR:
backbone: ResNet
transformer: DINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1, 2, 3]
lr_mult_list: [0.0, 0.1, 0.1, 0.1]
num_stages: 4
DINOTransformer:
num_queries: 900
position_embed_type: sine
num_levels: 4
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
dropout: 0.0
activation: relu
pe_temperature: 20
pe_offset: 0.0
num_denoising: 100
label_noise_ratio: 0.5
box_noise_scale: 1.0
learnt_init_query: True
DINOHead:
loss:
name: DINOLoss
loss_coeff: {class: 1, bbox: 5, giou: 2}
aux_loss: True
matcher:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
num_top_queries: 300
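
The `num_top_queries: 300` above caps the number of final detections: with `use_focal_loss: True`, scores are per-class sigmoids and the top 300 (query, class) pairs are kept over the flattened score matrix. A minimal sketch of that selection, with illustrative shapes and names (not the exact `DETRBBoxPostProcess` code):

```python
import paddle
import paddle.nn.functional as F

bs, num_queries, num_classes, num_top = 2, 900, 80, 300
logits = paddle.rand([bs, num_queries, num_classes])
boxes = paddle.rand([bs, num_queries, 4])  # normalized cxcywh

scores = F.sigmoid(logits)
# top 300 (query, class) pairs over the flattened [bs, 900 * 80] scores
topk_scores, index = paddle.topk(scores.flatten(1), num_top, axis=-1)
labels = index % num_classes        # predicted class of each detection
query_index = index // num_classes  # which decoder query it came from
topk_boxes = paddle.take_along_axis(
    boxes, query_index.unsqueeze(-1).tile([1, 1, 4]), axis=1)
```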
worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
RandomSizeCrop: { min_size: 384, max_size: 600 },
RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- NormalizeBox: {}
- BboxXYXY2XYWH: {}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 4
shuffle: true
drop_last: true
collate_batch: false
use_shared_memory: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
epoch: 12
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [11]
use_warmup: false
OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
epoch: 24
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [20]
use_warmup: false
OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/dino_r50.yml',
'_base_/dino_reader.yml',
]
weights: output/dino_r50_4scale_1x_coco/model_final
find_unused_parameters: True
log_iter: 100
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_2x.yml',
'_base_/dino_r50.yml',
'_base_/dino_reader.yml',
]
weights: output/dino_r50_4scale_2x_coco/model_final
find_unused_parameters: True
log_iter: 100
......@@ -69,7 +69,7 @@ class DETR(BaseArch):
# Transformer
pad_mask = self.inputs['pad_mask'] if self.training else None
out_transformer = self.transformer(body_feats, pad_mask)
out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
# DETR Head
if self.training:
......
......@@ -24,7 +24,7 @@ import pycocotools.mask as mask_util
from ..initializer import linear_init_, constant_
from ..transformers.utils import inverse_sigmoid
__all__ = ['DETRHead', 'DeformableDETRHead']
__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead']
class MLP(nn.Layer):
......@@ -362,3 +362,43 @@ class DeformableDETRHead(nn.Layer):
inputs['gt_class'])
else:
return (outputs_bbox[-1], outputs_logit[-1], None)
@register
class DINOHead(nn.Layer):
__inject__ = ['loss']
def __init__(self, loss='DINOLoss'):
super(DINOHead, self).__init__()
self.loss = loss
def forward(self, out_transformer, body_feats, inputs=None):
(dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits,
dn_meta) = out_transformer
if self.training:
assert inputs is not None
assert 'gt_bbox' in inputs and 'gt_class' in inputs
if dn_meta is not None:
dn_out_bboxes, dec_out_bboxes = paddle.split(
dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
dn_out_logits, dec_out_logits = paddle.split(
dec_out_logits, dn_meta['dn_num_split'], axis=2)
else:
dn_out_bboxes, dn_out_logits = None, None
out_bboxes = paddle.concat(
[enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
out_logits = paddle.concat(
[enc_topk_logits.unsqueeze(0), dec_out_logits])
return self.loss(
out_bboxes,
out_logits,
inputs['gt_bbox'],
inputs['gt_class'],
dn_out_bboxes=dn_out_bboxes,
dn_out_logits=dn_out_logits,
dn_meta=dn_meta)
else:
return (dec_out_bboxes[-1], dec_out_logits[-1], None)
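
The `dn_num_split` bookkeeping above separates denoising queries from the ordinary matching queries along the query axis. A small sketch of the split semantics (shapes are illustrative):

```python
import paddle

# decoder outputs stack all layers: [num_layers, bs, num_dn + num_queries, 4]
num_layers, bs, num_dn, num_queries = 6, 2, 200, 900
dec_out_bboxes = paddle.rand([num_layers, bs, num_dn + num_queries, 4])

# dn_meta['dn_num_split'] == [num_dn, num_queries]; axis=2 is the query axis
dn_out, match_out = paddle.split(dec_out_bboxes, [num_dn, num_queries], axis=2)
# dn_out: [6, 2, 200, 4], match_out: [6, 2, 900, 4]
```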
......@@ -23,7 +23,7 @@ from ppdet.core.workspace import register
from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
__all__ = ['DETRLoss']
__all__ = ['DETRLoss', 'DINOLoss']
@register
......@@ -67,9 +67,17 @@ class DETRLoss(nn.Layer):
self.loss_coeff['class'][-1] = loss_coeff['no_object']
self.giou_loss = GIoULoss()
def _get_loss_class(self, logits, gt_class, match_indices, bg_index,
num_gts):
def _get_loss_class(self,
logits,
gt_class,
match_indices,
bg_index,
num_gts,
postfix=""):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = "loss_class" + postfix
if logits is None:
return {name_class: paddle.zeros([1])}
target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
bs, num_query_objects = target_label.shape
if sum(len(a) for a in gt_class) > 0:
......@@ -82,36 +90,46 @@ class DETRLoss(nn.Layer):
target_label = F.one_hot(target_label,
self.num_classes + 1)[..., :-1]
return {
'loss_class': self.loss_coeff['class'] * sigmoid_focal_loss(
name_class: self.loss_coeff['class'] * sigmoid_focal_loss(
logits, target_label, num_gts / num_query_objects)
if self.use_focal_loss else F.cross_entropy(
logits, target_label, weight=self.loss_coeff['class'])
}
def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts):
def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
postfix=""):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = "loss_bbox" + postfix
name_giou = "loss_giou" + postfix
if boxes is None:
return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_bbox) == 0:
loss['loss_bbox'] = paddle.to_tensor([0.])
loss['loss_giou'] = paddle.to_tensor([0.])
loss[name_bbox] = paddle.to_tensor([0.])
loss[name_giou] = paddle.to_tensor([0.])
return loss
src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
match_indices)
loss['loss_bbox'] = self.loss_coeff['bbox'] * F.l1_loss(
loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
src_bbox, target_bbox, reduction='sum') / num_gts
loss['loss_giou'] = self.giou_loss(
loss[name_giou] = self.giou_loss(
bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
loss['loss_giou'] = loss['loss_giou'].sum() / num_gts
loss['loss_giou'] = self.loss_coeff['giou'] * loss['loss_giou']
loss[name_giou] = loss[name_giou].sum() / num_gts
loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
return loss
def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts):
def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
postfix=""):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = "loss_mask" + postfix
name_dice = "loss_dice" + postfix
if masks is None:
return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_mask) == 0:
loss['loss_mask'] = paddle.to_tensor([0.])
loss['loss_dice'] = paddle.to_tensor([0.])
loss[name_mask] = paddle.to_tensor([0.])
loss[name_dice] = paddle.to_tensor([0.])
return loss
src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
......@@ -120,12 +138,12 @@ class DETRLoss(nn.Layer):
src_masks.unsqueeze(0),
size=target_masks.shape[-2:],
mode="bilinear")[0]
loss['loss_mask'] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
loss[name_mask] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
src_masks,
target_masks,
paddle.to_tensor(
[num_gts], dtype='float32'))
loss['loss_dice'] = self.loss_coeff['dice'] * self._dice_loss(
loss[name_dice] = self.loss_coeff['dice'] * self._dice_loss(
src_masks, target_masks, num_gts)
return loss
......@@ -138,25 +156,40 @@ class DETRLoss(nn.Layer):
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_gts
def _get_loss_aux(self, boxes, logits, gt_bbox, gt_class, bg_index,
num_gts):
def _get_loss_aux(self,
boxes,
logits,
gt_bbox,
gt_class,
bg_index,
num_gts,
match_indices=None,
postfix=""):
if boxes is None and logits is None:
return {
"loss_class_aux" + postfix: paddle.paddle.zeros([1]),
"loss_bbox_aux" + postfix: paddle.paddle.zeros([1]),
"loss_giou_aux" + postfix: paddle.paddle.zeros([1])
}
loss_class = []
loss_bbox = []
loss_giou = []
for aux_boxes, aux_logits in zip(boxes, logits):
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class)
if match_indices is None:
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class)
loss_class.append(
self._get_loss_class(aux_logits, gt_class, match_indices,
bg_index, num_gts)['loss_class'])
bg_index, num_gts, postfix)['loss_class' +
postfix])
loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
num_gts)
loss_bbox.append(loss_['loss_bbox'])
loss_giou.append(loss_['loss_giou'])
num_gts, postfix)
loss_bbox.append(loss_['loss_bbox' + postfix])
loss_giou.append(loss_['loss_giou' + postfix])
loss = {
'loss_class_aux': paddle.add_n(loss_class),
'loss_bbox_aux': paddle.add_n(loss_bbox),
'loss_giou_aux': paddle.add_n(loss_giou)
"loss_class_aux" + postfix: paddle.add_n(loss_class),
"loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
"loss_giou_aux" + postfix: paddle.add_n(loss_giou)
}
return loss
......@@ -191,40 +224,105 @@ class DETRLoss(nn.Layer):
gt_bbox,
gt_class,
masks=None,
gt_mask=None):
gt_mask=None,
postfix="",
**kwargs):
r"""
Args:
boxes (Tensor): [l, b, query, 4]
logits (Tensor): [l, b, query, num_classes]
boxes (Tensor|None): [l, b, query, 4]
logits (Tensor|None): [l, b, query, num_classes]
gt_bbox (List(Tensor)): list[[n, 4]]
gt_class (List(Tensor)): list[[n, 1]]
masks (Tensor, optional): [b, query, h, w]
gt_mask (List(Tensor), optional): list[[n, H, W]]
postfix (str): postfix of loss name
"""
match_indices = self.matcher(boxes[-1].detach(), logits[-1].detach(),
gt_bbox, gt_class)
if "match_indices" in kwargs:
match_indices = kwargs["match_indices"]
else:
match_indices = self.matcher(boxes[-1].detach(),
logits[-1].detach(), gt_bbox, gt_class)
num_gts = sum(len(a) for a in gt_bbox)
try:
# TODO: Paddle does not have a "paddle.distributed.is_initialized()"
num_gts = paddle.to_tensor([num_gts], dtype=paddle.float32)
num_gts = paddle.to_tensor([num_gts], dtype="float32")
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(num_gts)
num_gts = paddle.clip(
num_gts / paddle.distributed.get_world_size(), min=1).item()
except:
num_gts = max(num_gts.item(), 1)
num_gts /= paddle.distributed.get_world_size()
num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)
total_loss = dict()
total_loss.update(
self._get_loss_class(logits[-1], gt_class, match_indices,
self.num_classes, num_gts))
self._get_loss_class(logits[
-1] if logits is not None else None, gt_class, match_indices,
self.num_classes, num_gts, postfix))
total_loss.update(
self._get_loss_bbox(boxes[-1], gt_bbox, match_indices, num_gts))
self._get_loss_bbox(boxes[-1] if boxes is not None else None,
gt_bbox, match_indices, num_gts, postfix))
if masks is not None and gt_mask is not None:
total_loss.update(
self._get_loss_mask(masks, gt_mask, match_indices, num_gts))
self._get_loss_mask(masks if masks is not None else None,
gt_mask, match_indices, num_gts, postfix))
if self.aux_loss:
if "match_indices" not in kwargs:
match_indices = None
total_loss.update(
self._get_loss_aux(boxes[:-1], logits[:-1], gt_bbox, gt_class,
self.num_classes, num_gts))
self._get_loss_aux(
boxes[:-1] if boxes is not None else None, logits[:-1]
if logits is not None else None, gt_bbox, gt_class,
self.num_classes, num_gts, match_indices, postfix))
return total_loss
@register
class DINOLoss(DETRLoss):
def forward(self,
boxes,
logits,
gt_bbox,
gt_class,
masks=None,
gt_mask=None,
postfix="",
dn_out_bboxes=None,
dn_out_logits=None,
dn_meta=None,
**kwargs):
total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
gt_class)
# denoising training loss
if dn_meta is not None:
dn_positive_idx, dn_num_group = \
dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
assert len(gt_class) == len(dn_positive_idx)
# denoising match indices
dn_match_indices = []
for i in range(len(gt_class)):
num_gt = len(gt_class[i])
if num_gt > 0:
gt_idx = paddle.arange(end=num_gt, dtype="int64")
gt_idx = gt_idx.unsqueeze(0).tile(
[dn_num_group, 1]).flatten()
assert len(gt_idx) == len(dn_positive_idx[i])
dn_match_indices.append((dn_positive_idx[i], gt_idx))
else:
dn_match_indices.append((paddle.zeros(
[0], dtype="int64"), paddle.zeros(
[0], dtype="int64")))
else:
dn_match_indices, dn_num_group = None, 1.
dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
gt_bbox,
gt_class,
postfix="_dn",
match_indices=dn_match_indices,
dn_num_group=dn_num_group)
total_loss.update(dn_loss)
return total_loss
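
For intuition: the denoising branch is matched deterministically rather than with the Hungarian matcher, pairing each positive denoising query with its source ground truth once per group. A tiny worked example of the `gt_idx` construction above:

```python
import paddle

num_gt, dn_num_group = 2, 3
gt_idx = paddle.arange(end=num_gt, dtype="int64")
gt_idx = gt_idx.unsqueeze(0).tile([dn_num_group, 1]).flatten()
print(gt_idx.numpy())  # [0 1 0 1 0 1]: gts 0 and 1, repeated once per group
```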
......@@ -17,9 +17,11 @@ from . import utils
from . import matchers
from . import position_encoding
from . import deformable_transformer
from . import dino_transformer
from .detr_transformer import *
from .utils import *
from .matchers import *
from .position_encoding import *
from .deformable_transformer import *
from .dino_transformer import *
......@@ -131,14 +131,24 @@ class MSDeformableAttention(nn.Layer):
[bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2])
attention_weights = self.attention_weights(query).reshape(
[bs, Len_q, self.num_heads, self.num_levels * self.num_points])
attention_weights = F.softmax(attention_weights, -1).reshape(
attention_weights = F.softmax(attention_weights).reshape(
[bs, Len_q, self.num_heads, self.num_levels, self.num_points])
offset_normalizer = value_spatial_shapes.flip([1]).reshape(
[1, 1, 1, self.num_levels, 1, 2])
sampling_locations = reference_points.reshape([
bs, Len_q, 1, self.num_levels, 1, 2
]) + sampling_offsets / offset_normalizer
if reference_points.shape[-1] == 2:
offset_normalizer = value_spatial_shapes.flip([1]).reshape(
[1, 1, 1, self.num_levels, 1, 2])
sampling_locations = reference_points.reshape([
bs, Len_q, 1, self.num_levels, 1, 2
]) + sampling_offsets / offset_normalizer
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] + sampling_offsets /
self.num_points * reference_points[:, :, None, :, None, 2:] *
0.5)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".
format(reference_points.shape[-1]))
output = self.ms_deformable_attn_core(
value, value_spatial_shapes, value_level_start_index,
......
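
The new 4-dim branch treats each reference point as a whole box `(cx, cy, w, h)` and scales the learned offsets by half the box size, so sampling points spread within the box. A self-contained sketch of that arithmetic with illustrative shapes:

```python
import paddle

bs, len_q, n_heads, n_levels, n_points = 2, 3, 8, 4, 4
reference_points = paddle.rand([bs, len_q, n_levels, 4])  # (cx, cy, w, h)
sampling_offsets = paddle.rand([bs, len_q, n_heads, n_levels, n_points, 2])

sampling_locations = (
    reference_points[:, :, None, :, None, :2] + sampling_offsets /
    n_points * reference_points[:, :, None, :, None, 2:] * 0.5)
# sampling_locations: [bs, len_q, n_heads, n_levels, n_points, 2]
```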
This diff is collapsed.
......@@ -16,6 +16,7 @@ python setup_ms_deformable_attn_op.py install
```
# import the custom op
from deformable_detr_ops import ms_deformable_attn
# construct fake input tensors
bs, n_heads, c = 2, 8, 8
query_length, n_levels, n_points = 2, 2, 2
......
......@@ -14,12 +14,15 @@
#
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from detrex (https://github.com/IDEA-Research/detrex)
# Copyright 2022 The IDEA Authors. All rights reserved.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -117,3 +120,146 @@ def get_valid_ratio(mask):
valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
# [b, 2]
return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
def get_contrastive_denoising_training_group(targets,
num_classes,
num_queries,
class_embed,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0):
if num_denoising <= 0:
return None, None, None, None
num_gts = [len(t) for t in targets["gt_class"]]
max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
num_group = num_denoising // max_gt_num
num_group = 1 if num_group == 0 else num_group
    # pad gt to the max gt number in the batch
bs = len(targets["gt_class"])
input_query_class = paddle.full(
[bs, max_gt_num], num_classes, dtype='int32')
input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
pad_gt_mask = paddle.zeros([bs, max_gt_num])
for i in range(bs):
num_gt = num_gts[i]
if num_gt > 0:
input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
pad_gt_mask[i, :num_gt] = 1
# each group has positive and negative queries.
input_query_class = input_query_class.tile([1, 2 * num_group])
input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
# positive and negative mask
negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1])
negative_gt_mask[:, max_gt_num:] = 1
negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
positive_gt_mask = 1 - negative_gt_mask
# contrastive denoising training positive index
positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]
dn_positive_idx = paddle.split(dn_positive_idx,
[n * num_group for n in num_gts])
# total denoising queries
num_denoising = int(max_gt_num * 2 * num_group)
if label_noise_ratio > 0:
input_query_class = input_query_class.flatten()
pad_gt_mask = pad_gt_mask.flatten()
        # corrupt each label with probability label_noise_ratio * 0.5
mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
        # replace the chosen labels with random classes
new_label = paddle.randint_like(
chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
input_query_class.scatter_(chosen_idx, new_label)
input_query_class.reshape_([bs, num_denoising])
pad_gt_mask.reshape_([bs, num_denoising])
if box_noise_scale > 0:
known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
diff = paddle.tile(input_query_bbox[..., 2:] * 0.5,
[1, 1, 2]) * box_noise_scale
rand_sign = paddle.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
rand_part = paddle.rand(input_query_bbox.shape)
rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
1 - negative_gt_mask)
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
input_query_bbox.clip_(min=0.0, max=1.0)
class_embed = paddle.concat(
[class_embed, paddle.zeros([1, class_embed.shape[-1]])])
input_query_class = paddle.gather(
class_embed, input_query_class.flatten(),
axis=0).reshape([bs, num_denoising, -1])
tgt_size = num_denoising + num_queries
attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
    # matching queries cannot see the denoising (reconstruction) part
attn_mask[num_denoising:, :num_denoising] = True
    # denoising queries in different groups cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
2 * (i + 1):num_denoising] = True
if i == num_group - 1:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
i * 2] = True
else:
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
2 * (i + 1):num_denoising] = True
attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
2 * i] = True
attn_mask = ~attn_mask
dn_meta = {
"dn_positive_idx": dn_positive_idx,
"dn_num_group": num_group,
"dn_num_split": [num_denoising, num_queries]
}
return input_query_class, input_query_bbox, attn_mask, dn_meta
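
A minimal smoke test for the builder above, assuming it is imported from `ppdet/modeling/transformers/utils.py` (the fake targets and sizes are arbitrary; the function itself appends the extra "no object" row to `class_embed`):

```python
import paddle
from ppdet.modeling.transformers.utils import \
    get_contrastive_denoising_training_group

num_classes, hidden_dim, num_queries = 80, 256, 900
targets = {
    "gt_class": [paddle.to_tensor([[3], [7]], dtype="int32")],  # 2 boxes
    "gt_bbox": [paddle.rand([2, 4])],                           # cxcywh in [0, 1]
}
class_embed = paddle.rand([num_classes, hidden_dim])

query_class, query_bbox, attn_mask, dn_meta = \
    get_contrastive_denoising_training_group(
        targets, num_classes, num_queries, class_embed,
        num_denoising=100, label_noise_ratio=0.5, box_noise_scale=1.0)
# with 2 gts: num_group = 50, so 2 * 2 * 50 = 200 denoising queries
# query_class: [1, 200, 256], query_bbox: [1, 200, 4]
# attn_mask: [1100, 1100], dn_meta["dn_num_split"] == [200, 900]
```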
def get_sine_pos_embed(pos_tensor,
num_pos_feats=128,
temperature=10000,
exchange_xy=True):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): Shape as `(None, n)`.
num_pos_feats (int): projected shape for each float in the tensor. Default: 128
temperature (int): The temperature used for scaling
the position embedding. Default: 10000.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is `[x, y]`, the results will # noqa
be `[pos(y), pos(x)]`. Defaults: True.
Returns:
torch.Tensor: Returned position embedding # noqa
with shape `(None, n * num_pos_feats)`.
"""
scale = 2. * math.pi
dim_t = 2. * paddle.floor_divide(
paddle.arange(num_pos_feats), paddle.to_tensor(2))
dim_t = scale / temperature**(dim_t / num_pos_feats)
def sine_func(x):
x *= dim_t
return paddle.stack(
(x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten(2)
pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = paddle.concat(pos_res, axis=2)
return pos_res
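
A quick usage check of `get_sine_pos_embed` (shapes are illustrative):

```python
import paddle

pos = paddle.rand([2, 100, 2])  # normalized (x, y) positions
emb = get_sine_pos_embed(pos, num_pos_feats=128)
print(emb.shape)  # [2, 100, 256], i.e. n * num_pos_feats = 2 * 128
```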