diff --git a/configs/faster_rcnn/README.md b/configs/faster_rcnn/README.md
index da495599ce180b80ce019ff1828ae63c1140a7ff..8ba30cbcbd7b7b64549522a6abd64c6e0495c6b4 100644
--- a/configs/faster_rcnn/README.md
+++ b/configs/faster_rcnn/README.md
@@ -23,7 +23,7 @@
 | ResNet50-vd-SSLDv2-FPN | Faster | 1 | 2x | ---- | 42.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams) | [配置文件](./faster_rcnn_r50_vd_fpn_ssld_2x_coco.yml) |
 | Swin-Tiny-FPN | Faster | 2 | 1x | ---- | 42.6 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_1x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_1x_coco.yml) |
 | Swin-Tiny-FPN | Faster | 2 | 2x | ---- | 44.8 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_2x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_2x_coco.yml) |
-| Swin-Tiny-FPN | Faster | 2 | 3x | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_3x_coco.yml) |
+| Swin-Tiny-FPN | Faster | 2 | 3x | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](../swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml) |

 ## Citations
 ```
diff --git a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml
index 5c1c6679940834f8ff3bb985bb44f6dc2f281428..66de8f0b5d93f7cb95b25acc052d6fdf4af2eed8 100644
--- a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml
+++ b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml
@@ -15,8 +15,7 @@
 OptimizerBuilder:
   optimizer:
     type: AdamW
     weight_decay: 0.05
-  param_groups:
-    -
-      params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
-      weight_decay: 0.
+    param_groups:
+      - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
+        weight_decay: 0.0
diff --git a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml
index 5848c4943b4a40a5b306fb87d9aae7508f56a8c7..902dcbe831a6f1e585f8a1c1ca96378826086be8 100644
--- a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml
+++ b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_2x_coco.yml
@@ -14,9 +14,3 @@ LearningRate:
   - !LinearWarmup
     start_factor: 0.1
     steps: 1000
-
-OptimizerBuilder:
-  clip_grad_by_norm: 1.0
-  optimizer:
-    type: AdamW
-    weight_decay: 0.05
diff --git a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml b/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml
deleted file mode 100644
index a1b68cf4703886be497d8efa6aea4b9c5d256797..0000000000000000000000000000000000000000
--- a/configs/faster_rcnn/faster_rcnn_swin_tiny_fpn_3x_coco.yml
+++ /dev/null
@@ -1,22 +0,0 @@
-_BASE_: [
-  'faster_rcnn_swin_tiny_fpn_1x_coco.yml',
-]
-weights: output/faster_rcnn_swin_tiny_fpn_3x_coco/model_final
-
-epoch: 36
-
-LearningRate:
-  base_lr: 0.0001
-  schedulers:
-  - !PiecewiseDecay
-    gamma: 0.1
-    milestones: [24, 33]
-  - !LinearWarmup
-    start_factor: 0.1
-    steps: 1000
-
-OptimizerBuilder:
-  clip_grad_by_norm: 1.0
-  optimizer:
-    type: AdamW
-    weight_decay: 0.05
diff --git a/configs/swin/README.md b/configs/swin/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..617ee67d3ff840687d0547d5a462cd0a74f07c1a
--- /dev/null
+++ b/configs/swin/README.md
@@ -0,0 +1,26 @@
+# Swin Transformer
+
+## COCO Model Zoo
+
+| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 | 推理时间(fps) | Box AP | 下载 | 配置文件 |
+| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| swin_T_224 | Faster R-CNN | 2 | 36e | ---- | 45.3 | [下载链接](https://paddledet.bj.bcebos.com/models/faster_rcnn_swin_tiny_fpn_3x_coco.pdparams) | [配置文件](./faster_rcnn_swin_tiny_fpn_3x_coco.yml) |
+| swin_T_224 | PP-YOLOE+ | 8 | 36e | ---- | 43.6 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_swin_tiny_36e_coco.pdparams) | [配置文件](./ppyoloe_plus_swin_tiny_36e_coco.yml) |
+
+
+## Citations
+```
+@article{liu2021Swin,
+  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
+  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
+  journal={arXiv preprint arXiv:2103.14030},
+  year={2021}
+}
+
+@inproceedings{liu2021swinv2,
+  title={Swin Transformer V2: Scaling Up Capacity and Resolution},
+  author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
+  booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2022}
+}
+```
diff --git a/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml b/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3fb2da3dde15322eb106d4bc1b069fc664c68ea8
--- /dev/null
+++ b/configs/swin/faster_rcnn_swin_tiny_fpn_3x_coco.yml
@@ -0,0 +1,82 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '../faster_rcnn/_base_/faster_rcnn_r50_fpn.yml',
+  '../faster_rcnn/_base_/faster_fpn_reader.yml',
+]
+weights: output/faster_rcnn_swin_tiny_fpn_3x_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/swin_tiny_patch4_window7_224_22kto1k_pretrained.pdparams
+
+
+FasterRCNN:
+  backbone: SwinTransformer
+  neck: FPN
+  rpn_head: RPNHead
+  bbox_head: BBoxHead
+  bbox_post_process: BBoxPostProcess
+
+SwinTransformer:
+  arch: 'swin_T_224' # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384']
+  ape: false
+  drop_path_rate: 0.1
+  patch_norm: true
+  out_indices: [0, 1, 2, 3]
+
+
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomResizeCrop: {resizes: [400, 500, 600], cropsizes: [[384, 600], ], prob: 0.5}
+  - RandomResize: {target_size: [[480, 1333], [512, 1333], [544, 1333], [576, 1333], [608, 1333], [640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 2}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 2
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+
+TestReader:
+  inputs_def:
+    image_shape: [-1, 3, 640, 640] # TODO deploy: set fixed shape currently
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: 640, keep_ratio: True}
+  - Pad: {size: 640}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_size: 1
+
+
+epoch: 36
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [24, 33]
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 1000
+
+OptimizerBuilder:
+  clip_grad_by_norm: 1.0
+  optimizer:
+    type: AdamW
+    weight_decay: 0.05
+    param_groups:
+      - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
+        weight_decay: 0.0
diff --git a/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml b/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a5403d86e840db4595d419727eeda54f161825be
--- /dev/null
+++ b/configs/swin/ppyoloe_plus_swin_tiny_36e_coco.yml
@@ -0,0 +1,67 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '../ppyoloe/_base_/ppyoloe_plus_crn.yml',
+  '../ppyoloe/_base_/ppyoloe_plus_reader.yml',
+]
+depth_mult: 0.33 # s version
+width_mult: 0.50
+
+log_iter: 50
+snapshot_epoch: 4
+weights: output/ppyoloe_plus_swin_tiny_36e_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/swin_tiny_patch4_window7_224_22kto1k_pretrained.pdparams
+
+
+architecture: PPYOLOE
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+ema_black_list: ['proj_conv.weight']
+custom_black_list: ['reduce_mean']
+
+PPYOLOE:
+  backbone: SwinTransformer
+  neck: CustomCSPPAN
+  yolo_head: PPYOLOEHead
+  post_process: ~
+
+SwinTransformer:
+  arch: 'swin_T_224' # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384']
+  ape: false
+  drop_path_rate: 0.1
+  patch_norm: true
+  out_indices: [1, 2, 3]
+
+PPYOLOEHead:
+  static_assigner_epoch: 12
+  nms:
+    nms_top_k: 10000
+    keep_top_k: 300
+    score_threshold: 0.01
+    nms_threshold: 0.7
+
+
+TrainReader:
+  batch_size: 8
+
+
+epoch: 36
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [24, 33]
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 1000
+
+OptimizerBuilder:
+  clip_grad_by_norm: 1.0
+  optimizer:
+    type: AdamW
+    weight_decay: 0.05
+    param_groups:
+      - params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
+        weight_decay: 0.0
diff --git a/ppdet/modeling/backbones/swin_transformer.py b/ppdet/modeling/backbones/swin_transformer.py
index 8a581b763d6e6ffb3498e337f4294ed82fbedce0..64aabab47811500e2534716d28c0233d82f1973c 100644
--- a/ppdet/modeling/backbones/swin_transformer.py
+++ b/ppdet/modeling/backbones/swin_transformer.py
@@ -191,8 +191,6 @@ class WindowAttention(nn.Layer):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         self.relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index",
-                             self.relative_position_index)

         self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -425,7 +423,6 @@ class BasicLayer(nn.Layer):
     """ A basic Swin Transformer layer for one stage.
     Args:
         dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
         depth (int): Number of blocks.
         num_heads (int): Number of attention heads.
         window_size (int): Local window size.
@@ -500,10 +497,7 @@
         cnt = 0
         for h in h_slices:
             for w in w_slices:
-                try:
-                    img_mask[:, h, w, :] = cnt
-                except:
-                    pass
+                img_mask[:, h, w, :] = cnt
                 cnt += 1
@@ -572,15 +566,12 @@ class PatchEmbed(nn.Layer):
 @register
 @serializable
 class SwinTransformer(nn.Layer):
-    """ Swin Transformer
-        A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
-          https://arxiv.org/pdf/2103.14030
-
+    """ Swin Transformer backbone
     Args:
-        img_size (int | tuple(int)): Input image size. Default 224
+        arch (str): Architecture of SwinTransformer
+        pretrain_img_size (int | tuple(int)): Input image size. Default 224
         patch_size (int | tuple(int)): Patch size. Default: 4
         in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
         embed_dim (int): Patch embedding dimension. Default: 96
         depths (tuple(int)): Depth of each Swin Transformer layer.
         num_heads (tuple(int)): Number of attention heads in different layers.
@@ -619,6 +610,7 @@ class SwinTransformer(nn.Layer):
                  pretrained=None):
         super(SwinTransformer, self).__init__()
         assert arch in MODEL_cfg.keys(), "Unsupported arch: {}".format(arch)
+
         pretrain_img_size = MODEL_cfg[arch]['pretrain_img_size']
         embed_dim = MODEL_cfg[arch]['embed_dim']
         depths = MODEL_cfg[arch]['depths']
@@ -748,7 +740,7 @@
                     (0, 3, 1, 2))
                 outs.append(out)

-        return tuple(outs)
+        return outs

     @property
     def out_shape(self):
diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py
index d8f869fbc27a807af88f2d4de262774a9f8638ce..c41b069cffb9fe4913899361a8e6ac1d708e0b17 100644
--- a/ppdet/modeling/transformers/utils.py
+++ b/ppdet/modeling/transformers/utils.py
@@ -236,7 +236,7 @@ def get_sine_pos_embed(pos_tensor,
     """generate sine position embedding from a position tensor

     Args:
-        pos_tensor (torch.Tensor): Shape as `(None, n)`.
+        pos_tensor (Tensor): Shape as `(None, n)`.
         num_pos_feats (int): projected shape for each float in the tensor. Default: 128
         temperature (int): The temperature used for scaling the position embedding. Default: 10000.
@@ -245,7 +245,7 @@
             be `[pos(y), pos(x)]`. Defaults: True.

     Returns:
-        torch.Tensor: Returned position embedding  # noqa
+        Tensor: Returned position embedding  # noqa
            with shape `(None, n * num_pos_feats)`.
     """
     scale = 2. * math.pi