未验证 提交 d48a4bb9 编写于 作者: Z Zhao-Yian 提交者: GitHub

add group detr for dino (#7865)

上级 b352ef88
# Data-reader settings: train / eval / test pipelines whose resize targets go
# up to a 1184 px short side and a 2000 px cap.
# NOTE(review): presumably `_base_/dino_2000_reader.yml` - confirm against the repo.
worker_num: 2

TrainReader:
  sample_transforms:
    - Decode: {}
    - RandomFlip:
        prob: 0.5
    # transforms1 / transforms2 are the two candidate pipelines handed to
    # RandomSelect.
    - RandomSelect:
        transforms1:
          - RandomShortSideResize:
              short_side_sizes:
                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832,
                 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184]
              max_size: 2000
        transforms2:
          - RandomShortSideResize:
              short_side_sizes: [400, 500, 600, 700, 800, 900]
          - RandomSizeCrop:
              min_size: 384
              max_size: 900
          - RandomShortSideResize:
              short_side_sizes:
                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832,
                 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184]
              max_size: 2000
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - NormalizeBox: {}
    - BboxXYXY2XYWH: {}
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 2
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: false

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize:
        target_size: [1184, 2000]
        keep_ratio: true
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 1
  shuffle: false
  drop_last: false

# Same pipeline as EvalReader.
TestReader:
  sample_transforms:
    - Decode: {}
    - Resize:
        target_size: [1184, 2000]
        keep_ratio: true
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 1
  shuffle: false
  drop_last: false
# Data-reader settings: train / eval / test pipelines whose resize targets go
# up to an 800 px short side and a 1333 px cap.
# NOTE(review): presumably `_base_/dino_reader.yml` - confirm against the repo.
worker_num: 2

TrainReader:
  sample_transforms:
    - Decode: {}
    - RandomFlip:
        prob: 0.5
    # transforms1 / transforms2 are the two candidate pipelines handed to
    # RandomSelect.
    - RandomSelect:
        transforms1:
          - RandomShortSideResize:
              short_side_sizes:
                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
              max_size: 1333
        transforms2:
          - RandomShortSideResize:
              short_side_sizes: [400, 500, 600]
          - RandomSizeCrop:
              min_size: 384
              max_size: 600
          - RandomShortSideResize:
              short_side_sizes:
                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
              max_size: 1333
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - NormalizeBox: {}
    - BboxXYXY2XYWH: {}
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 2
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: false

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize:
        target_size: [800, 1333]
        keep_ratio: true
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 1
  shuffle: false
  drop_last: false

# Same pipeline as EvalReader.
TestReader:
  sample_transforms:
    - Decode: {}
    - Resize:
        target_size: [800, 1333]
        keep_ratio: true
    - NormalizeImage:
        is_scale: true
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
    - Permute: {}
  batch_transforms:
    - PadMaskBatch:
        pad_to_stride: -1
        return_pad_mask: true
  batch_size: 1
  shuffle: false
  drop_last: false
# Model definition: DETR architecture wired with a ResNet backbone, the
# GroupDINOTransformer, a DINOHead and DETRBBoxPostProcess.
# NOTE(review): presumably `_base_/group_dino_r50.yml` - confirm against the repo.
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
hidden_dim: 256
use_focal_loss: true

# Component selection; each name refers to a section configured below.
DETR:
  backbone: ResNet
  transformer: GroupDINOTransformer
  detr_head: DINOHead
  post_process: DETRBBoxPostProcess

ResNet:
  # index 0 stands for res2
  depth: 50
  norm_type: bn
  freeze_at: 0
  return_idx: [1, 2, 3]
  lr_mult_list: [0.0, 0.1, 0.1, 0.1]
  num_stages: 4

GroupDINOTransformer:
  num_queries: 900
  position_embed_type: sine
  num_levels: 4
  nhead: 8
  num_encoder_layers: 6
  num_decoder_layers: 6
  dim_feedforward: 2048
  dropout: 0.0
  activation: relu
  pe_temperature: 20
  pe_offset: 0.0
  num_denoising: 100
  label_noise_ratio: 0.5
  box_noise_scale: 1.0
  learnt_init_query: true
  # Keep in sync with the dual_queries / dual_groups values under
  # DETRBBoxPostProcess below.
  dual_queries: true
  dual_groups: 10

DINOHead:
  loss:
    name: DINOLoss
    loss_coeff:
      class: 1
      bbox: 5
      giou: 2
    aux_loss: true
    matcher:
      name: HungarianMatcher
      matcher_coeff:
        class: 2
        bbox: 5
        giou: 2

