Unverified commit 5d1f8883 authored by: W Wenyu, committed by: GitHub

add rtdetr final (#8094)

* [exp] add r50vd in dino

add yoloe reader

alter reference points to unsigmoid

fix amp training

alter usage in paddle-inference

update new base

alter ext_ops

add hybrid encoder

* add pp rt-detr

---------
Co-authored-by: ghostxsl <451323469@qq.com>
Parent 92752b02
# DETRs Beat YOLOs on Real-time Object Detection
## Introduction
We propose a **R**eal-**T**ime **DE**tection **TR**ansformer (RT-DETR), the first real-time end-to-end object detector to the best of our knowledge. Specifically, we design an efficient hybrid encoder that processes multi-scale features by decoupling intra-scale interaction from cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexible adjustment of inference speed through the number of decoder layers used, without retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on a T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS. For more details, please refer to our [paper](https://arxiv.org/abs/2304.08069).
<div align="center">
<img src="https://user-images.githubusercontent.com/17582080/232390925-54e58fe6-1c17-4610-90b9-7e5525577d80.png" width=500 />
</div>
## Model Zoo
### Model Zoo on COCO
| Model | Epoch | Backbone | Input shape | $AP^{val}$ | $AP^{val}_{50}$ | Params (M) | FLOPs (G) | T4 TensorRT FP16 (FPS) | Pretrained Model | Config |
|:--------------:|:-----:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
| RT-DETR-R50 | 80 | ResNet-50 | 640 | 53.1 | 71.3 | 42 | 136 | 108 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r50vd_6x_coco.pdparams) | [config](./rtdetr_r50vd_6x_coco.yml)
| RT-DETR-R101 | 80 | ResNet-101 | 640 | 54.3 | 72.7 | 76 | 259 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r101vd_6x_coco.pdparams) | [config](./rtdetr_r101vd_6x_coco.yml)
| RT-DETR-L | 80 | HGNetv2 | 640 | 53.0 | 71.6 | 32 | 110 | 114 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams) | [coming soon](rtdetr_hgnetv2_l_6x_coco.yml)
| RT-DETR-X | 80 | HGNetv2 | 640 | 54.8 | 73.1 | 67 | 234 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_x_6x_coco.pdparams) | [coming soon](rtdetr_hgnetv2_x_6x_coco.yml)
**Notes:**
- RT-DETR is trained with 4 GPUs.
- RT-DETR is trained on the COCO train2017 set and evaluated on val2017; the reported metric is `mAP(IoU=0.5:0.95)`.

### Multi-GPU Training
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml --fleet --eval
```
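Evaluation and model export follow the stock PaddleDetection tooling; the commands below are a sketch assuming the standard `tools/` scripts and the released weights linked in the table above.

```bash
# single-GPU evaluation on COCO val2017
python tools/eval.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml \
    -o weights=https://bj.bcebos.com/v1/paddledet/models/rtdetr_r50vd_6x_coco.pdparams

# export an inference model for deployment (e.g. Paddle Inference / TensorRT)
python tools/export_model.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml \
    -o weights=https://bj.bcebos.com/v1/paddledet/models/rtdetr_r50vd_6x_coco.pdparams
```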
## Citations
```
@misc{lv2023detrs,
title={DETRs Beat YOLOs on Real-time Object Detection},
author={Wenyu Lv and Shangliang Xu and Yian Zhao and Guanzhong Wang and Jinman Wei and Cheng Cui and Yuning Du and Qingqing Dang and Yi Liu},
year={2023},
eprint={2304.08069},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# configs/rtdetr/_base_/optimizer_6x.yml
epoch: 72

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 1.0
    milestones: [100]
    use_warmup: true
  - !LinearWarmup
    start_factor: 0.001
    steps: 2000

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
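Taken together, this schedule is effectively a 2000-step linear warmup into a constant learning rate: the lone `PiecewiseDecay` milestone (epoch 100) lies beyond the 72-epoch run, and `gamma: 1.0` would leave the rate unchanged even if it fired. A minimal Python sketch of the resulting curve (illustration only, not ppdet's scheduler code):

```python
def lr_at(step, epoch, base_lr=1e-4, warmup_steps=2000,
          start_factor=0.001, milestones=(100,), gamma=1.0):
    """Sketch of LinearWarmup followed by PiecewiseDecay as configured above."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    lr = base_lr
    for m in milestones:   # milestone 100 > 72 training epochs: never reached
        if epoch >= m:
            lr *= gamma    # gamma 1.0: no decay even if it were
    return lr
```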
# configs/rtdetr/_base_/rtdetr_r50vd.yml
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
norm_type: sync_bn
use_ema: True
ema_decay: 0.9999
ema_decay_type: "exponential"
ema_filter_no_grad: True
hidden_dim: 256
use_focal_loss: True
eval_size: [640, 640]

DETR:
  backbone: ResNet
  neck: HybridEncoder
  transformer: RTDETRTransformer
  detr_head: DINOHead
  post_process: DETRPostProcess

ResNet:
  # index 0 stands for res2
  depth: 50
  variant: d
  norm_type: bn
  freeze_at: 0
  return_idx: [1, 2, 3]
  lr_mult_list: [0.1, 0.1, 0.1, 0.1]
  num_stages: 4
  freeze_stem_only: True

HybridEncoder:
  hidden_dim: 256
  use_encoder_idx: [2]
  num_encoder_layers: 1
  encoder_layer:
    name: TransformerLayer
    d_model: 256
    nhead: 8
    dim_feedforward: 1024
    dropout: 0.
    activation: 'gelu'
  expansion: 1.0

RTDETRTransformer:
  num_queries: 300
  position_embed_type: sine
  feat_strides: [8, 16, 32]
  num_levels: 3
  nhead: 8
  num_decoder_layers: 6
  dim_feedforward: 1024
  dropout: 0.0
  activation: relu
  num_denoising: 100
  label_noise_ratio: 0.5
  box_noise_scale: 1.0
  learnt_init_query: False

DINOHead:
  loss:
    name: DINOLoss
    loss_coeff: {class: 1, bbox: 5, giou: 2}
    aux_loss: True
    use_vfl: True
    matcher:
      name: HungarianMatcher
      matcher_coeff: {class: 2, bbox: 5, giou: 2}

DETRPostProcess:
  num_top_queries: 300
# configs/rtdetr/_base_/rtdetr_reader.yml
worker_num: 4

TrainReader:
  sample_transforms:
    - Decode: {}
    - RandomDistort: {prob: 0.8}
    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
    - RandomCrop: {prob: 0.8}
    - RandomFlip: {}
  batch_transforms:
    - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False}
    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
    - NormalizeBox: {}
    - BboxXYXY2XYWH: {}
    - Permute: {}
  batch_size: 4
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: false

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
    - Permute: {}
  batch_size: 4
  shuffle: false
  drop_last: false

TestReader:
  inputs_def:
    image_shape: [3, 640, 640]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
    - Permute: {}
  batch_size: 1
  shuffle: false
  drop_last: false
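Note `norm_type: none` with zero mean and unit std: the readers feed raw 0-255 pixel values to the network. A rough numpy sketch of the EvalReader pipeline above (Decode, Resize with `interp: 2`, i.e. cv2.INTER_CUBIC, NormalizeImage, Permute), assuming OpenCV-based decoding as in ppdet's `Decode` op:

```python
import cv2
import numpy as np

def eval_preprocess(img_path, size=640):
    img = cv2.imread(img_path)                    # HWC, BGR, uint8
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    # Decode yields RGB
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    img = img.astype(np.float32)                  # norm_type none: keep 0-255
    return img.transpose(2, 0, 1)[None]           # Permute to NCHW, add batch dim
```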
# configs/rtdetr/rtdetr_r101vd_6x_coco.yml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_6x.yml',
  '_base_/rtdetr_r50vd.yml',
  '_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_r101vd_6x_coco/model_final
find_unused_parameters: True
log_iter: 200

pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_ssld_pretrained.pdparams

ResNet:
  # index 0 stands for res2
  depth: 101
  variant: d
  norm_type: bn
  freeze_at: 0
  return_idx: [1, 2, 3]
  lr_mult_list: [0.01, 0.01, 0.01, 0.01]
  num_stages: 4
  freeze_stem_only: True

HybridEncoder:
  hidden_dim: 384
  use_encoder_idx: [2]
  num_encoder_layers: 1
  encoder_layer:
    name: TransformerLayer
    d_model: 384
    nhead: 8
    dim_feedforward: 2048
    dropout: 0.
    activation: 'gelu'
  expansion: 1.0
# configs/rtdetr/rtdetr_r50vd_6x_coco.yml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_6x.yml',
  '_base_/rtdetr_r50vd.yml',
  '_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_r50vd_6x_coco/model_final
find_unused_parameters: True
log_iter: 200
ppdet/data/transform/batch_operators.py
@@ -950,7 +950,7 @@ class Gt2SparseTarget(BaseOperator):
 @register_op
 class PadMaskBatch(BaseOperator):
     """
-    Pad a batch of samples so they can be divisible by a stride.
+    Pad a batch of samples so that they can be divisible by a stride.
     The layout of each image should be 'CHW'.
     Args:
         pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
@@ -959,7 +959,7 @@ class PadMaskBatch(BaseOperator):
             `pad_mask` for transformer.
     """

-    def __init__(self, pad_to_stride=0, return_pad_mask=False):
+    def __init__(self, pad_to_stride=0, return_pad_mask=True):
         super(PadMaskBatch, self).__init__()
         self.pad_to_stride = pad_to_stride
         self.return_pad_mask = return_pad_mask
@@ -984,7 +984,7 @@ class PadMaskBatch(BaseOperator):
             im_c, im_h, im_w = im.shape[:]
             padding_im = np.zeros(
                 (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
-            padding_im[:, :im_h, :im_w] = im
+            padding_im[:, :im_h, :im_w] = im.astype(np.float32)
             data['image'] = padding_im
             if 'semantic' in data and data['semantic'] is not None:
                 semantic = data['semantic']
@@ -1108,12 +1108,13 @@ class PadGT(BaseOperator):
         self.pad_img = pad_img
         self.minimum_gtnum = minimum_gtnum

-    def _impad(self, img: np.ndarray,
-               *,
-               shape = None,
-               padding = None,
-               pad_val = 0,
-               padding_mode = 'constant') -> np.ndarray:
+    def _impad(self,
+               img: np.ndarray,
+               *,
+               shape=None,
+               padding=None,
+               pad_val=0,
+               padding_mode='constant') -> np.ndarray:
         """Pad the given image to a certain shape or pad on all sides with
         specified padding mode and padding value.
@@ -1169,7 +1170,7 @@ class PadGT(BaseOperator):
             padding = (padding, padding, padding, padding)
         else:
             raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
-                            f'But received {padding}')
+                             f'But received {padding}')

         # check padding mode
         assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
@@ -1194,10 +1195,10 @@ class PadGT(BaseOperator):
     def checkmaxshape(self, samples):
         maxh, maxw = 0, 0
         for sample in samples:
-            h,w = sample['im_shape']
-            if h>maxh:
+            h, w = sample['im_shape']
+            if h > maxh:
                 maxh = h
-            if w>maxw:
+            if w > maxw:
                 maxw = w
         return (maxh, maxw)
@@ -1246,7 +1247,8 @@ class PadGT(BaseOperator):
             sample['difficult'] = pad_diff
         if 'gt_joints' in sample:
             num_joints = sample['gt_joints'].shape[1]
-            pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32)
+            pad_gt_joints = np.zeros(
+                (num_max_boxes, num_joints, 3), dtype=np.float32)
             if num_gt > 0:
                 pad_gt_joints[:num_gt] = sample['gt_joints']
             sample['gt_joints'] = pad_gt_joints
ppdet/data/transform/operators.py
@@ -501,7 +501,8 @@ class RandomDistort(BaseOperator):
                  brightness=[0.5, 1.5, 0.5],
                  random_apply=True,
                  count=4,
-                 random_channel=False):
+                 random_channel=False,
+                 prob=1.0):
         super(RandomDistort, self).__init__()
         self.hue = hue
         self.saturation = saturation
@@ -510,6 +511,7 @@ class RandomDistort(BaseOperator):
         self.random_apply = random_apply
         self.count = count
         self.random_channel = random_channel
+        self.prob = prob

     def apply_hue(self, img):
         low, high, prob = self.hue
@@ -563,6 +565,8 @@ class RandomDistort(BaseOperator):
         return img

     def apply(self, sample, context=None):
+        if random.random() > self.prob:
+            return sample
         img = sample['image']
         if self.random_apply:
             functions = [
@@ -1488,7 +1492,8 @@ class RandomCrop(BaseOperator):
                  allow_no_crop=True,
                  cover_all_box=False,
                  is_mask_crop=False,
-                 ioumode="iou"):
+                 ioumode="iou",
+                 prob=1.0):
         super(RandomCrop, self).__init__()
         self.aspect_ratio = aspect_ratio
         self.thresholds = thresholds
@@ -1498,6 +1503,7 @@ class RandomCrop(BaseOperator):
         self.cover_all_box = cover_all_box
         self.is_mask_crop = is_mask_crop
         self.ioumode = ioumode
+        self.prob = prob

     def crop_segms(self, segms, valid_ids, crop, height, width):
         def _crop_poly(segm, crop):
@@ -1588,6 +1594,9 @@ class RandomCrop(BaseOperator):
             return sample

     def apply(self, sample, context=None):
+        if random.random() > self.prob:
+            return sample
+
         if 'gt_bbox' not in sample:
             # only used in semi-det as unsup data
             sample = self.set_fake_bboxes(sample)
@@ -2829,22 +2838,23 @@ class RandomShortSideResize(BaseOperator):
     def get_size_with_aspect_ratio(self, image_shape, size, max_size=None):
         h, w = image_shape
+        max_clip = False
         if max_size is not None:
             min_original_size = float(min((w, h)))
             max_original_size = float(max((w, h)))
             if max_original_size / min_original_size * size > max_size:
-                size = int(
-                    round(max_size * min_original_size / max_original_size))
+                size = int(max_size * min_original_size / max_original_size)
+                max_clip = True

         if (w <= h and w == size) or (h <= w and h == size):
             return (w, h)

         if w < h:
             ow = size
-            oh = int(round(size * h / w))
+            oh = int(round(size * h / w)) if not max_clip else max_size
         else:
             oh = size
-            ow = int(round(size * w / h))
+            ow = int(round(size * w / h)) if not max_clip else max_size

         return (ow, oh)
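The `RandomShortSideResize` change above pins the long side to exactly `max_size` whenever the aspect-preserving resize would overshoot it; previously the rounded long side could land past `max_size`. A standalone copy of the fixed helper with a worked example:

```python
def get_size_with_aspect_ratio(image_shape, size, max_size=None):
    """Standalone version of the fixed logic above (illustration only)."""
    h, w = image_shape
    max_clip = False
    if max_size is not None:
        min_side, max_side = float(min(w, h)), float(max(w, h))
        if max_side / min_side * size > max_size:
            size = int(max_size * min_side / max_side)
            max_clip = True
    if (w <= h and w == size) or (h <= w and h == size):
        return (w, h)
    if w < h:
        return (size, int(round(size * h / w)) if not max_clip else max_size)
    return (int(round(size * w / h)) if not max_clip else max_size, size)

# (h, w) = (480, 1000): the long side is clamped to max_size exactly
print(get_size_with_aspect_ratio((480, 1000), size=800, max_size=1333))  # (1333, 639)
```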
ppdet/modeling/architectures/detr.py
@@ -40,9 +40,9 @@ class DETR(BaseArch):
                  exclude_post_process=False):
         super(DETR, self).__init__()
         self.backbone = backbone
-        self.neck = neck
         self.transformer = transformer
         self.detr_head = detr_head
+        self.neck = neck
         self.post_process = post_process
         self.with_mask = with_mask
         self.exclude_post_process = exclude_post_process
@@ -54,6 +54,7 @@ class DETR(BaseArch):
         # neck
         kwargs = {'input_shape': backbone.out_shape}
         neck = create(cfg['neck'], **kwargs) if cfg['neck'] else None

+        # transformer
         if neck is not None:
             kwargs = {'input_shape': neck.out_shape}
ppdet/modeling/losses/detr_loss.py
@@ -21,7 +21,8 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
 from .iou_loss import GIoULoss
-from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
+from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss, varifocal_loss_with_logits
+from ..bbox_utils import bbox_iou

 __all__ = ['DETRLoss', 'DINOLoss']
@@ -43,7 +44,10 @@ class DETRLoss(nn.Layer):
                     'dice': 1
                 },
                 aux_loss=True,
-                 use_focal_loss=False):
+                 use_focal_loss=False,
+                 use_vfl=False,
+                 use_uni_match=False,
+                 uni_match_ind=0):
         r"""
         Args:
             num_classes (int): The number of classes.
@@ -60,6 +64,9 @@ class DETRLoss(nn.Layer):
         self.loss_coeff = loss_coeff
         self.aux_loss = aux_loss
         self.use_focal_loss = use_focal_loss
+        self.use_vfl = use_vfl
+        self.use_uni_match = use_uni_match
+        self.uni_match_ind = uni_match_ind

         if not self.use_focal_loss:
             self.loss_coeff['class'] = paddle.full([num_classes + 1],
@@ -73,13 +80,15 @@ class DETRLoss(nn.Layer):
                         match_indices,
                         bg_index,
                         num_gts,
-                        postfix=""):
+                        postfix="",
+                        iou_score=None):
         # logits: [b, query, num_classes], gt_class: list[[n, 1]]
         name_class = "loss_class" + postfix

         target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
         bs, num_query_objects = target_label.shape
-        if sum(len(a) for a in gt_class) > 0:
+        num_gt = sum(len(a) for a in gt_class)
+        if num_gt > 0:
             index, updates = self._get_index_updates(num_query_objects,
                                                      gt_class, match_indices)
             target_label = paddle.scatter(
@@ -88,12 +97,23 @@ class DETRLoss(nn.Layer):
         if self.use_focal_loss:
             target_label = F.one_hot(target_label,
                                      self.num_classes + 1)[..., :-1]
-        return {
-            name_class: self.loss_coeff['class'] * sigmoid_focal_loss(
-                logits, target_label, num_gts / num_query_objects)
-            if self.use_focal_loss else F.cross_entropy(
+            if iou_score is not None and self.use_vfl:
+                target_score = paddle.zeros([bs, num_query_objects])
+                if num_gt > 0:
+                    target_score = paddle.scatter(
+                        target_score.reshape([-1, 1]), index, iou_score)
+                target_score = target_score.reshape(
+                    [bs, num_query_objects, 1]) * target_label
+                loss_ = self.loss_coeff['class'] * varifocal_loss_with_logits(
+                    logits, target_score, target_label,
+                    num_gts / num_query_objects)
+            else:
+                loss_ = self.loss_coeff['class'] * sigmoid_focal_loss(
+                    logits, target_label, num_gts / num_query_objects)
+        else:
+            loss_ = F.cross_entropy(
                 logits, target_label, weight=self.loss_coeff['class'])
-        }
+        return {name_class: loss_}

     def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
                        postfix=""):
@@ -167,9 +187,19 @@ class DETRLoss(nn.Layer):
         loss_class = []
         loss_bbox, loss_giou = [], []
         loss_mask, loss_dice = [], []
+        if dn_match_indices is not None:
+            match_indices = dn_match_indices
+        elif self.use_uni_match:
+            match_indices = self.matcher(
+                boxes[self.uni_match_ind],
+                logits[self.uni_match_ind],
+                gt_bbox,
+                gt_class,
+                masks=masks[self.uni_match_ind] if masks is not None else None,
+                gt_mask=gt_mask)
         for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)):
             aux_masks = masks[i] if masks is not None else None
-            if dn_match_indices is None:
+            if not self.use_uni_match and dn_match_indices is None:
                 match_indices = self.matcher(
                     aux_boxes,
                     aux_logits,
@@ -177,12 +207,21 @@
                     gt_class,
                     masks=aux_masks,
                     gt_mask=gt_mask)
+            if self.use_vfl:
+                if sum(len(a) for a in gt_bbox) > 0:
+                    src_bbox, target_bbox = self._get_src_target_assign(
+                        aux_boxes.detach(), gt_bbox, match_indices)
+                    iou_score = bbox_iou(
+                        bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
+                        bbox_cxcywh_to_xyxy(target_bbox).split(4, -1))
+                else:
+                    iou_score = None
             else:
-                match_indices = dn_match_indices
+                iou_score = None
             loss_class.append(
                 self._get_loss_class(aux_logits, gt_class, match_indices,
-                                     bg_index, num_gts, postfix)['loss_class' +
-                                                                 postfix])
+                                     bg_index, num_gts, postfix, iou_score)[
+                                         'loss_class' + postfix])
             loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
                                         num_gts, postfix)
             loss_bbox.append(loss_['loss_bbox' + postfix])
@@ -252,10 +291,22 @@
         else:
             match_indices = dn_match_indices

+        if self.use_vfl:
+            if sum(len(a) for a in gt_bbox) > 0:
+                src_bbox, target_bbox = self._get_src_target_assign(
+                    boxes.detach(), gt_bbox, match_indices)
+                iou_score = bbox_iou(
+                    bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
+                    bbox_cxcywh_to_xyxy(target_bbox).split(4, -1))
+            else:
+                iou_score = None
+        else:
+            iou_score = None
+
         loss = dict()
         loss.update(
             self._get_loss_class(logits, gt_class, match_indices,
-                                 self.num_classes, num_gts, postfix))
+                                 self.num_classes, num_gts, postfix, iou_score))
         loss.update(
             self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts,
                                 postfix))
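With `use_vfl` enabled, the classification target of each matched query becomes its IoU with the assigned ground-truth box (computed on detached predictions) instead of a hard one-hot label. A minimal sketch of that IoU computation using this repo's helpers, mirroring the calls in the diff above:

```python
import paddle
from ppdet.modeling.bbox_utils import bbox_iou
from ppdet.modeling.transformers.utils import bbox_cxcywh_to_xyxy

# one matched (prediction, ground truth) pair, cxcywh in [0, 1]
src_bbox = paddle.to_tensor([[0.50, 0.50, 0.20, 0.20]])  # detached prediction
tgt_bbox = paddle.to_tensor([[0.52, 0.50, 0.20, 0.20]])  # assigned ground truth
iou_score = bbox_iou(
    bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
    bbox_cxcywh_to_xyxy(tgt_bbox).split(4, -1))  # soft target in [0, 1]
```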
ppdet/modeling/transformers/__init__.py
@@ -20,6 +20,8 @@ from . import deformable_transformer
 from . import dino_transformer
 from . import group_detr_transformer
 from . import mask_dino_transformer
+from . import rtdetr_transformer
+from . import hybrid_encoder

 from .detr_transformer import *
 from .utils import *
@@ -30,3 +32,5 @@ from .dino_transformer import *
 from .petr_transformer import *
 from .group_detr_transformer import *
 from .mask_dino_transformer import *
+from .rtdetr_transformer import *
+from .hybrid_encoder import *
ppdet/modeling/transformers/hybrid_encoder.py (new file)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import get_act_fn
from ..shape_spec import ShapeSpec
from ..backbones.csp_darknet import BaseConv
from ..backbones.cspresnet import RepVggBlock
from ppdet.modeling.transformers.detr_transformer import TransformerEncoder
from ..initializer import xavier_uniform_, linear_init_
from ..layers import MultiHeadAttention
from paddle import ParamAttr
from paddle.regularizer import L2Decay

__all__ = ['HybridEncoder']


class CSPRepLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_blocks=3,
                 expansion=1.0,
                 bias=False,
                 act="silu"):
        super(CSPRepLayer, self).__init__()
        hidden_channels = int(out_channels * expansion)
        self.conv1 = BaseConv(
            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
        self.conv2 = BaseConv(
            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
        self.bottlenecks = nn.Sequential(*[
            RepVggBlock(
                hidden_channels, hidden_channels, act=act)
            for _ in range(num_blocks)
        ])
        if hidden_channels != out_channels:
            self.conv3 = BaseConv(
                hidden_channels,
                out_channels,
                ksize=1,
                stride=1,
                bias=bias,
                act=act)
        else:
            self.conv3 = nn.Identity()

    def forward(self, x):
        x_1 = self.conv1(x)
        x_1 = self.bottlenecks(x_1)
        x_2 = self.conv2(x)
        return self.conv3(x_1 + x_2)


@register
class TransformerLayer(nn.Layer):
    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=1024,
                 dropout=0.,
                 activation="relu",
                 attn_dropout=None,
                 act_dropout=None,
                 normalize_before=False):
        super(TransformerLayer, self).__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before

        self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.activation = getattr(F, activation)
        self._reset_parameters()

    def _reset_parameters(self):
        linear_init_(self.linear1)
        linear_init_(self.linear2)

    @staticmethod
    def with_pos_embed(tensor, pos_embed):
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(self, src, src_mask=None, pos_embed=None):
        residual = src
        if self.normalize_before:
            src = self.norm1(src)
        q = k = self.with_pos_embed(src, pos_embed)
        src = self.self_attn(q, k, value=src, attn_mask=src_mask)

        src = residual + self.dropout1(src)
        if not self.normalize_before:
            src = self.norm1(src)

        residual = src
        if self.normalize_before:
            src = self.norm2(src)
        src = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src)
        if not self.normalize_before:
            src = self.norm2(src)
        return src


@register
@serializable
class HybridEncoder(nn.Layer):
    __shared__ = ['depth_mult', 'act', 'trt', 'eval_size']
    __inject__ = ['encoder_layer']

    def __init__(self,
                 in_channels=[512, 1024, 2048],
                 feat_strides=[8, 16, 32],
                 hidden_dim=256,
                 use_encoder_idx=[2],
                 num_encoder_layers=1,
                 encoder_layer='TransformerLayer',
                 pe_temperature=10000,
                 expansion=1.0,
                 depth_mult=1.0,
                 act='silu',
                 trt=False,
                 eval_size=None):
        super(HybridEncoder, self).__init__()
        self.in_channels = in_channels
        self.feat_strides = feat_strides
        self.hidden_dim = hidden_dim
        self.use_encoder_idx = use_encoder_idx
        self.num_encoder_layers = num_encoder_layers
        self.pe_temperature = pe_temperature
        self.eval_size = eval_size

        # channel projection
        self.input_proj = nn.LayerList()
        for in_channel in in_channels:
            self.input_proj.append(
                nn.Sequential(
                    nn.Conv2D(
                        in_channel, hidden_dim, kernel_size=1, bias_attr=False),
                    nn.BatchNorm2D(
                        hidden_dim,
                        weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                        bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))

        # encoder transformer
        self.encoder = nn.LayerList([
            TransformerEncoder(encoder_layer, num_encoder_layers)
            for _ in range(len(use_encoder_idx))
        ])

        act = get_act_fn(
            act, trt=trt) if act is None or isinstance(act,
                                                       (str, dict)) else act
        # top-down fpn
        self.lateral_convs = nn.LayerList()
        self.fpn_blocks = nn.LayerList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.lateral_convs.append(
                BaseConv(
                    hidden_dim, hidden_dim, 1, 1, act=act))
            self.fpn_blocks.append(
                CSPRepLayer(
                    hidden_dim * 2,
                    hidden_dim,
                    round(3 * depth_mult),
                    act=act,
                    expansion=expansion))

        # bottom-up pan
        self.downsample_convs = nn.LayerList()
        self.pan_blocks = nn.LayerList()
        for idx in range(len(in_channels) - 1):
            self.downsample_convs.append(
                BaseConv(
                    hidden_dim, hidden_dim, 3, stride=2, act=act))
            self.pan_blocks.append(
                CSPRepLayer(
                    hidden_dim * 2,
                    hidden_dim,
                    round(3 * depth_mult),
                    act=act,
                    expansion=expansion))

        self._reset_parameters()

    def _reset_parameters(self):
        if self.eval_size:
            for idx in self.use_encoder_idx:
                stride = self.feat_strides[idx]
                pos_embed = self.build_2d_sincos_position_embedding(
                    self.eval_size[1] // stride, self.eval_size[0] // stride,
                    self.hidden_dim, self.pe_temperature)
                setattr(self, f'pos_embed{idx}', pos_embed)

    @staticmethod
    def build_2d_sincos_position_embedding(w,
                                           h,
                                           embed_dim=256,
                                           temperature=10000.):
        grid_w = paddle.arange(int(w), dtype=paddle.float32)
        grid_h = paddle.arange(int(h), dtype=paddle.float32)
        grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
        omega = 1. / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @omega[None]
        out_h = grid_h.flatten()[..., None] @omega[None]

        return paddle.concat(
            [
                paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
                paddle.cos(out_h)
            ],
            axis=1)[None, :, :]

    def forward(self, feats, for_mot=False):
        assert len(feats) == len(self.in_channels)
        # get projection features
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        # encoder
        if self.num_encoder_layers > 0:
            for i, enc_ind in enumerate(self.use_encoder_idx):
                h, w = proj_feats[enc_ind].shape[2:]
                # flatten [B, C, H, W] to [B, HxW, C]
                src_flatten = proj_feats[enc_ind].flatten(2).transpose(
                    [0, 2, 1])
                if self.training or self.eval_size is None:
                    pos_embed = self.build_2d_sincos_position_embedding(
                        w, h, self.hidden_dim, self.pe_temperature)
                else:
                    pos_embed = getattr(self, f'pos_embed{enc_ind}', None)
                memory = self.encoder[i](src_flatten, pos_embed=pos_embed)
                proj_feats[enc_ind] = memory.transpose([0, 2, 1]).reshape(
                    [-1, self.hidden_dim, h, w])

        # top-down fpn
        inner_outs = [proj_feats[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_heigh = inner_outs[0]
            feat_low = proj_feats[idx - 1]
            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
                feat_heigh)
            inner_outs[0] = feat_heigh

            upsample_feat = F.interpolate(
                feat_heigh, scale_factor=2., mode="nearest")
            inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                paddle.concat(
                    [upsample_feat, feat_low], axis=1))
            inner_outs.insert(0, inner_out)

        # bottom-up pan
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_height = inner_outs[idx + 1]
            downsample_feat = self.downsample_convs[idx](feat_low)
            out = self.pan_blocks[idx](paddle.concat(
                [downsample_feat, feat_height], axis=1))
            outs.append(out)

        return outs

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'in_channels': [i.channels for i in input_shape],
            'feat_strides': [i.stride for i in input_shape]
        }

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self.hidden_dim, stride=self.feat_strides[idx])
            for idx in range(len(self.in_channels))
        ]
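As a quick shape check, the encoder can be exercised standalone with random features shaped like a ResNet-50 C3/C4/C5 pyramid for a 640x640 input. The import path below assumes where this new file lands in the repo, and `encoder_layer` is passed as an instance since no config-driven injection happens outside ppdet's workspace:

```python
import paddle
from ppdet.modeling.transformers.hybrid_encoder import HybridEncoder, TransformerLayer

encoder = HybridEncoder(
    in_channels=[512, 1024, 2048],
    feat_strides=[8, 16, 32],
    hidden_dim=256,
    use_encoder_idx=[2],          # run self-attention only on the stride-32 map
    num_encoder_layers=1,
    encoder_layer=TransformerLayer(d_model=256, nhead=8, dim_feedforward=1024))

feats = [paddle.randn([2, c, 640 // s, 640 // s])
         for c, s in zip([512, 1024, 2048], [8, 16, 32])]
outs = encoder(feats)
print([o.shape for o in outs])    # three 256-channel maps at strides 8/16/32
```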
ppdet/modeling/transformers/matchers.py
@@ -107,16 +107,15 @@ class HungarianMatcher(nn.Layer):
         tgt_bbox = paddle.concat(gt_bbox)

         # Compute the classification cost
+        out_prob = paddle.gather(out_prob, tgt_ids, axis=1)
         if self.use_focal_loss:
             neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(
                 1 - out_prob + 1e-8).log())
             pos_cost_class = self.alpha * (
                 (1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
-            cost_class = paddle.gather(
-                pos_cost_class, tgt_ids, axis=1) - paddle.gather(
-                    neg_cost_class, tgt_ids, axis=1)
+            cost_class = pos_cost_class - neg_cost_class
         else:
-            cost_class = -paddle.gather(out_prob, tgt_ids, axis=1)
+            cost_class = -out_prob

         # Compute the L1 cost between boxes
         cost_bbox = (
ppdet/modeling/transformers/rtdetr_transformer.py (new file; this diff is collapsed in this view)
ppdet/modeling/transformers/utils.py
@@ -32,7 +32,7 @@ from ..bbox_utils import bbox_overlaps
 __all__ = [
     '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy',
     'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid',
-    'deformable_attention_core_func'
+    'deformable_attention_core_func', 'varifocal_loss_with_logits'
 ]
@@ -395,3 +395,16 @@ def mask_to_box_coordinate(mask,
     out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype)
     return out_bbox if format == "xyxy" else bbox_xyxy_to_cxcywh(out_bbox)
+
+
+def varifocal_loss_with_logits(pred_logits,
+                               gt_score,
+                               label,
+                               normalizer=1.0,
+                               alpha=0.75,
+                               gamma=2.0):
+    pred_score = F.sigmoid(pred_logits)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(
+        pred_logits, gt_score, weight=weight, reduction='none')
+    return loss.mean(1).sum() / normalizer
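For reference, a toy call showing the expected shapes: `[bs, num_queries, num_classes]` logits, a one-hot `label` marking each query's assigned class, and `gt_score` as that label scaled by the query's IoU. Values here are random and purely illustrative (in real use, background queries have all-zero label rows):

```python
import paddle
import paddle.nn.functional as F

bs, num_queries, num_classes = 2, 300, 80
logits = paddle.randn([bs, num_queries, num_classes])
label = F.one_hot(
    paddle.randint(0, num_classes, [bs, num_queries]), num_classes)
gt_score = label * paddle.rand([bs, num_queries, 1])  # IoU-weighted soft target
loss = varifocal_loss_with_logits(
    logits, gt_score, label, normalizer=float(num_queries))
```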