Unverified commit 8ffa45af; author: Feng Ni; committer: GitHub

[MOT] Fix deepsort detectors config (#5344)

* fix deepsort detectors cfgs, add ppyoloe

* modify yolov3 40e

* add ppyoloe mot17half modelzoo

* add resnet embedding reid

* remove fasterfpn picodet, fix nms for higher MOTA

* add mot17 yml

* fix _base_ cfgs

* fix doc
Parent f08c2ca7
......@@ -39,11 +39,14 @@
| :-------- | :----- | :----: |:------: | :----: |:-----: |:----:|:----: |
| MIX | JDE YOLOv3 | PCB Pyramid | - | 66.9 | 62.7 | - |[config](./deepsort_jde_yolov3_pcb_pyramid.yml) |
| MIX | JDE YOLOv3 | PPLCNet | - | 66.3 | 62.1 | - |[config](./deepsort_jde_yolov3_pplcnet.yml) |
| pedestrian (not open) | YOLOv3 | PPLCNet | 45.4 | 45.8 | 54.3 | - |[config](./deepsort_yolov3_pplcnet.yml) |
| MOT-17 half train | PPYOLOv2 | PPLCNet | 46.8 | 48.7 | 54.5 | - |[config](./deepsort_ppyolov2_pplcnet.yml) |
| MOT-17 half train | YOLOv3 | PPLCNet | 42.7 | 50.2 | 52.4 | - |[config](./deepsort_yolov3_pplcnet.yml) |
| MOT-17 half train | PPYOLOv2 | PPLCNet | 46.8 | 51.8 | 55.8 | - |[config](./deepsort_ppyolov2_pplcnet.yml) |
| MOT-17 half train | PPYOLOe | PPLCNet | 52.9 | 56.7 | 60.5 | - |[config](./deepsort_ppyoloe_pplcnet.yml) |
| MOT-17 half train | PPYOLOe | ResNet-50 | 52.9 | 56.7 | 64.6 | - |[config](./deepsort_ppyoloe_resnet.yml) |
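**Notes:**
DeepSORT does not need to be trained on the MOT dataset; it is only used for evaluation. Two evaluation methods are currently supported.
The download links for the model weights are given by ```det_weights``` and ```reid_weights``` in the config file; running the evaluation command downloads them automatically.
DeepSORT keeps the detector and the ReID model separate. The detector is trained on the MOT dataset on its own, while the assembled DeepSORT model is used only for evaluation. Two evaluation methods are currently supported.
- **Method 1**: Load a detection results file and the ReID model. Before evaluating with the DeepSORT model, first obtain detection results from a detection model, then prepare the result files like this: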
```
det_results_dir
......@@ -75,6 +78,15 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
### 1. Evaluation
#### 1.1 Evaluate detection
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/deepsort/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
```
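**Notes:**
- Detection is evaluated with ```tools/eval.py```; tracking is evaluated with ```tools/eval_mot.py```.
#### 1.2 Evaluate tracking
**Method 1**: Load the detection results file and the ReID model to obtain tracking results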
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/reid/deepsort_pcb_pyramid_r101.yml --det_results_dir {your detection results}
......@@ -89,6 +101,8 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pplcnet.yml
# or
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_ppyolov2_pplcnet.yml --scaled=True
# or
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_ppyoloe_resnet.yml --scaled=True
```
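**Notes:**
- The JDE YOLOv3 pedestrian detection model is trained on the same MOT datasets as JDE and FairMOT, so its MOTA is higher. Other general-purpose detection models such as PPYOLOv2 are trained only on the MOT17 half dataset.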
......
metric: COCO
num_classes: 1
# Detection Dataset for training
TrainDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17
anno_path: annotations/train_half.json
image_dir: images/train
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17
anno_path: annotations/val_half.json
image_dir: images/train
TestDataset:
!ImageFolder
anno_path: annotations/val_half.json
# MOTDataset for MOT evaluation and inference
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT and ByteTrack
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True if save visualization images or video
_BASE_: [
'detector/jde_yolov3_darknet53_30e_1088x608_mix.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......@@ -30,6 +29,7 @@ DeepSORT:
# reid and tracker configuration
# see 'configs/mot/deepsort/reid/deepsort_pcb_pyramid_r101.yml'
PCBPyramid:
model_name: "ResNet101"
num_conv_out_channels: 128
num_classes: 751
......
_BASE_: [
'detector/jde_yolov3_darknet53_30e_1088x608_mix.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......
_BASE_: [
'detector/picodet_l_esnet_300e_896x896_mot17half.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'detector/ppyoloe_crn_l_36e_640x640_mot17half.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......@@ -13,16 +12,35 @@ EvalMOTDataset:
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_esnet_300e_896x896_mot17half.pdparams
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
# reader
EvalMOTReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
TestMOTReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
# DeepSORT configuration
architecture: DeepSORT
pretrain_weights: None
DeepSORT:
detector: PicoDet
detector: YOLOv3 # PPYOLOe version
reid: PPLCNetEmbedding
tracker: DeepSORTTracker
......@@ -46,46 +64,47 @@ DeepSORTTracker:
motion: KalmanFilter
# detector configuration
# see 'configs/mot/deepsort/detector/picodet_l_esnet_300e_640x640_mot17half.yml'
PicoDet:
backbone: ESNet
neck: CSPPAN
head: PicoHead
# detector configuration: PPYOLOe version
# see 'configs/mot/deepsort/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml'
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 4
num_fpn_stride: 4
norm_type: bn
share_cls_reg: False
fpn_stride: [8, 16, 32, 64]
feat_in_chan: 128
prior_prob: 0.01
reg_max: 7
cell_offset: 0.5
loss_class:
name: VarifocalLoss
use_sigmoid: True
iou_weighted: True
loss_weight: 1.0
loss_dfl:
name: DistributionFocalLoss
loss_weight: 0.25
loss_bbox:
name: GIoULoss
loss_weight: 2.0
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 # 100
use_varifocal_loss: True
eval_input_size: [640, 640]
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: SimOTAAssigner
candidate_topk: 10
iou_weight: 6
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.25 # 0.025 in original detector
score_threshold: 0.4 # 0.01 in original detector
nms_threshold: 0.6
_BASE_: [
'detector/faster_rcnn_r50_fpn_2x_1333x800_mot17half.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'detector/ppyoloe_crn_l_36e_640x640_mot17half.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......@@ -13,8 +12,27 @@ EvalMOTDataset:
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/faster_rcnn_r50_fpn_2x_1333x800_mot17half.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_resnet.pdparams
# reader
EvalMOTReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
TestMOTReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
# DeepSORT configuration
......@@ -22,16 +40,15 @@ architecture: DeepSORT
pretrain_weights: None
DeepSORT:
detector: FasterRCNN
reid: PPLCNetEmbedding
detector: YOLOv3 # PPYOLOe version
reid: ResNetEmbedding
tracker: DeepSORTTracker
# reid and tracker configuration
# see 'configs/mot/deepsort/reid/deepsort_pplcnet.yml'
PPLCNetEmbedding:
input_ch: 1280
output_ch: 512
# see 'configs/mot/deepsort/reid/deepsort_resnet.yml'
ResNetEmbedding:
model_name: "ResNet50"
DeepSORTTracker:
input_size: [64, 192]
......@@ -46,20 +63,47 @@ DeepSORTTracker:
motion: KalmanFilter
# detector configuration
# see 'configs/mot/deepsort/detector/faster_rcnn_r50_fpn_2x_1333x800_mot17half.yml'
FasterRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
bbox_post_process: BBoxPostProcess
# detector configuration: PPYOLOe version
# see 'configs/mot/deepsort/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml'
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
# Tracking requires higher quality boxes, so nms.score_threshold will be higher
BBoxPostProcess:
decode: RCNNBox
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 # 100
use_varifocal_loss: True
eval_input_size: [640, 640]
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.2 # 0.05 in original detector
nms_threshold: 0.5
score_threshold: 0.4 # 0.01 in original detector
nms_threshold: 0.6
_BASE_: [
'detector/ppyolov2_r50vd_dcn_365e_640x640_mot17half.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......@@ -16,6 +15,25 @@ EvalMOTDataset:
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyolov2_r50vd_dcn_365e_640x640_mot17half.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
# reader
EvalMOTReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
TestMOTReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
# DeepSORT configuration
architecture: DeepSORT
......@@ -63,7 +81,7 @@ ResNet:
freeze_norm: false
norm_decay: 0.
# Tracking requires higher quality boxes, so decode.conf_thresh will be higher
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
BBoxPostProcess:
decode:
name: YOLOBox
......@@ -74,7 +92,7 @@ BBoxPostProcess:
nms:
name: MatrixNMS
keep_top_k: 100
score_threshold: 0.25 # 0.01 in original detector
post_threshold: 0.25 # 0.01 in original detector
score_threshold: 0.4 # 0.01 in original detector
post_threshold: 0.4 # 0.01 in original detector
nms_top_k: -1
background_label: -1
_BASE_: [
'detector/yolov3_darknet53_270e_608x608_pedestrian.yml',
'../../datasets/mot.yml',
'../../runtime.yml',
'detector/yolov3_darknet53_40e_608x608_mot17half.yml',
'_base_/mot17.yml',
'_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
......@@ -13,9 +12,28 @@ EvalMOTDataset:
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_270e_608x608_pedestrian.pdparams
det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_40e_608x608_mot17half.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
# reader
EvalMOTReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [608, 608], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
TestMOTReader:
inputs_def:
image_shape: [3, 608, 608]
sample_transforms:
- Decode: {}
- Resize: {target_size: [608, 608], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
# DeepSORT configuration
architecture: DeepSORT
......@@ -47,23 +65,23 @@ DeepSORTTracker:
# detector configuration: General YOLOv3 version
# see 'configs/mot/deepsort/detector/yolov3_darknet53_270e_608x608_pedestrian.yml'
# see 'configs/mot/deepsort/detector/yolov3_darknet53_40e_608x608_mot17half.yml'
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
# Tracking requires higher quality boxes, so decode.conf_thresh will be higher
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.1 # 0.005 in original detector
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
score_threshold: 0.3 # 0.01 in original detector
nms_threshold: 0.45
nms_top_k: 1000
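The threshold comments in these DeepSORT configs (for example `score_threshold: 0.4 # 0.01 in original detector`) describe one idea: tracking keeps only confident boxes. A minimal Python sketch of that filtering step follows. It assumes each detection is a row `[class_id, score, x1, y1, x2, y2]`; the box columns, the sample values, the non-strict comparison, and the helper name `filter_by_score` are illustrative assumptions, and the real post-processing is done by the NMS operators named in the configs, not by this function.

```python
def filter_by_score(dets, score_threshold):
    """Keep detections whose score (column 1) reaches the threshold."""
    return [d for d in dets if d[1] >= score_threshold]


# Hypothetical detections: [class_id, score, x1, y1, x2, y2]
dets = [
    [0, 0.92, 10, 20, 50, 120],
    [0, 0.35, 200, 40, 240, 130],
    [0, 0.02, 5, 5, 9, 12],
]

# A detector-style threshold keeps almost everything (useful for AP).
kept_for_detection = filter_by_score(dets, 0.01)
# A tracking-style threshold keeps only confident boxes.
kept_for_tracking = filter_by_score(dets, 0.4)
```

With the higher threshold, low-confidence boxes never reach the ReID model or the tracker, which is consistent with the "higher quality boxes" comments in the configs.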
......@@ -9,13 +9,12 @@ English | [简体中文](README_cn.md)
### Results on MOT17-half dataset
| Backbone | Model | input size | lr schedule | FPS | Box AP | download | config |
| :-------------- | :------------- | :--------: | :---------: | :-----------: | :-----: | :----------: | :-----: |
| DarkNet-53 | YOLOv3 | 608X608 | 40e | ---- | 42.7 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_40e_608x608_mot17half.pdparams) | [config](./yolov3_darknet53_40e_608x608_mot17half.yml) |
| ResNet50-vd | PPYOLOv2 | 640x640 | 365e | ---- | 46.8 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyolov2_r50vd_dcn_365e_640x640_mot17half.pdparams) | [config](./ppyolov2_r50vd_dcn_365e_640x640_mot17half.yml) |
| ResNet50-FPN | Faster R-CNN | 1333x800 | 1x | ---- | 44.2 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/faster_rcnn_r50_fpn_2x_1333x800_mot17half.pdparams) | [config](./faster_rcnn_r50_fpn_2x_1333x800_mot17half.yml) |
| DarkNet-53 | YOLOv3 | 608X608 | 270e | ---- | 45.4 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_270e_608x608_pedestrian.pdparams) | [config](./yolov3_darknet53_270e_608x608_pedestrian.yml) |
| ESNet | PicoDet | 896x896 | 300e | ---- | 40.9 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_esnet_300e_896x896_mot17half.pdparams) | [config](./picodet_l_esnet_300e_896x896_mot17half.yml) |
| CSPResNet | PPYOLOe | 640x640 | 36e | ---- | 52.9 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams) | [config](./ppyoloe_crn_l_36e_640x640_mot17half.yml) |
**Notes:**
- The above models, except for YOLOv3, are trained with the **MOT17-half train** set.
- The above models are trained with the **MOT17-half train** set, which can be downloaded from this [link](https://dataset.bj.bcebos.com/mot/MOT17.zip).
- The **MOT17-half train** set is composed of the images and labels of the first half of the frames of each video in the MOT17 Train dataset (7 sequences in total). The **MOT17-half val** set, used for evaluation, is composed of the second half of the frames of each video. They can be downloaded from this [link](https://paddledet.bj.bcebos.com/data/mot/mot17half/annotations.zip). Download and unzip them in the `dataset/mot/MOT17/images/` folder.
- YOLOv3 is trained with the same pedestrian dataset as `configs/pedestrian/pedestrian_yolov3_darknet.yml`, which is not open yet.
- For pedestrian tracking, please use pedestrian detector combined with pedestrian ReID model. For vehicle tracking, please use vehicle detector combined with vehicle ReID model.
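The half split described in the notes above can be sketched in a few lines of Python. This is an illustration only: it assumes frames are numbered from 1, that the boundary falls at `num_frames // 2`, and that a sequence has 600 frames; the exact boundary used to build `train_half.json` and `val_half.json` is not stated here, and `split_half` is a hypothetical helper.

```python
def split_half(num_frames):
    """Split 1-based frame ids into a first half (train) and second half (val)."""
    mid = num_frames // 2
    train_ids = list(range(1, mid + 1))
    val_ids = list(range(mid + 1, num_frames + 1))
    return train_ids, val_ids


# Hypothetical sequence length, used only for the example.
train_ids, val_ids = split_half(600)
```

Because both halves come from the same videos, the val half shares scenes with the train half, which is worth keeping in mind when comparing these numbers with results on the official MOT17 test set.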
......@@ -25,11 +24,11 @@ English | [简体中文](README_cn.md)
Start the training and evaluation with the following command
```bash
job_name=ppyolov2_r50vd_dcn_365e_640x640_mot17half
job_name=ppyoloe_crn_l_36e_640x640_mot17half
config=configs/mot/deepsort/detector/${job_name}.yml
log_dir=log_dir/${job_name}
# 1. training
python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config}
python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --eval --amp --fleet
# 2. evaluation
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/${job_name}.pdparams
```
......@@ -10,13 +10,12 @@
### Detection results on the MOT17-half val dataset
| Backbone | Model | Input size | LR schedule | Inference time (fps) | Box AP | Download | Config |
| :-------------- | :------------- | :--------: | :---------: | :-----------: | :-----: | :------: | :-----: |
| DarkNet-53 | YOLOv3 | 608X608 | 40e | ---- | 42.7 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_40e_608x608_mot17half.pdparams) | [config](./yolov3_darknet53_40e_608x608_mot17half.yml) |
| ResNet50-vd | PPYOLOv2 | 640x640 | 365e | ---- | 46.8 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyolov2_r50vd_dcn_365e_640x640_mot17half.pdparams) | [config](./ppyolov2_r50vd_dcn_365e_640x640_mot17half.yml) |
| ResNet50-FPN | Faster R-CNN | 1333x800 | 1x | ---- | 44.2 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/faster_rcnn_r50_fpn_2x_1333x800_mot17half.pdparams) | [config](./faster_rcnn_r50_fpn_2x_1333x800_mot17half.yml) |
| DarkNet-53 | YOLOv3 | 608X608 | 270e | ---- | 45.4 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/yolov3_darknet53_270e_608x608_pedestrian.pdparams) | [config](./yolov3_darknet53_270e_608x608_pedestrian.yml) |
| ESNet | PicoDet | 896x896 | 300e | ---- | 40.9 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_esnet_300e_896x896_mot17half.pdparams) | [config](./picodet_l_esnet_300e_896x896_mot17half.yml) |
| CSPResNet | PPYOLOe | 640x640 | 36e | ---- | 52.9 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams) | [config](./ppyoloe_crn_l_36e_640x640_mot17half.yml) |
**Notes:**
- All of the above models except YOLOv3 are trained with the **MOT17-half train** dataset.
- All of the above models can be trained with the **MOT17-half train** dataset, which can be downloaded from [this link](https://dataset.bj.bcebos.com/mot/MOT17.zip).
- **MOT17-half train** is composed of the images and annotations of the first half of the frames of each video in the MOT17 train sequences (7 in total). To verify accuracy, all models can be evaluated on the **MOT17-half val** dataset, which is composed of the second half of the frames of each video. The dataset can be downloaded from [this link](https://paddledet.bj.bcebos.com/data/mot/mot17half/annotations.zip); unzip it into the `dataset/mot/MOT17/images/` folder.
- YOLOv3 is trained with the same pedestrian dataset as `configs/pedestrian/pedestrian_yolov3_darknet.yml`; this dataset is not open yet.
- For pedestrian tracking, use a pedestrian detector combined with a pedestrian ReID model. For vehicle tracking, use a vehicle detector combined with a vehicle ReID model.
......@@ -31,7 +30,7 @@ job_name=ppyolov2_r50vd_dcn_365e_640x640_mot17half
config=configs/mot/deepsort/detector/${job_name}.yml
log_dir=log_dir/${job_name}
# 1. training
python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config}
python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --eval --amp --fleet
# 2. evaluation
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/${job_name}.pdparams
```
_BASE_: [
'../../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml',
]
weights: output/faster_rcnn_r50_fpn_2x_1333x800_mot17half/model_final
num_classes: 1
TrainDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/train_half.json
image_dir: train
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/val_half.json
image_dir: train
# detector configuration
architecture: FasterRCNN
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
FasterRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
bbox_post_process: BBoxPostProcess
ResNet:
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
FPN:
out_channel: 256
RPNHead:
anchor_generator:
aspect_ratios: [0.5, 1.0, 2.0]
anchor_sizes: [[32], [64], [128], [256], [512]]
strides: [4, 8, 16, 32, 64]
rpn_target_assign:
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
use_random: True
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 1000
topk_after_collect: True
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
BBoxHead:
head: TwoFCHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
BBoxAssigner:
batch_size_per_im: 512
bg_thresh: 0.5
fg_thresh: 0.5
fg_fraction: 0.25
use_random: True
TwoFCHead:
out_channel: 1024
BBoxPostProcess:
decode: RCNNBox
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
_BASE_: [
'../../../picodet/picodet_l_640_coco.yml',
]
weights: output/picodet_l_esnet_300e_896x896_mot17half/model_final
num_classes: 1
TrainDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/train_half.json
image_dir: train
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/val_half.json
image_dir: train
worker_num: 6
TrainReader:
sample_transforms:
- Decode: {}
- RandomCrop: {}
- RandomFlip: {prob: 0.5}
- RandomDistort: {}
batch_transforms:
- BatchRandomResize: {target_size: [832, 864, 896, 928, 960], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_size: 32
shuffle: true
drop_last: true
collate_batch: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [896, 896], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 8
shuffle: false
# detector configuration
architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
epoch: 250
PicoDet:
backbone: ESNet
neck: CSPPAN
head: PicoHead
ESNet:
scale: 1.25
feature_maps: [4, 11, 14]
act: hard_swish
channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75]
CSPPAN:
out_channels: 128
use_depthwise: True
num_csp_blocks: 1
num_features: 4
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 4
num_fpn_stride: 4
norm_type: bn
share_cls_reg: False
fpn_stride: [8, 16, 32, 64]
feat_in_chan: 128
prior_prob: 0.01
reg_max: 7
cell_offset: 0.5
loss_class:
name: VarifocalLoss
use_sigmoid: True
iou_weighted: True
loss_weight: 1.0
loss_dfl:
name: DistributionFocalLoss
loss_weight: 0.25
loss_bbox:
name: GIoULoss
loss_weight: 2.0
assigner:
name: SimOTAAssigner
candidate_topk: 10
iou_weight: 6
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.025
nms_threshold: 0.6
_BASE_: [
'../../../ppyoloe/ppyoloe_crn_l_300e_coco.yml',
'../_base_/mot17.yml',
]
weights: output/ppyoloe_crn_l_36e_640x640_mot17half/model_final
log_iter: 20
snapshot_epoch: 2
# schedule configuration for fine-tuning
epoch: 36
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 43
- !LinearWarmup
start_factor: 0.001
steps: 100
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
TrainReader:
batch_size: 8
# detector configuration
architecture: YOLOv3
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
depth_mult: 1.0
width_mult: 1.0
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 # 100
use_varifocal_loss: True
eval_input_size: [640, 640]
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.6
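The PPYOLOe detector config above trains for `epoch: 36` while its `CosineDecay` uses `max_epochs: 43`, with a `LinearWarmup` of 100 steps starting at factor 0.001. The sketch below uses a generic cosine schedule and a generic linear warmup to show what those numbers imply. The formulas are the textbook forms and are an assumption here; PaddleDetection's own scheduler may differ in detail (for example in how warmup and decay are combined per iteration), and both function names are hypothetical.

```python
import math


def warmup_factor(step, warmup_steps=100, start_factor=0.001):
    """Linear ramp of the LR multiplier from start_factor to 1.0."""
    if step >= warmup_steps:
        return 1.0
    alpha = step / warmup_steps
    return start_factor * (1.0 - alpha) + alpha


def cosine_lr(epoch, base_lr=0.001, max_epochs=43):
    """Generic cosine decay from base_lr towards 0 at max_epochs."""
    return base_lr * 0.5 * (math.cos(epoch * math.pi / max_epochs) + 1.0)


lr_start = cosine_lr(0)
lr_at_last_epoch = cosine_lr(36)  # training stops before the cosine reaches 0
```

Under these assumptions, stopping at epoch 36 of a 43-epoch cosine leaves the learning rate small but non-zero at the end of fine-tuning.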
_BASE_: [
'../../../ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml',
'../_base_/mot17.yml',
]
weights: output/ppyolov2_r50vd_dcn_365e_640x640_mot17half/model_final
num_classes: 1
TrainDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/train_half.json
image_dir: train
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/val_half.json
image_dir: train
log_iter: 20
snapshot_epoch: 2
# detector configuration
......
# This config is the same as '../../../pedestrian/pedestrian_yolov3_darknet.yml'.
_BASE_: [
'../../../yolov3/yolov3_darknet53_270e_coco.yml',
'../_base_/mot17.yml',
]
weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/pedestrian_yolov3_darknet.pdparams
num_classes: 1
weights: output/yolov3_darknet53_40e_608x608_mot17half/model_final
log_iter: 20
snapshot_epoch: 2
# This pedestrian training dataset used is not open temporarily.
# Only the trained yolov3 model is provided, but you can eval on MOT17 half val dataset.
TrainDataset:
!COCODataSet
dataset_dir: dataset/pedestrian
anno_path: annotations/instances_train2017.json
image_dir: train2017
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
# schedule configuration for fine-tuning
epoch: 40
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 32
- 36
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 100
EvalDataset:
!COCODataSet
dataset_dir: dataset/mot/MOT17/images
anno_path: annotations/val_half.json
image_dir: train
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
TrainReader:
batch_size: 8
mixup_epoch: 35
# detector configuration
architecture: YOLOv3
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/yolov3_darknet53_270e_coco.pdparams
norm_type: sync_bn
YOLOv3:
......
......@@ -9,7 +9,7 @@ English | [简体中文](README_cn.md)
### Results on Market1501 pedestrian ReID dataset
| Backbone | Model | Params | FPS | mAP | Top1 | Top5 | download | config |
| Backbone | Model | Params | FPS | mAP | Top1 | Top5 | download | config |
| :-------------: | :-----------------: | :-------: | :------: | :-------: | :-------: | :-------: | :-------: | :-------: |
| ResNet-101 | PCB Pyramid Embedding | 289M | --- | 86.31 | 94.95 | 98.28 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pcb_pyramid_r101.pdparams) | [config](./deepsort_pcb_pyramid_r101.yml) |
| PPLCNet-2.5x | PPLCNet Embedding | 36M | --- | 71.59 | 87.38 | 95.49 | [download](https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams) | [config](./deepsort_pplcnet.yml) |
......
......@@ -15,7 +15,7 @@ EvalMOTDataset:
keep_ori_im: True # set as True in DeepSORT
det_weights: None
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pcb_pyramid_r101.pdparams
# A ReID only configuration of DeepSORT, detector should be None.
......@@ -28,6 +28,7 @@ DeepSORT:
tracker: DeepSORTTracker
PCBPyramid:
model_name: "ResNet101"
num_conv_out_channels: 128
num_classes: 751 # default 751 classes in Market-1501 dataset.
......
# This is a ReID-only configuration of DeepSORT. It has two uses:
# one is loading the detection results and the ReID model to get tracking results;
# the other is exporting the ReID model for deployment inference.
_BASE_: [
'../../../datasets/mot.yml',
'../../../runtime.yml',
'../_base_/deepsort_reader_1088x608.yml',
]
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: True # set as True in DeepSORT
det_weights: None
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_resnet.pdparams
# A ReID only configuration of DeepSORT, detector should be None.
architecture: DeepSORT
pretrain_weights: None
DeepSORT:
detector: None
reid: ResNetEmbedding
tracker: DeepSORTTracker
ResNetEmbedding:
model_name: "ResNet50"
DeepSORTTracker:
input_size: [64, 192]
min_box_area: 0 # filter out too small boxes
vertical_ratio: -1 # filter out bboxes, usually set 1.6 for pedestrian
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
motion: KalmanFilter
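`DeepSORTTracker` above is configured with `metric_type: cosine` and `matching_threshold: 0.2`. As a rough illustration of what a cosine appearance gate does, the sketch below compares one detection embedding with a track's stored embeddings and accepts the pair when the smallest cosine distance is within the threshold. The function names, the tiny 2-D vectors, and the use of the minimum over the stored features are assumptions for illustration, not the tracker's actual code.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def passes_gate(track_feats, det_feat, matching_threshold=0.2):
    """Accept when the closest stored feature is within the threshold."""
    best = min(cosine_distance(f, det_feat) for f in track_feats)
    return best <= matching_threshold


# Hypothetical 2-D embeddings; real ReID features have far more dimensions.
gallery = [[1.0, 0.0], [0.9, 0.1]]
same_person = passes_gate(gallery, [1.0, 0.05])
other_person = passes_gate(gallery, [0.0, 1.0])
```

A smaller `matching_threshold` makes the appearance gate stricter, so fewer detections are associated with existing tracks on appearance alone.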
......@@ -241,7 +241,7 @@ class Tracker(object):
outs['bbox'] = outs['bbox'].numpy()
outs['bbox_num'] = outs['bbox_num'].numpy()
if outs['bbox_num'] > 0 and empty_detections == False:
if len(outs['bbox']) > 0 and empty_detections == False:
# detector outputs: pred_cls_ids, pred_scores, pred_bboxes
pred_cls_ids = outs['bbox'][:, 0:1]
pred_scores = outs['bbox'][:, 1:2]
......
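The guard in the hunk above now checks `len(outs['bbox']) > 0` before slicing the detector output into class ids and scores. A small pure-Python sketch of the same pattern follows, using lists in place of the NumPy array. The helper name `split_detections` and the four box columns after the score are assumptions made for illustration.

```python
def split_detections(rows):
    """Split rows of [class_id, score, box...] into three parallel lists."""
    if len(rows) == 0:
        # Nothing detected in this frame: return empty outputs instead of slicing.
        return [], [], []
    cls_ids = [int(r[0]) for r in rows]
    scores = [float(r[1]) for r in rows]
    boxes = [list(r[2:6]) for r in rows]
    return cls_ids, scores, boxes


cls_ids, scores, boxes = split_detections([[0, 0.9, 10, 20, 50, 120]])
empty = split_detections([])
```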
......@@ -17,9 +17,11 @@ from . import fairmot_embedding_head
from . import resnet
from . import pyramidal_embedding
from . import pplcnet_embedding
from . import resnet_embedding
from .fairmot_embedding_head import *
from .jde_embedding_head import *
from .resnet import *
from .pyramidal_embedding import *
from .pplcnet_embedding import *
from .resnet_embedding import *
......@@ -21,7 +21,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle import ParamAttr
from .resnet import *
from .resnet import ResNet50, ResNet101
from ppdet.core.workspace import register
__all__ = ['PCBPyramid']
......@@ -46,6 +46,7 @@ class PCBPyramid(nn.Layer):
def __init__(self,
input_ch=2048,
model_name='ResNet101',
num_stripes=6,
used_levels=(1, 1, 1, 1, 1, 1),
num_classes=751,
......@@ -60,7 +61,8 @@ class PCBPyramid(nn.Layer):
self.num_in_each_level = [i for i in range(self.num_stripes, 0, -1)]
self.num_branches = sum(self.num_in_each_level)
self.base = ResNet101(
assert model_name in ['ResNet50', 'ResNet101'], "Unsupported ReID arch: {}".format(model_name)
self.base = eval(model_name)(
lr_mult=0.1,
last_conv_stride=last_conv_stride,
last_conv_dilation=last_conv_dilation)
......
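The hunk above selects the backbone with `eval(model_name)` after an `assert` whitelist. An alternative, shown here only as a hedged sketch and not as what the commit does, is an explicit name-to-class mapping. It avoids `eval` and still rejects unknown names when assertions are disabled with `python -O`. `ResNet50` and `ResNet101` below are empty stand-ins for the real classes in `.resnet`, and `build_backbone` is a hypothetical helper.

```python
class ResNet50:
    """Stand-in for the real backbone class; stores its kwargs only."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ResNet101(ResNet50):
    """Stand-in for the deeper backbone."""


# Explicit registry: only these names can ever be instantiated.
_BACKBONES = {"ResNet50": ResNet50, "ResNet101": ResNet101}


def build_backbone(model_name, **kwargs):
    """Look up the backbone class by name and instantiate it."""
    try:
        backbone_cls = _BACKBONES[model_name]
    except KeyError:
        raise ValueError("Unsupported ReID arch: {}".format(model_name))
    return backbone_cls(**kwargs)


base = build_backbone("ResNet101", lr_mult=0.1, last_conv_stride=1)
```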
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn.functional as F
from paddle import nn
from .resnet import ResNet50, ResNet101
from ppdet.core.workspace import register
__all__ = ['ResNetEmbedding']
@register
class ResNetEmbedding(nn.Layer):
in_planes = 2048
def __init__(self, model_name='ResNet50', last_stride=1):
super(ResNetEmbedding, self).__init__()
assert model_name in ['ResNet50', 'ResNet101'], "Unsupported ReID arch: {}".format(model_name)
self.base = eval(model_name)(last_conv_stride=last_stride)
self.gap = nn.AdaptiveAvgPool2D(output_size=1)
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.bn = nn.BatchNorm1D(self.in_planes, bias_attr=False)
def forward(self, x):
base_out = self.base(x)
global_feat = self.gap(base_out)
global_feat = self.flatten(global_feat)
global_feat = self.bn(global_feat)
return global_feat