DETRBBoxPostProcess:
  num_top_queries: 300
  # With dual_queries enabled, post-processing keeps only the first
  # num_queries // (dual_groups + 1) predictions.
  dual_queries: true
  dual_groups: 10
# Model definition: DETR architecture wired with a VisionTransformer2D
# backbone, a SimpleFeaturePyramid neck, the GroupDINOTransformer, a DINOHead
# and DETRBBoxPostProcess.
# NOTE(review): presumably `_base_/group_dino_vit_huge.yml` - confirm against the repo.
architecture: DETR
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_huge_mae_patch14_dec512d8b_pretrained.pdparams
hidden_dim: 256
use_focal_loss: true

# Component selection; each name refers to a section configured below.
DETR:
  backbone: VisionTransformer2D
  neck: SimpleFeaturePyramid
  transformer: GroupDINOTransformer
  detr_head: DINOHead
  post_process: DETRBBoxPostProcess

VisionTransformer2D:
  patch_size: 16
  embed_dim: 1280
  depth: 32
  num_heads: 16
  mlp_ratio: 4
  attn_bias: true
  drop_rate: 0.0
  drop_path_rate: 0.1
  lr_decay_rate: 0.7
  global_attn_indexes: [7, 15, 23, 31]
  use_abs_pos: false
  use_rel_pos: true
  rel_pos_zero_init: true
  window_size: 14
  out_indices: [31]

SimpleFeaturePyramid:
  out_channels: 256
  num_levels: 4

GroupDINOTransformer:
  num_queries: 900
  position_embed_type: sine
  pe_temperature: 20
  pe_offset: 0.0
  num_levels: 4
  nhead: 8
  num_encoder_layers: 6
  num_decoder_layers: 6
  dim_feedforward: 2048
  use_input_proj: false
  dropout: 0.0
  activation: relu
  num_denoising: 100
  label_noise_ratio: 0.5
  box_noise_scale: 1.0
  learnt_init_query: true
  # Keep in sync with the dual_queries / dual_groups values under
  # DETRBBoxPostProcess below.
  dual_queries: true
  dual_groups: 10

DINOHead:
  loss:
    name: DINOLoss
    loss_coeff:
      class: 1
      bbox: 5
      giou: 2
    aux_loss: true
    matcher:
      name: HungarianMatcher
      matcher_coeff:
        class: 2
        bbox: 5
        giou: 2

DETRBBoxPostProcess:
  num_top_queries: 300
  # With dual_queries enabled, post-processing keeps only the first
  # num_queries // (dual_groups + 1) predictions.
  dual_queries: true
  dual_groups: 10
# Training schedule and optimizer: 12 epochs, AdamW, one piecewise decay step.
# NOTE(review): presumably `_base_/optimizer_1x.yml` - confirm against the repo.
epoch: 12

LearningRate:
  base_lr: 0.0001
  schedulers:
    # Single milestone at epoch index 11 with gamma 0.1; warmup disabled.
    - !PiecewiseDecay
      gamma: 0.1
      milestones: [11]
      use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
# Experiment entry point for the ResNet-50 Group-DINO run: pulls in the shared
# dataset/runtime files plus the optimizer, model and reader base configs.
_BASE_:
  - '../datasets/coco_detection.yml'
  - '../runtime.yml'
  - '_base_/optimizer_1x.yml'
  - '_base_/group_dino_r50.yml'
  - '_base_/dino_reader.yml'

weights: output/group_dino_r50_4scale_1x_coco/model_final
find_unused_parameters: true
log_iter: 100
# Experiment entry point for the ViT-Huge Group-DINO run: pulls in the shared
# dataset/runtime files plus the optimizer, model and 2000-px reader base
# configs.
_BASE_:
  - '../datasets/coco_detection.yml'
  - '../runtime.yml'
  - '_base_/optimizer_1x.yml'
  - '_base_/group_dino_vit_huge.yml'
  - '_base_/dino_2000_reader.yml'

