From c684201e30b6fb9948c620a2415fbd9444ccab23 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Thu, 14 Jul 2022 15:14:08 +0800 Subject: [PATCH] [WIP] Add vitdet (#6397) * add cascade vitdet * cascade vit * fix for model export * add vit cascade cfgs --- configs/vitdet/README.md | 65 +++++++++ configs/vitdet/_base_/optimizer_base_1x.yml | 22 +++ configs/vitdet/_base_/reader.yml | 41 ++++++ ...ascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml | 129 ++++++++++++++++++ ...scade_rcnn_vit_large_hrfpn_cae_1x_coco.yml | 27 ++++ .../modeling/backbones/vision_transformer.py | 9 +- ppdet/modeling/heads/bbox_head.py | 32 ++++- ppdet/modeling/heads/cascade_head.py | 24 +++- 8 files changed, 333 insertions(+), 16 deletions(-) create mode 100644 configs/vitdet/README.md create mode 100644 configs/vitdet/_base_/optimizer_base_1x.yml create mode 100644 configs/vitdet/_base_/reader.yml create mode 100644 configs/vitdet/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml create mode 100644 configs/vitdet/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml diff --git a/configs/vitdet/README.md b/configs/vitdet/README.md new file mode 100644 index 000000000..64d8cbf9a --- /dev/null +++ b/configs/vitdet/README.md @@ -0,0 +1,65 @@ +# Vision transformer Detection + +## Introduction + +- [Context Autoencoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026) +- [Benchmarking Detection Transfer Learning with Vision Transformers](https://arxiv.org/pdf/2111.11429.pdf) + +Object detection is a central downstream task used to +test if pre-trained network parameters confer benefits, such +as improved accuracy or training speed. The complexity +of object detection methods can make this benchmarking +non-trivial when new architectures, such as Vision Transformer (ViT) models, arrive. + +## Model Zoo + +| Backbone | Pretrained | Model | Scheduler | Images/GPU | Box AP | Config | Download | +|:------:|:--------:|:--------------:|:--------------:|:--------------:|:------:|:------:|:--------:| +| ViT-base | CAE | Cascade RCNN | 1x | 1 | -- | [config](./cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml) | [coming soon]() | +| ViT-large | CAE | Cascade RCNN | 1x | 1 | -- | [config](./cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml) | [coming soon]() | + +**Notes:** +- Model is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`. +- Base model is trained on 8x32G V100 GPU, large model on 8x80G A100. + + +## Citations +``` +@article{chen2022context, + title={Context autoencoder for self-supervised representation learning}, + author={Chen, Xiaokang and Ding, Mingyu and Wang, Xiaodi and Xin, Ying and Mo, Shentong and Wang, Yunhao and Han, Shumin and Luo, Ping and Zeng, Gang and Wang, Jingdong}, + journal={arXiv preprint arXiv:2202.03026}, + year={2022} +} + +@article{DBLP:journals/corr/abs-2111-11429, + author = {Yanghao Li and + Saining Xie and + Xinlei Chen and + Piotr Doll{\'{a}}r and + Kaiming He and + Ross B. Girshick}, + title = {Benchmarking Detection Transfer Learning with Vision Transformers}, + journal = {CoRR}, + volume = {abs/2111.11429}, + year = {2021}, + url = {https://arxiv.org/abs/2111.11429}, + eprinttype = {arXiv}, + eprint = {2111.11429}, + timestamp = {Fri, 26 Nov 2021 13:48:43 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-2111-11429.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{Cai_2019, + title={Cascade R-CNN: High Quality Object Detection and Instance Segmentation}, + ISSN={1939-3539}, + url={http://dx.doi.org/10.1109/tpami.2019.2956516}, + DOI={10.1109/tpami.2019.2956516}, + journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, + publisher={Institute of Electrical and Electronics Engineers (IEEE)}, + author={Cai, Zhaowei and Vasconcelos, Nuno}, + year={2019}, + pages={1–1} +} +``` diff --git a/configs/vitdet/_base_/optimizer_base_1x.yml b/configs/vitdet/_base_/optimizer_base_1x.yml new file mode 100644 index 000000000..b822b3bf9 --- /dev/null +++ b/configs/vitdet/_base_/optimizer_base_1x.yml @@ -0,0 +1,22 @@ +epoch: 12 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [9, 11] + - !LinearWarmup + start_factor: 0.001 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: AdamWDL + betas: [0.9, 0.999] + layer_decay: 0.75 + weight_decay: 0.02 + num_layers: 12 + filter_bias_and_bn: True + skip_decay_names: ['pos_embed', 'cls_token'] + set_param_lr_func: 'layerwise_lr_decay' diff --git a/configs/vitdet/_base_/reader.yml b/configs/vitdet/_base_/reader.yml new file mode 100644 index 000000000..1af6175a9 --- /dev/null +++ b/configs/vitdet/_base_/reader.yml @@ -0,0 +1,41 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomResizeCrop: {resizes: [400, 500, 600], cropsizes: [[384, 600], ], prob: 0.5} + - RandomResize: {target_size: [[480, 1333], [512, 1333], [544, 1333], [576, 1333], [608, 1333], [640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 2} + - RandomFlip: {prob: 0.5} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + drop_last: false + drop_empty: false + + +TestReader: + inputs_def: + image_shape: [-1, 3, 640, 640] + sample_transforms: + - Decode: {} + - LetterBoxResize: {target_size: 640} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/vitdet/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml b/configs/vitdet/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml new file mode 100644 index 000000000..3499ffd55 --- /dev/null +++ b/configs/vitdet/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml @@ -0,0 +1,129 @@ + +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/reader.yml', + './_base_/optimizer_base_1x.yml' +] + +weights: output/cascade_rcnn_vit_base_hrfpn_cae_1x_coco/model_final + + +# runtime +log_iter: 100 +snapshot_epoch: 1 +find_unused_parameters: True + +use_gpu: true +norm_type: sync_bn + + +# reader +worker_num: 2 +TrainReader: + batch_size: 1 + + +# model +architecture: CascadeRCNN + +CascadeRCNN: + backbone: VisionTransformer + neck: HRFPN + rpn_head: RPNHead + bbox_head: CascadeHead + # post process + bbox_post_process: BBoxPostProcess + + +VisionTransformer: + patch_size: 16 + embed_dim: 768 + depth: 12 + num_heads: 12 + mlp_ratio: 4 + qkv_bias: True + drop_rate: 0.0 + drop_path_rate: 0.2 + init_values: 0.1 + final_norm: False + use_rel_pos_bias: False + use_sincos_pos_emb: True + epsilon: 0.000001 # 1e-6 + out_indices: [3, 5, 7, 11] + with_fpn: True + pretrained: ~ + +HRFPN: + out_channel: 256 + use_bias: True + +RPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + anchor_sizes: [[32], [64], [128], [256], [512]] + strides: [4, 8, 16, 32, 64] + rpn_target_assign: + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + use_random: True + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + topk_after_collect: True + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + loss_rpn_bbox: SmoothL1Loss + +SmoothL1Loss: + beta: 0.1111111111111111 + + +CascadeHead: + head: CascadeXConvNormHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + bbox_loss: GIoULoss + num_cascade_stages: 3 + reg_class_agnostic: False + stage_loss_weights: [1, 0.5, 0.25] + loss_normalize_pos: True + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + cascade_iou: [0.5, 0.6, 0.7] + use_random: True + + +CascadeXConvNormHead: + norm_type: bn + + +GIoULoss: + loss_weight: 10. + reduction: 'none' + eps: 0.000001 + + +BBoxPostProcess: + decode: + name: RCNNBox + prior_box_var: [30.0, 30.0, 15.0, 15.0] + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 diff --git a/configs/vitdet/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml b/configs/vitdet/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml new file mode 100644 index 000000000..bdb20bdf1 --- /dev/null +++ b/configs/vitdet/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml @@ -0,0 +1,27 @@ +_BASE_: [ + './cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml' +] + +weights: output/cascade_rcnn_vit_large_hrfpn_cae_1x_coco/model_final + + +depth: &depth 24 +dim: &dim 1024 + +VisionTransformer: + img_size: [800, 1344] + embed_dim: *dim + depth: *depth + num_heads: 16 + drop_path_rate: 0.25 + out_indices: [7, 11, 15, 23] + pretrained: ~ + +HRFPN: + in_channels: [*dim, *dim, *dim, *dim] + +OptimizerBuilder: + optimizer: + layer_decay: 0.9 + weight_decay: 0.02 + num_layers: *depth diff --git a/ppdet/modeling/backbones/vision_transformer.py b/ppdet/modeling/backbones/vision_transformer.py index 4f980d482..798ea3768 100644 --- a/ppdet/modeling/backbones/vision_transformer.py +++ b/ppdet/modeling/backbones/vision_transformer.py @@ -578,9 +578,10 @@ class VisionTransformer(nn.Layer): x = self.patch_embed(x) - x_shape = paddle.shape(x) # b * c * h * w + B, D, Hp, Wp = x.shape # b * c * h * w - cls_tokens = self.cls_token.expand((x_shape[0], -1, -1)) + cls_tokens = self.cls_token.expand( + (B, self.cls_token.shape[-2], self.cls_token.shape[-1])) x = x.flatten(2).transpose([0, 2, 1]) # b * hw * c x = paddle.concat([cls_tokens, x], axis=1) @@ -593,8 +594,6 @@ class VisionTransformer(nn.Layer): rel_pos_bias = self.rel_pos_bias( ) if self.rel_pos_bias is not None else None - B, _, Hp, Wp = x_shape - feats = [] for idx, blk in enumerate(self.blocks): if self.use_checkpoint: @@ -607,7 +606,7 @@ class VisionTransformer(nn.Layer): xp = paddle.reshape( paddle.transpose( self.norm(x[:, 1:, :]), perm=[0, 2, 1]), - shape=[B, -1, Hp, Wp]) + shape=[B, D, Hp, Wp]) feats.append(xp) if self.with_fpn: diff --git a/ppdet/modeling/heads/bbox_head.py b/ppdet/modeling/heads/bbox_head.py index 8874ba067..debd3074c 100644 --- a/ppdet/modeling/heads/bbox_head.py +++ b/ppdet/modeling/heads/bbox_head.py @@ -257,7 +257,13 @@ class BBoxHead(nn.Layer): pred = self.get_prediction(scores, deltas) return pred, self.head - def get_loss(self, scores, deltas, targets, rois, bbox_weight): + def get_loss(self, + scores, + deltas, + targets, + rois, + bbox_weight, + loss_normalize_pos=False): """ scores (Tensor): scores from bbox head outputs deltas (Tensor): deltas from bbox head outputs @@ -280,8 +286,15 @@ class BBoxHead(nn.Layer): else: tgt_labels = tgt_labels.cast('int64') tgt_labels.stop_gradient = True - loss_bbox_cls = F.cross_entropy( - input=scores, label=tgt_labels, reduction='mean') + + if not loss_normalize_pos: + loss_bbox_cls = F.cross_entropy( + input=scores, label=tgt_labels, reduction='mean') + else: + loss_bbox_cls = F.cross_entropy( + input=scores, label=tgt_labels, + reduction='none').sum() / (tgt_labels.shape[0] + 1e-7) + loss_bbox[cls_name] = loss_bbox_cls # bbox reg @@ -322,9 +335,16 @@ class BBoxHead(nn.Layer): if self.bbox_loss is not None: reg_delta = self.bbox_transform(reg_delta) reg_target = self.bbox_transform(reg_target) - loss_bbox_reg = self.bbox_loss( - reg_delta, reg_target).sum() / tgt_labels.shape[0] - loss_bbox_reg *= self.num_classes + + if not loss_normalize_pos: + loss_bbox_reg = self.bbox_loss( + reg_delta, reg_target).sum() / tgt_labels.shape[0] + loss_bbox_reg *= self.num_classes + + else: + loss_bbox_reg = self.bbox_loss( + reg_delta, reg_target).sum() / (tgt_labels.shape[0] + 1e-7) + else: loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum( ) / tgt_labels.shape[0] diff --git a/ppdet/modeling/heads/cascade_head.py b/ppdet/modeling/heads/cascade_head.py index c07c22734..9a964e4e0 100644 --- a/ppdet/modeling/heads/cascade_head.py +++ b/ppdet/modeling/heads/cascade_head.py @@ -162,7 +162,8 @@ class CascadeHead(BBoxHead): num_cascade_stages=3, bbox_loss=None, reg_class_agnostic=True, - stage_loss_weights=None): + stage_loss_weights=None, + loss_normalize_pos=False): nn.Layer.__init__(self, ) self.head = head @@ -184,6 +185,7 @@ class CascadeHead(BBoxHead): self.reg_class_agnostic = reg_class_agnostic num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes + self.loss_normalize_pos = loss_normalize_pos self.bbox_score_list = [] self.bbox_delta_list = [] @@ -242,9 +244,16 @@ class CascadeHead(BBoxHead): # TODO (lyuwenyu) Is it correct for only one class ? if not self.reg_class_agnostic and i < self.num_cascade_stages - 1: - deltas = deltas.reshape([-1, self.num_classes, 4]) + deltas = deltas.reshape([deltas.shape[0], self.num_classes, 4]) labels = scores[:, :-1].argmax(axis=-1) - deltas = deltas[paddle.arange(deltas.shape[0]), labels] + + if self.training: + deltas = deltas[paddle.arange(deltas.shape[0]), labels] + else: + deltas = deltas[(deltas * F.one_hot( + labels, num_classes=self.num_classes).unsqueeze(-1) != 0 + ).nonzero(as_tuple=True)].reshape( + [deltas.shape[0], 4]) head_out_list.append([scores, deltas, rois]) pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i]) @@ -253,8 +262,13 @@ class CascadeHead(BBoxHead): loss = {} for stage, value in enumerate(zip(head_out_list, targets_list)): (scores, deltas, rois), targets = value - loss_stage = self.get_loss(scores, deltas, targets, rois, - self.bbox_weight[stage]) + loss_stage = self.get_loss( + scores, + deltas, + targets, + rois, + self.bbox_weight[stage], + loss_normalize_pos=self.loss_normalize_pos) for k, v in loss_stage.items(): loss[k + "_stage{}".format( stage)] = v * self.stage_loss_weights[stage] -- GitLab