From d48a4bb9cfb61d46ba57974164bdf2d6249f0208 Mon Sep 17 00:00:00 2001 From: Zhao-Yian <77494834+Zhao-Yian@users.noreply.github.com> Date: Mon, 6 Mar 2023 11:19:32 +0800 Subject: [PATCH] add group detr for dino (#7865) --- .../group_detr/_base_/dino_2000_reader.yml | 48 + configs/group_detr/_base_/dino_reader.yml | 48 + configs/group_detr/_base_/group_dino_r50.yml | 53 ++ .../group_detr/_base_/group_dino_vit_huge.yml | 68 ++ configs/group_detr/_base_/optimizer_1x.yml | 16 + .../group_dino_r50_4scale_1x_coco.yml | 11 + .../group_dino_vit_huge_4scale_1x_coco.yml | 11 + ppdet/modeling/architectures/detr.py | 13 +- ppdet/modeling/backbones/__init__.py | 2 + ppdet/modeling/backbones/transformer_utils.py | 50 + ppdet/modeling/backbones/vit_mae.py | 749 +++++++++++++++ ppdet/modeling/heads/detr_head.py | 65 +- ppdet/modeling/initializer.py | 3 +- ppdet/modeling/post_process.py | 11 +- ppdet/modeling/transformers/__init__.py | 2 + .../transformers/group_detr_transformer.py | 857 ++++++++++++++++++ 16 files changed, 2000 insertions(+), 7 deletions(-) create mode 100644 configs/group_detr/_base_/dino_2000_reader.yml create mode 100644 configs/group_detr/_base_/dino_reader.yml create mode 100644 configs/group_detr/_base_/group_dino_r50.yml create mode 100644 configs/group_detr/_base_/group_dino_vit_huge.yml create mode 100644 configs/group_detr/_base_/optimizer_1x.yml create mode 100644 configs/group_detr/group_dino_r50_4scale_1x_coco.yml create mode 100644 configs/group_detr/group_dino_vit_huge_4scale_1x_coco.yml create mode 100644 ppdet/modeling/backbones/vit_mae.py create mode 100644 ppdet/modeling/transformers/group_detr_transformer.py diff --git a/configs/group_detr/_base_/dino_2000_reader.yml b/configs/group_detr/_base_/dino_2000_reader.yml new file mode 100644 index 000000000..ef7620eb8 --- /dev/null +++ b/configs/group_detr/_base_/dino_2000_reader.yml @@ -0,0 +1,48 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomFlip: {prob: 0.5} + - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184], max_size: 2000 } ], + transforms2: [ + RandomShortSideResize: { short_side_sizes: [400, 500, 600, 700, 800, 900] }, + RandomSizeCrop: { min_size: 384, max_size: 900 }, + RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184], max_size: 2000 } ] + } + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [1184, 2000], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 1 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [1184, 2000], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 1 + 
shuffle: false + drop_last: false diff --git a/configs/group_detr/_base_/dino_reader.yml b/configs/group_detr/_base_/dino_reader.yml new file mode 100644 index 000000000..c15a0f3b6 --- /dev/null +++ b/configs/group_detr/_base_/dino_reader.yml @@ -0,0 +1,48 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomFlip: {prob: 0.5} + - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ], + transforms2: [ + RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] }, + RandomSizeCrop: { min_size: 384, max_size: 600 }, + RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ] + } + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 1 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/group_detr/_base_/group_dino_r50.yml b/configs/group_detr/_base_/group_dino_r50.yml new file mode 100644 index 000000000..587f7f519 --- /dev/null +++ b/configs/group_detr/_base_/group_dino_r50.yml @@ -0,0 +1,53 @@ +architecture: DETR +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams +hidden_dim: 256 +use_focal_loss: True + + +DETR: + backbone: ResNet + transformer: GroupDINOTransformer + detr_head: DINOHead + post_process: DETRBBoxPostProcess + +ResNet: + # index 0 stands for res2 + depth: 50 + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.0, 0.1, 0.1, 0.1] + num_stages: 4 + +GroupDINOTransformer: + num_queries: 900 + position_embed_type: sine + num_levels: 4 + nhead: 8 + num_encoder_layers: 6 + num_decoder_layers: 6 + dim_feedforward: 2048 + dropout: 0.0 + activation: relu + pe_temperature: 20 + pe_offset: 0.0 + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: True + dual_queries: True + dual_groups: 10 + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + +DETRBBoxPostProcess: + num_top_queries: 300 + dual_queries: True + dual_groups: 10 diff --git a/configs/group_detr/_base_/group_dino_vit_huge.yml b/configs/group_detr/_base_/group_dino_vit_huge.yml new file mode 100644 index 000000000..8849f8a2d --- /dev/null +++ b/configs/group_detr/_base_/group_dino_vit_huge.yml @@ -0,0 +1,68 @@ +architecture: DETR +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_huge_mae_patch14_dec512d8b_pretrained.pdparams +hidden_dim: 256 +use_focal_loss: 
True + +DETR: + backbone: VisionTransformer2D + neck: SimpleFeaturePyramid + transformer: GroupDINOTransformer + detr_head: DINOHead + post_process: DETRBBoxPostProcess + +VisionTransformer2D: + patch_size: 16 + embed_dim: 1280 + depth: 32 + num_heads: 16 + mlp_ratio: 4 + attn_bias: True + drop_rate: 0.0 + drop_path_rate: 0.1 + lr_decay_rate: 0.7 + global_attn_indexes: [7, 15, 23, 31] + use_abs_pos: False + use_rel_pos: True + rel_pos_zero_init: True + window_size: 14 + out_indices: [ 31, ] + +SimpleFeaturePyramid: + out_channels: 256 + num_levels: 4 + +GroupDINOTransformer: + num_queries: 900 + position_embed_type: sine + pe_temperature: 20 + pe_offset: 0.0 + num_levels: 4 + nhead: 8 + num_encoder_layers: 6 + num_decoder_layers: 6 + dim_feedforward: 2048 + use_input_proj: False + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: True + dual_queries: True + dual_groups: 10 + + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + + +DETRBBoxPostProcess: + num_top_queries: 300 + dual_queries: True + dual_groups: 10 diff --git a/configs/group_detr/_base_/optimizer_1x.yml b/configs/group_detr/_base_/optimizer_1x.yml new file mode 100644 index 000000000..63b3a9ed2 --- /dev/null +++ b/configs/group_detr/_base_/optimizer_1x.yml @@ -0,0 +1,16 @@ +epoch: 12 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [11] + use_warmup: false + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/group_detr/group_dino_r50_4scale_1x_coco.yml b/configs/group_detr/group_dino_r50_4scale_1x_coco.yml new file mode 100644 index 000000000..1f38c690d --- /dev/null +++ b/configs/group_detr/group_dino_r50_4scale_1x_coco.yml @@ -0,0 +1,11 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/group_dino_r50.yml', + '_base_/dino_reader.yml', +] + +weights: output/group_dino_r50_4scale_1x_coco/model_final +find_unused_parameters: True +log_iter: 100 diff --git a/configs/group_detr/group_dino_vit_huge_4scale_1x_coco.yml b/configs/group_detr/group_dino_vit_huge_4scale_1x_coco.yml new file mode 100644 index 000000000..90d0c483e --- /dev/null +++ b/configs/group_detr/group_dino_vit_huge_4scale_1x_coco.yml @@ -0,0 +1,11 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/group_dino_vit_huge.yml', + '_base_/dino_2000_reader.yml', +] + +weights: output/group_dino_vit_huge_4scale_1x_coco/model_final +find_unused_parameters: True +log_iter: 100 diff --git a/ppdet/modeling/architectures/detr.py b/ppdet/modeling/architectures/detr.py index 419a44377..223eeda89 100644 --- a/ppdet/modeling/architectures/detr.py +++ b/ppdet/modeling/architectures/detr.py @@ -34,10 +34,12 @@ class DETR(BaseArch): backbone, transformer='DETRTransformer', detr_head='DETRHead', + neck=None, post_process='DETRBBoxPostProcess', exclude_post_process=False): super(DETR, self).__init__() self.backbone = backbone + self.neck = neck self.transformer = transformer self.detr_head = detr_head self.post_process = post_process @@ -47,8 +49,12 @@ class DETR(BaseArch): def from_config(cls, cfg, *args, **kwargs): # backbone backbone = create(cfg['backbone']) - # transformer + # neck kwargs = {'input_shape': backbone.out_shape} + neck = 
create(cfg['neck'], **kwargs) if cfg['neck'] else None + # transformer + if neck is not None: + kwargs = {'input_shape': neck.out_shape} transformer = create(cfg['transformer'], **kwargs) # head kwargs = { @@ -62,12 +68,17 @@ class DETR(BaseArch): 'backbone': backbone, 'transformer': transformer, "detr_head": detr_head, + "neck": neck } def _forward(self): # Backbone body_feats = self.backbone(self.inputs) + # Neck + if self.neck is not None: + body_feats = self.neck(body_feats) + # Transformer pad_mask = self.inputs.get('pad_mask', None) out_transformer = self.transformer(body_feats, pad_mask, self.inputs) diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index 388ba0458..f8b183e27 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -36,6 +36,7 @@ from . import vision_transformer from . import mobileone from . import trans_encoder from . import focalnet +from . import vit_mae from .vgg import * from .resnet import * @@ -61,3 +62,4 @@ from .vision_transformer import * from .mobileone import * from .trans_encoder import * from .focalnet import * +from .vit_mae import * diff --git a/ppdet/modeling/backbones/transformer_utils.py b/ppdet/modeling/backbones/transformer_utils.py index 46d7b9f28..a0783e1e9 100644 --- a/ppdet/modeling/backbones/transformer_utils.py +++ b/ppdet/modeling/backbones/transformer_utils.py @@ -14,6 +14,7 @@ import paddle import paddle.nn as nn +import paddle.nn.functional as F from paddle.nn.initializer import TruncatedNormal, Constant, Assign @@ -72,3 +73,52 @@ def add_parameter(layer, datas, name=None): if name: layer.add_parameter(name, parameter) return parameter + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = paddle.shape(x) + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + x = F.pad(x.transpose([0, 3, 1, 2]), + paddle.to_tensor( + [0, int(pad_w), 0, int(pad_h)], + dtype='int32')).transpose([0, 2, 3, 1]) + Hp, Wp = H + pad_h, W + pad_w + + num_h, num_w = Hp // window_size, Wp // window_size + + x = x.reshape([B, num_h, window_size, num_w, window_size, C]) + windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape( + [-1, window_size, window_size, C]) + return windows, (Hp, Wp), (num_h, num_w) + + +def window_unpartition(x, pad_hw, num_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + num_h, num_w = num_hw + H, W = hw + B, window_size, _, C = paddle.shape(x) + B = B // (num_h * num_w) + x = x.reshape([B, num_h, num_w, window_size, window_size, C]) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, C]) + + return x[:, :H, :W, :] diff --git a/ppdet/modeling/backbones/vit_mae.py b/ppdet/modeling/backbones/vit_mae.py new file mode 100644 index 000000000..8d00da72b --- /dev/null +++ b/ppdet/modeling/backbones/vit_mae.py @@ -0,0 +1,749 @@ +# copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserve. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np +import math +from paddle import ParamAttr +from paddle.regularizer import L2Decay +from paddle.nn.initializer import Constant, TruncatedNormal + +from ppdet.modeling.shape_spec import ShapeSpec +from ppdet.core.workspace import register, serializable + +from .transformer_utils import (zeros_, DropPath, Identity, window_partition, + window_unpartition) +from ..initializer import linear_init_ + +__all__ = ['VisionTransformer2D', 'SimpleFeaturePyramid'] + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer='nn.GELU', + drop=0., + lr_factor=1.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear( + in_features, + hidden_features, + weight_attr=ParamAttr(learning_rate=lr_factor), + bias_attr=ParamAttr(learning_rate=lr_factor)) + self.act = eval(act_layer)() + self.fc2 = nn.Linear( + hidden_features, + out_features, + weight_attr=ParamAttr(learning_rate=lr_factor), + bias_attr=ParamAttr(learning_rate=lr_factor)) + self.drop = nn.Dropout(drop) + + self._init_weights() + + def _init_weights(self): + linear_init_(self.fc1) + linear_init_(self.fc2) + + def forward(self, x): + x = self.drop(self.act(self.fc1(x))) + x = self.drop(self.fc2(x)) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_bias=False, + attn_drop=0., + proj_drop=0., + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=None, + input_size=None, + qk_scale=None, + lr_factor=1.0): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + self.use_rel_pos = use_rel_pos + self.input_size = input_size + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.lr_factor = lr_factor + + self.qkv = nn.Linear( + dim, + dim * 3, + weight_attr=ParamAttr(learning_rate=lr_factor), + bias_attr=ParamAttr(learning_rate=lr_factor) + if attn_bias else False) + if qkv_bias: + self.q_bias = self.create_parameter( + shape=([dim]), default_initializer=zeros_) + self.v_bias = self.create_parameter( + shape=([dim]), default_initializer=zeros_) + else: + self.q_bias = None + self.v_bias = None + self.proj = nn.Linear( + dim, + dim, + weight_attr=ParamAttr(learning_rate=lr_factor), + bias_attr=ParamAttr(learning_rate=lr_factor)) + self.attn_drop = nn.Dropout(attn_drop) + if window_size is None: + self.window_size = self.input_size[0] + + self._init_weights() + + def _init_weights(self): + linear_init_(self.qkv) + linear_init_(self.proj) + + if self.use_rel_pos: + self.rel_pos_h = self.create_parameter( + [2 * self.window_size - 1, self.head_dim], + attr=ParamAttr(learning_rate=self.lr_factor), + default_initializer=Constant(value=0.)) + self.rel_pos_w = self.create_parameter( + [2 * 
self.window_size - 1, self.head_dim], + attr=ParamAttr(learning_rate=self.lr_factor), + default_initializer=Constant(value=0.)) + + if not self.rel_pos_zero_init: + TruncatedNormal(self.rel_pos_h, std=0.02) + TruncatedNormal(self.rel_pos_w, std=0.02) + + def get_rel_pos(self, seq_size, rel_pos): + max_rel_dist = int(2 * seq_size - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos = rel_pos.reshape([1, rel_pos.shape[0], -1]) + rel_pos = rel_pos.transpose([0, 2, 1]) + rel_pos_resized = F.interpolate( + rel_pos, + size=(max_rel_dist, ), + mode="linear", + data_format='NCW') + rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]) + rel_pos_resized = rel_pos_resized.transpose([1, 0]) + else: + rel_pos_resized = rel_pos + + coords = paddle.arange(seq_size, dtype='float32') + relative_coords = coords.unsqueeze(-1) - coords.unsqueeze(0) + relative_coords += (seq_size - 1) + relative_coords = relative_coords.astype('int64').flatten() + + return paddle.index_select(rel_pos_resized, relative_coords).reshape( + [seq_size, seq_size, self.head_dim]) + + def add_decomposed_rel_pos(self, attn, q, h, w): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + Rh = self.get_rel_pos(h, self.rel_pos_h) + Rw = self.get_rel_pos(w, self.rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape([B, h, w, dim]) + # bhwc, hch->bhwh1 + # bwhc, wcw->bhw1w + rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh).unsqueeze(-1) + rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw).unsqueeze(-2) + + attn = attn.reshape([B, h, w, h, w]) + rel_h + rel_w + return attn.reshape([B, h * w, h * w]) + + def forward(self, x): + B, H, W, C = paddle.shape(x) + + if self.q_bias is not None: + qkv_bias = paddle.concat( + (self.q_bias, paddle.zeros_like(self.v_bias), self.v_bias)) + qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x).reshape( + [B, H * W, 3, self.num_heads, self.head_dim]).transpose( + [2, 0, 3, 1, 4]).reshape( + [3, B * self.num_heads, H * W, self.head_dim]) + + q, k, v = qkv[0], qkv[1], qkv[2] + attn = q.matmul(k.transpose([0, 2, 1])) * self.scale + + if self.use_rel_pos: + attn = self.add_decomposed_rel_pos(attn, q, H, W) + + attn = F.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + x = attn.matmul(v).reshape( + [B, self.num_heads, H * W, self.head_dim]).transpose( + [0, 2, 1, 3]).reshape([B, H, W, C]) + x = self.proj(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + attn_bias=False, + qk_scale=None, + init_values=None, + drop=0., + attn_drop=0., + drop_path=0., + use_rel_pos=True, + rel_pos_zero_init=True, + window_size=None, + input_size=None, + act_layer='nn.GELU', + norm_layer='nn.LayerNorm', + lr_factor=1.0, + epsilon=1e-5): + super().__init__() + self.window_size = window_size + + self.norm1 = eval(norm_layer)(dim, + weight_attr=ParamAttr( + learning_rate=lr_factor, + regularizer=L2Decay(0.0)), + bias_attr=ParamAttr( + learning_rate=lr_factor, + regularizer=L2Decay(0.0)), + epsilon=epsilon) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_bias=attn_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + 
window_size=window_size, + input_size=input_size, + lr_factor=lr_factor) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + self.norm2 = eval(norm_layer)(dim, + weight_attr=ParamAttr( + learning_rate=lr_factor, + regularizer=L2Decay(0.0)), + bias_attr=ParamAttr( + learning_rate=lr_factor, + regularizer=L2Decay(0.0)), + epsilon=epsilon) + self.mlp = Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + lr_factor=lr_factor) + if init_values is not None: + self.gamma_1 = self.create_parameter( + shape=([dim]), default_initializer=Constant(value=init_values)) + self.gamma_2 = self.create_parameter( + shape=([dim]), default_initializer=Constant(value=init_values)) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x): + y = self.norm1(x) + if self.window_size is not None: + y, pad_hw, num_hw = window_partition(y, self.window_size) + y = self.attn(y) + if self.gamma_1 is not None: + y = self.gamma_1 * y + + if self.window_size is not None: + y = window_unpartition(y, pad_hw, num_hw, (x.shape[1], x.shape[2])) + x = x + self.drop_path(y) + if self.gamma_2 is None: + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=(224, 224), + patch_size=16, + in_chans=3, + embed_dim=768, + lr_factor=0.01): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.proj = nn.Conv2D( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + weight_attr=ParamAttr(learning_rate=lr_factor), + bias_attr=ParamAttr(learning_rate=lr_factor)) + + @property + def num_patches_in_h(self): + return self.img_size[1] // self.patch_size + + @property + def num_patches_in_w(self): + return self.img_size[0] // self.patch_size + + def forward(self, x): + out = self.proj(x) + return out + + +@register +@serializable +class VisionTransformer2D(nn.Layer): + """ Vision Transformer with support for patch input + """ + + def __init__(self, + img_size=(1024, 1024), + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + attn_bias=False, + qk_scale=None, + init_values=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_layer='nn.GELU', + norm_layer='nn.LayerNorm', + lr_decay_rate=1.0, + global_attn_indexes=(2, 5, 8, 11), + use_abs_pos=False, + use_rel_pos=False, + use_abs_pos_emb=False, + use_sincos_pos_emb=False, + rel_pos_zero_init=True, + epsilon=1e-5, + final_norm=False, + pretrained=None, + window_size=None, + out_indices=(11, ), + with_fpn=False, + use_checkpoint=False, + *args, + **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_heads = num_heads + self.depth = depth + self.global_attn_indexes = global_attn_indexes + self.epsilon = epsilon + self.with_fpn = with_fpn + self.use_checkpoint = use_checkpoint + + self.patch_h = img_size[0] // patch_size + self.patch_w = img_size[1] // patch_size + self.num_patches = self.patch_h * self.patch_w + self.use_abs_pos = use_abs_pos + self.use_abs_pos_emb = use_abs_pos_emb + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + + dpr = np.linspace(0, drop_path_rate, depth) + if use_checkpoint: + paddle.seed(0) + + if use_abs_pos_emb: + self.pos_w = 
self.patch_embed.num_patches_in_w + self.pos_h = self.patch_embed.num_patches_in_h + self.pos_embed = self.create_parameter( + shape=(1, self.pos_w * self.pos_h + 1, embed_dim), + default_initializer=paddle.nn.initializer.TruncatedNormal( + std=.02)) + elif use_sincos_pos_emb: + pos_embed = self.get_2d_sincos_position_embedding(self.patch_h, + self.patch_w) + + self.pos_embed = pos_embed + self.pos_embed = self.create_parameter(shape=pos_embed.shape) + self.pos_embed.set_value(pos_embed.numpy()) + self.pos_embed.stop_gradient = True + else: + self.pos_embed = None + + self.blocks = nn.LayerList([ + Block( + embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_bias=attn_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=None + if i in self.global_attn_indexes else window_size, + input_size=[self.patch_h, self.patch_w], + act_layer=act_layer, + lr_factor=self.get_vit_lr_decay_rate(i, lr_decay_rate), + norm_layer=norm_layer, + init_values=init_values, + epsilon=epsilon) for i in range(depth) + ]) + + assert len(out_indices) <= 4, 'out_indices out of bound' + self.out_indices = out_indices + self.pretrained = pretrained + self.init_weight() + + self.out_channels = [embed_dim for _ in range(len(out_indices))] + self.out_strides = [4, 8, 16, 32][-len(out_indices):] if with_fpn else [ + patch_size for _ in range(len(out_indices)) + ] + self.norm = Identity() + if self.with_fpn: + self.init_fpn( + embed_dim=embed_dim, + patch_size=patch_size, + out_with_norm=final_norm) + + def get_vit_lr_decay_rate(self, layer_id, lr_decay_rate): + return lr_decay_rate**(self.depth - layer_id) + + def init_weight(self): + pretrained = self.pretrained + if pretrained: + if 'http' in pretrained: + path = paddle.utils.download.get_weights_path_from_url( + pretrained) + else: + path = pretrained + + load_state_dict = paddle.load(path) + model_state_dict = self.state_dict() + pos_embed_name = "pos_embed" + + if pos_embed_name in load_state_dict.keys( + ) and self.use_abs_pos_emb: + load_pos_embed = paddle.to_tensor( + load_state_dict[pos_embed_name], dtype="float32") + if self.pos_embed.shape != load_pos_embed.shape: + pos_size = int(math.sqrt(load_pos_embed.shape[1] - 1)) + model_state_dict[pos_embed_name] = self.resize_pos_embed( + load_pos_embed, (pos_size, pos_size), + (self.pos_h, self.pos_w)) + + # self.set_state_dict(model_state_dict) + load_state_dict[pos_embed_name] = model_state_dict[ + pos_embed_name] + + print("Load pos_embed and resize it from {} to {} .".format( + load_pos_embed.shape, self.pos_embed.shape)) + + self.set_state_dict(load_state_dict) + print("Load load_state_dict....") + + def init_fpn(self, embed_dim=768, patch_size=16, out_with_norm=False): + if patch_size == 16: + self.fpn1 = nn.Sequential( + nn.Conv2DTranspose( + embed_dim, embed_dim, kernel_size=2, stride=2), + nn.BatchNorm2D(embed_dim), + nn.GELU(), + nn.Conv2DTranspose( + embed_dim, embed_dim, kernel_size=2, stride=2), ) + + self.fpn2 = nn.Sequential( + nn.Conv2DTranspose( + embed_dim, embed_dim, kernel_size=2, stride=2), ) + + self.fpn3 = Identity() + + self.fpn4 = nn.MaxPool2D(kernel_size=2, stride=2) + elif patch_size == 8: + self.fpn1 = nn.Sequential( + nn.Conv2DTranspose( + embed_dim, embed_dim, kernel_size=2, stride=2), ) + + self.fpn2 = Identity() + + self.fpn3 = nn.Sequential(nn.MaxPool2D(kernel_size=2, stride=2), ) + + self.fpn4 = nn.Sequential(nn.MaxPool2D(kernel_size=4, 
stride=4), ) + + if not out_with_norm: + self.norm = Identity() + else: + self.norm = nn.LayerNorm(embed_dim, epsilon=self.epsilon) + + def resize_pos_embed(self, pos_embed, old_hw, new_hw): + """ + Resize pos_embed weight. + Args: + pos_embed (Tensor): the pos_embed weight + old_hw (list[int]): the height and width of old pos_embed + new_hw (list[int]): the height and width of new pos_embed + Returns: + Tensor: the resized pos_embed weight + """ + cls_pos_embed = pos_embed[:, :1, :] + pos_embed = pos_embed[:, 1:, :] + + pos_embed = pos_embed.transpose([0, 2, 1]) + pos_embed = pos_embed.reshape([1, -1, old_hw[0], old_hw[1]]) + pos_embed = F.interpolate( + pos_embed, new_hw, mode='bicubic', align_corners=False) + pos_embed = pos_embed.flatten(2).transpose([0, 2, 1]) + pos_embed = paddle.concat([cls_pos_embed, pos_embed], axis=1) + + return pos_embed + + def get_2d_sincos_position_embedding(self, h, w, temperature=10000.): + grid_y, grid_x = paddle.meshgrid( + paddle.arange( + h, dtype=paddle.float32), + paddle.arange( + w, dtype=paddle.float32)) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = (1. / (temperature**omega)).unsqueeze(0) + + out_x = grid_x.reshape([-1, 1]).matmul(omega) + out_y = grid_y.reshape([-1, 1]).matmul(omega) + + pos_emb = paddle.concat( + [ + paddle.sin(out_y), paddle.cos(out_y), paddle.sin(out_x), + paddle.cos(out_x) + ], + axis=1) + + return pos_emb.reshape([1, h, w, self.embed_dim]) + + def forward(self, inputs): + x = self.patch_embed(inputs['image']).transpose([0, 2, 3, 1]) + B, Hp, Wp, _ = paddle.shape(x) + + if self.use_abs_pos: + x = x + self.get_2d_sincos_position_embedding(Hp, Wp) + + if self.use_abs_pos_emb: + x = x + self.resize_pos_embed(self.pos_embed, + (self.pos_h, self.pos_w), (Hp, Wp)) + + feats = [] + for idx, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + x = paddle.distributed.fleet.utils.recompute( + blk, x, **{"preserve_rng_state": True}) + else: + x = blk(x) + if idx in self.out_indices: + feats.append(self.norm(x.transpose([0, 3, 1, 2]))) + + if self.with_fpn: + fpns = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(feats)): + feats[i] = fpns[i](feats[i]) + return feats + + @property + def num_layers(self): + return len(self.blocks) + + @property + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=c, stride=s) + for c, s in zip(self.out_channels, self.out_strides) + ] + + +class LayerNorm(nn.Layer): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). + Note that, the modified LayerNorm on used in ResBlock and SimpleFeaturePyramid. 
+ + In ViT, we use the nn.LayerNorm + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = self.create_parameter([normalized_shape]) + self.bias = self.create_parameter([normalized_shape]) + self.eps = eps + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / paddle.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +@register +@serializable +class SimpleFeaturePyramid(nn.Layer): + def __init__(self, + in_channels, + out_channels, + spatial_scales, + num_levels=4, + use_bias=False): + """ + Args: + in_channels (list[int]): input channels of each level which can be + derived from the output shape of backbone by from_config + out_channel (int): output channel of each level. + spatial_scales (list[float]): list of scaling factors to upsample or downsample + the input features for creating pyramid features which can be derived from + the output shape of backbone by from_config + num_levels (int): number of levels of output features. + use_bias (bool): whether use bias or not. + """ + super(SimpleFeaturePyramid, self).__init__() + + self.in_channels = in_channels[0] + self.out_channels = out_channels + self.num_levels = num_levels + + self.stages = [] + dim = self.in_channels + if num_levels == 4: + scale_factors = [2.0, 1.0, 0.5] + elif num_levels == 5: + scale_factors = [4.0, 2.0, 1.0, 0.5] + else: + raise NotImplementedError( + f"num_levels={num_levels} is not supported yet.") + + dim = in_channels[0] + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.Conv2DTranspose( + dim, dim // 2, kernel_size=2, stride=2), + nn.LayerNorm(dim // 2), + nn.GELU(), + nn.Conv2DTranspose( + dim // 2, dim // 4, kernel_size=2, stride=2), + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [ + nn.Conv2DTranspose( + dim, dim // 2, kernel_size=2, stride=2) + ] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2D(kernel_size=2, stride=2)] + + layers.extend([ + nn.Conv2D( + out_dim, + out_channels, + kernel_size=1, + bias_attr=use_bias, ), LayerNorm(out_channels), nn.Conv2D( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias_attr=use_bias, ), LayerNorm(out_channels) + ]) + layers = nn.Sequential(*layers) + + stage = -int(math.log2(spatial_scales[0] * scale_factors[idx])) + self.add_sublayer(f"simfp_{stage}", layers) + self.stages.append(layers) + + # top block output feature maps. + self.top_block = nn.Sequential( + nn.MaxPool2D( + kernel_size=1, stride=2, padding=0)) + + @classmethod + def from_config(cls, cfg, input_shape): + return { + 'in_channels': [i.channels for i in input_shape], + 'spatial_scales': [1.0 / i.stride for i in input_shape], + } + + @property + def out_shape(self): + return [ + ShapeSpec(channels=self.out_channels) + for _ in range(self.num_levels) + ] + + def forward(self, feats): + """ + Args: + x: Tensor of shape (N,C,H,W). 
+ """ + features = feats[0] + results = [] + + for stage in self.stages: + results.append(stage(features)) + + top_block_in_feature = results[-1] + results.append(self.top_block(top_block_in_feature)) + assert self.num_levels == len(results) + + return results diff --git a/ppdet/modeling/heads/detr_head.py b/ppdet/modeling/heads/detr_head.py index 6b9d8d8db..61448e4e0 100644 --- a/ppdet/modeling/heads/detr_head.py +++ b/ppdet/modeling/heads/detr_head.py @@ -380,10 +380,67 @@ class DINOHead(nn.Layer): assert 'gt_bbox' in inputs and 'gt_class' in inputs if dn_meta is not None: - dn_out_bboxes, dec_out_bboxes = paddle.split( - dec_out_bboxes, dn_meta['dn_num_split'], axis=2) - dn_out_logits, dec_out_logits = paddle.split( - dec_out_logits, dn_meta['dn_num_split'], axis=2) + if isinstance(dn_meta, list): + dual_groups = len(dn_meta) - 1 + dec_out_bboxes = paddle.split( + dec_out_bboxes, dual_groups + 1, axis=2) + dec_out_logits = paddle.split( + dec_out_logits, dual_groups + 1, axis=2) + enc_topk_bboxes = paddle.split( + enc_topk_bboxes, dual_groups + 1, axis=1) + enc_topk_logits = paddle.split( + enc_topk_logits, dual_groups + 1, axis=1) + + dec_out_bboxes_list = [] + dec_out_logits_list = [] + dn_out_bboxes_list = [] + dn_out_logits_list = [] + loss = {} + for g_id in range(dual_groups + 1): + if dn_meta[g_id] is not None: + dn_out_bboxes_gid, dec_out_bboxes_gid = paddle.split( + dec_out_bboxes[g_id], + dn_meta[g_id]['dn_num_split'], + axis=2) + dn_out_logits_gid, dec_out_logits_gid = paddle.split( + dec_out_logits[g_id], + dn_meta[g_id]['dn_num_split'], + axis=2) + else: + dn_out_bboxes_gid, dn_out_logits_gid = None, None + dec_out_bboxes_gid = dec_out_bboxes[g_id] + dec_out_logits_gid = dec_out_logits[g_id] + out_bboxes_gid = paddle.concat([ + enc_topk_bboxes[g_id].unsqueeze(0), + dec_out_bboxes_gid + ]) + out_logits_gid = paddle.concat([ + enc_topk_logits[g_id].unsqueeze(0), + dec_out_logits_gid + ]) + loss_gid = self.loss( + out_bboxes_gid, + out_logits_gid, + inputs['gt_bbox'], + inputs['gt_class'], + dn_out_bboxes=dn_out_bboxes_gid, + dn_out_logits=dn_out_logits_gid, + dn_meta=dn_meta[g_id]) + # sum loss + for key, value in loss_gid.items(): + loss.update({ + key: loss.get(key, paddle.zeros([1])) + value + }) + + # average across (dual_groups + 1) + for key, value in loss.items(): + loss.update({key: value / (dual_groups + 1)}) + return loss + else: + dn_out_bboxes, dec_out_bboxes = paddle.split( + dec_out_bboxes, dn_meta['dn_num_split'], axis=2) + dn_out_logits, dec_out_logits = paddle.split( + dec_out_logits, dn_meta['dn_num_split'], axis=2) else: dn_out_bboxes, dn_out_logits = None, None diff --git a/ppdet/modeling/initializer.py b/ppdet/modeling/initializer.py index 758eed240..308c51baf 100644 --- a/ppdet/modeling/initializer.py +++ b/ppdet/modeling/initializer.py @@ -273,7 +273,8 @@ def kaiming_normal_(tensor, def linear_init_(module): bound = 1 / math.sqrt(module.weight.shape[0]) uniform_(module.weight, -bound, bound) - uniform_(module.bias, -bound, bound) + if hasattr(module, "bias") and module.bias is not None: + uniform_(module.bias, -bound, bound) def conv_init_(module): diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 933d012de..b48cc98a7 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -67,7 +67,8 @@ class BBoxPostProcess(object): """ if self.nms is not None: bboxes, score = self.decode(head_out, rois, im_shape, scale_factor) - bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, 
self.num_classes) + bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, + self.num_classes) else: bbox_pred, bbox_num = self.decode(head_out, rois, im_shape, @@ -449,10 +450,14 @@ class DETRBBoxPostProcess(object): def __init__(self, num_classes=80, num_top_queries=100, + dual_queries=False, + dual_groups=0, use_focal_loss=False): super(DETRBBoxPostProcess, self).__init__() self.num_classes = num_classes self.num_top_queries = num_top_queries + self.dual_queries = dual_queries + self.dual_groups = dual_groups self.use_focal_loss = use_focal_loss def __call__(self, head_out, im_shape, scale_factor): @@ -471,6 +476,10 @@ class DETRBBoxPostProcess(object): shape [bs], and is N. """ bboxes, logits, masks = head_out + if self.dual_queries: + num_queries = logits.shape[1] + logits, bboxes = logits[:, :int(num_queries // (self.dual_groups + 1)), :], \ + bboxes[:, :int(num_queries // (self.dual_groups + 1)), :] bbox_pred = bbox_cxcywh_to_xyxy(bboxes) origin_shape = paddle.floor(im_shape / scale_factor + 0.5) diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index e55cb0c1d..0457e0414 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -18,6 +18,7 @@ from . import matchers from . import position_encoding from . import deformable_transformer from . import dino_transformer +from . import group_detr_transformer from .detr_transformer import * from .utils import * @@ -26,3 +27,4 @@ from .position_encoding import * from .deformable_transformer import * from .dino_transformer import * from .petr_transformer import * +from .group_detr_transformer import * diff --git a/ppdet/modeling/transformers/group_detr_transformer.py b/ppdet/modeling/transformers/group_detr_transformer.py new file mode 100644 index 000000000..31ec6172e --- /dev/null +++ b/ppdet/modeling/transformers/group_detr_transformer.py @@ -0,0 +1,857 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Modified from detrex (https://github.com/IDEA-Research/detrex) +# Copyright 2022 The IDEA Authors. All rights reserved. 
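+#
+# Group DETR in brief: (dual_groups + 1) parallel groups of object queries are
+# trained through the same encoder/decoder layers, with decoder self-attention
+# restricted to queries of the same group; at inference the post-process keeps
+# only the first group, so the extra groups add supervision during training
+# but no cost at test time.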
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention +from .position_encoding import PositionEmbedding +from ..heads.detr_head import MLP +from .deformable_transformer import MSDeformableAttention +from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_, + bias_init_with_prob) +from .utils import (_get_clones, get_valid_ratio, + get_contrastive_denoising_training_group, + get_sine_pos_embed, inverse_sigmoid) + +__all__ = ['GroupDINOTransformer'] + + +class DINOTransformerEncoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation="relu", + n_levels=4, + n_points=4, + weight_attr=None, + bias_attr=None): + super(DINOTransformerEncoderLayer, self).__init__() + # self attention + self.self_attn = MSDeformableAttention(d_model, n_head, n_levels, + n_points, 1.0) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, + src, + reference_points, + spatial_shapes, + level_start_index, + src_mask=None, + query_pos_embed=None): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, query_pos_embed), reference_points, src, + spatial_shapes, level_start_index, src_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + # ffn + src = self.forward_ffn(src) + + return src + + +class DINOTransformerEncoder(nn.Layer): + def __init__(self, encoder_layer, num_layers): + super(DINOTransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, offset=0.5): + valid_ratios = valid_ratios.unsqueeze(1) + reference_points = [] + for i, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = paddle.meshgrid( + paddle.arange(end=H) + offset, paddle.arange(end=W) + offset) + ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] * + H) + ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] * + W) + reference_points.append(paddle.stack((ref_x, ref_y), axis=-1)) + reference_points = paddle.concat(reference_points, 1).unsqueeze(2) + reference_points = reference_points * valid_ratios + return reference_points + + def forward(self, + feat, + spatial_shapes, + 
level_start_index, + feat_mask=None, + query_pos_embed=None, + valid_ratios=None): + if valid_ratios is None: + valid_ratios = paddle.ones( + [feat.shape[0], spatial_shapes.shape[0], 2]) + reference_points = self.get_reference_points(spatial_shapes, + valid_ratios) + for layer in self.layers: + feat = layer(feat, reference_points, spatial_shapes, + level_start_index, feat_mask, query_pos_embed) + + return feat + + +class DINOTransformerDecoderLayer(nn.Layer): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation="relu", + n_levels=4, + n_points=4, + dual_queries=False, + dual_groups=0, + weight_attr=None, + bias_attr=None): + super(DINOTransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # cross attention + self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, + n_points, 1.0) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr, + bias_attr) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr, + bias_attr) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm( + d_model, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # for dual groups + self.dual_queries = dual_queries + self.dual_groups = dual_groups + self.n_head = n_head + + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + xavier_uniform_(self.linear1.weight) + xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask=None, + memory_mask=None, + query_pos_embed=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos_embed) + if self.dual_queries: + dual_groups = self.dual_groups + bs, num_queries, n_model = paddle.shape(q) + q = paddle.concat(q.split(dual_groups + 1, axis=1), axis=0) + k = paddle.concat(k.split(dual_groups + 1, axis=1), axis=0) + tgt = paddle.concat(tgt.split(dual_groups + 1, axis=1), axis=0) + + g_num_queries = num_queries // (dual_groups + 1) + if attn_mask is None or attn_mask[0] is None: + attn_mask = None + else: + # [(dual_groups + 1), g_num_queries, g_num_queries] + attn_mask = paddle.concat( + [sa_mask.unsqueeze(0) for sa_mask in attn_mask], axis=0) + # [1, (dual_groups + 1), 1, g_num_queries, g_num_queries] + # --> [bs, (dual_groups + 1), nhead, g_num_queries, g_num_queries] + # --> [bs * (dual_groups + 1), nhead, g_num_queries, g_num_queries] + attn_mask = attn_mask.unsqueeze(0).unsqueeze(2).tile( + [bs, 1, self.n_head, 1, 1]) + attn_mask = attn_mask.reshape([ + bs * (dual_groups + 1), self.n_head, g_num_queries, + g_num_queries + ]) + + if attn_mask is not None: + attn_mask = attn_mask.astype('bool') + + tgt2 = self.self_attn(q, k, value=tgt, 
attn_mask=attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm2(tgt) + + # trace back + if self.dual_queries: + tgt = paddle.concat(tgt.split(dual_groups + 1, axis=0), axis=1) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos_embed), reference_points, memory, + memory_spatial_shapes, memory_level_start_index, memory_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt2 = self.forward_ffn(tgt) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class DINOTransformerDecoder(nn.Layer): + def __init__(self, + hidden_dim, + decoder_layer, + num_layers, + return_intermediate=True): + super(DINOTransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + self.norm = nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + bbox_head, + query_pos_head, + valid_ratios=None, + attn_mask=None, + memory_mask=None): + if valid_ratios is None: + valid_ratios = paddle.ones( + [memory.shape[0], memory_spatial_shapes.shape[0], 2]) + + output = tgt + intermediate = [] + inter_ref_bboxes = [] + for i, layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze( + 2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1) + query_pos_embed = get_sine_pos_embed( + reference_points_input[..., 0, :], self.hidden_dim // 2) + query_pos_embed = query_pos_head(query_pos_embed) + + output = layer(output, reference_points_input, memory, + memory_spatial_shapes, memory_level_start_index, + attn_mask, memory_mask, query_pos_embed) + inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid( + reference_points)) + + if self.return_intermediate: + intermediate.append(self.norm(output)) + inter_ref_bboxes.append(inter_ref_bbox) + + reference_points = inter_ref_bbox.detach() + + if self.return_intermediate: + return paddle.stack(intermediate), paddle.stack(inter_ref_bboxes) + + return output, reference_points + + +@register +class GroupDINOTransformer(nn.Layer): + __shared__ = ['num_classes', 'hidden_dim'] + + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=900, + position_embed_type='sine', + return_intermediate_dec=True, + backbone_feat_channels=[512, 1024, 2048], + num_levels=4, + num_encoder_points=4, + num_decoder_points=4, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + pe_temperature=10000, + pe_offset=-0.5, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=True, + use_input_proj=True, + dual_queries=False, + dual_groups=0, + eps=1e-2): + super(GroupDINOTransformer, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' 
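+        # With dual_queries enabled, the decoder receives (dual_groups + 1)
+        # groups of num_queries queries each, concatenated along the query
+        # axis; the decoder layer folds the groups into the batch axis so
+        # self-attention never crosses group boundaries.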
+ assert len(backbone_feat_channels) <= num_levels + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.use_input_proj = use_input_proj + + if use_input_proj: + # backbone feature projection + self._build_input_proj_layer(backbone_feat_channels) + + # Transformer module + encoder_layer = DINOTransformerEncoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, + num_encoder_points) + self.encoder = DINOTransformerEncoder(encoder_layer, num_encoder_layers) + decoder_layer = DINOTransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_levels, + num_decoder_points, + dual_queries=dual_queries, + dual_groups=dual_groups) + self.decoder = DINOTransformerDecoder(hidden_dim, decoder_layer, + num_decoder_layers, + return_intermediate_dec) + + # denoising part + self.denoising_class_embed = nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal())) + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # for dual group + self.dual_queries = dual_queries + self.dual_groups = dual_groups + if self.dual_queries: + self.denoising_class_embed_groups = nn.LayerList([ + nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal())) + for _ in range(self.dual_groups) + ]) + + # position embedding + self.position_embedding = PositionEmbedding( + hidden_dim // 2, + temperature=pe_temperature, + normalize=True if position_embed_type == 'sine' else False, + embed_type=position_embed_type, + offset=pe_offset) + self.level_embed = nn.Embedding(num_levels, hidden_dim) + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + normal_(self.tgt_embed.weight) + if self.dual_queries: + self.tgt_embed_dual = nn.LayerList([ + nn.Embedding(num_queries, hidden_dim) + for _ in range(self.dual_groups) + ]) + for dual_tgt_module in self.tgt_embed_dual: + normal_(dual_tgt_module.weight) + self.query_pos_head = MLP(2 * hidden_dim, + hidden_dim, + hidden_dim, + num_layers=2) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))) + if self.dual_queries: + self.enc_output = _get_clones(self.enc_output, self.dual_groups + 1) + else: + self.enc_output = _get_clones(self.enc_output, 1) + + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + + if self.dual_queries: + self.enc_bbox_head_dq = nn.LayerList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for i in range(self.dual_groups) + ]) + self.enc_score_head_dq = nn.LayerList([ + nn.Linear(hidden_dim, num_classes) + for i in range(self.dual_groups) + ]) + + # decoder head + self.dec_score_head = nn.LayerList([ + nn.Linear(hidden_dim, num_classes) + for _ in range(num_decoder_layers) + ]) + self.dec_bbox_head = nn.LayerList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(num_decoder_layers) + ]) + + self._reset_parameters() + + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) + 
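+        # classification biases start at -log((1 - 0.01) / 0.01), i.e. a
+        # foreground prior of ~0.01 per class (focal-loss-style initialization)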
linear_init_(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight) + constant_(self.enc_bbox_head.layers[-1].bias) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + linear_init_(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight) + constant_(reg_.layers[-1].bias) + + for enc_output in self.enc_output: + linear_init_(enc_output[0]) + xavier_uniform_(enc_output[0].weight) + normal_(self.level_embed.weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + normal_(self.denoising_class_embed.weight) + if self.use_input_proj: + for l in self.input_proj: + xavier_uniform_(l[0].weight) + constant_(l[0].bias) + + @classmethod + def from_config(cls, cfg, input_shape): + return {'backbone_feat_channels': [i.channels for i in input_shape], } + + def _build_input_proj_layer(self, backbone_feat_channels): + self.input_proj = nn.LayerList() + for in_channels in backbone_feat_channels: + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, self.hidden_dim, kernel_size=1)), + ('norm', nn.GroupNorm( + 32, + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = backbone_feat_channels[-1] + for _ in range(self.num_levels - len(backbone_feat_channels)): + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=3, + stride=2, + padding=1)), ('norm', nn.GroupNorm( + 32, + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats, pad_mask=None): + if self.use_input_proj: + # get projection features + proj_feats = [ + self.input_proj[i](feat) for i, feat in enumerate(feats) + ] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + else: + proj_feats = feats + # get encoder inputs + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + valid_ratios = [] + for i, feat in enumerate(proj_feats): + bs, _, h, w = paddle.shape(feat) + spatial_shapes.append(paddle.concat([h, w])) + # [b,c,h,w] -> [b,h*w,c] + feat_flatten.append(feat.flatten(2).transpose([0, 2, 1])) + if pad_mask is not None: + mask = F.interpolate(pad_mask.unsqueeze(0), size=(h, w))[0] + else: + mask = paddle.ones([bs, h, w]) + valid_ratios.append(get_valid_ratio(mask)) + # [b, h*w, c] + pos_embed = self.position_embedding(mask).flatten(1, 2) + lvl_pos_embed = pos_embed + self.level_embed.weight[i].reshape( + [1, 1, -1]) + lvl_pos_embed_flatten.append(lvl_pos_embed) + if pad_mask is not None: + # [b, h*w] + mask_flatten.append(mask.flatten(1)) + + # [b, l, c] + feat_flatten = paddle.concat(feat_flatten, 1) + # [b, l] + mask_flatten = None if pad_mask is None else paddle.concat(mask_flatten, + 1) + # [b, l, c] + lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1) + # [num_levels, 2] + spatial_shapes = paddle.to_tensor( + paddle.stack(spatial_shapes).astype('int64')) + # [l] start index of each level + level_start_index = paddle.concat([ + paddle.zeros( + [1], dtype='int64'), 
spatial_shapes.prod(1).cumsum(0)[:-1] + ]) + # [b, num_levels, 2] + valid_ratios = paddle.stack(valid_ratios, 1) + return (feat_flatten, spatial_shapes, level_start_index, mask_flatten, + lvl_pos_embed_flatten, valid_ratios) + + def forward(self, feats, pad_mask=None, gt_meta=None): + # input projection and embedding + (feat_flatten, spatial_shapes, level_start_index, mask_flatten, + lvl_pos_embed_flatten, + valid_ratios) = self._get_encoder_input(feats, pad_mask) + + # encoder + memory = self.encoder(feat_flatten, spatial_shapes, level_start_index, + mask_flatten, lvl_pos_embed_flatten, valid_ratios) + + # prepare denoising training + if self.training: + denoising_class, denoising_bbox, attn_mask, dn_meta = \ + get_contrastive_denoising_training_group(gt_meta, + self.num_classes, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale) + if self.dual_queries: + denoising_class_groups = [] + denoising_bbox_groups = [] + attn_mask_groups = [] + dn_meta_groups = [] + for g_id in range(self.dual_groups): + denoising_class_gid, denoising_bbox_gid, attn_mask_gid, dn_meta_gid = \ + get_contrastive_denoising_training_group(gt_meta, + self.num_classes, + self.num_queries, + self.denoising_class_embed_groups[g_id].weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale) + denoising_class_groups.append(denoising_class_gid) + denoising_bbox_groups.append(denoising_bbox_gid) + attn_mask_groups.append(attn_mask_gid) + dn_meta_groups.append(dn_meta_gid) + + # combine + denoising_class = [denoising_class] + denoising_class_groups + denoising_bbox = [denoising_bbox] + denoising_bbox_groups + attn_mask = [attn_mask] + attn_mask_groups + dn_meta = [dn_meta] + dn_meta_groups + else: + denoising_class, denoising_bbox, attn_mask, dn_meta = None, None, None, None + + target, init_ref_points, enc_topk_bboxes, enc_topk_logits = \ + self._get_decoder_input( + memory, spatial_shapes, mask_flatten, denoising_class, + denoising_bbox) + + # decoder + inter_feats, inter_ref_bboxes = self.decoder( + target, init_ref_points, memory, spatial_shapes, level_start_index, + self.dec_bbox_head, self.query_pos_head, valid_ratios, attn_mask, + mask_flatten) + # solve hang during distributed training + inter_feats[0] += self.denoising_class_embed.weight[0, 0] * 0. 
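+        # the group-specific denoising embeddings below get the same
+        # zero-valued contribution, keeping every parameter in the
+        # computation graph so no rank hangs waiting on gradients for
+        # unused parameters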
+ if self.dual_queries: + for g_id in range(self.dual_groups): + inter_feats[0] += self.denoising_class_embed_groups[ + g_id].weight[0, 0] * 0.0 + + out_bboxes = [] + out_logits = [] + for i in range(self.num_decoder_layers): + out_logits.append(self.dec_score_head[i](inter_feats[i])) + if i == 0: + out_bboxes.append( + F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) + + inverse_sigmoid(init_ref_points))) + else: + out_bboxes.append( + F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) + + inverse_sigmoid(inter_ref_bboxes[i - 1]))) + + out_bboxes = paddle.stack(out_bboxes) + out_logits = paddle.stack(out_logits) + return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, + dn_meta) + + def _get_encoder_output_anchors(self, + memory, + spatial_shapes, + memory_mask=None, + grid_size=0.05): + output_anchors = [] + idx = 0 + for lvl, (h, w) in enumerate(spatial_shapes): + if memory_mask is not None: + mask_ = memory_mask[:, idx:idx + h * w].reshape([-1, h, w]) + valid_H = paddle.sum(mask_[:, :, 0], 1) + valid_W = paddle.sum(mask_[:, 0, :], 1) + else: + valid_H, valid_W = h, w + + grid_y, grid_x = paddle.meshgrid( + paddle.arange( + end=h, dtype=memory.dtype), + paddle.arange( + end=w, dtype=memory.dtype)) + grid_xy = paddle.stack([grid_x, grid_y], -1) + + valid_WH = paddle.stack([valid_W, valid_H], -1).reshape( + [-1, 1, 1, 2]).astype(grid_xy.dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = paddle.ones_like(grid_xy) * grid_size * (2.0**lvl) + output_anchors.append( + paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + idx += h * w + + output_anchors = paddle.concat(output_anchors, 1) + valid_mask = ((output_anchors > self.eps) * + (output_anchors < 1 - self.eps)).all(-1, keepdim=True) + output_anchors = paddle.log(output_anchors / (1 - output_anchors)) + if memory_mask is not None: + valid_mask = (valid_mask * (memory_mask.unsqueeze(-1) > 0)) > 0 + output_anchors = paddle.where(valid_mask, output_anchors, + paddle.to_tensor(float("inf"))) + + memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.)) + if self.dual_queries: + output_memory = [ + self.enc_output[g_id](memory) + for g_id in range(self.dual_groups + 1) + ] + else: + output_memory = self.enc_output[0](memory) + return output_memory, output_anchors + + def _get_decoder_input(self, + memory, + spatial_shapes, + memory_mask=None, + denoising_class=None, + denoising_bbox=None): + bs, _, _ = memory.shape + # prepare input for decoder + output_memory, output_anchors = self._get_encoder_output_anchors( + memory, spatial_shapes, memory_mask) + if self.dual_queries: + enc_outputs_class = self.enc_score_head(output_memory[0]) + enc_outputs_coord_unact = self.enc_bbox_head(output_memory[ + 0]) + output_anchors + else: + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_unact = self.enc_bbox_head( + output_memory) + output_anchors + + _, topk_ind = paddle.topk( + enc_outputs_class.max(-1), self.num_queries, axis=1) + # extract region proposal boxes + batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + topk_coords_unact = paddle.gather_nd(enc_outputs_coord_unact, + topk_ind) # unsigmoided. 
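+        # topk_ind now has shape [bs, num_queries, 2] (batch index, token
+        # index), so gather_nd selects the num_queries highest-scoring
+        # proposals per image; sigmoid then maps them to normalized boxes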
+ enc_topk_bboxes = F.sigmoid(topk_coords_unact) + reference_points = enc_topk_bboxes.detach() + enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind) + + if self.dual_queries: + enc_topk_logits_groups = [] + enc_topk_bboxes_groups = [] + reference_points_groups = [] + topk_ind_groups = [] + for g_id in range(self.dual_groups): + enc_outputs_class_gid = self.enc_score_head_dq[g_id]( + output_memory[g_id + 1]) + enc_outputs_coord_unact_gid = self.enc_bbox_head_dq[g_id]( + output_memory[g_id + 1]) + output_anchors + _, topk_ind_gid = paddle.topk( + enc_outputs_class_gid.max(-1), self.num_queries, axis=1) + # extract region proposal boxes + batch_ind = paddle.arange(end=bs, dtype=topk_ind_gid.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) + topk_ind_gid = paddle.stack([batch_ind, topk_ind_gid], axis=-1) + topk_coords_unact_gid = paddle.gather_nd( + enc_outputs_coord_unact_gid, topk_ind_gid) # unsigmoided. + enc_topk_bboxes_gid = F.sigmoid(topk_coords_unact_gid) + reference_points_gid = enc_topk_bboxes_gid.detach() + enc_topk_logits_gid = paddle.gather_nd(enc_outputs_class_gid, + topk_ind_gid) + + # append and combine + topk_ind_groups.append(topk_ind_gid) + enc_topk_logits_groups.append(enc_topk_logits_gid) + enc_topk_bboxes_groups.append(enc_topk_bboxes_gid) + reference_points_groups.append(reference_points_gid) + + enc_topk_bboxes = paddle.concat( + [enc_topk_bboxes] + enc_topk_bboxes_groups, 1) + enc_topk_logits = paddle.concat( + [enc_topk_logits] + enc_topk_logits_groups, 1) + reference_points = paddle.concat( + [reference_points] + reference_points_groups, 1) + topk_ind = paddle.concat([topk_ind] + topk_ind_groups, 1) + + # extract region features + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + if self.dual_queries: + target = paddle.concat([target] + [ + self.tgt_embed_dual[g_id].weight.unsqueeze(0).tile( + [bs, 1, 1]) for g_id in range(self.dual_groups) + ], 1) + else: + if self.dual_queries: + target = paddle.gather_nd(output_memory[0], topk_ind) + target_groups = [] + for g_id in range(self.dual_groups): + target_gid = paddle.gather_nd(output_memory[g_id + 1], + topk_ind_groups[g_id]) + target_groups.append(target_gid) + target = paddle.concat([target] + target_groups, 1).detach() + else: + target = paddle.gather_nd(output_memory, topk_ind).detach() + + if denoising_bbox is not None: + if isinstance(denoising_bbox, list) and isinstance( + denoising_class, list) and self.dual_queries: + if denoising_bbox[0] is not None: + reference_points_list = paddle.split( + reference_points, self.dual_groups + 1, axis=1) + reference_points = paddle.concat( + [ + paddle.concat( + [ref, ref_], axis=1) + for ref, ref_ in zip(denoising_bbox, + reference_points_list) + ], + axis=1) + + target_list = paddle.split( + target, self.dual_groups + 1, axis=1) + target = paddle.concat( + [ + paddle.concat( + [tgt, tgt_], axis=1) + for tgt, tgt_ in zip(denoising_class, target_list) + ], + axis=1) + else: + reference_points, target = reference_points, target + else: + reference_points = paddle.concat( + [denoising_bbox, reference_points], 1) + target = paddle.concat([denoising_class, target], 1) + + return target, reference_points, enc_topk_bboxes, enc_topk_logits -- GitLab