weights: output/group_dino_vit_huge_4scale_1x_coco/model_final
find_unused_parameters: true
log_iter: 100
......@@ -34,10 +34,12 @@ class DETR(BaseArch):
backbone,
transformer='DETRTransformer',
detr_head='DETRHead',
neck=None,
post_process='DETRBBoxPostProcess',
exclude_post_process=False):
super(DETR, self).__init__()
self.backbone = backbone
self.neck = neck
self.transformer = transformer
self.detr_head = detr_head
self.post_process = post_process
......@@ -47,8 +49,12 @@ class DETR(BaseArch):
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# transformer
# neck
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs) if cfg['neck'] else None
# transformer
if neck is not None:
kwargs = {'input_shape': neck.out_shape}
transformer = create(cfg['transformer'], **kwargs)
# head
kwargs = {
......@@ -62,12 +68,17 @@ class DETR(BaseArch):
'backbone': backbone,
'transformer': transformer,
"detr_head": detr_head,
"neck": neck
}
def _forward(self):
# Backbone
body_feats = self.backbone(self.inputs)
# Neck
if self.neck is not None:
body_feats = self.neck(body_feats)
# Transformer
pad_mask = self.inputs.get('pad_mask', None)
out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
......
......@@ -36,6 +36,7 @@ from . import vision_transformer
from . import mobileone
from . import trans_encoder
from . import focalnet
from . import vit_mae
from .vgg import *
from .resnet import *
......@@ -61,3 +62,4 @@ from .vision_transformer import *
from .mobileone import *
from .trans_encoder import *
from .focalnet import *
from .vit_mae import *
......@@ -14,6 +14,7 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant, Assign
......@@ -72,3 +73,52 @@ def add_parameter(layer, datas, name=None):
if name:
layer.add_parameter(name, parameter)
return parameter
def window_partition(x, window_size):
    """
    Cut a feature map into non-overlapping square windows, first padding the
    height and width up to a multiple of ``window_size`` when necessary.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): side length of each window.
    Returns:
        windows: windows after partition with
            [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
        (num_h, num_w): number of windows along the height and the width.
    """
    batch, height, width, channels = paddle.shape(x)

    # Extra rows/columns needed to reach the next multiple of window_size;
    # the outer modulo makes this 0 when the size is already divisible.
    extra_h = (window_size - height % window_size) % window_size
    extra_w = (window_size - width % window_size) % window_size

    # The pad is applied on a channels-first view, then the layout is restored.
    channels_first = x.transpose([0, 3, 1, 2])
    pad_spec = paddle.to_tensor(
        [0, int(extra_w), 0, int(extra_h)], dtype='int32')
    x = F.pad(channels_first, pad_spec).transpose([0, 2, 3, 1])

    padded_h = height + extra_h
    padded_w = width + extra_w
    num_h = padded_h // window_size
    num_w = padded_w // window_size

    # Expose the window grid as separate axes, then bring the two grid axes
    # together so every window is contiguous before flattening.
    grid = x.reshape(
        [batch, num_h, window_size, num_w, window_size, channels])
    windows = grid.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [-1, window_size, window_size, channels])
    return windows, (padded_h, padded_w), (num_h, num_w)
