Commit c684201e (unverified), authored by Wenyu, committed by GitHub. Parent commit: 99f891be.

[WIP] Add vitdet (#6397)

* add cascade vitdet
* cascade vit
* fix for model export
* add vit cascade cfgs
# Vision Transformer Detection
## Introduction
- [Context Autoencoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026)
- [Benchmarking Detection Transfer Learning with Vision Transformers](https://arxiv.org/pdf/2111.11429.pdf)
Object detection is a central downstream task used to
test if pre-trained network parameters confer benefits, such
as improved accuracy or training speed. The complexity
of object detection methods can make this benchmarking
non-trivial when new architectures, such as Vision Transformer (ViT) models, arrive.
## Model Zoo
| Backbone | Pretrained | Model | Scheduler | Images/GPU | Box AP | Config | Download |
|:------:|:--------:|:--------------:|:--------------:|:--------------:|:------:|:------:|:--------:|
| ViT-base | CAE | Cascade RCNN | 1x | 1 | -- | [config](./cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml) | [coming soon]() |
| ViT-large | CAE | Cascade RCNN | 1x | 1 | -- | [config](./cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml) | [coming soon]() |
**Notes:**
- Models are trained on the COCO train2017 dataset and evaluated on val2017; the reported metric is `mAP(IoU=0.5:0.95)`.
- The base model is trained on 8 32GB V100 GPUs; the large model on 8 80GB A100 GPUs.
## Citations
```
@article{chen2022context,
  title={Context autoencoder for self-supervised representation learning},
  author={Chen, Xiaokang and Ding, Mingyu and Wang, Xiaodi and Xin, Ying and Mo, Shentong and Wang, Yunhao and Han, Shumin and Luo, Ping and Zeng, Gang and Wang, Jingdong},
  journal={arXiv preprint arXiv:2202.03026},
  year={2022}
}

@article{DBLP:journals/corr/abs-2111-11429,
  author     = {Yanghao Li and Saining Xie and Xinlei Chen and Piotr Doll{\'{a}}r and Kaiming He and Ross B. Girshick},
  title      = {Benchmarking Detection Transfer Learning with Vision Transformers},
  journal    = {CoRR},
  volume     = {abs/2111.11429},
  year       = {2021},
  url        = {https://arxiv.org/abs/2111.11429},
  eprinttype = {arXiv},
  eprint     = {2111.11429},
  timestamp  = {Fri, 26 Nov 2021 13:48:43 +0100},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2111-11429.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}

@article{Cai_2019,
  title={Cascade R-CNN: High Quality Object Detection and Instance Segmentation},
  ISSN={1939-3539},
  url={http://dx.doi.org/10.1109/tpami.2019.2956516},
  DOI={10.1109/tpami.2019.2956516},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  publisher={Institute of Electrical and Electronics Engineers (IEEE)},
  author={Cai, Zhaowei and Vasconcelos, Nuno},
  year={2019},
  pages={1--1}
}
```
`_base_/optimizer_base_1x.yml`:

```yaml
epoch: 12

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [9, 11]
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

OptimizerBuilder:
  optimizer:
    type: AdamWDL
    betas: [0.9, 0.999]
    layer_decay: 0.75
    weight_decay: 0.02
    num_layers: 12
    filter_bias_and_bn: True
    skip_decay_names: ['pos_embed', 'cls_token']
    set_param_lr_func: 'layerwise_lr_decay'
```
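With `set_param_lr_func: 'layerwise_lr_decay'`, each transformer block's learning rate is scaled by `layer_decay` raised to its distance from the top of the network, so early blocks move slowly while the head trains at full `base_lr`. A minimal sketch of one common formulation (the helper below is illustrative, not the library's exact implementation):

```python
# Illustrative only: one common form of layer-wise LR decay. Layer 0 (the
# embeddings) gets the smallest multiplier; the top block keeps base_lr.
def layerwise_lr(base_lr, layer_id, num_layers=12, layer_decay=0.75):
    return base_lr * layer_decay ** (num_layers - layer_id)

for layer_id in (0, 6, 12):
    print(layer_id, layerwise_lr(1e-4, layer_id))
# 0 -> ~3.2e-6, 6 -> ~1.8e-5, 12 -> 1e-4
```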
`_base_/reader.yml`:

```yaml
worker_num: 2

TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomResizeCrop: {resizes: [400, 500, 600], cropsizes: [[384, 600], ], prob: 0.5}
  - RandomResize: {target_size: [[480, 1333], [512, 1333], [544, 1333], [576, 1333], [608, 1333], [640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 2}
  - RandomFlip: {prob: 0.5}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 2
  shuffle: true
  drop_last: true
  collate_batch: false

EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  shuffle: false
  drop_last: false
  drop_empty: false

TestReader:
  inputs_def:
    image_shape: [-1, 3, 640, 640]
  sample_transforms:
  - Decode: {}
  - LetterBoxResize: {target_size: 640}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_size: 1
  shuffle: false
  drop_last: false
```
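For reference, the last two sample transforms in each pipeline amount to the following numpy operations (a minimal sketch, assuming an HWC uint8 image as produced by `Decode`):

```python
import numpy as np

img = np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)  # stand-in for a decoded image

# NormalizeImage with is_scale: true — rescale to [0, 1], then standardize
# with the ImageNet mean/std from the config
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
img = (img.astype(np.float32) / 255.0 - mean) / std

# Permute — HWC to CHW, the layout the backbone expects
img = img.transpose(2, 0, 1)
```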
`cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml`:

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/reader.yml',
  './_base_/optimizer_base_1x.yml'
]
weights: output/cascade_rcnn_vit_base_hrfpn_cae_1x_coco/model_final

# runtime
log_iter: 100
snapshot_epoch: 1
find_unused_parameters: True
use_gpu: true
norm_type: sync_bn

# reader
worker_num: 2
TrainReader:
  batch_size: 1

# model
architecture: CascadeRCNN

CascadeRCNN:
  backbone: VisionTransformer
  neck: HRFPN
  rpn_head: RPNHead
  bbox_head: CascadeHead
  # post process
  bbox_post_process: BBoxPostProcess

VisionTransformer:
  patch_size: 16
  embed_dim: 768
  depth: 12
  num_heads: 12
  mlp_ratio: 4
  qkv_bias: True
  drop_rate: 0.0
  drop_path_rate: 0.2
  init_values: 0.1
  final_norm: False
  use_rel_pos_bias: False
  use_sincos_pos_emb: True
  epsilon: 0.000001  # 1e-6
  out_indices: [3, 5, 7, 11]
  with_fpn: True
  pretrained: ~

HRFPN:
  out_channel: 256
  use_bias: True

RPNHead:
  anchor_generator:
    aspect_ratios: [0.5, 1.0, 2.0]
    anchor_sizes: [[32], [64], [128], [256], [512]]
    strides: [4, 8, 16, 32, 64]
  rpn_target_assign:
    batch_size_per_im: 256
    fg_fraction: 0.5
    negative_overlap: 0.3
    positive_overlap: 0.7
    use_random: True
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 2000
    post_nms_top_n: 2000
    topk_after_collect: True
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 1000
    post_nms_top_n: 1000
  loss_rpn_bbox: SmoothL1Loss

SmoothL1Loss:
  beta: 0.1111111111111111

CascadeHead:
  head: CascadeXConvNormHead
  roi_extractor:
    resolution: 7
    sampling_ratio: 0
    aligned: True
  bbox_assigner: BBoxAssigner
  bbox_loss: GIoULoss
  num_cascade_stages: 3
  reg_class_agnostic: False
  stage_loss_weights: [1, 0.5, 0.25]
  loss_normalize_pos: True

BBoxAssigner:
  batch_size_per_im: 512
  bg_thresh: 0.5
  fg_thresh: 0.5
  fg_fraction: 0.25
  cascade_iou: [0.5, 0.6, 0.7]
  use_random: True

CascadeXConvNormHead:
  norm_type: bn

GIoULoss:
  loss_weight: 10.
  reduction: 'none'
  eps: 0.000001

BBoxPostProcess:
  decode:
    name: RCNNBox
    prior_box_var: [30.0, 30.0, 15.0, 15.0]
  nms:
    name: MultiClassNMS
    keep_top_k: 100
    score_threshold: 0.05
    nms_threshold: 0.5
```
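`_BASE_` composition loads the listed files first and then overlays the current file's keys on top, which is why `TrainReader.batch_size: 1` here overrides `batch_size: 2` from `_base_/reader.yml`. A rough sketch of that resolution logic (simplified; PaddleDetection's actual config loader handles more cases, such as resolving paths relative to the including file):

```python
import yaml  # assumes PyYAML is available

def deep_merge(base, override):
    """Recursively overlay `override` onto `base` (simplified sketch)."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_merge(base[key], value)
        else:
            base[key] = value
    return base

def load_config(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    merged = {}
    for base_path in cfg.pop('_BASE_', []):
        deep_merge(merged, load_config(base_path))  # assumes paths resolve from cwd
    return deep_merge(merged, cfg)  # current file's keys win
```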
`cascade_rcnn_vit_large_hrfpn_cae_1x_coco.yml`:

```yaml
_BASE_: [
  './cascade_rcnn_vit_base_hrfpn_cae_1x_coco.yml'
]
weights: output/cascade_rcnn_vit_large_hrfpn_cae_1x_coco/model_final

depth: &depth 24
dim: &dim 1024

VisionTransformer:
  img_size: [800, 1344]
  embed_dim: *dim
  depth: *depth
  num_heads: 16
  drop_path_rate: 0.25
  out_indices: [7, 11, 15, 23]
  pretrained: ~

HRFPN:
  in_channels: [*dim, *dim, *dim, *dim]

OptimizerBuilder:
  optimizer:
    layer_decay: 0.9
    weight_decay: 0.02
    num_layers: *depth
```
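The `&depth`/`*depth` pair is a plain YAML anchor and alias, so the single `depth: 24` value propagates to both the backbone depth and the optimizer's `num_layers`. A quick check (assumes PyYAML):

```python
import yaml

cfg = yaml.safe_load("""
depth: &depth 24
dim: &dim 1024
VisionTransformer:
  embed_dim: *dim
  depth: *depth
OptimizerBuilder:
  optimizer:
    num_layers: *depth
""")
# aliases resolve to the anchored scalars on load
assert cfg['OptimizerBuilder']['optimizer']['num_layers'] == 24
assert cfg['VisionTransformer']['embed_dim'] == 1024
```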
Changes to `VisionTransformer`:

```diff
@@ -578,9 +578,10 @@ class VisionTransformer(nn.Layer):
         x = self.patch_embed(x)
-        x_shape = paddle.shape(x)  # b * c * h * w
-        cls_tokens = self.cls_token.expand((x_shape[0], -1, -1))
+        B, D, Hp, Wp = x.shape  # b * c * h * w
+        cls_tokens = self.cls_token.expand(
+            (B, self.cls_token.shape[-2], self.cls_token.shape[-1]))
         x = x.flatten(2).transpose([0, 2, 1])  # b * hw * c
         x = paddle.concat([cls_tokens, x], axis=1)
@@ -593,8 +594,6 @@ class VisionTransformer(nn.Layer):
         rel_pos_bias = self.rel_pos_bias(
         ) if self.rel_pos_bias is not None else None
 
-        B, _, Hp, Wp = x_shape
-
         feats = []
         for idx, blk in enumerate(self.blocks):
             if self.use_checkpoint:
@@ -607,7 +606,7 @@ class VisionTransformer(nn.Layer):
                 xp = paddle.reshape(
                     paddle.transpose(
                         self.norm(x[:, 1:, :]), perm=[0, 2, 1]),
-                    shape=[B, -1, Hp, Wp])
+                    shape=[B, D, Hp, Wp])
                 feats.append(xp)
 
         if self.with_fpn:
```
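This change replaces `paddle.shape(x)` (a runtime tensor) with direct `x.shape` unpacking and spells out every dimension passed to `expand` instead of using `-1`, which helps shape inference when the model is exported to a static graph. A minimal sketch of the new shape handling, with hypothetical sizes:

```python
import paddle

x = paddle.randn([2, 768, 50, 84])  # (b, c, h, w), e.g. after patch_embed
B, D, Hp, Wp = x.shape              # plain Python ints in dygraph mode

cls_token = paddle.zeros([1, 1, 768])
# every expand dimension is given explicitly instead of -1
cls_tokens = cls_token.expand((B, cls_token.shape[-2], cls_token.shape[-1]))
print(cls_tokens.shape)  # [2, 1, 768]
```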
Changes to `BBoxHead`:

```diff
@@ -257,7 +257,13 @@ class BBoxHead(nn.Layer):
         pred = self.get_prediction(scores, deltas)
         return pred, self.head
 
-    def get_loss(self, scores, deltas, targets, rois, bbox_weight):
+    def get_loss(self,
+                 scores,
+                 deltas,
+                 targets,
+                 rois,
+                 bbox_weight,
+                 loss_normalize_pos=False):
         """
         scores (Tensor): scores from bbox head outputs
         deltas (Tensor): deltas from bbox head outputs
@@ -280,8 +286,15 @@ class BBoxHead(nn.Layer):
         else:
             tgt_labels = tgt_labels.cast('int64')
             tgt_labels.stop_gradient = True
-            loss_bbox_cls = F.cross_entropy(
-                input=scores, label=tgt_labels, reduction='mean')
+
+            if not loss_normalize_pos:
+                loss_bbox_cls = F.cross_entropy(
+                    input=scores, label=tgt_labels, reduction='mean')
+            else:
+                loss_bbox_cls = F.cross_entropy(
+                    input=scores, label=tgt_labels,
+                    reduction='none').sum() / (tgt_labels.shape[0] + 1e-7)
             loss_bbox[cls_name] = loss_bbox_cls
 
         # bbox reg
@@ -322,9 +335,16 @@ class BBoxHead(nn.Layer):
             if self.bbox_loss is not None:
                 reg_delta = self.bbox_transform(reg_delta)
                 reg_target = self.bbox_transform(reg_target)
-                loss_bbox_reg = self.bbox_loss(
-                    reg_delta, reg_target).sum() / tgt_labels.shape[0]
-                loss_bbox_reg *= self.num_classes
+
+                if not loss_normalize_pos:
+                    loss_bbox_reg = self.bbox_loss(
+                        reg_delta, reg_target).sum() / tgt_labels.shape[0]
+                    loss_bbox_reg *= self.num_classes
+                else:
+                    loss_bbox_reg = self.bbox_loss(
+                        reg_delta,
+                        reg_target).sum() / (tgt_labels.shape[0] + 1e-7)
             else:
                 loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
                 ) / tgt_labels.shape[0]
```
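With `loss_normalize_pos=True`, the classification loss switches from `reduction='mean'` to an explicit sum divided by the RoI count with an epsilon guard, and the regression loss drops the `* num_classes` factor. A numeric sketch of the regression difference, using plain L1 as a stand-in for the configured `GIoULoss` and hypothetical sizes:

```python
import paddle

n, num_classes = 4, 81                    # hypothetical RoI count / class count
reg_delta = paddle.randn([n, 4])
reg_target = paddle.randn([n, 4])
l1_sum = paddle.abs(reg_delta - reg_target).sum()

loss_default = l1_sum / n * num_classes   # original path
loss_norm_pos = l1_sum / (n + 1e-7)       # loss_normalize_pos path: no
                                          # num_classes factor, eps-guarded
```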
Changes to `CascadeHead`:

```diff
@@ -162,7 +162,8 @@ class CascadeHead(BBoxHead):
                  num_cascade_stages=3,
                  bbox_loss=None,
                  reg_class_agnostic=True,
-                 stage_loss_weights=None):
+                 stage_loss_weights=None,
+                 loss_normalize_pos=False):
         nn.Layer.__init__(self, )
         self.head = head
@@ -184,6 +185,7 @@ class CascadeHead(BBoxHead):
         self.reg_class_agnostic = reg_class_agnostic
         num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes
+        self.loss_normalize_pos = loss_normalize_pos
 
         self.bbox_score_list = []
         self.bbox_delta_list = []
@@ -242,9 +244,16 @@ class CascadeHead(BBoxHead):
             # TODO (lyuwenyu) Is it correct for only one class ?
             if not self.reg_class_agnostic and i < self.num_cascade_stages - 1:
-                deltas = deltas.reshape([-1, self.num_classes, 4])
+                deltas = deltas.reshape([deltas.shape[0], self.num_classes, 4])
                 labels = scores[:, :-1].argmax(axis=-1)
-                deltas = deltas[paddle.arange(deltas.shape[0]), labels]
+
+                if self.training:
+                    deltas = deltas[paddle.arange(deltas.shape[0]), labels]
+                else:
+                    deltas = deltas[(deltas * F.one_hot(
+                        labels, num_classes=self.num_classes).unsqueeze(-1) != 0
+                                     ).nonzero(as_tuple=True)].reshape(
+                                         [deltas.shape[0], 4])
             head_out_list.append([scores, deltas, rois])
             pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i])
@@ -253,8 +262,13 @@ class CascadeHead(BBoxHead):
         loss = {}
         for stage, value in enumerate(zip(head_out_list, targets_list)):
             (scores, deltas, rois), targets = value
-            loss_stage = self.get_loss(scores, deltas, targets, rois,
-                                       self.bbox_weight[stage])
+            loss_stage = self.get_loss(
+                scores,
+                deltas,
+                targets,
+                rois,
+                self.bbox_weight[stage],
+                loss_normalize_pos=self.loss_normalize_pos)
             for k, v in loss_stage.items():
                 loss[k + "_stage{}".format(
                     stage)] = v * self.stage_loss_weights[stage]
```
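At eval time, the per-class delta selection avoids tensor fancy indexing, which static-graph export handles poorly, by building a one-hot mask over the predicted classes instead. A simplified sketch of the same gather, using a masked sum rather than the commit's `nonzero`-based indexing, with hypothetical sizes:

```python
import paddle
import paddle.nn.functional as F

N, C = 4, 81                               # hypothetical RoIs / classes
deltas = paddle.randn([N, C, 4])
labels = paddle.randint(0, C, [N])

# training path: gather each RoI's deltas for its predicted class directly
picked_train = deltas[paddle.arange(N), labels]            # [N, 4]

# export-friendly path: select the same rows with a one-hot mask
mask = F.one_hot(labels, num_classes=C).unsqueeze(-1)      # [N, C, 1]
picked_eval = (deltas * mask).sum(axis=1)                  # [N, 4]
```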