Unverified · commit 6c2ee01b · authored by niefeng · committed by GitHub

add swin focalnet backbone rtdetr models (#8309)

Parent: 579254f4
@@ -2,8 +2,9 @@
## Latest News
- Released the RT-DETR-Swin and RT-DETR-FocalNet models
- Released code and pretrained models for RT-DETR-R50 and RT-DETR-R101
- Released code and pretrained models for RT-DETR-L and RT-DETR-X
- Released the RT-DETR-R50-m model (an example of model scaling)
- Released the RT-DETR-R34 model
- Released the RT-DETR-R18 model
@@ -17,7 +18,7 @@ RT-DETR is the first real-time end-to-end object detector. Specifically, we design
<img src="https://github.com/PaddlePaddle/PaddleDetection/assets/17582080/3184a08e-aa4d-49cf-9079-f3695c4cc1c3" width=500 />
</div>
## Base Models
| Model | Epoch | backbone | input shape | $AP^{val}$ | $AP^{val}_{50}$| Params(M) | FLOPs(G) | T4 TensorRT FP16(FPS) | Pretrained Model | config |
|:--------------:|:-----:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
@@ -29,10 +30,17 @@ RT-DETR is the first real-time end-to-end object detector. Specifically, we design
| RT-DETR-L | 6x | HGNetv2 | 640 | 53.0 | 71.6 | 32 | 110 | 114 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams) | [config](rtdetr_hgnetv2_l_6x_coco.yml) |
| RT-DETR-X | 6x | HGNetv2 | 640 | 54.8 | 73.1 | 67 | 234 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_x_6x_coco.pdparams) | [config](rtdetr_hgnetv2_x_6x_coco.yml) |
## High-Accuracy Models
| Model | Epoch | backbone | input shape | $AP^{val}$ | $AP^{val}_{50}$ | Pretrained Model | config |
|:-----:|:-----:|:---------:| :---------:|:-----------:|:---------------:|:----------------:|:------:|
| RT-DETR-Swin | 3x | Swin_L_384 | 640 | 56.2 | 73.5 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_swin_L_384_3x_coco.pdparams) | [config](./rtdetr_swin_L_384_3x_coco.yml) |
| RT-DETR-FocalNet | 3x | FocalNet_L_384 | 640 | 56.9 | 74.3 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_focalnet_L_384_3x_coco.pdparams) | [config](./rtdetr_focalnet_L_384_3x_coco.yml) |
**Notes:**
- All RT-DETR base models are trained on 4 GPUs.
- RT-DETR is trained on COCO train2017 and evaluated on val2017.
- The high-accuracy models RT-DETR-Swin and RT-DETR-FocalNet are trained on 8 GPUs and require substantially more GPU memory.
## Quick Start
......
# rtdetr_focalnet_L_384_3x_coco.yml

_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_6x.yml',
  '_base_/rtdetr_r50vd.yml',
  '_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_focalnet_L_384_3x_coco/model_final
find_unused_parameters: True
log_iter: 100
snapshot_epoch: 2
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/focalnet_large_fl4_pretrained_on_o365.pdparams

DETR:
  backbone: FocalNet
  neck: HybridEncoder
  transformer: RTDETRTransformer
  detr_head: DINOHead
  post_process: DETRPostProcess

FocalNet:
  arch: 'focalnet_L_384_22k_fl4'
  out_indices: [1, 2, 3]

HybridEncoder:
  hidden_dim: 256
  use_encoder_idx: [2]
  num_encoder_layers: 6
  encoder_layer:
    name: TransformerLayer
    d_model: 256
    nhead: 8
    dim_feedforward: 2048
    dropout: 0.
    activation: 'gelu'
  expansion: 1.0

RTDETRTransformer:
  num_queries: 300
  position_embed_type: sine
  feat_strides: [8, 16, 32]
  num_levels: 3
  nhead: 8
  num_decoder_layers: 6
  dim_feedforward: 2048
  dropout: 0.0
  activation: relu
  num_denoising: 100
  label_noise_ratio: 0.5
  box_noise_scale: 1.0
  learnt_init_query: False
  query_pos_head_inv_sig: True

DINOHead:
  loss:
    name: DINOLoss
    loss_coeff: {class: 1, bbox: 5, giou: 2}
    aux_loss: True
    use_vfl: True
    matcher:
      name: HungarianMatcher
      matcher_coeff: {class: 2, bbox: 5, giou: 2}

DETRPostProcess:
  num_top_queries: 300

epoch: 36

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [36]
    use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
    param_groups:
      - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
        weight_decay: 0.0
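A note on the `param_groups` entry above: parameters whose names match `absolute_pos_embed`, `relative_position_bias_table`, or `norm` are trained without weight decay, the usual AdamW practice for positional embeddings and normalization layers. Below is a minimal sketch of that matching rule, assuming simple name-substring matching; `build_param_groups` is a hypothetical helper for illustration, not PaddleDetection's actual optimizer builder.

```python
import paddle.nn as nn

# Substrings from the YAML above; matching parameters get no weight decay.
NO_DECAY_KEYS = ['absolute_pos_embed', 'relative_position_bias_table', 'norm']

def build_param_groups(model: nn.Layer, weight_decay: float = 0.0001):
    # Hypothetical helper: split parameters into decay / no-decay groups.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        target = no_decay if any(k in name for k in NO_DECAY_KEYS) else decay
        target.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
```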
# rtdetr_swin_L_384_3x_coco.yml

_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_6x.yml',
  '_base_/rtdetr_r50vd.yml',
  '_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_swin_L_384_3x_coco/model_final
find_unused_parameters: True
log_iter: 100
snapshot_epoch: 2
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/dino_swin_large_384_4scale_3x_coco.pdparams

DETR:
  backbone: SwinTransformer
  neck: HybridEncoder
  transformer: RTDETRTransformer
  detr_head: DINOHead
  post_process: DETRPostProcess

SwinTransformer:
  arch: 'swin_L_384'  # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384']
  ape: false
  drop_path_rate: 0.2
  patch_norm: true
  out_indices: [1, 2, 3]

HybridEncoder:
  hidden_dim: 256
  use_encoder_idx: [2]
  num_encoder_layers: 6
  encoder_layer:
    name: TransformerLayer
    d_model: 256
    nhead: 8
    dim_feedforward: 2048
    dropout: 0.
    activation: 'gelu'
  expansion: 1.0

RTDETRTransformer:
  num_queries: 300
  position_embed_type: sine
  feat_strides: [8, 16, 32]
  num_levels: 3
  nhead: 8
  num_decoder_layers: 6
  dim_feedforward: 2048
  dropout: 0.0
  activation: relu
  num_denoising: 100
  label_noise_ratio: 0.5
  box_noise_scale: 1.0
  learnt_init_query: False

DINOHead:
  loss:
    name: DINOLoss
    loss_coeff: {class: 1, bbox: 5, giou: 2}
    aux_loss: True
    use_vfl: True
    matcher:
      name: HungarianMatcher
      matcher_coeff: {class: 2, bbox: 5, giou: 2}

DETRPostProcess:
  num_top_queries: 300

epoch: 36

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [36]
    use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
    param_groups:
      - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
        weight_decay: 0.0
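Both configs pull three backbone stages (`out_indices: [1, 2, 3]`) and declare `feat_strides: [8, 16, 32]` for the transformer. The correspondence follows from the stride-4 patch embedding plus one 2x downsample per later stage; a quick illustrative check:

```python
# Swin/FocalNet stems embed patches at stride 4, and every later stage
# halves the resolution, so stage i has total stride 4 * 2**i.
out_indices = [1, 2, 3]
feat_strides = [4 * 2**i for i in out_indices]
assert feat_strides == [8, 16, 32]
```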
@@ -218,14 +218,19 @@ class TransformerDecoder(nn.Layer):
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None,
                query_pos_head_inv_sig=False):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points_detach = F.sigmoid(ref_points_unact)

        for i, layer in enumerate(self.layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            if not query_pos_head_inv_sig:
                query_pos_embed = query_pos_head(ref_points_detach)
            else:
                query_pos_embed = query_pos_head(
                    inverse_sigmoid(ref_points_detach))

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, memory_level_start_index,
@@ -276,6 +281,7 @@ class RTDETRTransformer(nn.Layer):
                 label_noise_ratio=0.5,
                 box_noise_scale=1.0,
                 learnt_init_query=True,
                 query_pos_head_inv_sig=False,
                 eval_size=None,
                 eval_idx=-1,
                 eps=1e-2):
@@ -321,6 +327,7 @@
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
        self.query_pos_head_inv_sig = query_pos_head_inv_sig

        # encoder head
        self.enc_output = nn.Sequential(
@@ -464,7 +471,9 @@
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
            memory_mask=None,
            query_pos_head_inv_sig=self.query_pos_head_inv_sig)

        return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits,
                dn_meta)
......
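For context on the `query_pos_head_inv_sig` flag wired through above: when enabled, the query-position MLP receives the inverse-sigmoid (logit) of the detached reference points instead of the raw (0, 1) box coordinates. Below is a minimal sketch of the idea; the clamped-logit form of `inverse_sigmoid` is an assumption for illustration, not necessarily the repo's exact implementation.

```python
import paddle
import paddle.nn.functional as F

def inverse_sigmoid(x, eps=1e-6):
    # Assumed stable logit: clamp to (0, 1), then invert the sigmoid.
    x = x.clip(min=0., max=1.)
    return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps))

# With the flag off, query_pos_head sees normalized boxes in (0, 1);
# with it on, it sees the corresponding unbounded logits.
ref_points = F.sigmoid(paddle.randn([2, 300, 4]))  # stand-in reference boxes
mlp_in_default = ref_points                        # query_pos_head_inv_sig=False
mlp_in_inverse = inverse_sigmoid(ref_points)       # query_pos_head_inv_sig=True
```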