def window_unpartition(x, pad_hw, num_hw, hw):
    """
    Reassemble windows into a single feature map and crop away the padding.

    Args:
        x (tensor): input tokens with
            [B * num_windows, window_size, window_size, C].
        pad_hw (Tuple): padded height and width (Hp, Wp).
        num_hw (Tuple): number of windows along height and width
            (num_h, num_w).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    rows, cols = num_hw
    orig_h, orig_w = hw

    total_windows, window_size, _, channels = paddle.shape(x)
    # The leading axis holds batch * rows * cols windows.
    batch = total_windows // (rows * cols)

    grid = x.reshape([batch, rows, cols, window_size, window_size, channels])
    # Interleave grid and in-window axes so rows/columns of pixels line up,
    # then collapse them into the padded spatial size.
    merged = grid.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [batch, padded_h, padded_w, channels])
    # Keep only the original (unpadded) extent.
    return merged[:, :orig_h, :orig_w, :]
此差异已折叠。
......@@ -380,10 +380,67 @@ class DINOHead(nn.Layer):
assert 'gt_bbox' in inputs and 'gt_class' in inputs
if dn_meta is not None:
dn_out_bboxes, dec_out_bboxes = paddle.split(
dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
dn_out_logits, dec_out_logits = paddle.split(
dec_out_logits, dn_meta['dn_num_split'], axis=2)
if isinstance(dn_meta, list):
dual_groups = len(dn_meta) - 1
dec_out_bboxes = paddle.split(
dec_out_bboxes, dual_groups + 1, axis=2)
dec_out_logits = paddle.split(
dec_out_logits, dual_groups + 1, axis=2)
enc_topk_bboxes = paddle.split(
enc_topk_bboxes, dual_groups + 1, axis=1)
enc_topk_logits = paddle.split(
enc_topk_logits, dual_groups + 1, axis=1)
dec_out_bboxes_list = []
dec_out_logits_list = []
dn_out_bboxes_list = []
dn_out_logits_list = []
loss = {}
for g_id in range(dual_groups + 1):
if dn_meta[g_id] is not None:
dn_out_bboxes_gid, dec_out_bboxes_gid = paddle.split(
dec_out_bboxes[g_id],
dn_meta[g_id]['dn_num_split'],
axis=2)
dn_out_logits_gid, dec_out_logits_gid = paddle.split(
dec_out_logits[g_id],
dn_meta[g_id]['dn_num_split'],
axis=2)
else:
dn_out_bboxes_gid, dn_out_logits_gid = None, None
dec_out_bboxes_gid = dec_out_bboxes[g_id]
dec_out_logits_gid = dec_out_logits[g_id]
out_bboxes_gid = paddle.concat([
enc_topk_bboxes[g_id].unsqueeze(0),
dec_out_bboxes_gid
])
out_logits_gid = paddle.concat([
enc_topk_logits[g_id].unsqueeze(0),
dec_out_logits_gid
])
loss_gid = self.loss(
out_bboxes_gid,
out_logits_gid,
inputs['gt_bbox'],
inputs['gt_class'],
dn_out_bboxes=dn_out_bboxes_gid,
dn_out_logits=dn_out_logits_gid,
dn_meta=dn_meta[g_id])
# sum loss
for key, value in loss_gid.items():
loss.update({
key: loss.get(key, paddle.zeros([1])) + value
})
# average across (dual_groups + 1)
for key, value in loss.items():
loss.update({key: value / (dual_groups + 1)})
return loss
else:
dn_out_bboxes, dec_out_bboxes = paddle.split(
dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
dn_out_logits, dec_out_logits = paddle.split(
dec_out_logits, dn_meta['dn_num_split'], axis=2)
else:
dn_out_bboxes, dn_out_logits = None, None
......
......@@ -273,7 +273,8 @@ def kaiming_normal_(tensor,
def linear_init_(module):
    """
    Initialise a linear layer in place through the file's ``uniform_`` helper.

    The range is ``[-bound, bound]`` with
    ``bound = 1 / sqrt(module.weight.shape[0])``. The weight is always
    initialised; the bias is initialised only when the layer actually has one.

    Args:
        module: layer exposing ``weight`` and, optionally, ``bias``.
    """
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    # Bias-free layers expose no bias (or a None one); an unguarded
    # uniform_ call on it would fail, so the bias is touched only when present
    # and exactly once.
    bias = getattr(module, "bias", None)
    if bias is not None:
        uniform_(bias, -bound, bound)
def conv_init_(module):
......
......@@ -67,7 +67,8 @@ class BBoxPostProcess(object):
"""
if self.nms is not None:
bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, self.num_classes)
bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score,
self.num_classes)
else:
bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
......@@ -449,10 +450,14 @@ class DETRBBoxPostProcess(object):
def __init__(self,
             num_classes=80,
             num_top_queries=100,
             dual_queries=False,
             dual_groups=0,
             use_focal_loss=False):
    """
    Store the post-processing settings on the instance.

    Args:
        num_classes (int): stored as ``self.num_classes``.
        num_top_queries (int): stored as ``self.num_top_queries``.
        dual_queries (bool): when true, ``__call__`` keeps only the first
            ``num_queries // (dual_groups + 1)`` logits and boxes.
        dual_groups (int): group count used in the slice described above.
        use_focal_loss (bool): stored as ``self.use_focal_loss``.
    """
    super(DETRBBoxPostProcess, self).__init__()
    # Query-group settings.
    self.dual_groups = dual_groups
    self.dual_queries = dual_queries
    # General decoding settings.
    self.use_focal_loss = use_focal_loss
    self.num_top_queries = num_top_queries
    self.num_classes = num_classes
def __call__(self, head_out, im_shape, scale_factor):
......@@ -471,6 +476,10 @@ class DETRBBoxPostProcess(object):
shape [bs], and is N.
"""
bboxes, logits, masks = head_out
if self.dual_queries:
num_queries = logits.shape[1]
logits, bboxes = logits[:, :int(num_queries // (self.dual_groups + 1)), :], \
bboxes[:, :int(num_queries // (self.dual_groups + 1)), :]
bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
......
......@@ -18,6 +18,7 @@ from . import matchers
from . import position_encoding
from . import deformable_transformer
from . import dino_transformer
from . import group_detr_transformer
from .detr_transformer import *
from .utils import *
......@@ -26,3 +27,4 @@ from .position_encoding import *
from .deformable_transformer import *
from .dino_transformer import *
from .petr_transformer import *
from .group_detr_transformer import *
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册