From 653604c05ba15ee204b3b8a5642108758df78327 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Tue, 7 Mar 2023 11:45:00 +0800 Subject: [PATCH] Fix swin and add swin ppyoloe (#7857) * refine swin configs and codes * fix swin ppyoloe * fix swin for ema and distill training * fix configs for CI * fix docs, test=document_fix --- configs/faster_rcnn/README.md | 2 +- .../faster_rcnn/_base_/optimizer_swin_1x.yml | 6 +- .../faster_rcnn_swin_tiny_fpn_2x_coco.yml | 6 -- .../faster_rcnn_swin_tiny_fpn_3x_coco.yml | 22 ----- configs/swin/README.md | 26 ++++++ .../faster_rcnn_swin_tiny_fpn_3x_coco.yml | 82 +++++++++++++++++++ .../swin/ppyoloe_plus_swin_tiny_36e_coco.yml | 67 +++++++++++++++ ppdet/modeling/backbones/swin_transformer.py | 20 ++--- ppdet/modeling/transformers/utils.py | 4 +- 9 files changed, 186 insertions(+), 49 deletions(-) delete mode 100644 configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml create mode 100644 configs/swin/README.md create mode 100644 configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml create mode 100644 configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml diff --git a/configs/faster_rcnn/README.md b/configs/faster_rcnn/README.md index da495599c..8ba30cbcb 100644 --- a/configs/faster_rcnn/README.md +++ b/configs/faster_rcnn/README.md @@ -23,7 +23,7 @@ | ResNet50-vd-SSLDv2-FPN | Faster | 1 | 2x | ---- | 42.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams) | [配置文件](./faster_rcnn_r50_vd_fpn_ssld_2x_coco.yml) | | Swin-Tiny-FPN | Faster | 2 | 1x | ---- | 42.6 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_1x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_1x_coco.yml) | | Swin-Tiny-FPN | Faster | 2 | 2x | ---- | 44.8 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_2x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_2x_coco.yml) | -| Swin-Tiny-FPN | Faster | 2 | 3x | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_3x_coco.yml) | +| Swin-Tiny-FPN | Faster | 2 | 3x | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](../swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml) | ## Citations ``` diff --git a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml index 5c1c66799..66de8f0b5 100644 --- a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml +++ b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml @@ -15,8 +15,6 @@ OptimizerBuilder: optimizer: type: AdamW weight_decay: 0.05 - param_groups: - - - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] - weight_decay: 0. + - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + weight_decay: 0.0 diff --git a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml index 5848c4943..902dcbe83 100644 --- a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml +++ b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml @@ -14,9 +14,3 @@ LearningRate: - !LinearWarmup start_factor: 0.1 steps: 1000 - -OptimizerBuilder: - clip_grad_by_norm: 1.0 - optimizer: - type: AdamW - weight_decay: 0.05 diff --git a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml deleted file mode 100644 index a1b68cf47..000000000 --- a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml +++ /dev/null @@ -1,22 +0,0 @@ -_BASE_: [ - 'faster_rcnn_swin_tiny_fpn_1x_coco.yml', -] -weights: output/faster_rcnn_swin_tiny_fpn_3x_coco/model_final - -epoch: 36 - -LearningRate: - base_lr: 0.0001 - schedulers: - - !PiecewiseDecay - gamma: 0.1 - milestones: [24, 33] - - !LinearWarmup - start_factor: 0.1 - steps: 1000 - -OptimizerBuilder: - clip_grad_by_norm: 1.0 - optimizer: - type: AdamW - weight_decay: 0.05 diff --git a/configs/swin/README.md b/configs/swin/README.md new file mode 100644 index 000000000..617ee67d3 --- /dev/null +++ b/configs/swin/README.md @@ -0,0 +1,26 @@ +# Swin Transformer + +## COCO Model Zoo + +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | +| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| swin_T_224 | Faster R-CNN | 2 | 36e | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_3x_coco.yml) | +| swin_T_224 | PP-YOLOE+ | 8 | 36e | ---- | 43.6 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_swin_tiny_36e_coco.pdparams) | [配置文件](./ppyoloe_plus_swin_tiny_36e_coco.yml) | + + +## Citations +``` +@article{liu2021Swin, + title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, + author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, + journal={arXiv preprint arXiv:2103.14030}, + year={2021} +} + +@inproceedings{liu2021swinv2, + title={Swin Transformer V2: Scaling Up Capacity and Resolution}, + author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, + booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2022} +} +``` diff --git a/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml b/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml new file mode 100644 index 000000000..3fb2da3dd --- /dev/null +++ b/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml @@ -0,0 +1,82 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '../faster_rcnn/_base_/faster_rcnn_r50_fpn.yml', + '../faster_rcnn/_base_/faster_fpn_reader.yml', +] +weights: output/faster_rcnn_swin_tiny_fpn_3x_coco/model_final +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/swin_tiny_patch4_window7_224_22kto1k_pretrained.pdparams + + +FasterRCNN: + backbone: SwinTransformer + neck: FPN + rpn_head: RPNHead + bbox_head: BBoxHead + bbox_post_process: BBoxPostProcess + +SwinTransformer: + arch: 'swin_T_224' # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384'] + ape: false + drop_path_rate: 0.1 + patch_norm: true + out_indices: [0, 1, 2, 3] + + +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomResizeCrop: {resizes: [400, 500, 600], cropsizes: [[384, 600], ], prob: 0.5} + - RandomResize: {target_size: [[480, 1333], [512, 1333], [544, 1333], [576, 1333], [608, 1333], [640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 2} + - RandomFlip: {prob: 0.5} + - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + +TestReader: + inputs_def: + image_shape: [-1, 3, 640, 640] # TODO deploy: set fixes shape currently + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: 640, keep_ratio: True} + - Pad: {size: 640} + - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} + - Permute: {} + batch_size: 1 + + +epoch: 36 +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [24, 33] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 1.0 + optimizer: + type: AdamW + weight_decay: 0.05 + param_groups: + - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + weight_decay: 0.0 diff --git a/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml b/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml new file mode 100644 index 000000000..a5403d86e --- /dev/null +++ b/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml @@ -0,0 +1,67 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '../ppyoloe/_base_/ppyoloe_plus_crn.yml', + '../ppyoloe/_base_/ppyoloe_plus_reader.yml', +] +depth_mult: 0.33 # s version +width_mult: 0.50 + +log_iter: 50 +snapshot_epoch: 4 +weights: output/ppyoloe_plus_swin_tiny_36e_coco/model_final +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/swin_tiny_patch4_window7_224_22kto1k_pretrained.pdparams + + +architecture: PPYOLOE +norm_type: sync_bn +use_ema: true +ema_decay: 0.9998 +ema_black_list: ['proj_conv.weight'] +custom_black_list: ['reduce_mean'] + +PPYOLOE: + backbone: SwinTransformer + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + +SwinTransformer: + arch: 'swin_T_224' # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384'] + ape: false + drop_path_rate: 0.1 + patch_norm: true + out_indices: [1, 2, 3] + +PPYOLOEHead: + static_assigner_epoch: 12 + nms: + nms_top_k: 10000 + keep_top_k: 300 + score_threshold: 0.01 + nms_threshold: 0.7 + + +TrainReader: + batch_size: 8 + + +epoch: 36 +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [24, 33] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 1.0 + optimizer: + type: AdamW + weight_decay: 0.05 + param_groups: + - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + weight_decay: 0.0 diff --git a/ppdet/modeling/backbones/swin_transformer.py b/ppdet/modeling/backbones/swin_transformer.py index 8a581b763..64aabab47 100644 --- a/ppdet/modeling/backbones/swin_transformer.py +++ b/ppdet/modeling/backbones/swin_transformer.py @@ -191,8 +191,6 @@ class WindowAttention(nn.Layer): relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 self.relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", - self.relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -425,7 +423,6 @@ class BasicLayer(nn.Layer): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. @@ -500,10 +497,7 @@ class BasicLayer(nn.Layer): cnt = 0 for h in h_slices: for w in w_slices: - try: - img_mask[:, h, w, :] = cnt - except: - pass + img_mask[:, h, w, :] = cnt cnt += 1 @@ -572,15 +566,12 @@ class PatchEmbed(nn.Layer): @register @serializable class SwinTransformer(nn.Layer): - """ Swin Transformer - A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - + """ Swin Transformer backbone Args: - img_size (int | tuple(int)): Input image size. Default 224 + arch (str): Architecture of FocalNet + pretrain_img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. @@ -619,6 +610,7 @@ class SwinTransformer(nn.Layer): pretrained=None): super(SwinTransformer, self).__init__() assert arch in MODEL_cfg.keys(), "Unsupported arch: {}".format(arch) + pretrain_img_size = MODEL_cfg[arch]['pretrain_img_size'] embed_dim = MODEL_cfg[arch]['embed_dim'] depths = MODEL_cfg[arch]['depths'] @@ -748,7 +740,7 @@ class SwinTransformer(nn.Layer): (0, 3, 1, 2)) outs.append(out) - return tuple(outs) + return outs @property def out_shape(self): diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index d8f869fbc..c41b069cf 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -236,7 +236,7 @@ def get_sine_pos_embed(pos_tensor, """generate sine position embedding from a position tensor Args: - pos_tensor (torch.Tensor): Shape as `(None, n)`. + pos_tensor (Tensor): Shape as `(None, n)`. num_pos_feats (int): projected shape for each float in the tensor. Default: 128 temperature (int): The temperature used for scaling the position embedding. Default: 10000. @@ -245,7 +245,7 @@ def get_sine_pos_embed(pos_tensor, be `[pos(y), pos(x)]`. Defaults: True. Returns: - torch.Tensor: Returned position embedding # noqa + Tensor: Returned position embedding # noqa with shape `(None, n * num_pos_feats)`. """ scale = 2. * math.pi -- GitLab