未验证 提交 3f7e70d4 编写于 作者: W Wenyu 提交者: GitHub

add vit + mask rcnn (#7592)

上级 9daf164b
...@@ -13,11 +13,14 @@ non-trivial when new architectures, such as Vision Transformer (ViT) models, arr ...@@ -13,11 +13,14 @@ non-trivial when new architectures, such as Vision Transformer (ViT) models, arr
## Model Zoo ## Model Zoo
| Backbone | Pretrained | Model | Scheduler | Images/GPU | Box AP | Config | Download | | Model | Backbone | Pretrained | Scheduler | Images/GPU | Box AP | Mask AP | Config | Download |
|:------:|:--------:|:--------------:|:--------------:|:--------------:|:------:|:------:|:--------:| |:------:|:--------:|:--------------:|:--------------:|:--------------:|:--------------:|:------:|:------:|:--------:|
| ViT-base | CAE | Cascade RCNN | 1x | 1 | 52.7 | [config](./cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.pdparams) | | Cascade RCNN | ViT-base | CAE | 1x | 1 | 52.7 | - | [config](./cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/cascade_rcnn_vit_base_hrfpn_cae_1x_coco.pdparams) |
| ViT-large | CAE | Cascade RCNN | 1x | 1 | 55.7 | [config](./cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.pdparams) | | Cascade RCNN | ViT-large | CAE | 1x | 1 | 55.7 | - | [config](./cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/cascade_rcnn_vit_large_hrfpn_cae_1x_coco.pdparams) |
| ViT-base | CAE | PP-YOLOE | 36e | 2 | 52.2 | [config](./ppyoloe_vit_base_csppan_cae_36e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_vit_base_csppan_cae_36e_coco.pdparams) | | PP-YOLOE | ViT-base | CAE | 36e | 2 | 52.2 | - | [config](./ppyoloe_vit_base_csppan_cae_36e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_vit_base_csppan_cae_36e_coco.pdparams) |
| Mask RCNN | ViT-base | CAE | 1x | 1 | 50.6 | 44.9 | [config](./mask_rcnn_vit_base_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/mask_rcnn_vit_base_hrfpn_cae_1x_coco.pdparams) |
| Mask RCNN | ViT-large | CAE | 1x | 1 | 54.2 | 47.4 | [config](./mask_rcnn_vit_large_hrfpn_cae_1x_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/mask_rcnn_vit_large_hrfpn_cae_1x_coco.pdparams) |
**Notes:** **Notes:**
- Model is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95) - Model is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'./_base_/reader.yml', './_base_/faster_rcnn_reader.yml',
'./_base_/optimizer_base_1x.yml' './_base_/optimizer_base_1x.yml'
] ]
...@@ -81,15 +81,30 @@ RPNHead: ...@@ -81,15 +81,30 @@ RPNHead:
nms_thresh: 0.7 nms_thresh: 0.7
pre_nms_top_n: 1000 pre_nms_top_n: 1000
post_nms_top_n: 1000 post_nms_top_n: 1000
loss_rpn_bbox: SmoothL1Loss
SmoothL1Loss:
beta: 0.1111111111111111
BBoxHead: BBoxHead:
head: TwoFCHead # head: TwoFCHead
head: XConvNormHead
roi_extractor: roi_extractor:
resolution: 7 resolution: 7
sampling_ratio: 0 sampling_ratio: 0
aligned: True aligned: True
bbox_assigner: BBoxAssigner bbox_assigner: BBoxAssigner
loss_normalize_pos: True
bbox_loss: GIoULoss
GIoULoss:
loss_weight: 10.
reduction: 'none'
eps: 0.000001 # 1e-6
BBoxAssigner: BBoxAssigner:
batch_size_per_im: 512 batch_size_per_im: 512
...@@ -98,8 +113,13 @@ BBoxAssigner: ...@@ -98,8 +113,13 @@ BBoxAssigner:
fg_fraction: 0.25 fg_fraction: 0.25
use_random: True use_random: True
TwoFCHead: # TwoFCHead:
out_channel: 1024 # out_channel: 1024
XConvNormHead:
num_convs: 4
norm_type: bn
BBoxPostProcess: BBoxPostProcess:
decode: RCNNBox decode: RCNNBox
......
_BASE_: [
'./mask_rcnn_vit_base_hrfpn_cae_1x_coco.yml'
]
weights: output/mask_rcnn_vit_large_hrfpn_cae_1x_coco/model_final
depth: &depth 24
dim: &dim 1024
use_fused_allreduce_gradients: &use_checkpoint True
VisionTransformer:
img_size: [800, 1344]
embed_dim: *dim
depth: *depth
num_heads: 16
drop_path_rate: 0.25
out_indices: [7, 11, 15, 23]
use_checkpoint: *use_checkpoint
pretrained: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_large_cae_pretrained.pdparams
HRFPN:
in_channels: [*dim, *dim, *dim, *dim]
OptimizerBuilder:
optimizer:
layer_decay: 0.9
weight_decay: 0.02
num_layers: *depth
...@@ -509,16 +509,24 @@ class VisionTransformer(nn.Layer): ...@@ -509,16 +509,24 @@ class VisionTransformer(nn.Layer):
dim = x.shape[-1] dim = x.shape[-1]
# we add a small number to avoid floating point error in the interpolation # we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8 # see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1 # w0, h0 = w0 + 0.1, h0 + 0.1
# patch_pos_embed = nn.functional.interpolate(
# patch_pos_embed.reshape([
# 1, self.patch_embed.num_patches_w,
# self.patch_embed.num_patches_h, dim
# ]).transpose((0, 3, 1, 2)),
# scale_factor=(w0 / self.patch_embed.num_patches_w,
# h0 / self.patch_embed.num_patches_h),
# mode='bicubic', )
patch_pos_embed = nn.functional.interpolate( patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape([ patch_pos_embed.reshape([
1, self.patch_embed.num_patches_w, 1, self.patch_embed.num_patches_w,
self.patch_embed.num_patches_h, dim self.patch_embed.num_patches_h, dim
]).transpose((0, 3, 1, 2)), ]).transpose((0, 3, 1, 2)),
scale_factor=(w0 / self.patch_embed.num_patches_w, (w0, h0),
h0 / self.patch_embed.num_patches_h),
mode='bicubic', ) mode='bicubic', )
assert int(w0) == patch_pos_embed.shape[-2] and int( assert int(w0) == patch_pos_embed.shape[-2] and int(
h0) == patch_pos_embed.shape[-1] h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.transpose( patch_pos_embed = patch_pos_embed.transpose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册