Unverified · Commit 50410757 · Author: George Ni · Committer: GitHub

[MOT] add deepsort (#2976)

* add deepsort configs comments

* fix doc, add readme_cn

* fix doc
Parent 76ab2780
English | [简体中文](README_cn.md)
# DeepSORT (Simple Online and Realtime Tracking with a Deep Association Metric)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Start](#Getting_Start)
## Introduction
[DeepSORT](https://arxiv.org/abs/1812.00442) is basically the same as SORT, but adds a CNN model that extracts appearance features from the person image regions bounded by the detector's boxes. We use JDE as the detection model to generate boxes and select `PCBPyramid` as the ReID model. Loading boxes from saved detection result files is also supported.
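The appearance-matching idea described above can be sketched in a few lines. This is a minimal illustration, not PaddleDetection's implementation: the function names and toy feature vectors are hypothetical, the greedy pairing stands in for the assignment step a real tracker would use, and the `0.2` threshold simply mirrors `matching_threshold` from the tracker config in this commit.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def match_detections(track_features, det_features, threshold=0.2):
    """Greedily pair each detection with the closest unused track.

    track_features: dict mapping track id -> feature vector (hypothetical layout)
    det_features: list of feature vectors, one per detection
    Returns a list of (detection_index, track_id or None).
    """
    unused = set(track_features)
    matches = []
    for i, feat in enumerate(det_features):
        best_id, best_dist = None, threshold
        for tid in sorted(unused):
            dist = cosine_distance(track_features[tid], feat)
            if dist < best_dist:
                best_id, best_dist = tid, dist
        if best_id is not None:
            unused.discard(best_id)
        matches.append((i, best_id))
    return matches


# Toy ReID features: two existing tracks and three new detections.
tracks = {1: [1.0, 0.0, 0.0], 2: [0.0, 1.0, 0.0]}
detections = [[0.1, 0.9, 0.0], [0.9, 0.1, 0.0], [0.0, 0.0, 1.0]]
result = match_detections(tracks, detections)
print(result)  # [(0, 2), (1, 1), (2, None)]
```

The third detection matches no track, so a tracker would typically start a new tentative track for it.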
## Model Zoo
### DeepSORT on MOT-16 training set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | Detector | ReID | config |
| :---------| :------- | :----: | :----: | :--: | :----: | :---: | :---: |:---: | :---: | :---: |
| DarkNet53 | 1088x608 | 72.2 | 60.3 | 998 | 8055 | 21631 | 3.28 |[JDE](https://paddledet.bj.bcebos.com/models/mot/jde_darknet53_30e_1088x608.pdparams)| [ReID](https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams)|[config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml) |
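
The MOTA column can be cross-checked against the FP, FN and IDS columns using the CLEAR MOT definition, MOTA = 1 - (FN + FP + IDS) / GT. The sketch below assumes a ground-truth box count of 110407 for the MOT16 training set; that number is not stated in the table and should be treated as an assumption.

```python
def mota(fn, fp, ids, num_gt):
    """CLEAR MOT accuracy: 1 - (misses + false positives + ID switches) / GT boxes."""
    return 1.0 - (fn + fp + ids) / num_gt


# Assumption: total ground-truth boxes in the MOT16 training set.
NUM_GT_BOXES = 110407

# FN, FP and IDS are taken from the table row above.
score = mota(fn=21631, fp=8055, ids=998, num_gt=NUM_GT_BOXES)
print(f"MOTA = {score * 100:.1f}")  # MOTA = 72.2
```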
**Notes:**
DeepSORT does not need to be trained here; it is used only for evaluation. Before evaluating DeepSORT, first obtain detection results from a detection model (here we use JDE), then arrange the result files like this:
```
det_results_dir
|——————MOT16-02.txt
|——————MOT16-04.txt
|——————MOT16-05.txt
|——————MOT16-09.txt
|——————MOT16-10.txt
|——————MOT16-11.txt
|——————MOT16-13.txt
```
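Before running the tracking step, it can help to confirm that the directory really contains one result file per sequence. The helper below is a hypothetical sketch (`missing_result_files` is not part of PaddleDetection); it only checks the file names listed above and makes no assumption about the file contents.

```python
import tempfile
from pathlib import Path

# Sequence names taken from the directory layout above.
MOT16_TRAIN_SEQS = [
    "MOT16-02", "MOT16-04", "MOT16-05", "MOT16-09",
    "MOT16-10", "MOT16-11", "MOT16-13",
]


def missing_result_files(det_results_dir, seqs=MOT16_TRAIN_SEQS):
    """Return the expected '<sequence>.txt' files that are absent."""
    root = Path(det_results_dir)
    return [f"{seq}.txt" for seq in seqs if not (root / f"{seq}.txt").is_file()]


# Demo on a temporary directory that lacks the last sequence.
with tempfile.TemporaryDirectory() as tmp:
    for seq in MOT16_TRAIN_SEQS[:-1]:
        (Path(tmp) / f"{seq}.txt").touch()
    missing = missing_result_files(tmp)

print(missing)  # ['MOT16-13.txt']
```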
## Getting Start
### 1. Evaluate a detector to get detection results
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53_30e_1088x608_track.yml -o metric=MOT weights=https://paddledet.bj.bcebos.com/models/mot/jde_darknet53_30e_1088x608.pdparams --output ./det_results_dir
# use saved checkpoint after training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53_30e_1088x608_track.yml -o metric=MOT weights=output/jde_darknet53_30e_1088x608/model_final --output ./det_results_dir
```
### 2. Tracking
```bash
# track the objects by loading detected result files
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir ./det_results_dir/mot_results
```
## Citations
```
@inproceedings{Wojke2017simple,
title={Simple Online and Realtime Tracking with a Deep Association Metric},
author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich},
booktitle={2017 IEEE International Conference on Image Processing (ICIP)},
year={2017},
pages={3645--3649},
organization={IEEE},
doi={10.1109/ICIP.2017.8296962}
}
@inproceedings{Wojke2018deep,
title={Deep Cosine Metric Learning for Person Re-identification},
author={Wojke, Nicolai and Bewley, Alex},
booktitle={2018 IEEE Winter Conference on Applications of Computer Vision (WACV)},
year={2018},
pages={748--756},
organization={IEEE},
doi={10.1109/WACV.2018.00087}
}
```
简体中文 | [English](README.md)
# DeepSORT (Simple Online and Realtime Tracking with a Deep Association Metric)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo and Baselines](#Model_Zoo_and_Baselines)
- [Quick Start](#Quick_Start)
## Introduction
[DeepSORT](https://arxiv.org/abs/1812.00442) is basically the same as SORT, but adds a CNN model that extracts features from the person image regions bounded by the detector. We use JDE as the detection model to generate boxes and select `PCBPyramid` as the ReID model. Loading saved detection result files for tracking is also supported.
## Model Zoo and Baselines
### DeepSORT on MOT-16 training set
| Backbone | Input shape | MOTA | IDF1 | IDS | FP | FN | FPS | Detector | ReID | Config |
| :---------| :------- | :----: | :----: | :--: | :----: | :---: | :---: |:---: | :---: | :---: |
| DarkNet53 | 1088x608 | 72.2 | 60.3 | 998 | 8055 | 21631 | 3.28 |[JDE](https://paddledet.bj.bcebos.com/models/mot/jde_darknet53_30e_1088x608.pdparams)| [ReID](https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams)|[config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml) |
**Notes:**
DeepSORT does not need to be trained here; it is used only for evaluation. Before evaluating with the DeepSORT model, first obtain detection results from a detection model (JDE is used here), then prepare the result files like this:
```
det_results_dir
|——————MOT16-02.txt
|——————MOT16-04.txt
|——————MOT16-05.txt
|——————MOT16-09.txt
|——————MOT16-10.txt
|——————MOT16-11.txt
|——————MOT16-13.txt
```
## Quick Start
### 1. Evaluate a detector to get detection result files
```bash
# use weights released by PaddleDetection
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53_30e_1088x608_track.yml -o metric=MOT weights=https://paddledet.bj.bcebos.com/models/mot/jde_darknet53_30e_1088x608.pdparams --output ./det_results_dir
# use a checkpoint saved during training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53_30e_1088x608_track.yml -o metric=MOT weights=output/jde_darknet53_30e_1088x608/model_final --output ./det_results_dir
```
### 2. Tracking
```bash
# load the detection result files to get tracking results
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir ./det_results_dir/mot_results
```
## Citations
```
@inproceedings{Wojke2017simple,
title={Simple Online and Realtime Tracking with a Deep Association Metric},
author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich},
booktitle={2017 IEEE International Conference on Image Processing (ICIP)},
year={2017},
pages={3645--3649},
organization={IEEE},
doi={10.1109/ICIP.2017.8296962}
}
@inproceedings{Wojke2018deep,
title={Deep Cosine Metric Learning for Person Re-identification},
author={Wojke, Nicolai and Bewley, Alex},
booktitle={2018 IEEE Winter Conference on Applications of Computer Vision (WACV)},
year={2018},
pages={748--756},
organization={IEEE},
doi={10.1109/WACV.2018.00087}
}
```
TestReader:
  inputs_def:
    image_shape: [3, 608, 1088]
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
    - Permute: {}
  batch_size: 1
EvalMOTReader:
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
    - Permute: {}
  batch_size: 1
TestMOTReader:
  inputs_def:
    image_shape: [3, 608, 1088]
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
    - Permute: {}
  batch_size: 1
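The `LetterBoxResize` entries above use `target_size: [608, 1088]`, read here as height by width to match `image_shape: [3, 608, 1088]`. A letterbox resize generally scales the image to fit the target while keeping its aspect ratio, then pads the remainder. The sketch below shows that arithmetic only; the function name is hypothetical, and the rounding and centered padding are assumptions rather than PaddleDetection's exact behavior.

```python
def letterbox_params(img_h, img_w, target_h=608, target_w=1088):
    """Return (new_h, new_w, pad_top, pad_left) for an aspect-preserving resize."""
    # Use the smaller scale factor so the resized image fits inside the target.
    ratio = min(target_h / img_h, target_w / img_w)
    new_h, new_w = round(img_h * ratio), round(img_w * ratio)
    # Assumption: the leftover space is split evenly, so only the leading pad is reported.
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    return new_h, new_w, pad_top, pad_left


full_hd = letterbox_params(1080, 1920)
vga = letterbox_params(480, 640)
print(full_hd)  # (608, 1081, 0, 3)
print(vga)      # (608, 811, 0, 138)
```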
architecture: DeepSORT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
DeepSORT:
  detector: YOLOv3
  reid: PCBPyramid
  tracker: DeepSORTTracker
YOLOv3:
  backbone: DarkNet
  neck: YOLOv3FPN
  yolo_head: YOLOv3Head
  post_process: BBoxPostProcess
DarkNet:
  depth: 53
  return_idx: [2, 3, 4]
# use default config
# YOLOv3FPN:
YOLOv3Head:
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  loss: YOLOv3Loss
YOLOv3Loss:
  ignore_thresh: 0.7
  downsample: [32, 16, 8]
  label_smooth: false
BBoxPostProcess:
  decode:
    name: YOLOBox
    conf_thresh: 0.2
    downsample_ratio: 32
    clip_bbox: true
  nms:
    name: MultiClassNMS
    keep_top_k: 100
    score_threshold: 0.01
    nms_threshold: 0.45
    nms_top_k: 1000
PCBPyramid:
  num_conv_out_channels: 128
  num_classes: 751
DeepSORTTracker:
  budget: 100
  max_age: 70
  n_init: 3
  metric_type: 'cosine'
  matching_threshold: 0.2
  max_iou_distance: 0.9
  motion: 'KalmanFilter'
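The `n_init` and `max_age` settings above control when a track is trusted and when it is dropped. The toy class below sketches the lifecycle used by the reference DeepSORT tracker (tentative, confirmed, deleted). The class is hypothetical and simplified, and it is an assumption that PaddleDetection's `DeepSORTTracker` follows the same rules.

```python
class Track:
    """Toy track following a tentative -> confirmed -> deleted lifecycle."""

    def __init__(self, n_init=3, max_age=70):
        self.n_init = n_init
        self.max_age = max_age
        self.hits = 1                # the detection that created the track
        self.time_since_update = 0   # consecutive frames without a match
        self.state = "tentative"

    def mark_matched(self):
        """Record a frame in which a detection was associated with this track."""
        self.hits += 1
        self.time_since_update = 0
        if self.state == "tentative" and self.hits >= self.n_init:
            self.state = "confirmed"

    def mark_missed(self):
        """Record a frame in which no detection was associated with this track."""
        self.time_since_update += 1
        if self.state == "tentative":
            # An unconfirmed track is dropped on its first miss.
            self.state = "deleted"
        elif self.time_since_update > self.max_age:
            self.state = "deleted"


track = Track(n_init=3, max_age=70)
track.mark_matched()
state_after_two_hits = track.state    # still "tentative"
track.mark_matched()
state_after_three_hits = track.state  # "confirmed"

for _ in range(70):
    track.mark_missed()
state_after_70_misses = track.state   # still "confirmed"
track.mark_missed()
state_after_71_misses = track.state   # "deleted"
```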
_BASE_: [
  '../../datasets/mot.yml',
  '../../runtime.yml',
  '_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml',
  '_base_/deepsort_reader_1088x608.yml',
]
metric: MOT
EvalMOTDataset:
  !MOTImageFolder
    task: MOT16_train
    dataset_dir: dataset/mot
    data_root: MOT16/images/train
    keep_ori_im: True # set True if used in DeepSORT
det_weights: None
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
DeepSORT:
  detector: None
  reid: PCBPyramid
  tracker: DeepSORTTracker
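This top-level config lists base files under `_BASE_` and then overrides `DeepSORT.detector`. One common way such inheritance works is a recursive dictionary merge in which the child file wins on conflicts. The sketch below illustrates that idea with a hypothetical `merge_config` helper; whether PaddleDetection merges nested keys in exactly this way is an assumption. Note that a plain YAML scalar `None` is loaded as the string `"None"`, not as a null value, which is why the dictionaries use a string.

```python
def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`; `override` wins."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Values copied from the base architecture config.
base_cfg = {
    "architecture": "DeepSORT",
    "DeepSORT": {
        "detector": "YOLOv3",
        "reid": "PCBPyramid",
        "tracker": "DeepSORTTracker",
    },
}

# Values copied from the top-level config; YAML `None` arrives as a string.
top_cfg = {
    "metric": "MOT",
    "det_weights": "None",
    "DeepSORT": {"detector": "None"},
}

final_cfg = merge_config(base_cfg, top_cfg)
print(final_cfg["DeepSORT"])
# {'detector': 'None', 'reid': 'PCBPyramid', 'tracker': 'DeepSORTTracker'}
```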
......
@@ -38,7 +38,7 @@ class DeepSORT(BaseArch):
     def __init__(self,
                  detector='YOLOv3',
-                 reid='PCBPlusDropoutPyramid',
+                 reid='PCBPyramid',
                  tracker='DeepSORTTracker'):
         super(DeepSORT, self).__init__()
         self.detector = detector
......
......
@@ -24,21 +24,34 @@ from paddle import ParamAttr
 from .resnet import *
 from ppdet.core.workspace import register
-__all__ = ['PCBPlusDropoutPyramid']
+__all__ = ['PCBPyramid']
 @register
-class PCBPlusDropoutPyramid(nn.Layer):
-    def __init__(
-            self,
-            input_ch=2048,
-            num_stripes=6,  # number of sub-parts
-            used_levels=(1, 1, 1, 1, 1, 1),
-            num_classes=751,
-            last_conv_stride=1,
-            last_conv_dilation=1,
-            num_conv_out_channels=128):
-        super(PCBPlusDropoutPyramid, self).__init__()
+class PCBPyramid(nn.Layer):
+    """
+    PCB (Part-based Convolutional Baseline), see https://arxiv.org/abs/1711.09349,
+    Pyramidal Person Re-IDentification, see https://arxiv.org/abs/1810.12193
+    Args:
+        input_ch (int): Number of channels of the input feature.
+        num_stripes (int): Number of sub-parts.
+        used_levels (tuple): Whether the level is used, 1 means used.
+        num_classes (int): Number of classes for identities.
+        last_conv_stride (int): Stride of the last conv.
+        last_conv_dilation (int): Dilation of the last conv.
+        num_conv_out_channels (int): Number of channels of conv feature.
+    """
+    def __init__(self,
+                 input_ch=2048,
+                 num_stripes=6,
+                 used_levels=(1, 1, 1, 1, 1, 1),
+                 num_classes=751,
+                 last_conv_stride=1,
+                 last_conv_dilation=1,
+                 num_conv_out_channels=128):
+        super(PCBPyramid, self).__init__()
         self.num_stripes = num_stripes
         self.used_levels = used_levels
         self.num_classes = num_classes
......
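The `num_stripes` and `used_levels` arguments documented above suggest a pyramid in which each level groups more adjacent stripes than the previous one. The sketch below enumerates the stripe ranges under that reading: level `k` merges `k + 1` neighboring stripes, so it has `num_stripes - k` branches. The function is hypothetical and is not taken from the elided body of `PCBPyramid`; it only illustrates how the two arguments could interact.

```python
def pyramid_branches(num_stripes=6, used_levels=(1, 1, 1, 1, 1, 1)):
    """List (level, start, end) stripe ranges for every used pyramid level.

    `start` and `end` are indices over the basic stripes, with `end` exclusive.
    """
    branches = []
    for level, used in enumerate(used_levels):
        if not used:
            continue  # a 0 in used_levels skips the whole level
        width = level + 1  # number of adjacent basic stripes merged at this level
        for start in range(num_stripes - width + 1):
            branches.append((level, start, start + width))
    return branches


all_levels = pyramid_branches()
first_and_last = pyramid_branches(used_levels=(1, 0, 0, 0, 0, 1))
print(len(all_levels))      # 21
print(len(first_and_last))  # 7
print(all_levels[-1])       # (5, 0, 6)
```

With the default arguments this gives 6 + 5 + 4 + 3 + 2 + 1 = 21 branches, the last of which covers the whole feature map.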