diff --git a/configs/rtdetr/README.md b/configs/rtdetr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a11e87bca3e8533c744fc108ffad9987b314e8a --- /dev/null +++ b/configs/rtdetr/README.md @@ -0,0 +1,41 @@ +# DETRs Beat YOLOs on Real-time Object Detection + +## Introduction +We propose a **R**eal-**T**ime **DE**tection **TR**ansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS. For more details, please refer to our [paper](https://arxiv.org/abs/2304.08069). + +
+ +
+ + +## Model Zoo + +### Model Zoo on COCO + +| Model | Epoch | backbone | input shape | $AP^{val}$ | $AP^{val}_{50}$| Params(M) | FLOPs(G) | T4 TensorRT FP16(FPS) | Pretrained Model | config | +|:--------------:|:-----:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:| +| RT-DETR-R50 | 80 | ResNet-50 | 640 | 53.1 | 71.3 | 42 | 136 | 108 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r50vd_6x_coco.pdparams) | [config](./rtdetr_r50vd_6x_coco.yml) +| RT-DETR-R101 | 80 | ResNet-101 | 640 | 54.3 | 72.7 | 76 | 259 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r101vd_6x_coco.pdparams) | [config](./rtdetr_r101vd_6x_coco.yml) +| RT-DETR-L | 80 | HGNetv2 | 640 | 53.0 | 71.6 | 32 | 110 | 114 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams) | [comming soon](rtdetr_hgnetv2_l_6x_coco.yml) +| RT-DETR-X | 80 | HGNetv2 | 640 | 54.8 | 73.1 | 67 | 234 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_x_6x_coco.pdparams) | [comming soon](rtdetr_hgnetv2_x_6x_coco.yml) + +**Notes:** +- RT-DETR uses 4GPU to train. +- RT-DETR is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`. + +GPU multi-card training +```bash +python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml --fleet --eval +``` + +## Citations +``` +@misc{lv2023detrs, + title={DETRs Beat YOLOs on Real-time Object Detection}, + author={Wenyu Lv and Shangliang Xu and Yian Zhao and Guanzhong Wang and Jinman Wei and Cheng Cui and Yuning Du and Qingqing Dang and Yi Liu}, + year={2023}, + eprint={2304.08069}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` diff --git a/configs/rtdetr/_base_/optimizer_6x.yml b/configs/rtdetr/_base_/optimizer_6x.yml new file mode 100644 index 0000000000000000000000000000000000000000..5abe2f75a2c0796abca8e4fd2ea847343aea5a71 --- /dev/null +++ b/configs/rtdetr/_base_/optimizer_6x.yml @@ -0,0 +1,19 @@ +epoch: 72 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [100] + use_warmup: true + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/rtdetr/_base_/rtdetr_r50vd.yml b/configs/rtdetr/_base_/rtdetr_r50vd.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc5fb3ada4f311dad7a3c3c611866e47442e8c45 --- /dev/null +++ b/configs/rtdetr/_base_/rtdetr_r50vd.yml @@ -0,0 +1,71 @@ +architecture: DETR +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +norm_type: sync_bn +use_ema: True +ema_decay: 0.9999 +ema_decay_type: "exponential" +ema_filter_no_grad: True +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + + +DETR: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformer + detr_head: DINOHead + post_process: DETRPostProcess + +ResNet: + # index 0 stands for res2 + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 1.0 + + +RTDETRTransformer: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + use_vfl: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + +DETRPostProcess: + num_top_queries: 300 diff --git a/configs/rtdetr/_base_/rtdetr_reader.yml b/configs/rtdetr/_base_/rtdetr_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..1b6f86c8c50206e71205a8c360b9a248c07d207a --- /dev/null +++ b/configs/rtdetr/_base_/rtdetr_reader.yml @@ -0,0 +1,43 @@ +worker_num: 4 +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 4 + shuffle: false + drop_last: false + + +TestReader: + inputs_def: + image_shape: [3, 640, 640] + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/rtdetr/rtdetr_r101vd_6x_coco.yml b/configs/rtdetr/rtdetr_r101vd_6x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..fd2f55ae1aecf5c63262fa3afaa956b885c1b438 --- /dev/null +++ b/configs/rtdetr/rtdetr_r101vd_6x_coco.yml @@ -0,0 +1,37 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_6x.yml', + '_base_/rtdetr_r50vd.yml', + '_base_/rtdetr_reader.yml', +] + +weights: output/rtdetr_r101vd_6x_coco/model_final +find_unused_parameters: True +log_iter: 200 + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_ssld_pretrained.pdparams + +ResNet: + # index 0 stands for res2 + depth: 101 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.01, 0.01, 0.01, 0.01] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 384 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 384 + nhead: 8 + dim_feedforward: 2048 + dropout: 0. + activation: 'gelu' + expansion: 1.0 diff --git a/configs/rtdetr/rtdetr_r50vd_6x_coco.yml b/configs/rtdetr/rtdetr_r50vd_6x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..51bf4437fb9a476e41e2fdbd5fdd829a2e42893b --- /dev/null +++ b/configs/rtdetr/rtdetr_r50vd_6x_coco.yml @@ -0,0 +1,11 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_6x.yml', + '_base_/rtdetr_r50vd.yml', + '_base_/rtdetr_reader.yml', +] + +weights: output/rtdetr_r50vd_6x_coco/model_final +find_unused_parameters: True +log_iter: 200 diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 2637db43d217e5b9bcbc7900f396f03bf4f5319e..5b8bbcd3b631f410e12f20368b541e988de74585 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -950,7 +950,7 @@ class Gt2SparseTarget(BaseOperator): @register_op class PadMaskBatch(BaseOperator): """ - Pad a batch of samples so they can be divisible by a stride. + Pad a batch of samples so that they can be divisible by a stride. The layout of each image should be 'CHW'. Args: pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure @@ -959,7 +959,7 @@ class PadMaskBatch(BaseOperator): `pad_mask` for transformer. """ - def __init__(self, pad_to_stride=0, return_pad_mask=False): + def __init__(self, pad_to_stride=0, return_pad_mask=True): super(PadMaskBatch, self).__init__() self.pad_to_stride = pad_to_stride self.return_pad_mask = return_pad_mask @@ -984,7 +984,7 @@ class PadMaskBatch(BaseOperator): im_c, im_h, im_w = im.shape[:] padding_im = np.zeros( (im_c, max_shape[1], max_shape[2]), dtype=np.float32) - padding_im[:, :im_h, :im_w] = im + padding_im[:, :im_h, :im_w] = im.astype(np.float32) data['image'] = padding_im if 'semantic' in data and data['semantic'] is not None: semantic = data['semantic'] @@ -1108,12 +1108,13 @@ class PadGT(BaseOperator): self.pad_img = pad_img self.minimum_gtnum = minimum_gtnum - def _impad(self, img: np.ndarray, - *, - shape = None, - padding = None, - pad_val = 0, - padding_mode = 'constant') -> np.ndarray: + def _impad(self, + img: np.ndarray, + *, + shape=None, + padding=None, + pad_val=0, + padding_mode='constant') -> np.ndarray: """Pad the given image to a certain shape or pad on all sides with specified padding mode and padding value. @@ -1169,7 +1170,7 @@ class PadGT(BaseOperator): padding = (padding, padding, padding, padding) else: raise ValueError('Padding must be a int or a 2, or 4 element tuple.' - f'But received {padding}') + f'But received {padding}') # check padding mode assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] @@ -1194,10 +1195,10 @@ class PadGT(BaseOperator): def checkmaxshape(self, samples): maxh, maxw = 0, 0 for sample in samples: - h,w = sample['im_shape'] - if h>maxh: + h, w = sample['im_shape'] + if h > maxh: maxh = h - if w>maxw: + if w > maxw: maxw = w return (maxh, maxw) @@ -1246,7 +1247,8 @@ class PadGT(BaseOperator): sample['difficult'] = pad_diff if 'gt_joints' in sample: num_joints = sample['gt_joints'].shape[1] - pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32) + pad_gt_joints = np.zeros( + (num_max_boxes, num_joints, 3), dtype=np.float32) if num_gt > 0: pad_gt_joints[:num_gt] = sample['gt_joints'] sample['gt_joints'] = pad_gt_joints diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 25f3452993ee7c8fa17e606c30e62238d5098248..206d9a48d0febbba671738c83660753962904bbd 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -501,7 +501,8 @@ class RandomDistort(BaseOperator): brightness=[0.5, 1.5, 0.5], random_apply=True, count=4, - random_channel=False): + random_channel=False, + prob=1.0): super(RandomDistort, self).__init__() self.hue = hue self.saturation = saturation @@ -510,6 +511,7 @@ class RandomDistort(BaseOperator): self.random_apply = random_apply self.count = count self.random_channel = random_channel + self.prob = prob def apply_hue(self, img): low, high, prob = self.hue @@ -563,6 +565,8 @@ class RandomDistort(BaseOperator): return img def apply(self, sample, context=None): + if random.random() > self.prob: + return sample img = sample['image'] if self.random_apply: functions = [ @@ -1488,7 +1492,8 @@ class RandomCrop(BaseOperator): allow_no_crop=True, cover_all_box=False, is_mask_crop=False, - ioumode="iou"): + ioumode="iou", + prob=1.0): super(RandomCrop, self).__init__() self.aspect_ratio = aspect_ratio self.thresholds = thresholds @@ -1498,6 +1503,7 @@ class RandomCrop(BaseOperator): self.cover_all_box = cover_all_box self.is_mask_crop = is_mask_crop self.ioumode = ioumode + self.prob = prob def crop_segms(self, segms, valid_ids, crop, height, width): def _crop_poly(segm, crop): @@ -1588,6 +1594,9 @@ class RandomCrop(BaseOperator): return sample def apply(self, sample, context=None): + if random.random() > self.prob: + return sample + if 'gt_bbox' not in sample: # only used in semi-det as unsup data sample = self.set_fake_bboxes(sample) @@ -2829,22 +2838,23 @@ class RandomShortSideResize(BaseOperator): def get_size_with_aspect_ratio(self, image_shape, size, max_size=None): h, w = image_shape + max_clip = False if max_size is not None: min_original_size = float(min((w, h))) max_original_size = float(max((w, h))) if max_original_size / min_original_size * size > max_size: - size = int( - round(max_size * min_original_size / max_original_size)) + size = int(max_size * min_original_size / max_original_size) + max_clip = True if (w <= h and w == size) or (h <= w and h == size): return (w, h) if w < h: ow = size - oh = int(round(size * h / w)) + oh = int(round(size * h / w)) if not max_clip else max_size else: oh = size - ow = int(round(size * w / h)) + ow = int(round(size * w / h)) if not max_clip else max_size return (ow, oh) diff --git a/ppdet/modeling/architectures/detr.py b/ppdet/modeling/architectures/detr.py index 2d59925859211dddcfe7c1778b4bccbfd0d87d78..7839a1263ffc02a97edf231ff44395c8960a2ec9 100644 --- a/ppdet/modeling/architectures/detr.py +++ b/ppdet/modeling/architectures/detr.py @@ -40,9 +40,9 @@ class DETR(BaseArch): exclude_post_process=False): super(DETR, self).__init__() self.backbone = backbone - self.neck = neck self.transformer = transformer self.detr_head = detr_head + self.neck = neck self.post_process = post_process self.with_mask = with_mask self.exclude_post_process = exclude_post_process @@ -54,6 +54,7 @@ class DETR(BaseArch): # neck kwargs = {'input_shape': backbone.out_shape} neck = create(cfg['neck'], **kwargs) if cfg['neck'] else None + # transformer if neck is not None: kwargs = {'input_shape': neck.out_shape} diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py index 45a2d5e14d72850363e8fc30b44fbe8b811d0756..24f14c3d4893826f3d660a2765e5e4a5236e44a5 100644 --- a/ppdet/modeling/losses/detr_loss.py +++ b/ppdet/modeling/losses/detr_loss.py @@ -21,7 +21,8 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register from .iou_loss import GIoULoss -from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss +from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss, varifocal_loss_with_logits +from ..bbox_utils import bbox_iou __all__ = ['DETRLoss', 'DINOLoss'] @@ -43,7 +44,10 @@ class DETRLoss(nn.Layer): 'dice': 1 }, aux_loss=True, - use_focal_loss=False): + use_focal_loss=False, + use_vfl=False, + use_uni_match=False, + uni_match_ind=0): r""" Args: num_classes (int): The number of classes. @@ -60,6 +64,9 @@ class DETRLoss(nn.Layer): self.loss_coeff = loss_coeff self.aux_loss = aux_loss self.use_focal_loss = use_focal_loss + self.use_vfl = use_vfl + self.use_uni_match = use_uni_match + self.uni_match_ind = uni_match_ind if not self.use_focal_loss: self.loss_coeff['class'] = paddle.full([num_classes + 1], @@ -73,13 +80,15 @@ class DETRLoss(nn.Layer): match_indices, bg_index, num_gts, - postfix=""): + postfix="", + iou_score=None): # logits: [b, query, num_classes], gt_class: list[[n, 1]] name_class = "loss_class" + postfix target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64') bs, num_query_objects = target_label.shape - if sum(len(a) for a in gt_class) > 0: + num_gt = sum(len(a) for a in gt_class) + if num_gt > 0: index, updates = self._get_index_updates(num_query_objects, gt_class, match_indices) target_label = paddle.scatter( @@ -88,12 +97,23 @@ class DETRLoss(nn.Layer): if self.use_focal_loss: target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1] - return { - name_class: self.loss_coeff['class'] * sigmoid_focal_loss( - logits, target_label, num_gts / num_query_objects) - if self.use_focal_loss else F.cross_entropy( + if iou_score is not None and self.use_vfl: + target_score = paddle.zeros([bs, num_query_objects]) + if num_gt > 0: + target_score = paddle.scatter( + target_score.reshape([-1, 1]), index, iou_score) + target_score = target_score.reshape( + [bs, num_query_objects, 1]) * target_label + loss_ = self.loss_coeff['class'] * varifocal_loss_with_logits( + logits, target_score, target_label, + num_gts / num_query_objects) + else: + loss_ = self.loss_coeff['class'] * sigmoid_focal_loss( + logits, target_label, num_gts / num_query_objects) + else: + loss_ = F.cross_entropy( logits, target_label, weight=self.loss_coeff['class']) - } + return {name_class: loss_} def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts, postfix=""): @@ -167,9 +187,19 @@ class DETRLoss(nn.Layer): loss_class = [] loss_bbox, loss_giou = [], [] loss_mask, loss_dice = [], [] + if dn_match_indices is not None: + match_indices = dn_match_indices + elif self.use_uni_match: + match_indices = self.matcher( + boxes[self.uni_match_ind], + logits[self.uni_match_ind], + gt_bbox, + gt_class, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask) for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)): aux_masks = masks[i] if masks is not None else None - if dn_match_indices is None: + if not self.use_uni_match and dn_match_indices is None: match_indices = self.matcher( aux_boxes, aux_logits, @@ -177,12 +207,21 @@ class DETRLoss(nn.Layer): gt_class, masks=aux_masks, gt_mask=gt_mask) + if self.use_vfl: + if sum(len(a) for a in gt_bbox) > 0: + src_bbox, target_bbox = self._get_src_target_assign( + aux_boxes.detach(), gt_bbox, match_indices) + iou_score = bbox_iou( + bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), + bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) + else: + iou_score = None else: - match_indices = dn_match_indices + iou_score = None loss_class.append( self._get_loss_class(aux_logits, gt_class, match_indices, - bg_index, num_gts, postfix)['loss_class' + - postfix]) + bg_index, num_gts, postfix, iou_score)[ + 'loss_class' + postfix]) loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices, num_gts, postfix) loss_bbox.append(loss_['loss_bbox' + postfix]) @@ -252,10 +291,22 @@ class DETRLoss(nn.Layer): else: match_indices = dn_match_indices + if self.use_vfl: + if sum(len(a) for a in gt_bbox) > 0: + src_bbox, target_bbox = self._get_src_target_assign( + boxes.detach(), gt_bbox, match_indices) + iou_score = bbox_iou( + bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), + bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) + else: + iou_score = None + else: + iou_score = None + loss = dict() loss.update( self._get_loss_class(logits, gt_class, match_indices, - self.num_classes, num_gts, postfix)) + self.num_classes, num_gts, postfix, iou_score)) loss.update( self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts, postfix)) diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index e20bd6203cedd9eed4905fd3446a167588d33bfc..33a124026563c64dd75e345a809206812ddce749 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -20,6 +20,8 @@ from . import deformable_transformer from . import dino_transformer from . import group_detr_transformer from . import mask_dino_transformer +from . import rtdetr_transformer +from . import hybrid_encoder from .detr_transformer import * from .utils import * @@ -30,3 +32,5 @@ from .dino_transformer import * from .petr_transformer import * from .group_detr_transformer import * from .mask_dino_transformer import * +from .rtdetr_transformer import * +from .hybrid_encoder import * diff --git a/ppdet/modeling/transformers/hybrid_encoder.py b/ppdet/modeling/transformers/hybrid_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b64c4ee3ba64996c5bf83c40762d6742500a0a50 --- /dev/null +++ b/ppdet/modeling/transformers/hybrid_encoder.py @@ -0,0 +1,301 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register, serializable +from ppdet.modeling.ops import get_act_fn +from ..shape_spec import ShapeSpec +from ..backbones.csp_darknet import BaseConv +from ..backbones.cspresnet import RepVggBlock +from ppdet.modeling.transformers.detr_transformer import TransformerEncoder +from ..initializer import xavier_uniform_, linear_init_ +from ..layers import MultiHeadAttention +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +__all__ = ['HybridEncoder'] + + +class CSPRepLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + num_blocks=3, + expansion=1.0, + bias=False, + act="silu"): + super(CSPRepLayer, self).__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act) + self.bottlenecks = nn.Sequential(*[ + RepVggBlock( + hidden_channels, hidden_channels, act=act) + for _ in range(num_blocks) + ]) + if hidden_channels != out_channels: + self.conv3 = BaseConv( + hidden_channels, + out_channels, + ksize=1, + stride=1, + bias=bias, + act=act) + else: + self.conv3 = nn.Identity() + + def forward(self, x): + x_1 = self.conv1(x) + x_1 = self.bottlenecks(x_1) + x_2 = self.conv2(x) + return self.conv3(x_1 + x_2) + + +@register +class TransformerLayer(nn.Layer): + def __init__(self, + d_model, + nhead, + dim_feedforward=1024, + dropout=0., + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False): + super(TransformerLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train") + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, src, src_mask=None, pos_embed=None): + residual = src + if self.normalize_before: + src = self.norm1(src) + q = k = self.with_pos_embed(src, pos_embed) + src = self.self_attn(q, k, value=src, attn_mask=src_mask) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +@register +@serializable +class HybridEncoder(nn.Layer): + __shared__ = ['depth_mult', 'act', 'trt', 'eval_size'] + __inject__ = ['encoder_layer'] + + def __init__(self, + in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + hidden_dim=256, + use_encoder_idx=[2], + num_encoder_layers=1, + encoder_layer='TransformerLayer', + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + act='silu', + trt=False, + eval_size=None): + super(HybridEncoder, self).__init__() + self.in_channels = in_channels + self.feat_strides = feat_strides + self.hidden_dim = hidden_dim + self.use_encoder_idx = use_encoder_idx + self.num_encoder_layers = num_encoder_layers + self.pe_temperature = pe_temperature + self.eval_size = eval_size + + # channel projection + self.input_proj = nn.LayerList() + for in_channel in in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2D( + in_channel, hidden_dim, kernel_size=1, bias_attr=False), + nn.BatchNorm2D( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))))) + # encoder transformer + self.encoder = nn.LayerList([ + TransformerEncoder(encoder_layer, num_encoder_layers) + for _ in range(len(use_encoder_idx)) + ]) + + act = get_act_fn( + act, trt=trt) if act is None or isinstance(act, + (str, dict)) else act + # top-down fpn + self.lateral_convs = nn.LayerList() + self.fpn_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1, 0, -1): + self.lateral_convs.append( + BaseConv( + hidden_dim, hidden_dim, 1, 1, act=act)) + self.fpn_blocks.append( + CSPRepLayer( + hidden_dim * 2, + hidden_dim, + round(3 * depth_mult), + act=act, + expansion=expansion)) + + # bottom-up pan + self.downsample_convs = nn.LayerList() + self.pan_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1): + self.downsample_convs.append( + BaseConv( + hidden_dim, hidden_dim, 3, stride=2, act=act)) + self.pan_blocks.append( + CSPRepLayer( + hidden_dim * 2, + hidden_dim, + round(3 * depth_mult), + act=act, + expansion=expansion)) + + self._reset_parameters() + + def _reset_parameters(self): + if self.eval_size: + for idx in self.use_encoder_idx: + stride = self.feat_strides[idx] + pos_embed = self.build_2d_sincos_position_embedding( + self.eval_size[1] // stride, self.eval_size[0] // stride, + self.hidden_dim, self.pe_temperature) + setattr(self, f'pos_embed{idx}', pos_embed) + + @staticmethod + def build_2d_sincos_position_embedding(w, + h, + embed_dim=256, + temperature=10000.): + grid_w = paddle.arange(int(w), dtype=paddle.float32) + grid_h = paddle.arange(int(h), dtype=paddle.float32) + grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) + assert embed_dim % 4 == 0, \ + 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = 1. / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @omega[None] + out_h = grid_h.flatten()[..., None] @omega[None] + + return paddle.concat( + [ + paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), + paddle.cos(out_h) + ], + axis=1)[None, :, :] + + def forward(self, feats, for_mot=False): + assert len(feats) == len(self.in_channels) + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + # encoder + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.use_encoder_idx): + h, w = proj_feats[enc_ind].shape[2:] + # flatten [B, C, H, W] to [B, HxW, C] + src_flatten = proj_feats[enc_ind].flatten(2).transpose( + [0, 2, 1]) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + w, h, self.hidden_dim, self.pe_temperature) + else: + pos_embed = getattr(self, f'pos_embed{enc_ind}', None) + memory = self.encoder[i](src_flatten, pos_embed=pos_embed) + proj_feats[enc_ind] = memory.transpose([0, 2, 1]).reshape( + [-1, self.hidden_dim, h, w]) + + # top-down fpn + inner_outs = [proj_feats[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = proj_feats[idx - 1] + feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx]( + feat_heigh) + inner_outs[0] = feat_heigh + + upsample_feat = F.interpolate( + feat_heigh, scale_factor=2., mode="nearest") + inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx]( + paddle.concat( + [upsample_feat, feat_low], axis=1)) + inner_outs.insert(0, inner_out) + + # bottom-up pan + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + out = self.pan_blocks[idx](paddle.concat( + [downsample_feat, feat_height], axis=1)) + outs.append(out) + + return outs + + @classmethod + def from_config(cls, cfg, input_shape): + return { + 'in_channels': [i.channels for i in input_shape], + 'feat_strides': [i.stride for i in input_shape] + } + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.hidden_dim, stride=self.feat_strides[idx]) + for idx in range(len(self.in_channels)) + ] diff --git a/ppdet/modeling/transformers/matchers.py b/ppdet/modeling/transformers/matchers.py index f163a6eeae9b4769ef15deb9521cc0bca7a976d5..72459a3f909806f212a4b204a50a494875589e51 100644 --- a/ppdet/modeling/transformers/matchers.py +++ b/ppdet/modeling/transformers/matchers.py @@ -107,16 +107,15 @@ class HungarianMatcher(nn.Layer): tgt_bbox = paddle.concat(gt_bbox) # Compute the classification cost + out_prob = paddle.gather(out_prob, tgt_ids, axis=1) if self.use_focal_loss: neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-( 1 - out_prob + 1e-8).log()) pos_cost_class = self.alpha * ( (1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log()) - cost_class = paddle.gather( - pos_cost_class, tgt_ids, axis=1) - paddle.gather( - neg_cost_class, tgt_ids, axis=1) + cost_class = pos_cost_class - neg_cost_class else: - cost_class = -paddle.gather(out_prob, tgt_ids, axis=1) + cost_class = -out_prob # Compute the L1 cost between boxes cost_bbox = ( diff --git a/ppdet/modeling/transformers/rtdetr_transformer.py b/ppdet/modeling/transformers/rtdetr_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..672590edfdee3349570418e9487ff99fe6fafc23 --- /dev/null +++ b/ppdet/modeling/transformers/rtdetr_transformer.py @@ -0,0 +1,546 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Modified from detrex (https://github.com/IDEA-Research/detrex) +# Copyright 2022 The IDEA Authors. All rights reserved. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention +from ..heads.detr_head import MLP +from .deformable_transformer import MSDeformableAttention +from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_, + bias_init_with_prob) +from .utils import (_get_clones, get_sine_pos_embed, + get_contrastive_denoising_training_group, inverse_sigmoid) + +__all__ = ['RTDETRTransformer'] + + +class PPMSDeformableAttention(MSDeformableAttention): + def forward(self, + query, + reference_points, + value, + value_spatial_shapes, + value_level_start_index, + value_mask=None): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] + value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + + value = self.value_proj(value) + if value_mask is not None: + value_mask = value_mask.astype(value.dtype).unsqueeze(-1) + value *= value_mask + value = value.reshape([bs, Len_v, self.num_heads, self.head_dim]) + + sampling_offsets = self.sampling_offsets(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2]) + attention_weights = self.attention_weights(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels * self.num_points]) + attention_weights = F.softmax(attention_weights).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points]) + + if reference_points.shape[-1] == 2: + offset_normalizer = paddle.to_tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape( + [1, 1, 1, self.num_levels, 1, 2]) + sampling_locations = reference_points.reshape([ + bs, Len_q, 1, self.num_levels, 1, 2 + ]) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + sampling_offsets / + self.num_points * reference_points[:, :, None, :, None, 2:] * + 0.5) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". + format(reference_points.shape[-1])) + + if not isinstance(query, paddle.Tensor): + from ppdet.modeling.transformers.utils import deformable_attention_core_func + output = deformable_attention_core_func( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + else: + value_spatial_shapes = paddle.to_tensor(value_spatial_shapes) + value_level_start_index = paddle.to_tensor(value_level_start_index) + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + output = self.output_proj(output) + + return output + + +class TransformerDecoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None): + super(TransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # cross attention + self.cross_attn = PPMSDeformableAttention(d_model, n_head, n_levels, + n_points, 1.0) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask=None, + memory_mask=None, + query_pos_embed=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + if attn_mask is not None: + attn_mask = paddle.where( + attn_mask.astype('bool'), + paddle.zeros(attn_mask.shape, tgt.dtype), + paddle.full(attn_mask.shape, float("-inf"), tgt.dtype)) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), reference_points, memory, + memory_spatial_shapes, memory_level_start_index, memory_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt2 = self.forward_ffn(tgt) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class TransformerDecoder(nn.Layer): + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward(self, + tgt, + ref_points_unact, + memory, + memory_spatial_shapes, + memory_level_start_index, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None): + output = tgt + dec_out_bboxes = [] + dec_out_logits = [] + ref_points_detach = F.sigmoid(ref_points_unact) + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = query_pos_head(ref_points_detach) + + output = layer(output, ref_points_input, memory, + memory_spatial_shapes, memory_level_start_index, + attn_mask, memory_mask, query_pos_embed) + + inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid( + ref_points_detach)) + + if self.training: + dec_out_logits.append(score_head[i](output)) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append( + F.sigmoid(bbox_head[i](output) + inverse_sigmoid( + ref_points))) + elif i == self.eval_idx: + dec_out_logits.append(score_head[i](output)) + dec_out_bboxes.append(inter_ref_bbox) + + ref_points = inter_ref_bbox + ref_points_detach = inter_ref_bbox.detach( + ) if self.training else inter_ref_bbox + + return paddle.stack(dec_out_bboxes), paddle.stack(dec_out_logits) + + +@register +class RTDETRTransformer(nn.Layer): + __shared__ = ['num_classes', 'hidden_dim', 'eval_size'] + + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=300, + position_embed_type='sine', + backbone_feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_decoder_points=4, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=True, + eval_size=None, + eval_idx=-1, + eps=1e-2): + super(RTDETRTransformer, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' + assert len(backbone_feat_channels) <= num_levels + assert len(feat_strides) == len(backbone_feat_channels) + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.eval_size = eval_size + + # backbone feature projection + self._build_input_proj_layer(backbone_feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, + num_decoder_points) + self.decoder = TransformerDecoder(hidden_dim, decoder_layer, + num_decoder_layers, eval_idx) + + # denoising part + self.denoising_class_embed = nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal())) + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))) + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + + # decoder head + self.dec_score_head = nn.LayerList([ + nn.Linear(hidden_dim, num_classes) + for _ in range(num_decoder_layers) + ]) + self.dec_bbox_head = nn.LayerList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(num_decoder_layers) + ]) + + self._reset_parameters() + + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) + linear_init_(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight) + constant_(self.enc_bbox_head.layers[-1].bias) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + linear_init_(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight) + constant_(reg_.layers[-1].bias) + + linear_init_(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + + # init encoder output anchors and valid_mask + if self.eval_size: + self.anchors, self.valid_mask = self._generate_anchors() + + @classmethod + def from_config(cls, cfg, input_shape): + return {'backbone_feat_channels': [i.channels for i in input_shape]} + + def _build_input_proj_layer(self, backbone_feat_channels): + self.input_proj = nn.LayerList() + for in_channels in backbone_feat_channels: + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=1, + bias_attr=False)), ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = backbone_feat_channels[-1] + for _ in range(self.num_levels - len(backbone_feat_channels)): + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)), ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + level_start_index = [0, ] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).transpose([0, 2, 1])) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [l], start index of each level + level_start_index.append(h * w + level_start_index[-1]) + + # [b, l, c] + feat_flatten = paddle.concat(feat_flatten, 1) + level_start_index.pop() + return (feat_flatten, spatial_shapes, level_start_index) + + def forward(self, feats, pad_mask=None, gt_meta=None): + # input projection and embedding + (memory, spatial_shapes, + level_start_index) = self._get_encoder_input(feats) + + # prepare denoising training + if self.training: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \ + get_contrastive_denoising_training_group(gt_meta, + self.num_classes, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale) + else: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None + + target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \ + self._get_decoder_input( + memory, spatial_shapes, denoising_class, denoising_bbox_unact) + + # decoder + out_bboxes, out_logits = self.decoder( + target, + init_ref_points_unact, + memory, + spatial_shapes, + level_start_index, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask) + return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, + dn_meta) + + def _generate_anchors(self, + spatial_shapes=None, + grid_size=0.05, + dtype="float32"): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.eval_size[0] / s), int(self.eval_size[1] / s)] + for s in self.feat_strides + ] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = paddle.meshgrid( + paddle.arange( + end=h, dtype=dtype), + paddle.arange( + end=w, dtype=dtype)) + grid_xy = paddle.stack([grid_x, grid_y], -1) + + valid_WH = paddle.to_tensor([h, w]).astype(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl) + anchors.append( + paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + + anchors = paddle.concat(anchors, 1) + valid_mask = ((anchors > self.eps) * + (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = paddle.log(anchors / (1 - anchors)) + anchors = paddle.where(valid_mask, anchors, + paddle.to_tensor(float("inf"))) + return anchors, valid_mask + + def _get_decoder_input(self, + memory, + spatial_shapes, + denoising_class=None, + denoising_bbox_unact=None): + bs, _, _ = memory.shape + # prepare input for decoder + if self.training or self.eval_size is None: + anchors, valid_mask = self._generate_anchors(spatial_shapes) + else: + anchors, valid_mask = self.anchors, self.valid_mask + memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.)) + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = paddle.topk( + enc_outputs_class.max(-1), self.num_queries, axis=1) + # extract region proposal boxes + batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + + reference_points_unact = paddle.gather_nd(enc_outputs_coord_unact, + topk_ind) # unsigmoided. + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = paddle.concat( + [denoising_bbox_unact, reference_points_unact], 1) + if self.training: + reference_points_unact = reference_points_unact.detach() + enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind) + + # extract region features + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + else: + target = paddle.gather_nd(output_memory, topk_ind) + if self.training: + target = target.detach() + if denoising_class is not None: + target = paddle.concat([denoising_class, target], 1) + + return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index b19233fdeecfbb9084a65c66795eabd81440d928..a6f211a78f21de6bc15b9332dc0f823dedbe6efa 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -32,7 +32,7 @@ from ..bbox_utils import bbox_overlaps __all__ = [ '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid', - 'deformable_attention_core_func' + 'deformable_attention_core_func', 'varifocal_loss_with_logits' ] @@ -395,3 +395,16 @@ def mask_to_box_coordinate(mask, out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype) return out_bbox if format == "xyxy" else bbox_xyxy_to_cxcywh(out_bbox) + + +def varifocal_loss_with_logits(pred_logits, + gt_score, + label, + normalizer=1.0, + alpha=0.75, + gamma=2.0): + pred_score = F.sigmoid(pred_logits) + weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label + loss = F.binary_cross_entropy_with_logits( + pred_logits, gt_score, weight=weight, reduction='none') + return loss.mean(1).sum() / normalizer