diff --git a/README_cn.md b/README_cn.md
index 1f5eda7ec60da758b252de0e190da3b4edb48101..a8a659b7505f5ccb76cd15688bc0306e81cd1f1d 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -253,10 +253,10 @@ PaddleDetection整理工业、农业、林业、交通、医疗、金融、能
Common
|
@@ -799,4 +800,4 @@ PP-Vehicle囊括四大交通场景核心功能:车牌识别、属性识别、
@misc{ppdet2019,
title={PaddleDetection, Object detection and instance segmentation toolkit based on PaddlePaddle.},
author={PaddlePaddle Authors},
-howpublished = {\url{https://github.com/PaddlePaddle/PaddleDetection}},
\ No newline at end of file
+howpublished = {\url{https://github.com/PaddlePaddle/PaddleDetection}},
diff --git a/README_en.md b/README_en.md
index db2eda9f3e8dd20c06c6a26525c84bf3955a2813..25d8a474512d46e7d07f2892f2a5f178770ce88b 100644
--- a/README_en.md
+++ b/README_en.md
@@ -156,6 +156,7 @@
DeepSORT
ByteTrack
OC-SORT
+ CenterTrack
KeyPoint-Detection
diff --git a/configs/mot/README.md b/configs/mot/README.md
index 8d194dce43d087aff3a054ea1fc0152ca4cf4364..533c3fcbae8796f85d24074b48a83bd83f1ef479 100644
--- a/configs/mot/README.md
+++ b/configs/mot/README.md
@@ -24,6 +24,7 @@ PaddleDetection中提供了SDE和JDE两个系列的多种算法实现:
- [ByteTrack](./bytetrack)
- [OC-SORT](./ocsort)
- [DeepSORT](./deepsort)
+ - [CenterTrack](./centertrack)
- JDE
- [JDE](./jde)
- [FairMOT](./fairmot)
@@ -31,7 +32,7 @@ PaddleDetection中提供了SDE和JDE两个系列的多种算法实现:
**Notes:**
 - The above algorithms were all proposed for single-class multi-object tracking in their original papers; the PaddleDetection team additionally supports multi-class multi-object tracking for [ByteTrack](./bytetrack) and FairMOT ([MCFairMOT](./mcfairmot));
- - [DeepSORT](./deepsort) and [JDE](./jde) only support single-class multi-object tracking;
+ - [DeepSORT](./deepsort), [JDE](./jde) and [CenterTrack](./centertrack) only support single-class multi-object tracking;
 - [DeepSORT](./deepsort) needs additional ReID weights to run, while [ByteTrack](./bytetrack) can run with or without ReID weights and does not use them by default;
@@ -96,6 +97,7 @@ pip install lap motmetrics sklearn filterpy
- [DeepSORT](deepsort/README_cn.md)
- [JDE](jde/README_cn.md)
- [FairMOT](fairmot/README_cn.md)
+ - [CenterTrack](centertrack/README_cn.md)
- Feature models
 - [Pedestrian tracking](pedestrian/README_cn.md)
 - [Head tracking](headtracking21/README_cn.md)
@@ -111,7 +113,7 @@ pip install lap motmetrics sklearn filterpy
| MOT approach | Representative algorithms | Pipeline | Dataset requirements | Other characteristics |
| :--------------| :--------------| :------- | :----: | :----: |
-| SDE series | DeepSORT, ByteTrack, OC-SORT | Separated: two independent model weights, detection first and then ReID (ReID optional) | Detection and ReID data are relatively independent; without ReID only a detection dataset is needed | Detection and ReID can be tuned separately, high robustness, commonly used in AI competitions |
+| SDE series | DeepSORT, ByteTrack, OC-SORT, CenterTrack | Separated: two independent model weights, detection first and then ReID (ReID optional) | Detection and ReID data are relatively independent; without ReID only a detection dataset is needed | Detection and ReID can be tuned separately, high robustness, commonly used in AI competitions |
| JDE series | FairMOT, JDE | Joint: one model weight performs detection and ReID end to end | Requires both detection and ReID annotations | Detection and ReID are trained jointly, harder to tune, weaker generalization |
**Notes:**
@@ -266,4 +268,25 @@ MOT17
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
+
+@article{zhang2021bytetrack,
+ title={ByteTrack: Multi-Object Tracking by Associating Every Detection Box},
+ author={Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang},
+ journal={arXiv preprint arXiv:2110.06864},
+ year={2021}
+}
+
+@article{cao2022observation,
+ title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
+ author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
+ journal={arXiv preprint arXiv:2203.14360},
+ year={2022}
+}
+
+@article{zhou2020tracking,
+ title={Tracking Objects as Points},
+ author={Zhou, Xingyi and Koltun, Vladlen and Kr{\"a}henb{\"u}hl, Philipp},
+ journal={ECCV},
+ year={2020}
+}
```
diff --git a/configs/mot/README_en.md b/configs/mot/README_en.md
index e3817d9093a8d512566965e50115e967e8b89e63..ec78a85d0f330a38d6b41080f34ef1ac4612a718 100644
--- a/configs/mot/README_en.md
+++ b/configs/mot/README_en.md
@@ -64,6 +64,7 @@ pip install -r requirements.txt
- [DeepSORT](deepsort/README.md)
- [JDE](jde/README.md)
- [FairMOT](fairmot/README.md)
+ - [CenterTrack](centertrack/README.md)
- Feature models
- [Pedestrian](pedestrian/README.md)
- [Head](headtracking21/README.md)
@@ -184,4 +185,25 @@ In the annotation text, each line is describing a bounding box and has the follo
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
+
+@article{zhang2021bytetrack,
+ title={ByteTrack: Multi-Object Tracking by Associating Every Detection Box},
+ author={Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang},
+ journal={arXiv preprint arXiv:2110.06864},
+ year={2021}
+}
+
+@article{cao2022observation,
+ title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
+ author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
+ journal={arXiv preprint arXiv:2203.14360},
+ year={2022}
+}
+
+@article{zhou2020tracking,
+ title={Tracking Objects as Points},
+ author={Zhou, Xingyi and Koltun, Vladlen and Kr{\"a}henb{\"u}hl, Philipp},
+ journal={ECCV},
+ year={2020}
+}
```
diff --git a/configs/mot/centertrack/README.md b/configs/mot/centertrack/README.md
new file mode 120000
index 0000000000000000000000000000000000000000..4015683cfa5969297febc12e7ca1264afabbc0b5
--- /dev/null
+++ b/configs/mot/centertrack/README.md
@@ -0,0 +1 @@
+README_cn.md
\ No newline at end of file
diff --git a/configs/mot/centertrack/README_cn.md b/configs/mot/centertrack/README_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..a91a844402ac3ddbcad27b44938fb35438c44e49
--- /dev/null
+++ b/configs/mot/centertrack/README_cn.md
@@ -0,0 +1,156 @@
+简体中文 | [English](README.md)
+
+# CenterTrack (Tracking Objects as Points)
+
+## Contents
+- [Model Zoo](#model-zoo)
+- [Getting Started](#getting-started)
+- [Citations](#citations)
+
+## Model Zoo
+
+### MOT17
+
+| Training dataset | Input size | Total batch_size | val MOTA | test MOTA | FPS | Config | Download |
+| :---------------: | :-------: | :------------: | :----------------: | :---------: | :-------: | :----: | :-----: |
+| MOT17-half train | 544x960 | 32 | 69.2(MOT17-half) | - | - |[config](./centertrack_dla34_70e_mot17half.yml) | [download](https://paddledet.bj.bcebos.com/models/mot/centertrack_dla34_70e_mot17half.pdparams) |
+| MOT17 train | 544x960 | 32 | 87.9(MOT17-train) | 70.5(MOT17-test) | - |[config](./centertrack_dla34_70e_mot17.yml) | [download](https://paddledet.bj.bcebos.com/models/mot/centertrack_dla34_70e_mot17.pdparams) |
+| MOT17 train(paper) | 544x960| 32 | - | 67.8(MOT17-test) | - | - | - |
+
+
+**Notes:**
+ - CenterTrack is trained with 2 GPUs and a total batch_size of 32 by default. If you change the number of GPUs or the per-GPU batch_size, it is best to keep the total batch_size at 32.
+ - **val MOTA** may fluctuate by around 1.0 MOTA; training with the default setting of 2 GPUs and a total batch_size of 32 is recommended.
+ - **MOT17-half train** uses the images and annotations of the **first half of the frames** of each of the 7 MOT17 train sequences as the training set, and **MOT17-half val**, built from the second half of the frames of each sequence, as the validation set to obtain **val MOTA**. The dataset can be downloaded from [this link](https://bj.bcebos.com/v1/paddledet/data/mot/MOT17.zip) and should be unzipped into the `dataset/mot/` folder (an expected layout is sketched below).
+ - **MOT17 train** uses the images and annotations of all frames of the 7 MOT17 train sequences as the training set. Since MOT17 data is limited, the **MOT17 train** set is also used for evaluation to obtain **val MOTA**, while **test MOTA** is obtained by submitting results to the [MOT Challenge website](https://motchallenge.net).
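+
+Based on the dataset paths referenced by the configs in this directory (`dataset_dir`, `anno_path`, `image_dir`, `data_root`), the unzipped data is expected to look roughly as follows; this is only a sketch for orientation, and entries not referenced by the configs are illustrative:
+```
+dataset/mot/MOT17/
+├── annotations/
+│   ├── train.json        # all frames of the train sequences, COCO style with 'track_id'
+│   ├── train_half.json   # first half of each train sequence
+│   └── val_half.json     # second half of each train sequence
+└── images/
+    ├── train/            # train sequences (also used as data_root for MOT evaluation)
+    ├── half/             # frames used by EvalMOTDataset in the mot17half config
+    └── test/             # test sequences for MOTChallenge submission
+```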
+
+
+## Getting Started
+
+### 1. Training
+Start training and evaluation with the following commands:
+```bash
+# single-GPU training (not recommended)
+CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml --amp
+# multi-GPU training
+python -m paddle.distributed.launch --log_dir=centertrack_dla34_70e_mot17half/ --gpus 0,1 tools/train.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml --amp
+```
+**Notes:**
+ - `--eval` does not yet support evaluating tracking MOTA during training. If you want to enable `--eval` to evaluate detection mAP during training, you need to **comment out `mot_metric: True` and `metric: MOT` in the config file**;
+ - `--amp` enables mixed-precision training to avoid running out of GPU memory;
+ - CenterTrack is trained with 2 GPUs and a total batch_size of 32 by default. If you change the number of GPUs or the per-GPU batch_size, it is best to keep the total batch_size at 32;
+
+
+### 2. Evaluation
+
+#### 2.1 Evaluate detection performance
+
+Note that you first need to **comment out `mot_metric: True` and `metric: MOT` in the config file**:
+```yaml
+### for detection eval.py/infer.py
+mot_metric: False
+metric: COCO
+
+### for MOT eval_mot.py/infer_mot.py
+#mot_metric: True # uncommented by default; tracking evaluation requires True, which overrides the mot_metric: False above
+#metric: MOT # uncommented by default; tracking evaluation requires MOT, which overrides the metric: COCO above
+```
+
+Then run the following command:
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml -o weights=output/centertrack_dla34_70e_mot17half/model_final.pdparams
+```
+
+**Notes:**
+ - Detection evaluation uses ```tools/eval.py```, while tracking evaluation uses ```tools/eval_mot.py```.
+
+#### 2.2 Evaluate tracking performance
+
+Note that you should first make sure **`mot_metric: True` and `metric: MOT` are set in the config file**;
+
+Then run the following command:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml -o weights=output/centertrack_dla34_70e_mot17half/model_final.pdparams
+```
+**Notes:**
+ - Detection evaluation uses ```tools/eval.py```, while tracking evaluation uses ```tools/eval_mot.py```.
+ - The tracking results are saved in `{output_dir}/mot_results/`, with one txt file per video sequence. Each line of a txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`. `{output_dir}` can be set with `--output_dir` and defaults to `output`. A minimal parsing sketch is shown below.
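+
+The following is a minimal sketch for reading such a result file; the helper `load_mot_results` and the file name `MOT17-02-SDP.txt` are only illustrative and are not part of the PaddleDetection tools:
+```python
+from collections import defaultdict
+
+def load_mot_results(txt_path):
+    """Parse `frame,id,x1,y1,w,h,score,-1,-1,-1` lines into {frame: [(id, (x1, y1, w, h), score), ...]}."""
+    results = defaultdict(list)
+    with open(txt_path) as f:
+        for line in f:
+            if not line.strip():
+                continue
+            frame, tid, x1, y1, w, h, score = line.strip().split(',')[:7]
+            results[int(float(frame))].append(
+                (int(float(tid)), (float(x1), float(y1), float(w), float(h)), float(score)))
+    return results
+
+# e.g. print all tracked boxes of the first frame
+print(load_mot_results('output/mot_results/MOT17-02-SDP.txt')[1])
+```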
+
+
+### 3. Inference
+
+#### 3.1 Detection inference
+Note that you first need to **comment out `mot_metric: True` and `metric: MOT` in the config file**:
+```yaml
+### for detection eval.py/infer.py
+mot_metric: False
+metric: COCO
+
+### for MOT eval_mot.py/infer_mot.py
+#mot_metric: True # uncommented by default; tracking evaluation requires True, which overrides the mot_metric: False above
+#metric: MOT # uncommented by default; tracking evaluation requires MOT, which overrides the metric: COCO above
+```
+
+Then run the following command:
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml -o weights=output/centertrack_dla34_70e_mot17half/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg --draw_threshold=0.5
+```
+
+**Notes:**
+ - Detection inference uses ```tools/infer.py```, while tracking inference uses ```tools/infer_mot.py```.
+
+
+#### 3.2 Tracking inference
+
+Note that you should first make sure **`mot_metric: True` and `metric: MOT` are set in the config file**;
+
+Then run the following command:
+```bash
+# download the demo video
+wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
+# run tracking on the video
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml --video_file=mot17_demo.mp4 --draw_threshold=0.5 --save_videos -o weights=output/centertrack_dla34_70e_mot17half/model_final.pdparams
+# or run tracking on an image folder
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml --image_dir=mot17_demo/ --draw_threshold=0.5 --save_videos -o weights=output/centertrack_dla34_70e_mot17half/model_final.pdparams
+```
+
+**Notes:**
+ - Please make sure [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first. On Linux (Ubuntu) it can be installed with: `apt-get update && apt-get install -y ffmpeg`.
+ - `--save_videos` saves the visualized video; the visualized images are also saved in `{output_dir}/mot_outputs/`. `{output_dir}` can be set with `--output_dir` and defaults to `output`.
+
+
+### 4. Export the inference model
+
+Note that you should first make sure **`mot_metric: True` and `metric: MOT` are set in the config file**;
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/centertrack_dla34_70e_mot17half.pdparams
+```
+
+### 5. Python inference with the exported model
+
+Note that you should first set `type: CenterTracker` in `deploy/python/tracker_config.yml`, as shown in the snippet below.
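+
+For reference, a sketch of the relevant part of `deploy/python/tracker_config.yml` after that change; the `CenterTracker` values simply mirror the defaults added in this PR:
+```yaml
+type: CenterTracker   # switched from the default JDETracker
+
+CenterTracker:
+  min_box_area: -1
+  vertical_ratio: -1
+  track_thresh: 0.4
+  pre_thresh: 0.5
+```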
+
+```bash
+# run tracking on a video
+# wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
+python deploy/python/mot_centertrack_infer.py --model_dir=output_inference/centertrack_dla34_70e_mot17half/ --tracker_config=deploy/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --save_images=True --save_mot_txts
+# run tracking on an image folder
+python deploy/python/mot_centertrack_infer.py --model_dir=output_inference/centertrack_dla34_70e_mot17half/ --tracker_config=deploy/python/tracker_config.yml --image_dir=mot17_demo/ --device=GPU --save_images=True --save_mot_txts
+```
+
+**Notes:**
+ - The tracking model runs on videos and does not support single-image inference. The visualized tracking video is saved by default. You can add `--save_mot_txts` (one txt per video) or `--save_mot_txt_per_img` (one txt per image) to save the tracking results as txt files, or `--save_images` to save the visualized tracking images.
+ - Each line of a tracking result txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
+
+
+## Citations
+```
+@article{zhou2020tracking,
+ title={Tracking Objects as Points},
+ author={Zhou, Xingyi and Koltun, Vladlen and Kr{\"a}henb{\"u}hl, Philipp},
+ journal={ECCV},
+ year={2020}
+}
+```
diff --git a/configs/mot/centertrack/_base_/centertrack_dla34.yml b/configs/mot/centertrack/_base_/centertrack_dla34.yml
new file mode 100644
index 0000000000000000000000000000000000000000..159165bd159ff7f5ee310b546b5a137fbf470259
--- /dev/null
+++ b/configs/mot/centertrack/_base_/centertrack_dla34.yml
@@ -0,0 +1,57 @@
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/crowdhuman_centertrack.pdparams
+architecture: CenterTrack
+for_mot: True
+mot_metric: True
+
+### model
+CenterTrack:
+ detector: CenterNet
+ plugin_head: CenterTrackHead
+ tracker: CenterTracker
+
+
+### CenterTrack.detector
+CenterNet:
+ backbone: DLA
+ neck: CenterNetDLAFPN
+ head: CenterNetHead
+ post_process: CenterNetPostProcess
+ for_mot: True # Note
+
+DLA:
+ depth: 34
+ pre_img: True # Note
+ pre_hm: True # Note
+
+CenterNetDLAFPN:
+ down_ratio: 4
+ last_level: 5
+ out_channel: 0
+ dcn_v2: True
+
+CenterNetHead:
+ head_planes: 256
+ prior_bias: -4.6 # Note
+ regress_ltrb: False
+ size_loss: 'L1'
+ loss_weight: {'heatmap': 1.0, 'size': 0.1, 'offset': 1.0}
+
+CenterNetPostProcess:
+ max_per_img: 100 # top-K
+ regress_ltrb: False
+
+
+### CenterTrack.plugin_head
+CenterTrackHead:
+ head_planes: 256
+ task: tracking
+ loss_weight: {'tracking': 1.0, 'ltrb_amodal': 0.1}
+ add_ltrb_amodal: True
+
+
+### CenterTrack.tracker
+CenterTracker:
+ min_box_area: -1
+ vertical_ratio: -1
+ track_thresh: 0.4
+ pre_thresh: 0.5
diff --git a/configs/mot/centertrack/_base_/centertrack_reader.yml b/configs/mot/centertrack/_base_/centertrack_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7a5bf6fda60242be1628635bd97eac4d0a85bb2b
--- /dev/null
+++ b/configs/mot/centertrack/_base_/centertrack_reader.yml
@@ -0,0 +1,75 @@
+input_h: &input_h 544
+input_w: &input_w 960
+input_size: &input_size [*input_h, *input_w]
+pre_img_epoch: &pre_img_epoch 70 # Add previous image as input
+
+worker_num: 4
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - FlipWarpAffine:
+ keep_res: False
+ input_h: *input_h
+ input_w: *input_w
+ not_rand_crop: False
+ flip: 0.5
+ is_scale: True
+ use_random: True
+ add_pre_img: True
+ - CenterRandColor: {saturation: 0.4, contrast: 0.4, brightness: 0.4}
+ - Lighting: {alphastd: 0.1, eigval: [0.2141788, 0.01817699, 0.00341571], eigvec: [[-0.58752847, -0.69563484, 0.41340352], [-0.5832747, 0.00994535, -0.81221408], [-0.56089297, 0.71832671, 0.41158938]]}
+ - NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: False}
+ - Permute: {}
+ - Gt2CenterTrackTarget:
+ down_ratio: 4
+ max_objs: 256
+ hm_disturb: 0.05
+ lost_disturb: 0.4
+ fp_disturb: 0.1
+ pre_hm: True
+ add_tracking: True
+ add_ltrb_amodal: True
+ batch_size: 16 # total 32 for 2 GPUs
+ shuffle: True
+ drop_last: True
+ collate_batch: True
+ use_shared_memory: True
+ pre_img_epoch: *pre_img_epoch
+
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - WarpAffine: {keep_res: True, input_h: *input_h, input_w: *input_w}
+ - NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+TestReader:
+ sample_transforms:
+ - Decode: {}
+ - WarpAffine: {keep_res: True, input_h: *input_h, input_w: *input_w}
+ - NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ fuse_normalize: True
+
+
+EvalMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - WarpAffine: {keep_res: False, input_h: *input_h, input_w: *input_w}
+ - NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+TestMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - WarpAffine: {keep_res: False, input_h: *input_h, input_w: *input_w}
+ - NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ fuse_normalize: True
diff --git a/configs/mot/centertrack/_base_/optimizer_70e.yml b/configs/mot/centertrack/_base_/optimizer_70e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a336290f2cecb9597b8c5fe351f132eef3235e4c
--- /dev/null
+++ b/configs/mot/centertrack/_base_/optimizer_70e.yml
@@ -0,0 +1,14 @@
+epoch: 70
+
+LearningRate:
+ base_lr: 0.000125
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [60]
+ use_warmup: False
+
+OptimizerBuilder:
+ optimizer:
+ type: Adam
+ regularizer: NULL
diff --git a/configs/mot/centertrack/centertrack_dla34_70e_mot17.yml b/configs/mot/centertrack/centertrack_dla34_70e_mot17.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2888a01747a078af34a92dfae014358f61bc668d
--- /dev/null
+++ b/configs/mot/centertrack/centertrack_dla34_70e_mot17.yml
@@ -0,0 +1,66 @@
+_BASE_: [
+ '_base_/optimizer_70e.yml',
+ '_base_/centertrack_dla34.yml',
+ '_base_/centertrack_reader.yml',
+ '../../runtime.yml',
+]
+log_iter: 20
+snapshot_epoch: 5
+weights: output/centertrack_dla34_70e_mot17/model_final
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/crowdhuman_centertrack.pdparams
+
+
+### for Detection eval.py/infer.py
+# mot_metric: False
+# metric: COCO
+
+### for MOT eval_mot.py/infer_mot.py
+mot_metric: True
+metric: MOT
+
+
+worker_num: 4
+TrainReader:
+ batch_size: 16 # total 32 for 2 GPUs
+
+EvalReader:
+ batch_size: 1
+
+EvalMOTReader:
+ batch_size: 1
+
+
+# COCO style dataset for training
+num_classes: 1
+TrainDataset:
+ !COCODataSet
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/train.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_track_id']
+ # add 'gt_track_id', the boxes annotations of json file should have 'gt_track_id'
+
+EvalDataset:
+ !COCODataSet
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+
+TestDataset:
+ !ImageFolder
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/val_half.json
+
+# for MOT evaluation
+# If you want to change the MOT evaluation dataset, please modify 'data_root'
+EvalMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot/MOT17
+ data_root: images/train # set 'images/test' for MOTChallenge test
+ keep_ori_im: True # set True if save visualization images or video, or used in SDE MOT
+
+# for MOT video inference
+TestMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot/MOT17
+ keep_ori_im: True # set True if save visualization images or video
diff --git a/configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml b/configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2aff70fe32aa6510d8441d2acf976457b759b9c0
--- /dev/null
+++ b/configs/mot/centertrack/centertrack_dla34_70e_mot17half.yml
@@ -0,0 +1,66 @@
+_BASE_: [
+ '_base_/optimizer_70e.yml',
+ '_base_/centertrack_dla34.yml',
+ '_base_/centertrack_reader.yml',
+ '../../runtime.yml',
+]
+log_iter: 20
+snapshot_epoch: 5
+weights: output/centertrack_dla34_70e_mot17half/model_final
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/crowdhuman_centertrack.pdparams
+
+
+### for Detection eval.py/infer.py
+# mot_metric: False
+# metric: COCO
+
+### for MOT eval_mot.py/infer_mot.py
+mot_metric: True
+metric: MOT
+
+
+worker_num: 4
+TrainReader:
+ batch_size: 16 # total 32 for 2 GPUs
+
+EvalReader:
+ batch_size: 1
+
+EvalMOTReader:
+ batch_size: 1
+
+
+# COCO style dataset for training
+num_classes: 1
+TrainDataset:
+ !COCODataSet
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/train_half.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_track_id']
+ # add 'gt_track_id', the boxes annotations of json file should have 'gt_track_id'
+
+EvalDataset:
+ !COCODataSet
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+
+TestDataset:
+ !ImageFolder
+ dataset_dir: dataset/mot/MOT17
+ anno_path: annotations/val_half.json
+
+# for MOT evaluation
+# If you want to change the MOT evaluation dataset, please modify 'data_root'
+EvalMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot/MOT17
+ data_root: images/half
+ keep_ori_im: True # set True if save visualization images or video, or used in SDE MOT
+
+# for MOT video inference
+TestMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot/MOT17
+ keep_ori_im: True # set True if save visualization images or video
diff --git a/configs/mot/fairmot/_base_/fairmot_dla34.yml b/configs/mot/fairmot/_base_/fairmot_dla34.yml
index b9f5c65a34d82899546a02c3c0ec246a375227bd..9388ab6692be242f5532c696393944b71b232821 100644
--- a/configs/mot/fairmot/_base_/fairmot_dla34.yml
+++ b/configs/mot/fairmot/_base_/fairmot_dla34.yml
@@ -23,12 +23,11 @@ CenterNetDLAFPN:
CenterNetHead:
head_planes: 256
- heatmap_weight: 1
+ prior_bias: -2.19
regress_ltrb: True
- size_weight: 0.1
size_loss: 'L1'
- offset_weight: 1
- iou_weight: 0
+ loss_weight: {'heatmap': 1.0, 'size': 0.1, 'offset': 1.0, 'iou': 0.0}
+ add_iou: False
FairMOTEmbeddingHead:
ch_head: 256
diff --git a/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml b/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
index 61adb15659c8a512c42c3a660453bdda3b3e1bbf..287255fdaf032d2979083d460ef49335409e0b9f 100644
--- a/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
+++ b/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
@@ -15,7 +15,6 @@ CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
- for_mot: True
regress_ltrb: False
max_per_img: 200
diff --git a/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml b/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml
index 6118f053c43aa043a00d1310c00ac0a68018d6bb..99452f5dc55115a4267c9bb4ad4608009a54a16e 100644
--- a/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml
+++ b/configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml
@@ -40,7 +40,6 @@ CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
- for_mot: True
regress_ltrb: False
max_per_img: 200
diff --git a/deploy/pptracking/python/mot/tracker/__init__.py b/deploy/pptracking/python/mot/tracker/__init__.py
index 03a5dd0a169203b86edbc6c81a44a095ebe9b3cc..76ee2a6c99c5b5fd3da0f6749a13b12f935cb588 100644
--- a/deploy/pptracking/python/mot/tracker/__init__.py
+++ b/deploy/pptracking/python/mot/tracker/__init__.py
@@ -14,12 +14,16 @@
from . import base_jde_tracker
from . import base_sde_tracker
+
+from .base_jde_tracker import *
+from .base_sde_tracker import *
+
from . import jde_tracker
from . import deepsort_tracker
from . import ocsort_tracker
+from . import center_tracker
-from .base_jde_tracker import *
-from .base_sde_tracker import *
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
+from .center_tracker import *
diff --git a/deploy/pptracking/python/mot/tracker/center_tracker.py b/deploy/pptracking/python/mot/tracker/center_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b30ba9269711b21a60aa553e97d68f4950b7d1a
--- /dev/null
+++ b/deploy/pptracking/python/mot/tracker/center_tracker.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/xingyizhou/CenterTrack/blob/master/src/lib/utils/tracker.py
+"""
+
+import copy
+import numpy as np
+import sklearn
+
+__all__ = ['CenterTracker']
+
+
+class CenterTracker(object):
+ __shared__ = ['num_classes']
+
+ def __init__(self,
+ num_classes=1,
+ min_box_area=0,
+ vertical_ratio=-1,
+ track_thresh=0.4,
+ pre_thresh=0.5,
+ new_thresh=0.4,
+ out_thresh=0.4,
+ hungarian=False):
+ self.num_classes = num_classes
+ self.min_box_area = min_box_area
+ self.vertical_ratio = vertical_ratio
+
+ self.track_thresh = track_thresh
+ self.pre_thresh = max(track_thresh, pre_thresh)
+ self.new_thresh = max(track_thresh, new_thresh)
+ self.out_thresh = max(track_thresh, out_thresh)
+ self.hungarian = hungarian
+
+ self.reset()
+
+ def init_track(self, results):
+ print('Initialize tracking!')
+ for item in results:
+ if item['score'] > self.new_thresh:
+ self.id_count += 1
+ item['tracking_id'] = self.id_count
+ if not ('ct' in item):
+ bbox = item['bbox']
+ item['ct'] = [(bbox[0] + bbox[2]) / 2,
+ (bbox[1] + bbox[3]) / 2]
+ self.tracks.append(item)
+
+ def reset(self):
+ self.id_count = 0
+ self.tracks = []
+
+ def update(self, results, public_det=None):
+ N = len(results)
+ M = len(self.tracks)
+
+ dets = np.array([det['ct'] + det['tracking'] for det in results],
+ np.float32) # N x 2
+ track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
+ (track['bbox'][3] - track['bbox'][1])) \
+ for track in self.tracks], np.float32) # M
+ track_cat = np.array([track['class'] for track in self.tracks],
+ np.int32) # M
+ item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
+ (item['bbox'][3] - item['bbox'][1])) \
+ for item in results], np.float32) # N
+ item_cat = np.array([item['class'] for item in results], np.int32) # N
+ tracks = np.array([pre_det['ct'] for pre_det in self.tracks],
+ np.float32) # M x 2
+ dist = (((tracks.reshape(1, -1, 2) - \
+ dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
+
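+ # a (det, track) pair is invalid when the squared center distance exceeds the box area of either side, or when their classes differ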
+ invalid = ((dist > track_size.reshape(1, M)) + \
+ (dist > item_size.reshape(N, 1)) + \
+ (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
+ dist = dist + invalid * 1e18
+
+ if self.hungarian:
+ item_score = np.array([item['score'] for item in results],
+ np.float32)
+ dist[dist > 1e18] = 1e18
+ from sklearn.utils.linear_assignment_ import linear_assignment
+ matched_indices = linear_assignment(dist)
+ else:
+ matched_indices = greedy_assignment(copy.deepcopy(dist))
+
+ unmatched_dets = [d for d in range(dets.shape[0]) \
+ if not (d in matched_indices[:, 0])]
+ unmatched_tracks = [d for d in range(tracks.shape[0]) \
+ if not (d in matched_indices[:, 1])]
+
+ if self.hungarian:
+ matches = []
+ for m in matched_indices:
+ if dist[m[0], m[1]] > 1e16:
+ unmatched_dets.append(m[0])
+ unmatched_tracks.append(m[1])
+ else:
+ matches.append(m)
+ matches = np.array(matches).reshape(-1, 2)
+ else:
+ matches = matched_indices
+
+ ret = []
+ for m in matches:
+ track = results[m[0]]
+ track['tracking_id'] = self.tracks[m[1]]['tracking_id']
+ ret.append(track)
+
+ # Private detection: create tracks for all un-matched detections
+ for i in unmatched_dets:
+ track = results[i]
+ if track['score'] > self.new_thresh:
+ self.id_count += 1
+ track['tracking_id'] = self.id_count
+ ret.append(track)
+
+ self.tracks = ret
+ return ret
+
+
+def greedy_assignment(dist):
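+ # for each detection row, greedily take the nearest track that is still free; cost entries >= 1e16 mean no feasible match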
+ matched_indices = []
+ if dist.shape[1] == 0:
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
+ for i in range(dist.shape[0]):
+ j = dist[i].argmin()
+ if dist[i][j] < 1e16:
+ dist[:, j] = 1e18
+ matched_indices.append([i, j])
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 344eca59f75740bd9cec53693ecdc57e970b02ee..6136d0db1564436b054f230250a23bd9e6f4fe49 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -42,7 +42,8 @@ from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
- 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', 'PPLCNet', 'DETR'
+ 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', 'PPLCNet', 'DETR',
+ 'CenterTrack'
}
TUNED_TRT_DYNAMIC_MODELS = {'DETR'}
@@ -197,8 +198,12 @@ class Detector(object):
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
- boxes_num = self.predictor.get_output_handle(output_names[1])
- np_boxes_num = boxes_num.copy_to_cpu()
+ if len(output_names) == 1:
+ # some exported models do not provide the output tensor 'bbox_num'
+ np_boxes_num = np.array([len(np_boxes)])
+ else:
+ boxes_num = self.predictor.get_output_handle(output_names[1])
+ np_boxes_num = boxes_num.copy_to_cpu()
if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu()
diff --git a/deploy/python/mot_centertrack_infer.py b/deploy/python/mot_centertrack_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04a96876ae3d3b785366029e119be3d943f92fa
--- /dev/null
+++ b/deploy/python/mot_centertrack_infer.py
@@ -0,0 +1,505 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+import math
+import time
+import yaml
+import cv2
+import numpy as np
+from collections import defaultdict
+import paddle
+
+from benchmark_utils import PaddleInferBenchmark
+from utils import gaussian_radius, gaussian2D, draw_umich_gaussian
+from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute
+from utils import argsparser, Timer, get_current_memory_mb
+from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
+from keypoint_preprocess import get_affine_transform
+
+# add python path
+import sys
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
+sys.path.insert(0, parent_path)
+
+from pptracking.python.mot import CenterTracker
+from pptracking.python.mot.utils import MOTTimer, write_mot_results
+from pptracking.python.mot.visualize import plot_tracking
+
+
+def transform_preds_with_trans(coords, trans):
+ target_coords = np.ones((coords.shape[0], 3), np.float32)
+ target_coords[:, :2] = coords
+ target_coords = np.dot(trans, target_coords.transpose()).transpose()
+ return target_coords[:, :2]
+
+
+def affine_transform(pt, t):
+ new_pt = np.array([pt[0], pt[1], 1.]).T
+ new_pt = np.dot(t, new_pt)
+ return new_pt[:2]
+
+
+def affine_transform_bbox(bbox, trans, width, height):
+ bbox = np.array(copy.deepcopy(bbox), dtype=np.float32)
+ bbox[:2] = affine_transform(bbox[:2], trans)
+ bbox[2:] = affine_transform(bbox[2:], trans)
+ bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, width - 1)
+ bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, height - 1)
+ return bbox
+
+
+class CenterTrack(Detector):
+ """
+ Args:
+ model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
+ batch_size (int): batch size used in inference
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ cpu_threads (int): cpu threads
+ enable_mkldnn (bool): whether to open MKLDNN
+ output_dir (string): The path of output, default as 'output'
+ threshold (float): Score threshold of the detected bbox, default as 0.5
+ save_images (bool): Whether to save visualization image results, default as False
+ save_mot_txts (bool): Whether to save tracking results (txt), default as False
+ """
+
+ def __init__(
+ self,
+ model_dir,
+ tracker_config=None,
+ device='CPU',
+ run_mode='paddle',
+ batch_size=1,
+ trt_min_shape=1,
+ trt_max_shape=960,
+ trt_opt_shape=544,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ output_dir='output',
+ threshold=0.5,
+ save_images=False,
+ save_mot_txts=False, ):
+ super(CenterTrack, self).__init__(
+ model_dir=model_dir,
+ device=device,
+ run_mode=run_mode,
+ batch_size=batch_size,
+ trt_min_shape=trt_min_shape,
+ trt_max_shape=trt_max_shape,
+ trt_opt_shape=trt_opt_shape,
+ trt_calib_mode=trt_calib_mode,
+ cpu_threads=cpu_threads,
+ enable_mkldnn=enable_mkldnn,
+ output_dir=output_dir,
+ threshold=threshold, )
+ self.save_images = save_images
+ self.save_mot_txts = save_mot_txts
+ assert batch_size == 1, "MOT model only supports batch_size=1."
+ self.det_times = Timer(with_tracker=True)
+ self.num_classes = len(self.pred_config.labels)
+
+ # tracker config
+ cfg = self.pred_config.tracker
+ min_box_area = cfg.get('min_box_area', -1)
+ vertical_ratio = cfg.get('vertical_ratio', -1)
+ track_thresh = cfg.get('track_thresh', 0.4)
+ pre_thresh = cfg.get('pre_thresh', 0.5)
+
+ self.tracker = CenterTracker(
+ num_classes=self.num_classes,
+ min_box_area=min_box_area,
+ vertical_ratio=vertical_ratio,
+ track_thresh=track_thresh,
+ pre_thresh=pre_thresh)
+
+ self.pre_image = None
+
+ def get_additional_inputs(self, dets, meta, with_hm=True):
+ # Render input heatmap from previous trackings.
+ trans_input = meta['trans_input']
+ inp_width, inp_height = int(meta['inp_width']), int(meta['inp_height'])
+ input_hm = np.zeros((1, inp_height, inp_width), dtype=np.float32)
+
+ for det in dets:
+ if det['score'] < self.tracker.pre_thresh:
+ continue
+ bbox = affine_transform_bbox(det['bbox'], trans_input, inp_width,
+ inp_height)
+ h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ if (h > 0 and w > 0):
+ radius = gaussian_radius(
+ (math.ceil(h), math.ceil(w)), min_overlap=0.7)
+ radius = max(0, int(radius))
+ ct = np.array(
+ [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
+ dtype=np.float32)
+ ct_int = ct.astype(np.int32)
+ if with_hm:
+ input_hm[0] = draw_umich_gaussian(input_hm[0], ct_int,
+ radius)
+ if with_hm:
+ input_hm = input_hm[np.newaxis]
+ return input_hm
+
+ def preprocess(self, image_list):
+ preprocess_ops = []
+ for op_info in self.pred_config.preprocess_infos:
+ new_op_info = op_info.copy()
+ op_type = new_op_info.pop('type')
+ preprocess_ops.append(eval(op_type)(**new_op_info))
+
+ assert len(image_list) == 1, 'MOT only supports bs=1'
+ im_path = image_list[0]
+ im, im_info = preprocess(im_path, preprocess_ops)
+ #inputs = create_inputs(im, im_info)
+ inputs = {}
+ inputs['image'] = np.array((im, )).astype('float32')
+ inputs['im_shape'] = np.array(
+ (im_info['im_shape'], )).astype('float32')
+ inputs['scale_factor'] = np.array(
+ (im_info['scale_factor'], )).astype('float32')
+
+ inputs['trans_input'] = im_info['trans_input']
+ inputs['inp_width'] = im_info['inp_width']
+ inputs['inp_height'] = im_info['inp_height']
+ inputs['center'] = im_info['center']
+ inputs['scale'] = im_info['scale']
+ inputs['out_height'] = im_info['out_height']
+ inputs['out_width'] = im_info['out_width']
+
+ if self.pre_image is None:
+ self.pre_image = inputs['image']
+ # initializing tracker for the first frame
+ self.tracker.init_track([])
+ inputs['pre_image'] = self.pre_image
+ self.pre_image = inputs['image'] # Note: update for next image
+
+ # render input heatmap from tracker status
+ pre_hm = self.get_additional_inputs(
+ self.tracker.tracks, inputs, with_hm=True)
+ inputs['pre_hm'] = pre_hm #.to_tensor(pre_hm)
+
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(input_names[i])
+ if input_names[i] == 'x':
+ input_tensor.copy_from_cpu(inputs['image'])
+ else:
+ input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+ return inputs
+
+ def postprocess(self, inputs, result):
+ # postprocess output of predictor
+ np_bboxes = result['bboxes']
+ if np_bboxes.shape[0] <= 0:
+ print('[WARNING] No object detected and tracked.')
+ result = {'bboxes': np.zeros([0, 6]), 'cts': None, 'tracking': None}
+ return result
+ result = {k: v for k, v in result.items() if v is not None}
+ return result
+
+ def centertrack_post_process(self, dets, meta, out_thresh):
+ if not ('bboxes' in dets):
+ return [{}]
+
+ preds = []
+ c, s = meta['center'], meta['scale']
+ h, w = meta['out_height'], meta['out_width']
+ trans = get_affine_transform(
+ center=c,
+ input_size=s,
+ rot=0,
+ output_size=[w, h],
+ shift=(0., 0.),
+ inv=True).astype(np.float32)
+ for i, dets_bbox in enumerate(dets['bboxes']):
+ if dets_bbox[1] < out_thresh:
+ break
+ item = {}
+ item['score'] = dets_bbox[1]
+ item['class'] = int(dets_bbox[0]) + 1
+ item['ct'] = transform_preds_with_trans(
+ dets['cts'][i].reshape([1, 2]), trans).reshape(2)
+
+ if 'tracking' in dets:
+ tracking = transform_preds_with_trans(
+ (dets['tracking'][i] + dets['cts'][i]).reshape([1, 2]),
+ trans).reshape(2)
+ item['tracking'] = tracking - item['ct']
+
+ if 'bboxes' in dets:
+ bbox = transform_preds_with_trans(
+ dets_bbox[2:6].reshape([2, 2]), trans).reshape(4)
+ item['bbox'] = bbox
+
+ preds.append(item)
+ return preds
+
+ def tracking(self, inputs, det_results):
+ result = self.centertrack_post_process(
+ det_results, inputs, self.tracker.out_thresh)
+ online_targets = self.tracker.update(result)
+
+ online_tlwhs, online_scores, online_ids = [], [], []
+ for t in online_targets:
+ bbox = t['bbox']
+ tlwh = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]
+ tscore = float(t['score'])
+ tid = int(t['tracking_id'])
+ if tlwh[2] * tlwh[3] > 0:
+ online_tlwhs.append(tlwh)
+ online_ids.append(tid)
+ online_scores.append(tscore)
+ return online_tlwhs, online_scores, online_ids
+
+ def predict(self, repeats=1):
+ '''
+ Args:
+ repeats (int): repeats number for prediction
+ Returns:
+ result (dict): include 'bboxes', 'cts' and 'tracking':
+ np.ndarray: shape:[N,6],[N,2] and [N,2], N: number of box
+ '''
+ # model prediction
+ np_bboxes, np_cts, np_tracking = None, None, None
+ for i in range(repeats):
+ self.predictor.run()
+ output_names = self.predictor.get_output_names()
+ bboxes_tensor = self.predictor.get_output_handle(output_names[0])
+ np_bboxes = bboxes_tensor.copy_to_cpu()
+ cts_tensor = self.predictor.get_output_handle(output_names[1])
+ np_cts = cts_tensor.copy_to_cpu()
+ tracking_tensor = self.predictor.get_output_handle(output_names[2])
+ np_tracking = tracking_tensor.copy_to_cpu()
+
+ result = dict(
+ bboxes=np_bboxes,
+ cts=np_cts,
+ tracking=np_tracking)
+ return result
+
+ def predict_image(self,
+ image_list,
+ run_benchmark=False,
+ repeats=1,
+ visual=True,
+ seq_name=None):
+ mot_results = []
+ num_classes = self.num_classes
+ image_list.sort()
+ ids2names = self.pred_config.labels
+ data_type = 'mcmot' if num_classes > 1 else 'mot'
+ for frame_id, img_file in enumerate(image_list):
+ batch_image_list = [img_file] # bs=1 in MOT model
+ if run_benchmark:
+ # preprocess
+ inputs = self.preprocess(batch_image_list) # warmup
+ self.det_times.preprocess_time_s.start()
+ inputs = self.preprocess(batch_image_list)
+ self.det_times.preprocess_time_s.end()
+
+ # model prediction
+ result_warmup = self.predict(repeats=repeats) # warmup
+ self.det_times.inference_time_s.start()
+ result = self.predict(repeats=repeats)
+ self.det_times.inference_time_s.end(repeats=repeats)
+
+ # postprocess
+ result_warmup = self.postprocess(inputs, result) # warmup
+ self.det_times.postprocess_time_s.start()
+ det_result = self.postprocess(inputs, result)
+ self.det_times.postprocess_time_s.end()
+
+ # tracking
+ result_warmup = self.tracking(inputs, det_result)
+ self.det_times.tracking_time_s.start()
+ online_tlwhs, online_scores, online_ids = self.tracking(inputs,
+ det_result)
+ self.det_times.tracking_time_s.end()
+ self.det_times.img_num += 1
+
+ cm, gm, gu = get_current_memory_mb()
+ self.cpu_mem += cm
+ self.gpu_mem += gm
+ self.gpu_util += gu
+
+ else:
+ self.det_times.preprocess_time_s.start()
+ inputs = self.preprocess(batch_image_list)
+ self.det_times.preprocess_time_s.end()
+
+ self.det_times.inference_time_s.start()
+ result = self.predict()
+ self.det_times.inference_time_s.end()
+
+ self.det_times.postprocess_time_s.start()
+ det_result = self.postprocess(inputs, result)
+ self.det_times.postprocess_time_s.end()
+
+ # tracking process
+ self.det_times.tracking_time_s.start()
+ online_tlwhs, online_scores, online_ids = self.tracking(inputs,
+ det_result)
+ self.det_times.tracking_time_s.end()
+ self.det_times.img_num += 1
+
+ if visual:
+ if len(image_list) > 1 and frame_id % 10 == 0:
+ print('Tracking frame {}'.format(frame_id))
+ frame, _ = decode_image(img_file, {})
+
+ im = plot_tracking(
+ frame,
+ online_tlwhs,
+ online_ids,
+ online_scores,
+ frame_id=frame_id,
+ ids2names=ids2names)
+ if seq_name is None:
+ seq_name = image_list[0].split('/')[-2]
+ save_dir = os.path.join(self.output_dir, seq_name)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ cv2.imwrite(
+ os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
+
+ mot_results.append([online_tlwhs, online_scores, online_ids])
+ return mot_results
+
+ def predict_video(self, video_file, camera_id):
+ video_out_name = 'mot_output.mp4'
+ if camera_id != -1:
+ capture = cv2.VideoCapture(camera_id)
+ else:
+ capture = cv2.VideoCapture(video_file)
+ video_out_name = os.path.split(video_file)[-1]
+ # Get Video info : resolution, fps, frame count
+ width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = int(capture.get(cv2.CAP_PROP_FPS))
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+ print("fps: %d, frame_count: %d" % (fps, frame_count))
+
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir)
+ out_path = os.path.join(self.output_dir, video_out_name)
+ video_format = 'mp4v'
+ fourcc = cv2.VideoWriter_fourcc(*video_format)
+ writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
+
+ frame_id = 1
+ timer = MOTTimer()
+ results = defaultdict(list) # CenterTrack only supports single class
+ num_classes = self.num_classes
+ data_type = 'mcmot' if num_classes > 1 else 'mot'
+ ids2names = self.pred_config.labels
+ while (1):
+ ret, frame = capture.read()
+ if not ret:
+ break
+ if frame_id % 10 == 0:
+ print('Tracking frame: %d' % (frame_id))
+ frame_id += 1
+
+ timer.tic()
+ seq_name = video_out_name.split('.')[0]
+ mot_results = self.predict_image(
+ [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
+ timer.toc()
+
+ fps = 1. / timer.duration
+ online_tlwhs, online_scores, online_ids = mot_results[0]
+ results[0].append(
+ (frame_id + 1, online_tlwhs, online_scores, online_ids))
+ im = plot_tracking(
+ frame,
+ online_tlwhs,
+ online_ids,
+ online_scores,
+ frame_id=frame_id,
+ fps=fps,
+ ids2names=ids2names)
+
+ writer.write(im)
+ if camera_id != -1:
+ cv2.imshow('Tracking', im)
+ if cv2.waitKey(1) & 0xFF == ord('q'):
+ break
+
+ if self.save_mot_txts:
+ result_filename = os.path.join(
+ self.output_dir, video_out_name.split('.')[-2] + '.txt')
+
+ write_mot_results(result_filename, results, data_type, num_classes)
+
+ writer.release()
+
+
+def main():
+ detector = CenterTrack(
+ FLAGS.model_dir,
+ tracker_config=None,
+ device=FLAGS.device,
+ run_mode=FLAGS.run_mode,
+ batch_size=1,
+ trt_min_shape=FLAGS.trt_min_shape,
+ trt_max_shape=FLAGS.trt_max_shape,
+ trt_opt_shape=FLAGS.trt_opt_shape,
+ trt_calib_mode=FLAGS.trt_calib_mode,
+ cpu_threads=FLAGS.cpu_threads,
+ enable_mkldnn=FLAGS.enable_mkldnn,
+ output_dir=FLAGS.output_dir,
+ threshold=FLAGS.threshold,
+ save_images=FLAGS.save_images,
+ save_mot_txts=FLAGS.save_mot_txts)
+
+ # predict from video file or camera video stream
+ if FLAGS.video_file is not None or FLAGS.camera_id != -1:
+ detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
+ else:
+ # predict from image
+ img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
+ detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
+
+ if not FLAGS.run_benchmark:
+ detector.det_times.info(average=True)
+ else:
+ mode = FLAGS.run_mode
+ model_dir = FLAGS.model_dir
+ model_info = {
+ 'model_name': model_dir.strip('/').split('/')[-1],
+ 'precision': mode.split('_')[-1]
+ }
+ bench_log(detector, img_list, model_info, name='MOT')
+
+
+if __name__ == '__main__':
+ paddle.enable_static()
+ parser = argsparser()
+ FLAGS = parser.parse_args()
+ print_arguments(FLAGS)
+ FLAGS.device = FLAGS.device.upper()
+ assert FLAGS.device in ['CPU', 'GPU', 'XPU'
+ ], "device should be CPU, GPU or XPU"
+
+ main()
diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py
index 51033aebc818a6e53c0528e202db272bfc5dd9bc..6f1a5a2a1a0e38e3edbd9685ad4013b6579ddb87 100644
--- a/deploy/python/preprocess.py
+++ b/deploy/python/preprocess.py
@@ -450,13 +450,15 @@ class WarpAffine(object):
input_h=512,
input_w=512,
scale=0.4,
- shift=0.1):
+ shift=0.1,
+ down_ratio=4):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
+ self.down_ratio = down_ratio
def __call__(self, im, im_info):
"""
@@ -472,12 +474,14 @@ class WarpAffine(object):
h, w = img.shape[:2]
if self.keep_res:
+ # True in detection eval/infer
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
+ # False in CenterTrack eval_mot/infer_mot
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
@@ -486,6 +490,22 @@ class WarpAffine(object):
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+
+ if not self.keep_res:
+ out_h = input_h // self.down_ratio
+ out_w = input_w // self.down_ratio
+ trans_output = get_affine_transform(c, s, 0, [out_w, out_h])
+
+ im_info.update({
+ 'center': c,
+ 'scale': s,
+ 'out_height': out_h,
+ 'out_width': out_w,
+ 'inp_height': input_h,
+ 'inp_width': input_w,
+ 'trans_input': trans_input,
+ 'trans_output': trans_output,
+ })
return inp, im_info
diff --git a/deploy/python/tracker_config.yml b/deploy/python/tracker_config.yml
index ddd55e8653870ed9bdfe9734995e8af5b56f49e2..9531c549e3f6993da81147a41d55d47b35a12fef 100644
--- a/deploy/python/tracker_config.yml
+++ b/deploy/python/tracker_config.yml
@@ -2,7 +2,7 @@
# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
-type: JDETracker # 'JDETracker' or 'DeepSORTTracker'
+type: JDETracker # 'JDETracker', 'DeepSORTTracker' or 'CenterTracker'
# BYTETracker
JDETracker:
@@ -24,3 +24,9 @@ DeepSORTTracker:
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
+
+CenterTracker:
+ min_box_area: -1
+ vertical_ratio: -1
+ track_thresh: 0.4
+ pre_thresh: 0.5
diff --git a/deploy/python/utils.py b/deploy/python/utils.py
index b7f514ebc999c361944fb0f16f73043fbc4e6460..d1f7d59f8571f2795af059300be69016f47fb4d7 100644
--- a/deploy/python/utils.py
+++ b/deploy/python/utils.py
@@ -476,3 +476,59 @@ coco_clsid2catid = {
78: 89,
79: 90
}
+
+
+def gaussian_radius(bbox_size, min_overlap):
+ height, width = bbox_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+ radius1 = (b1 + sq1) / (2 * a1)
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+ radius2 = (b2 + sq2) / 2
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+ radius3 = (b3 + sq3) / 2
+ return min(radius1, radius2, radius3)
+
+
+def gaussian2D(shape, sigma_x=1, sigma_y=1):
+ m, n = [(ss - 1.) / 2. for ss in shape]
+ y, x = np.ogrid[-m:m + 1, -n:n + 1]
+
+ h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
+ sigma_y)))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def draw_umich_gaussian(heatmap, center, radius, k=1):
+ """
+ draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
+ """
+ diameter = 2 * radius + 1
+ gaussian = gaussian2D(
+ (diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
+
+ x, y = int(center[0]), int(center[1])
+
+ height, width = heatmap.shape[0:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
+ radius + right]
+ if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+ return heatmap
diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md
index 49b4a54e6b75f23cfd497c267b12cec02d252d0d..2eb099eec8831e45bf34fa8d1ff6883b2751bbb0 100644
--- a/docs/MODEL_ZOO_cn.md
+++ b/docs/MODEL_ZOO_cn.md
@@ -247,6 +247,10 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
Please refer to [OC-SORT](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/ocsort)
+### CenterTrack
+
+Please refer to [CenterTrack](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/centertrack)
+
### FairMOT/MC-FairMOT
Please refer to [FairMOT](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot)
diff --git a/docs/MODEL_ZOO_en.md b/docs/MODEL_ZOO_en.md
index 87f7fb16680e4a16a1436a8f574f1f5614e4c82b..ac725bcf9a04831e72ffd3afcc4d66a954323950 100644
--- a/docs/MODEL_ZOO_en.md
+++ b/docs/MODEL_ZOO_en.md
@@ -246,6 +246,10 @@ Please refer to [ByteTrack](https://github.com/PaddlePaddle/PaddleDetection/tree
Please refer to [OC-SORT](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/ocsort)
+### CenterTrack
+
+Please refer to [CenterTrack](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/centertrack)
+
### FairMOT/MC-FairMOT
Please refer to [FairMOT](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mot/fairmot)
diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py
index e87e9bf703df485800d3d442086eaed1114ccc98..330dae6775115bb4401e5adcdc30471b7099f3e8 100644
--- a/ppdet/data/source/coco.py
+++ b/ppdet/data/source/coco.py
@@ -177,8 +177,10 @@ class COCODataSet(DetDataset):
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
+ gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)
has_segmentation = False
+ has_track_id = False
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = self.catid2clsid[catid]
@@ -200,6 +202,10 @@ class COCODataSet(DetDataset):
gt_poly[i] = box['segmentation']
has_segmentation = True
+ if 'track_id' in box:
+ gt_track_id[i][0] = box['track_id']
+ has_track_id = True
+
if has_segmentation and not any(
gt_poly) and not self.allow_empty:
continue
@@ -210,6 +216,8 @@ class COCODataSet(DetDataset):
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
+ if has_track_id:
+ gt_rec.update({'gt_track_id': gt_track_id})
for k, v in gt_rec.items():
if k in self.data_fields:
diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py
index d378b4da5e010a4872f8bb0fb26a769a00b9a677..4f22b222aa1a99bf1239db5c379cc4bd1a6632e0 100644
--- a/ppdet/data/source/dataset.py
+++ b/ppdet/data/source/dataset.py
@@ -86,6 +86,12 @@ class DetDataset(Dataset):
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(4)
]
+ elif self.pre_img_epoch == 0 or self._epoch < self.pre_img_epoch:
+ # Add previous image as input, only used in CenterTrack
+ idx_pre_img = idx - 1
+ if idx_pre_img < 0:
+ idx_pre_img = idx + 1
+ roidb = [roidb, ] + [copy.deepcopy(self.roidbs[idx_pre_img])]
if isinstance(roidb, Sequence):
for r in roidb:
r['curr_iter'] = self._curr_iter
@@ -103,6 +109,7 @@ class DetDataset(Dataset):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
+ self.pre_img_epoch = kwargs.get('pre_img_epoch', -1)
def set_transform(self, transform):
self.transform = transform
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index aaf4d18e24b7190ee387751ba71926263ab77b45..bdb7989f0a5f1c90c39b086113c7a6f4166f53b9 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -24,6 +24,7 @@ except Exception:
from collections import Sequence
import cv2
+import copy
import math
import numpy as np
from .operators import register_op, BaseOperator, Resize
@@ -47,6 +48,7 @@ __all__ = [
'PadMaskBatch',
'Gt2GFLTarget',
'Gt2CenterNetTarget',
+ 'Gt2CenterTrackTarget',
'PadGT',
'PadRGT',
]
@@ -169,6 +171,7 @@ class BatchRandomResize(BaseOperator):
@register_op
class Gt2YoloTarget(BaseOperator):
+ __shared__ = ['num_classes']
"""
Generate YOLOv3 targets by ground truth data, this operator is only used in
fine grained YOLOv3 loss mode
@@ -492,6 +495,7 @@ class Gt2FCOSTarget(BaseOperator):
@register_op
class Gt2GFLTarget(BaseOperator):
+ __shared__ = ['num_classes']
"""
Generate GFocal loss targets by ground truth data
"""
@@ -1000,6 +1004,7 @@ class PadMaskBatch(BaseOperator):
@register_op
class Gt2CenterNetTarget(BaseOperator):
+ __shared__ = ['num_classes']
"""Gt2CenterNetTarget
Generate CenterNet targets by ground-truth
Args:
@@ -1009,40 +1014,39 @@ class Gt2CenterNetTarget(BaseOperator):
max_objs (int): The maximum objects detected, 128 by default.
"""
- def __init__(self, down_ratio, num_classes=80, max_objs=128):
+ def __init__(self, num_classes=80, down_ratio=4, max_objs=128):
super(Gt2CenterNetTarget, self).__init__()
+ self.nc = num_classes
self.down_ratio = down_ratio
- self.num_classes = num_classes
self.max_objs = max_objs
def __call__(self, sample, context=None):
input_h, input_w = sample['image'].shape[1:]
output_h = input_h // self.down_ratio
output_w = input_w // self.down_ratio
- num_classes = self.num_classes
- c = sample['center']
- s = sample['scale']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
- hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
+ hm = np.zeros((self.nc, output_h, output_w), dtype=np.float32)
wh = np.zeros((self.max_objs, 2), dtype=np.float32)
- dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
reg = np.zeros((self.max_objs, 2), dtype=np.float32)
ind = np.zeros((self.max_objs), dtype=np.int64)
reg_mask = np.zeros((self.max_objs), dtype=np.int32)
- cat_spec_wh = np.zeros(
- (self.max_objs, num_classes * 2), dtype=np.float32)
- cat_spec_mask = np.zeros(
- (self.max_objs, num_classes * 2), dtype=np.int32)
+ cat_spec_wh = np.zeros((self.max_objs, self.nc * 2), dtype=np.float32)
+ cat_spec_mask = np.zeros((self.max_objs, self.nc * 2), dtype=np.int32)
- trans_output = get_affine_transform(c, [s, s], 0, [output_w, output_h])
+ trans_output = get_affine_transform(
+ center=sample['center'],
+ input_size=[sample['scale'], sample['scale']],
+ rot=0,
+ output_size=[output_w, output_h])
gt_det = []
for i, (bbox, cls) in enumerate(zip(gt_bbox, gt_class)):
cls = int(cls)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
+ bbox_amodal = copy.deepcopy(bbox)
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
@@ -1053,10 +1057,12 @@ class Gt2CenterNetTarget(BaseOperator):
[(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
dtype=np.float32)
ct_int = ct.astype(np.int32)
+
+ # get hm,wh,reg,ind,ind_mask
draw_umich_gaussian(hm[cls], ct_int, radius)
wh[i] = 1. * w, 1. * h
- ind[i] = ct_int[1] * output_w + ct_int[0]
reg[i] = ct - ct_int
+ ind[i] = ct_int[1] * output_w + ct_int[0]
reg_mask[i] = 1
cat_spec_wh[i, cls * 2:cls * 2 + 2] = wh[i]
cat_spec_mask[i, cls * 2:cls * 2 + 2] = 1
@@ -1071,9 +1077,10 @@ class Gt2CenterNetTarget(BaseOperator):
sample.pop('scale', None)
sample.pop('is_crowd', None)
sample.pop('difficult', None)
- sample['heatmap'] = hm
- sample['index_mask'] = reg_mask
+
sample['index'] = ind
+ sample['index_mask'] = reg_mask
+ sample['heatmap'] = hm
sample['size'] = wh
sample['offset'] = reg
return sample
@@ -1184,3 +1191,175 @@ class PadRGT(BaseOperator):
num_gt)
return samples
+
+
+@register_op
+class Gt2CenterTrackTarget(BaseOperator):
+ __shared__ = ['num_classes']
+ """Gt2CenterTrackTarget
+ Generate CenterTrack targets by ground-truth
+ Args:
+ num_classes (int): The number of classes, 1 by default.
+ down_ratio (int): The down sample ratio between output feature and
+ input image.
+ max_objs (int): The maximum objects detected, 256 by default.
+ """
+
+ def __init__(self,
+ num_classes=1,
+ down_ratio=4,
+ max_objs=256,
+ hm_disturb=0.05,
+ lost_disturb=0.4,
+ fp_disturb=0.1,
+ pre_hm=True,
+ add_tracking=True,
+ add_ltrb_amodal=True):
+ super(Gt2CenterTrackTarget, self).__init__()
+ self.nc = num_classes
+ self.down_ratio = down_ratio
+ self.max_objs = max_objs
+
+ self.hm_disturb = hm_disturb
+ self.lost_disturb = lost_disturb
+ self.fp_disturb = fp_disturb
+ self.pre_hm = pre_hm
+ self.add_tracking = add_tracking
+ self.add_ltrb_amodal = add_ltrb_amodal
+
+ def _get_pre_dets(self, input_h, input_w, trans_input_pre, gt_bbox_pre,
+ gt_class_pre, gt_track_id_pre):
+ hm_h, hm_w = input_h, input_w
+ return_hm = self.pre_hm
+ pre_hm = np.zeros(
+ (1, hm_h, hm_w), dtype=np.float32) if return_hm else None
+ pre_cts, track_ids = [], []
+
+ for i, (
+ bbox, cls, track_id
+ ) in enumerate(zip(gt_bbox_pre, gt_class_pre, gt_track_id_pre)):
+ cls = int(cls)
+ bbox[:2] = affine_transform(bbox[:2], trans_input_pre)
+ bbox[2:] = affine_transform(bbox[2:], trans_input_pre)
+ bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, hm_w - 1)
+ bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, hm_h - 1)
+ h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ max_rad = 1
+ if (h > 0 and w > 0):
+ radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
+ radius = max(0, int(radius))
+ max_rad = max(max_rad, radius)
+ ct = np.array(
+ [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
+ dtype=np.float32)
+ ct0 = ct.copy()
+ conf = 1
+
+ ct[0] = ct[0] + np.random.randn() * self.hm_disturb * w
+ ct[1] = ct[1] + np.random.randn() * self.hm_disturb * h
+ conf = 1 if np.random.rand() > self.lost_disturb else 0
+
+ ct_int = ct.astype(np.int32)
+ if conf == 0:
+ pre_cts.append(ct / self.down_ratio)
+ else:
+ pre_cts.append(ct0 / self.down_ratio)
+
+ track_ids.append(track_id)
+ if return_hm:
+ draw_umich_gaussian(pre_hm[0], ct_int, radius, k=conf)
+
+ if np.random.rand() < self.fp_disturb and return_hm:
+ ct2 = ct0.copy()
+ # Hard code heatmap disturb ratio, haven't tried other numbers.
+ ct2[0] = ct2[0] + np.random.randn() * 0.05 * w
+ ct2[1] = ct2[1] + np.random.randn() * 0.05 * h
+ ct2_int = ct2.astype(np.int32)
+ draw_umich_gaussian(pre_hm[0], ct2_int, radius, k=conf)
+ return pre_hm, pre_cts, track_ids
+
+ def __call__(self, sample, context=None):
+ input_h, input_w = sample['image'].shape[1:]
+ output_h = input_h // self.down_ratio
+ output_w = input_w // self.down_ratio
+ gt_bbox = sample['gt_bbox']
+ gt_class = sample['gt_class']
+
+ # init
+ hm = np.zeros((self.nc, output_h, output_w), dtype=np.float32)
+ wh = np.zeros((self.max_objs, 2), dtype=np.float32)
+ reg = np.zeros((self.max_objs, 2), dtype=np.float32)
+ ind = np.zeros((self.max_objs), dtype=np.int64)
+ reg_mask = np.zeros((self.max_objs), dtype=np.int32)
+ if self.add_tracking:
+ tr = np.zeros((self.max_objs, 2), dtype=np.float32)
+ if self.add_ltrb_amodal:
+ ltrb_amodal = np.zeros((self.max_objs, 4), dtype=np.float32)
+
+ trans_output = get_affine_transform(
+ center=sample['center'],
+ input_size=[sample['scale'], sample['scale']],
+ rot=0,
+ output_size=[output_w, output_h])
+
+ pre_hm, pre_cts, track_ids = self._get_pre_dets(
+ input_h, input_w, sample['trans_input'], sample['pre_gt_bbox'],
+ sample['pre_gt_class'], sample['pre_gt_track_id'])
+
+ for i, (bbox, cls) in enumerate(zip(gt_bbox, gt_class)):
+ cls = int(cls)
+ rect = np.array(
+ [[bbox[0], bbox[1]], [bbox[0], bbox[3]], [bbox[2], bbox[3]],
+ [bbox[2], bbox[1]]],
+ dtype=np.float32)
+ for t in range(4):
+ rect[t] = affine_transform(rect[t], trans_output)
+ bbox[:2] = rect[:, 0].min(), rect[:, 1].min()
+ bbox[2:] = rect[:, 0].max(), rect[:, 1].max()
+
+ bbox_amodal = copy.deepcopy(bbox)
+ bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
+ bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
+
+ h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ if h > 0 and w > 0:
+ radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
+ radius = max(0, int(radius))
+ ct = np.array(
+ [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
+ dtype=np.float32)
+ ct_int = ct.astype(np.int32)
+
+ # get hm,wh,reg,ind,ind_mask
+ draw_umich_gaussian(hm[cls], ct_int, radius)
+ wh[i] = 1. * w, 1. * h
+ reg[i] = ct - ct_int
+ ind[i] = ct_int[1] * output_w + ct_int[0]
+ reg_mask[i] = 1
+ if self.add_tracking:
+ if sample['gt_track_id'][i] in track_ids:
+ pre_ct = pre_cts[track_ids.index(sample['gt_track_id'][
+ i])]
+ tr[i] = pre_ct - ct_int
+
+ if self.add_ltrb_amodal:
+ ltrb_amodal[i] = \
+ bbox_amodal[0] - ct_int[0], bbox_amodal[1] - ct_int[1], \
+ bbox_amodal[2] - ct_int[0], bbox_amodal[3] - ct_int[1]
+
+ new_sample = {'image': sample['image']}
+ new_sample['index'] = ind
+ new_sample['index_mask'] = reg_mask
+ new_sample['heatmap'] = hm
+ new_sample['size'] = wh
+ new_sample['offset'] = reg
+ if self.add_tracking:
+ new_sample['tracking'] = tr
+ if self.add_ltrb_amodal:
+ new_sample['ltrb_amodal'] = ltrb_amodal
+
+ new_sample['pre_image'] = sample['pre_image']
+ new_sample['pre_hm'] = pre_hm
+
+ del sample
+ return new_sample
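
Note: the target layout built above (heatmap / size / offset / index / index_mask) follows the standard CenterNet-style encoding: each ground-truth center is splatted as a Gaussian peak on its class heatmap, `index` stores the flattened position of the integer center on the output grid, and `offset` keeps the sub-pixel remainder. A minimal, self-contained numpy sketch of that encoding (fixed Gaussian radius instead of `gaussian_radius`, none of the ppdet helpers):

```python
import numpy as np

def encode_center_targets(boxes, num_classes=1, out_h=128, out_w=128, max_objs=256):
    """Simplified sketch of the CenterNet/CenterTrack target encoding (not the ppdet helpers)."""
    hm = np.zeros((num_classes, out_h, out_w), dtype=np.float32)
    wh = np.zeros((max_objs, 2), dtype=np.float32)
    reg = np.zeros((max_objs, 2), dtype=np.float32)
    ind = np.zeros((max_objs, ), dtype=np.int64)
    mask = np.zeros((max_objs, ), dtype=np.int32)

    radius = 2  # fixed radius for brevity; ppdet derives it via gaussian_radius()
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gauss = np.exp(-(xs * xs + ys * ys) / (2. * (radius / 3.)**2)).astype(np.float32)

    for i, (x1, y1, x2, y2, cls) in enumerate(boxes[:max_objs]):
        ct = np.array([(x1 + x2) / 2, (y1 + y2) / 2], dtype=np.float32)
        x0, y0 = ct.astype(np.int32)
        # assume the center lies at least `radius` pixels inside the border for simplicity
        patch = hm[int(cls), y0 - radius:y0 + radius + 1, x0 - radius:x0 + radius + 1]
        np.maximum(patch, gauss, out=patch)   # splat the Gaussian peak at the center

        wh[i] = x2 - x1, y2 - y1              # box size target
        reg[i] = ct - ct.astype(np.int32)     # sub-pixel offset target
        ind[i] = y0 * out_w + x0              # flattened position of the integer center
        mask[i] = 1
    return {'heatmap': hm, 'size': wh, 'offset': reg, 'index': ind, 'index_mask': mask}

targets = encode_center_targets([(10., 20., 50., 80., 0)])
print(targets['index'][0], targets['offset'][0])   # 6430 [0. 0.]
```
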
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index 8c6a65e7bd02dd169e392447529ee06c488852e4..a9dfb1647e6cede35544038c9a05e3642edce301 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -283,6 +283,11 @@ class Permute(BaseOperator):
im = sample['image']
im = im.transpose((2, 0, 1))
sample['image'] = im
+
+ if 'pre_image' in sample:
+ pre_im = sample['pre_image']
+ pre_im = pre_im.transpose((2, 0, 1))
+ sample['pre_image'] = pre_im
return sample
@@ -305,6 +310,9 @@ class Lighting(BaseOperator):
def apply(self, sample, context=None):
alpha = np.random.normal(scale=self.alphastd, size=(3, ))
sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
+
+ if 'pre_image' in sample:
+ sample['pre_image'] += np.dot(self.eigvec, self.eigval * alpha)
return sample
@@ -403,6 +411,20 @@ class NormalizeImage(BaseOperator):
im -= mean
im /= std
sample['image'] = im
+
+ if 'pre_image' in sample:
+ pre_im = sample['pre_image']
+ pre_im = pre_im.astype(np.float32, copy=False)
+ if self.is_scale:
+ scale = 1.0 / 255.0
+ pre_im *= scale
+
+ if self.norm_type == 'mean_std':
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ pre_im -= mean
+ pre_im /= std
+ sample['pre_image'] = pre_im
return sample
@@ -2826,13 +2848,11 @@ class WarpAffine(BaseOperator):
input_h=512,
input_w=512,
scale=0.4,
- shift=0.1):
+ shift=0.1,
+ down_ratio=4):
"""WarpAffine
Warp affine the image
-
The code is based on https://github.com/xingyizhou/CenterNet/blob/master/src/lib/datasets/sample/ctdet.py
-
-
"""
super(WarpAffine, self).__init__()
self.keep_res = keep_res
@@ -2841,22 +2861,22 @@ class WarpAffine(BaseOperator):
self.input_w = input_w
self.scale = scale
self.shift = shift
+ self.down_ratio = down_ratio
def apply(self, sample, context=None):
img = sample['image']
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
- return sample
h, w = img.shape[:2]
if self.keep_res:
+ # True in detection eval/infer
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
-
else:
+ # False in centertrack eval_mot/infer_mot
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
@@ -2866,6 +2886,22 @@ class WarpAffine(BaseOperator):
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
sample['image'] = inp
+
+ if not self.keep_res:
+ out_h = input_h // self.down_ratio
+ out_w = input_w // self.down_ratio
+ trans_output = get_affine_transform(c, s, 0, [out_w, out_h])
+
+ sample.update({
+ 'center': c,
+ 'scale': s,
+ 'out_height': out_h,
+ 'out_width': out_w,
+ 'inp_height': input_h,
+ 'inp_width': input_w,
+ 'trans_input': trans_input,
+ 'trans_output': trans_output,
+ })
return sample
@@ -2881,11 +2917,13 @@ class FlipWarpAffine(BaseOperator):
shift=0.1,
flip=0.5,
is_scale=True,
- use_random=True):
+ use_random=True,
+ add_pre_img=False):
"""FlipWarpAffine
1. Random Crop
2. Flip the image horizontal
- 3. Warp affine the image
+ 3. Warp affine the image
+ 4. (Optional) Add previous image
"""
super(FlipWarpAffine, self).__init__()
self.keep_res = keep_res
@@ -2898,22 +2936,30 @@ class FlipWarpAffine(BaseOperator):
self.flip = flip
self.is_scale = is_scale
self.use_random = use_random
+ self.add_pre_img = add_pre_img
+
+ def __call__(self, samples, context=None):
+ if self.add_pre_img:
+ assert isinstance(samples, Sequence) and len(samples) == 2
+ sample, pre_sample = samples[0], samples[1]
+ else:
+ sample = samples
- def apply(self, sample, context=None):
img = sample['image']
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h, w = img.shape[:2]
+ flipped = 0
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
-
else:
+ # centernet training default
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
@@ -2921,6 +2967,7 @@ class FlipWarpAffine(BaseOperator):
if self.use_random:
gt_bbox = sample['gt_bbox']
if not self.not_rand_crop:
+ # centernet default
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
w_border = get_border(128, w)
h_border = get_border(128, h)
@@ -2940,18 +2987,50 @@ class FlipWarpAffine(BaseOperator):
oldx2 = gt_bbox[:, 2].copy()
gt_bbox[:, 0] = w - oldx2 - 1
gt_bbox[:, 2] = w - oldx1 - 1
+ flipped = 1
sample['gt_bbox'] = gt_bbox
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
- if not self.use_random:
- img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
if self.is_scale:
inp = (inp.astype(np.float32) / 255.)
+
sample['image'] = inp
sample['center'] = c
sample['scale'] = s
+
+ if self.add_pre_img:
+ sample['trans_input'] = trans_input
+
+ # previous image, use same aug trans_input as current image
+ pre_img = pre_sample['image']
+ pre_img = cv2.cvtColor(pre_img, cv2.COLOR_RGB2BGR)
+ if flipped:
+ pre_img = pre_img[:, ::-1, :].copy()
+ pre_inp = cv2.warpAffine(
+ pre_img,
+ trans_input, (input_w, input_h),
+ flags=cv2.INTER_LINEAR)
+ if self.is_scale:
+ pre_inp = (pre_inp.astype(np.float32) / 255.)
+ sample['pre_image'] = pre_inp
+
+ # if empty gt_bbox
+ if 'gt_bbox' in pre_sample and len(pre_sample['gt_bbox']) == 0:
+ return sample
+ pre_gt_bbox = pre_sample['gt_bbox']
+ if flipped:
+ pre_oldx1 = pre_gt_bbox[:, 0].copy()
+ pre_oldx2 = pre_gt_bbox[:, 2].copy()
+ pre_gt_bbox[:, 0] = w - pre_oldx1 - 1
+ pre_gt_bbox[:, 2] = w - pre_oldx2 - 1
+ sample['pre_gt_bbox'] = pre_gt_bbox
+
+ sample['pre_gt_class'] = pre_sample['gt_class']
+ sample['pre_gt_track_id'] = pre_sample['gt_track_id']
+ del pre_sample
+
return sample
@@ -2993,18 +3072,28 @@ class CenterRandColor(BaseOperator):
img_mean *= (1 - alpha)
img += img_mean
- def __call__(self, sample, context=None):
- img = sample['image']
- img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ def apply(self, sample, context=None):
functions = [
self.apply_brightness,
self.apply_contrast,
self.apply_saturation,
]
+
+ img = sample['image']
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
distortions = np.random.permutation(functions)
for func in distortions:
img = func(img, img_gray)
sample['image'] = img
+
+ if 'pre_image' in sample:
+ pre_img = sample['pre_image']
+ pre_img_gray = cv2.cvtColor(pre_img, cv2.COLOR_BGR2GRAY)
+ pre_distortions = np.random.permutation(functions)
+ for func in pre_distortions:
+ pre_img = func(pre_img, pre_img_gray)
+ sample['pre_image'] = pre_img
+
return sample
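
Note: the recurring pattern in the operator changes above is that whatever augmentation is applied to `sample['image']` is re-applied with identical parameters to `sample['pre_image']`, so the current and previous frames fed to CenterTrack stay consistent. A minimal sketch of that idea, using a hypothetical `normalize_pair` helper rather than the ppdet operators:

```python
import numpy as np

def normalize_pair(sample, mean, std, is_scale=True):
    """Sketch: normalize the current frame and, if present, the previous frame identically."""
    mean = np.asarray(mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
    std = np.asarray(std, dtype=np.float32)[np.newaxis, np.newaxis, :]
    for key in ('image', 'pre_image'):
        if key not in sample:
            continue
        im = sample[key].astype(np.float32, copy=False)
        if is_scale:
            im /= 255.0
        im = (im - mean) / std
        sample[key] = im
    return sample

sample = {'image': np.random.randint(0, 256, (32, 32, 3)).astype(np.uint8),
          'pre_image': np.random.randint(0, 256, (32, 32, 3)).astype(np.uint8)}
sample = normalize_pair(sample, mean=[0.408, 0.447, 0.470], std=[0.289, 0.274, 0.278])
```
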
diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index 44ef8125fa6305f4eafe3822437a49e644425979..136bea3411dee26b8acbe3d0de5d549de5784655 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -42,6 +42,7 @@ TRT_MIN_SUBGRAPH = {
'HRNet': 3,
'DeepSORT': 3,
'ByteTrack': 10,
+ 'CenterTrack': 5,
'JDE': 10,
'FairMOT': 5,
'GFL': 16,
@@ -55,7 +56,7 @@ TRT_MIN_SUBGRAPH = {
}
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
-MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
+MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
TO_STATIC_SPEC = {
'yolov3_darknet53_270e_coco': [{
@@ -179,6 +180,8 @@ def _dump_infer_config(config, path, image_shape, model):
if infer_arch in MOT_ARCH:
if infer_arch == 'DeepSORT':
tracker_cfg = config['DeepSORTTracker']
+ elif infer_arch == 'CenterTrack':
+ tracker_cfg = config['CenterTracker']
else:
tracker_cfg = config['JDETracker']
infer_cfg['tracker'] = _parse_tracker(tracker_cfg)
@@ -209,9 +212,15 @@ def _dump_infer_config(config, path, image_shape, model):
label_arch = 'keypoint_arch'
if infer_arch in MOT_ARCH:
- label_arch = 'mot_arch'
- reader_cfg = config['TestMOTReader']
- dataset_cfg = config['TestMOTDataset']
+ if config['metric'] in ['COCO', 'VOC']:
+ # MOT model run as Detector
+ reader_cfg = config['TestReader']
+ dataset_cfg = config['TestDataset']
+ else:
+ # 'metric' in ['MOT', 'MCMOT', 'KITTI']
+ label_arch = 'mot_arch'
+ reader_cfg = config['TestMOTReader']
+ dataset_cfg = config['TestMOTDataset']
else:
reader_cfg = config['TestReader']
dataset_cfg = config['TestDataset']
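
Note: the export change above does two things: it selects the `CenterTracker` config for the `CenterTrack` architecture, and it falls back to the plain `TestReader`/`TestDataset` when a MOT architecture is exported with a detection metric (COCO/VOC). A condensed sketch of that selection logic, with a toy `config` dict standing in for the parsed YAML:

```python
MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']

def select_export_cfgs(config):
    """Sketch of the tracker/reader selection used when dumping the infer config."""
    arch = config['architecture']
    tracker_key = None
    if arch in MOT_ARCH:
        tracker_key = ('DeepSORTTracker' if arch == 'DeepSORT' else
                       'CenterTracker' if arch == 'CenterTrack' else 'JDETracker')
    if arch in MOT_ARCH and config['metric'] not in ['COCO', 'VOC']:
        reader, dataset = 'TestMOTReader', 'TestMOTDataset'   # exported as a tracker
    else:
        reader, dataset = 'TestReader', 'TestDataset'         # exported as a plain detector
    return tracker_key, reader, dataset

print(select_export_cfgs({'architecture': 'CenterTrack', 'metric': 'MOT'}))
# ('CenterTracker', 'TestMOTReader', 'TestMOTDataset')
```
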
diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py
index 52195356de4737efe32e744c951f77317d029aa0..f8f45cdd7fb43e9326a78bcc6ff655b032993666 100644
--- a/ppdet/engine/tracker.py
+++ b/ppdet/engine/tracker.py
@@ -29,7 +29,8 @@ from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
-from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker
+from ppdet.modeling.mot.tracker import JDETracker, CenterTracker
+from ppdet.modeling.mot.tracker import DeepSORTTracker, OCSORTTracker
from ppdet.modeling.architectures import YOLOX
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
from ppdet.data.source.category import get_categories
@@ -40,9 +41,9 @@ from .callbacks import Callback, ComposeCallback
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
-MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
-MOT_ARCH_JDE = ['JDE', 'FairMOT']
-MOT_ARCH_SDE = ['DeepSORT', 'ByteTrack']
+MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
+MOT_ARCH_JDE = MOT_ARCH[:2]
+MOT_ARCH_SDE = MOT_ARCH[2:4]
MOT_DATA_TYPE = ['mot', 'mcmot', 'kitti']
__all__ = ['Tracker']
@@ -138,6 +139,53 @@ class Tracker(object):
else:
load_weight(self.model.reid, reid_weights)
+ def _eval_seq_centertrack(self,
+ dataloader,
+ save_dir=None,
+ show_image=False,
+ frame_rate=30,
+ draw_threshold=0):
+ assert isinstance(self.model.tracker, CenterTracker)
+ if save_dir:
+ if not os.path.exists(save_dir): os.makedirs(save_dir)
+ tracker = self.model.tracker
+
+ timer = MOTTimer()
+ frame_id = 0
+ self.status['mode'] = 'track'
+ self.model.eval()
+ results = defaultdict(list) # only support single class now
+
+ for step_id, data in enumerate(tqdm(dataloader)):
+ self.status['step_id'] = step_id
+ if step_id == 0:
+ self.model.reset_tracking()
+
+ # forward
+ timer.tic()
+ pred_ret = self.model(data)
+
+ online_targets = tracker.update(pred_ret)
+ online_tlwhs, online_scores, online_ids = [], [], []
+ for t in online_targets:
+ bbox = t['bbox']
+ tlwh = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]
+ tscore = float(t['score'])
+ tid = int(t['tracking_id'])
+ if tlwh[2] * tlwh[3] > 0:
+ online_tlwhs.append(tlwh)
+ online_ids.append(tid)
+ online_scores.append(tscore)
+ timer.toc()
+ # save results
+ results[0].append(
+ (frame_id + 1, online_tlwhs, online_scores, online_ids))
+ save_vis_results(data, frame_id, online_ids, online_tlwhs,
+ online_scores, timer.average_time, show_image,
+ save_dir, self.cfg.num_classes, self.ids2names)
+ frame_id += 1
+ return results, frame_id, timer.average_time, timer.calls
+
def _eval_seq_jde(self,
dataloader,
save_dir=None,
@@ -205,7 +253,11 @@ class Tracker(object):
if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir)
use_detector = False if not self.model.detector else True
- use_reid = False if not self.model.reid else True
+ use_reid = hasattr(self.model, 'reid')
+ if use_reid and self.model.reid is not None:
+ use_reid = True
+ else:
+ use_reid = False
timer = MOTTimer()
results = defaultdict(list)
@@ -378,6 +430,7 @@ class Tracker(object):
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes, self.ids2names)
+
elif isinstance(tracker, OCSORTTracker):
# OC_SORT Tracker
online_targets = tracker.update(pred_dets_old, pred_embs)
@@ -399,6 +452,7 @@ class Tracker(object):
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes, self.ids2names)
+
else:
raise ValueError(tracker)
frame_id += 1
@@ -469,6 +523,12 @@ class Tracker(object):
scaled=scaled,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)))
+ elif model_type == 'CenterTrack':
+ results, nf, ta, tc = self._eval_seq_centertrack(
+ dataloader,
+ save_dir=save_dir,
+ show_image=show_image,
+ frame_rate=frame_rate)
else:
raise ValueError(model_type)
@@ -595,6 +655,12 @@ class Tracker(object):
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)),
draw_threshold=draw_threshold)
+ elif model_type == 'CenterTrack':
+ results, nf, ta, tc = self._eval_seq_centertrack(
+ dataloader,
+ save_dir=save_dir,
+ show_image=show_image,
+ frame_rate=frame_rate)
else:
raise ValueError(model_type)
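
Note: `_eval_seq_centertrack` consumes the dict-style tracks returned by `CenterTracker.update()` (`bbox` as x1y1x2y2, `score`, `tracking_id`) and converts them into MOT-Challenge tlwh rows before saving. A small standalone sketch of that conversion step (hypothetical helper name):

```python
def tracks_to_mot_row(online_targets, frame_id):
    """Sketch: turn CenterTracker-style dict tracks into (frame, tlwhs, scores, ids)."""
    tlwhs, scores, ids = [], [], []
    for t in online_targets:
        x1, y1, x2, y2 = t['bbox']
        tlwh = [x1, y1, x2 - x1, y2 - y1]
        if tlwh[2] * tlwh[3] > 0:           # drop degenerate boxes
            tlwhs.append(tlwh)
            scores.append(float(t['score']))
            ids.append(int(t['tracking_id']))
    return frame_id + 1, tlwhs, scores, ids

row = tracks_to_mot_row([{'bbox': [10, 20, 60, 120], 'score': 0.9, 'tracking_id': 3}], 0)
# (1, [[10, 20, 50, 100]], [0.9], [3])
```
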
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 73012668658181dd54af3808271ff46a20fb444f..f72805f083a5c65bd99c68e0cea5dd1de2bc13a1 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -57,7 +57,7 @@ logger = setup_logger('ppdet.engine')
__all__ = ['Trainer']
-MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
+MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
class Trainer(object):
@@ -75,7 +75,9 @@ class Trainer(object):
# build data loader
capital_mode = self.mode.capitalize()
- if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
+ if cfg.architecture in MOT_ARCH and self.mode in [
+ 'eval', 'test'
+ ] and cfg.metric not in ['COCO', 'VOC']:
self.dataset = self.cfg['{}MOTDataset'.format(
capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
else:
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index b09e61c6c8221be445195877f1c477eeb545f083..83ab3e0ad842ab473325f9854934f219158a0450 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -38,6 +38,7 @@ from . import bytetrack
from . import yolox
from . import yolof
from . import pose3d_metro
+from . import centertrack
from .meta_arch import *
from .faster_rcnn import *
@@ -66,3 +67,4 @@ from .bytetrack import *
from .yolox import *
from .yolof import *
from .pose3d_metro import *
+from .centertrack import *
diff --git a/ppdet/modeling/architectures/centernet.py b/ppdet/modeling/architectures/centernet.py
index 2287d743baef289ec75491f44edc720c7a3ae37c..439e5f872668e5bcd5445adf5a6e41b320679c59 100755
--- a/ppdet/modeling/architectures/centernet.py
+++ b/ppdet/modeling/architectures/centernet.py
@@ -78,30 +78,25 @@ class CenterNet(BaseArch):
def get_pred(self):
head_out = self._forward()
+ bbox, bbox_num, bbox_inds, topk_clses, topk_ys, topk_xs = self.post_process(
+ head_out['heatmap'],
+ head_out['size'],
+ head_out['offset'],
+ im_shape=self.inputs['im_shape'],
+ scale_factor=self.inputs['scale_factor'])
+
if self.for_mot:
- bbox, bbox_inds, topk_clses = self.post_process(
- head_out['heatmap'],
- head_out['size'],
- head_out['offset'],
- im_shape=self.inputs['im_shape'],
- scale_factor=self.inputs['scale_factor'])
output = {
"bbox": bbox,
+ "bbox_num": bbox_num,
"bbox_inds": bbox_inds,
"topk_clses": topk_clses,
+ "topk_ys": topk_ys,
+ "topk_xs": topk_xs,
"neck_feat": head_out['neck_feat']
}
else:
- bbox, bbox_num, _ = self.post_process(
- head_out['heatmap'],
- head_out['size'],
- head_out['offset'],
- im_shape=self.inputs['im_shape'],
- scale_factor=self.inputs['scale_factor'])
- output = {
- "bbox": bbox,
- "bbox_num": bbox_num,
- }
+ output = {"bbox": bbox, "bbox_num": bbox_num}
return output
def get_loss(self):
diff --git a/ppdet/modeling/architectures/centertrack.py b/ppdet/modeling/architectures/centertrack.py
new file mode 100755
index 0000000000000000000000000000000000000000..b9880dbbb21435f1fc84c4c3203e5d818143e776
--- /dev/null
+++ b/ppdet/modeling/architectures/centertrack.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+import numpy as np
+import paddle
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+
+from ..keypoint_utils import affine_transform
+from ppdet.data.transform.op_helper import gaussian_radius, gaussian2D, draw_umich_gaussian
+
+__all__ = ['CenterTrack']
+
+
+@register
+class CenterTrack(BaseArch):
+ """
+ CenterTrack network, see http://arxiv.org/abs/2004.01177
+
+ Args:
+ detector (object): 'CenterNet' instance
+ plugin_head (object): 'CenterTrackHead' instance
+ tracker (object): 'CenterTracker' instance
+ """
+ __category__ = 'architecture'
+ __shared__ = ['mot_metric']
+
+ def __init__(self,
+ detector='CenterNet',
+ plugin_head='CenterTrackHead',
+ tracker='CenterTracker',
+ mot_metric=False):
+ super(CenterTrack, self).__init__()
+ self.detector = detector
+ self.plugin_head = plugin_head
+ self.tracker = tracker
+ self.mot_metric = mot_metric
+ self.pre_image = None
+ self.deploy = False
+
+ @classmethod
+ def from_config(cls, cfg, *args, **kwargs):
+ detector = create(cfg['detector'])
+ detector_out_shape = detector.neck and detector.neck.out_shape or detector.backbone.out_shape
+
+ kwargs = {'input_shape': detector_out_shape}
+ plugin_head = create(cfg['plugin_head'], **kwargs)
+ tracker = create(cfg['tracker'])
+
+ return {
+ 'detector': detector,
+ 'plugin_head': plugin_head,
+ 'tracker': tracker,
+ }
+
+ def _forward(self):
+ if self.training:
+ det_outs = self.detector(self.inputs)
+ neck_feat = det_outs['neck_feat']
+
+ losses = {}
+ for k, v in det_outs.items():
+ if 'loss' not in k: continue
+ losses.update({k: v})
+
+ plugin_outs = self.plugin_head(neck_feat, self.inputs)
+ for k, v in plugin_outs.items():
+ if 'loss' not in k: continue
+ losses.update({k: v})
+
+ losses['loss'] = det_outs['det_loss'] + plugin_outs['plugin_loss']
+ return losses
+
+ else:
+ if not self.mot_metric:
+ # detection, support bs>=1
+ det_outs = self.detector(self.inputs)
+ return {
+ 'bbox': det_outs['bbox'],
+ 'bbox_num': det_outs['bbox_num']
+ }
+
+ else:
+ # MOT, only support bs=1
+ if not self.deploy:
+ if self.pre_image is None:
+ self.pre_image = self.inputs['image']
+ # initializing tracker for the first frame
+ self.tracker.init_track([])
+ self.inputs['pre_image'] = self.pre_image
+ self.pre_image = self.inputs[
+ 'image'] # Note: update for next image
+
+ # render input heatmap from tracker status
+ pre_hm = self.get_additional_inputs(
+ self.tracker.tracks, self.inputs, with_hm=True)
+ self.inputs['pre_hm'] = paddle.to_tensor(pre_hm)
+
+ # model inference
+ det_outs = self.detector(self.inputs)
+ neck_feat = det_outs['neck_feat']
+ result = self.plugin_head(
+ neck_feat, self.inputs, det_outs['bbox'],
+ det_outs['bbox_inds'], det_outs['topk_clses'],
+ det_outs['topk_ys'], det_outs['topk_xs'])
+
+ if not self.deploy:
+ # convert the cropped and 4x downsampled output coordinate system
+ # back to the input image coordinate system
+ result = self.plugin_head.centertrack_post_process(
+ result, self.inputs, self.tracker.out_thresh)
+ return result
+
+ def get_pred(self):
+ return self._forward()
+
+ def get_loss(self):
+ return self._forward()
+
+ def reset_tracking(self):
+ self.tracker.reset()
+ self.pre_image = None
+
+ def get_additional_inputs(self, dets, meta, with_hm=True):
+ # Render input heatmap from previous trackings.
+ trans_input = meta['trans_input'][0].numpy()
+ inp_width, inp_height = int(meta['inp_width'][0]), int(meta[
+ 'inp_height'][0])
+ input_hm = np.zeros((1, inp_height, inp_width), dtype=np.float32)
+
+ for det in dets:
+ if det['score'] < self.tracker.pre_thresh:
+ continue
+ bbox = affine_transform_bbox(det['bbox'], trans_input, inp_width,
+ inp_height)
+ h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ if (h > 0 and w > 0):
+ radius = gaussian_radius(
+ (math.ceil(h), math.ceil(w)), min_overlap=0.7)
+ radius = max(0, int(radius))
+ ct = np.array(
+ [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
+ dtype=np.float32)
+ ct_int = ct.astype(np.int32)
+ if with_hm:
+ input_hm[0] = draw_umich_gaussian(input_hm[0], ct_int,
+ radius)
+ if with_hm:
+ input_hm = input_hm[np.newaxis]
+ return input_hm
+
+
+def affine_transform_bbox(bbox, trans, width, height):
+ bbox = np.array(copy.deepcopy(bbox), dtype=np.float32)
+ bbox[:2] = affine_transform(bbox[:2], trans)
+ bbox[2:] = affine_transform(bbox[2:], trans)
+ bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, width - 1)
+ bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, height - 1)
+ return bbox
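
Note: `get_additional_inputs()` re-renders the previous frame's surviving tracks into a single-channel heatmap (`pre_hm`) in the current input coordinate system; this is the conditioning signal the backbone later adds onto the image features. A hedged numpy sketch with a fixed-radius Gaussian standing in for `gaussian_radius`/`draw_umich_gaussian`:

```python
import numpy as np

def render_pre_hm(tracks, inp_h, inp_w, pre_thresh=0.5, radius=3):
    """Sketch: splat previous-frame track centers into a 1-channel heatmap."""
    pre_hm = np.zeros((1, inp_h, inp_w), dtype=np.float32)
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gauss = np.exp(-(xs * xs + ys * ys) / (2. * (radius / 3.)**2)).astype(np.float32)
    for t in tracks:
        if t['score'] < pre_thresh:
            continue
        x1, y1, x2, y2 = t['bbox']
        cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
        if radius <= cx < inp_w - radius and radius <= cy < inp_h - radius:
            patch = pre_hm[0, cy - radius:cy + radius + 1, cx - radius:cx + radius + 1]
            np.maximum(patch, gauss, out=patch)
    return pre_hm[np.newaxis]   # add the batch dimension, as get_additional_inputs() does

pre_hm = render_pre_hm([{'bbox': [100, 80, 160, 200], 'score': 0.9}], 512, 512)
```
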
diff --git a/ppdet/modeling/backbones/dla.py b/ppdet/modeling/backbones/dla.py
index 4ab06ab7f763a55232b7cc182e2e9df89e99bb88..51d1f0782f760839b5320e272a72ca765f47fd79 100755
--- a/ppdet/modeling/backbones/dla.py
+++ b/ppdet/modeling/backbones/dla.py
@@ -19,7 +19,7 @@ from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
-DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])}
+DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512]), }
class BasicBlock(nn.Layer):
@@ -157,17 +157,25 @@ class DLA(nn.Layer):
DLA, see https://arxiv.org/pdf/1707.06484.pdf
Args:
- depth (int): DLA depth, should be 34.
+ depth (int): DLA depth, only 34 is supported now.
residual_root (bool): whether to use a residual layer in the root block
-
+ pre_img (bool): add pre_img, only used in CenterTrack
+ pre_hm (bool): add pre_hm, only used in CenterTrack
"""
- def __init__(self, depth=34, residual_root=False):
+ def __init__(self,
+ depth=34,
+ residual_root=False,
+ pre_img=False,
+ pre_hm=False):
super(DLA, self).__init__()
- levels, channels = DLA_cfg[depth]
+ assert depth == 34, 'Only support DLA with depth of 34 now.'
if depth == 34:
block = BasicBlock
+ levels, channels = DLA_cfg[depth]
self.channels = channels
+ self.num_levels = len(levels)
+
self.base_layer = nn.Sequential(
ConvNormLayer(
3,
@@ -213,6 +221,29 @@ class DLA(nn.Layer):
level_root=True,
root_residual=residual_root)
+ if pre_img:
+ self.pre_img_layer = nn.Sequential(
+ ConvNormLayer(
+ 3,
+ channels[0],
+ filter_size=7,
+ stride=1,
+ bias_on=False,
+ norm_decay=None),
+ nn.ReLU())
+ if pre_hm:
+ self.pre_hm_layer = nn.Sequential(
+ ConvNormLayer(
+ 1,
+ channels[0],
+ filter_size=7,
+ stride=1,
+ bias_on=False,
+ norm_decay=None),
+ nn.ReLU())
+ self.pre_img = pre_img
+ self.pre_hm = pre_hm
+
def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1):
modules = []
for i in range(conv_num):
@@ -230,13 +261,22 @@ class DLA(nn.Layer):
@property
def out_shape(self):
- return [ShapeSpec(channels=self.channels[i]) for i in range(6)]
+ return [
+ ShapeSpec(channels=self.channels[i]) for i in range(self.num_levels)
+ ]
def forward(self, inputs):
outs = []
- im = inputs['image']
- feats = self.base_layer(im)
- for i in range(6):
+ feats = self.base_layer(inputs['image'])
+
+ if self.pre_img and 'pre_image' in inputs and inputs[
+ 'pre_image'] is not None:
+ feats = feats + self.pre_img_layer(inputs['pre_image'])
+
+ if self.pre_hm and 'pre_hm' in inputs and inputs['pre_hm'] is not None:
+ feats = feats + self.pre_hm_layer(inputs['pre_hm'])
+
+ for i in range(self.num_levels):
feats = getattr(self, 'level{}'.format(i))(feats)
outs.append(feats)
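
Note: with `pre_img`/`pre_hm` enabled, the backbone simply adds 7x7-conv projections of the previous image and the rendered previous heatmap onto the stem features of the current image before running the DLA levels. A minimal paddle sketch of that additive fusion (plain `Conv2D` layers stand in for ppdet's `ConvNormLayer`):

```python
import paddle
import paddle.nn as nn

ch = 16
base = nn.Sequential(nn.Conv2D(3, ch, 7, padding=3), nn.ReLU())
pre_img_layer = nn.Sequential(nn.Conv2D(3, ch, 7, padding=3), nn.ReLU())
pre_hm_layer = nn.Sequential(nn.Conv2D(1, ch, 7, padding=3), nn.ReLU())

image = paddle.rand([1, 3, 64, 64])
pre_image = paddle.rand([1, 3, 64, 64])
pre_hm = paddle.rand([1, 1, 64, 64])

feat = base(image)
feat = feat + pre_img_layer(pre_image)   # condition on the previous frame
feat = feat + pre_hm_layer(pre_hm)       # condition on the previous-frame heatmap
print(feat.shape)                        # [1, 16, 64, 64]
```
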
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index 8856d7c182f4efcd41f9366fb807808160ad7e36..bef87d2c17f9aecc70114100cdc3670092ca4a51 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -38,6 +38,7 @@ from . import ppyoloe_r_head
from . import ld_gfl_head
from . import yolof_head
from . import ppyoloe_contrast_head
+from . import centertrack_head
from .bbox_head import *
from .mask_head import *
@@ -64,4 +65,5 @@ from .fcosr_head import *
from .ld_gfl_head import *
from .ppyoloe_r_head import *
from .yolof_head import *
-from .ppyoloe_contrast_head import *
\ No newline at end of file
+from .ppyoloe_contrast_head import *
+from .centertrack_head import *
diff --git a/ppdet/modeling/heads/centernet_head.py b/ppdet/modeling/heads/centernet_head.py
index ce8b5c15ddd92c4da0aa217c98e7388cf9b6a3b5..76577749a8c45cf752cba6572ab81490ad4d1e7a 100755
--- a/ppdet/modeling/heads/centernet_head.py
+++ b/ppdet/modeling/heads/centernet_head.py
@@ -61,13 +61,12 @@ class CenterNetHead(nn.Layer):
in_channels (int): the channel number of input to CenterNetHead.
num_classes (int): the number of classes, 80 (COCO dataset) by default.
head_planes (int): the channel number in all head, 256 by default.
- heatmap_weight (float): the weight of heatmap loss, 1 by default.
+ prior_bias (float): prior bias in heatmap head, -2.19 by default, -4.6 in CenterTrack
regress_ltrb (bool): whether to regress left/top/right/bottom or
- width/height for a box, true by default
- size_weight (float): the weight of box size loss, 0.1 by default.
- size_loss (): the type of size regression loss, 'L1 loss' by default.
- offset_weight (float): the weight of center offset loss, 1 by default.
- iou_weight (float): the weight of iou head loss, 0 by default.
+ width/height for a box, True by default.
+ size_loss (str): the type of size regression loss, 'L1' by default, can be 'giou'.
+ loss_weight (dict): the weight of each loss.
+ add_iou (bool): whether to add iou branch, False by default.
"""
__shared__ = ['num_classes']
@@ -76,20 +75,20 @@ class CenterNetHead(nn.Layer):
in_channels,
num_classes=80,
head_planes=256,
- heatmap_weight=1,
+ prior_bias=-2.19,
regress_ltrb=True,
- size_weight=0.1,
size_loss='L1',
- offset_weight=1,
- iou_weight=0):
+ loss_weight={
+ 'heatmap': 1.0,
+ 'size': 0.1,
+ 'offset': 1.0,
+ 'iou': 0.0,
+ },
+ add_iou=False):
super(CenterNetHead, self).__init__()
self.regress_ltrb = regress_ltrb
- self.weights = {
- 'heatmap': heatmap_weight,
- 'size': size_weight,
- 'offset': offset_weight,
- 'iou': iou_weight
- }
+ self.loss_weight = loss_weight
+ self.add_iou = add_iou
# heatmap head
self.heatmap = nn.Sequential(
@@ -104,7 +103,7 @@ class CenterNetHead(nn.Layer):
padding=0,
bias=True))
with paddle.no_grad():
- self.heatmap[2].conv.bias[:] = -2.19
+ self.heatmap[2].conv.bias[:] = prior_bias
# size(ltrb or wh) head
self.size = nn.Sequential(
@@ -129,7 +128,7 @@ class CenterNetHead(nn.Layer):
head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
# iou head (optional)
- if iou_weight > 0:
+ if self.add_iou and 'iou' in self.loss_weight:
self.iou = nn.Sequential(
ConvLayer(
in_channels,
@@ -153,34 +152,34 @@ class CenterNetHead(nn.Layer):
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
- heatmap = self.heatmap(feat)
+ heatmap = F.sigmoid(self.heatmap(feat))
size = self.size(feat)
offset = self.offset(feat)
- iou = self.iou(feat) if hasattr(self, 'iou_weight') else None
+ head_outs = {'heatmap': heatmap, 'size': size, 'offset': offset}
+ if self.add_iou and 'iou' in self.loss_weight:
+ iou = self.iou(feat)
+ head_outs.update({'iou': iou})
if self.training:
- loss = self.get_loss(
- inputs, self.weights, heatmap, size, offset, iou=iou)
- return loss
+ losses = self.get_loss(inputs, self.loss_weight, head_outs)
+ return losses
else:
- heatmap = F.sigmoid(heatmap)
- head_outs = {'heatmap': heatmap, 'size': size, 'offset': offset}
- if iou is not None:
- head_outs.update({'iou': iou})
return head_outs
- def get_loss(self, inputs, weights, heatmap, size, offset, iou=None):
- # heatmap head loss: CTFocalLoss
+ def get_loss(self, inputs, weights, head_outs):
+ # 1.heatmap(hm) head loss: CTFocalLoss
+ heatmap = head_outs['heatmap']
heatmap_target = inputs['heatmap']
- heatmap = paddle.clip(F.sigmoid(heatmap), 1e-4, 1 - 1e-4)
+ heatmap = paddle.clip(heatmap, 1e-4, 1 - 1e-4)
ctfocal_loss = CTFocalLoss()
heatmap_loss = ctfocal_loss(heatmap, heatmap_target)
- # size head loss: L1 loss or GIoU loss
+ # 2.size(wh) head loss: L1 loss or GIoU loss
+ size = head_outs['size']
index = inputs['index']
mask = inputs['index_mask']
size = paddle.transpose(size, perm=[0, 2, 3, 1])
- size_n, size_h, size_w, size_c = size.shape
+ size_n, _, _, size_c = size.shape
size = paddle.reshape(size, shape=[size_n, -1, size_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
@@ -208,7 +207,8 @@ class CenterNetHead(nn.Layer):
else:
# inputs['size'] is ltrb, but regress as wh
# shape: [bs, max_per_img, 4]
- size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, 2:]
+ size_target = inputs['size'][:, :, 0:2] + inputs[
+ 'size'][:, :, 2:]
size_target.stop_gradient = True
size_loss = F.l1_loss(
@@ -232,10 +232,11 @@ class CenterNetHead(nn.Layer):
loc_reweight=None)
size_loss = size_loss / (pos_num + 1e-4)
- # offset head loss: L1 loss
+ # 3.offset(reg) head loss: L1 loss
+ offset = head_outs['offset']
offset_target = inputs['offset']
offset = paddle.transpose(offset, perm=[0, 2, 3, 1])
- offset_n, offset_h, offset_w, offset_c = offset.shape
+ offset_n, _, _, offset_c = offset.shape
offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c])
pos_offset = paddle.gather_nd(offset, index=index)
offset_mask = paddle.expand_as(mask, pos_offset)
@@ -249,10 +250,11 @@ class CenterNetHead(nn.Layer):
reduction='sum')
offset_loss = offset_loss / (pos_num + 1e-4)
- # iou head loss: GIoU loss
- if iou is not None:
+ # 4.iou head loss: GIoU loss (optional)
+ if self.add_iou and 'iou' in self.loss_weight:
+ iou = head_outs['iou']
iou = paddle.transpose(iou, perm=[0, 2, 3, 1])
- iou_n, iou_h, iou_w, iou_c = iou.shape
+ iou_n, _, _, iou_c = iou.shape
iou = paddle.reshape(iou, shape=[iou_n, -1, iou_c])
pos_iou = paddle.gather_nd(iou, index=index)
iou_mask = paddle.expand_as(mask, pos_iou)
@@ -284,8 +286,8 @@ class CenterNetHead(nn.Layer):
det_loss = weights['heatmap'] * heatmap_loss + weights[
'size'] * size_loss + weights['offset'] * offset_loss
- if iou is not None:
+ if self.add_iou and 'iou' in self.loss_weight:
losses.update({'iou_loss': iou_loss})
- det_loss = det_loss + weights['iou'] * iou_loss
+ det_loss += weights['iou'] * iou_loss
losses.update({'det_loss': det_loss})
return losses
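
Note: the head refactor replaces the individual `*_weight` arguments with one `loss_weight` dict, so the detector loss is simply the weighted sum det_loss = w_heatmap*L_heatmap + w_size*L_size + w_offset*L_offset (+ w_iou*L_iou). A toy arithmetic sketch of that combination with the default weights:

```python
def combine_centernet_losses(losses, loss_weight=None, add_iou=False):
    """Sketch: weighted sum of the per-branch losses, mirroring the refactored get_loss()."""
    if loss_weight is None:
        loss_weight = {'heatmap': 1.0, 'size': 0.1, 'offset': 1.0, 'iou': 0.0}
    det_loss = (loss_weight['heatmap'] * losses['heatmap_loss'] +
                loss_weight['size'] * losses['size_loss'] +
                loss_weight['offset'] * losses['offset_loss'])
    if add_iou and 'iou' in loss_weight:
        det_loss += loss_weight['iou'] * losses['iou_loss']
    return det_loss

print(combine_centernet_losses({'heatmap_loss': 1.2, 'size_loss': 3.0, 'offset_loss': 0.5}))
# 1.2 * 1.0 + 3.0 * 0.1 + 0.5 * 1.0 = 2.0
```
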
diff --git a/ppdet/modeling/heads/centertrack_head.py b/ppdet/modeling/heads/centertrack_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc353362ad85bc6f61619e9210627d0e6f6c9862
--- /dev/null
+++ b/ppdet/modeling/heads/centertrack_head.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register
+from .centernet_head import ConvLayer
+from ..keypoint_utils import get_affine_transform
+
+__all__ = ['CenterTrackHead']
+
+
+@register
+class CenterTrackHead(nn.Layer):
+ """
+ Args:
+ in_channels (int): the channel number of input to CenterTrackHead.
+ num_classes (int): the number of classes, 1 (MOT17 dataset) by default.
+ head_planes (int): the channel number in all head, 256 by default.
+ task (str): the type of task for regression, 'tracking' by default.
+ loss_weight (dict): the weight of each loss.
+ add_ltrb_amodal (bool): whether to add the ltrb_amodal branch, True by default.
+ """
+
+ __shared__ = ['num_classes']
+
+ def __init__(self,
+ in_channels,
+ num_classes=1,
+ head_planes=256,
+ task='tracking',
+ loss_weight={
+ 'tracking': 1.0,
+ 'ltrb_amodal': 0.1,
+ },
+ add_ltrb_amodal=True):
+ super(CenterTrackHead, self).__init__()
+ self.task = task
+ self.loss_weight = loss_weight
+ self.add_ltrb_amodal = add_ltrb_amodal
+
+ # tracking head
+ self.tracking = nn.Sequential(
+ ConvLayer(
+ in_channels, head_planes, kernel_size=3, padding=1, bias=True),
+ nn.ReLU(),
+ ConvLayer(
+ head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
+
+ # ltrb_amodal head
+ if self.add_ltrb_amodal and 'ltrb_amodal' in self.loss_weight:
+ self.ltrb_amodal = nn.Sequential(
+ ConvLayer(
+ in_channels,
+ head_planes,
+ kernel_size=3,
+ padding=1,
+ bias=True),
+ nn.ReLU(),
+ ConvLayer(
+ head_planes,
+ 4,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True))
+
+ # TODO: add more tasks
+
+ @classmethod
+ def from_config(cls, cfg, input_shape):
+ if isinstance(input_shape, (list, tuple)):
+ input_shape = input_shape[0]
+ return {'in_channels': input_shape.channels}
+
+ def forward(self,
+ feat,
+ inputs,
+ bboxes=None,
+ bbox_inds=None,
+ topk_clses=None,
+ topk_ys=None,
+ topk_xs=None):
+ tracking = self.tracking(feat)
+ head_outs = {'tracking': tracking}
+ if self.add_ltrb_amodal and 'ltrb_amodal' in self.loss_weight:
+ ltrb_amodal = self.ltrb_amodal(feat)
+ head_outs.update({'ltrb_amodal': ltrb_amodal})
+
+ if self.training:
+ losses = self.get_loss(inputs, self.loss_weight, head_outs)
+ return losses
+ else:
+ ret = self.generic_decode(head_outs, bboxes, bbox_inds, topk_ys,
+ topk_xs)
+ return ret
+
+ def get_loss(self, inputs, weights, head_outs):
+ index = inputs['index'].unsqueeze(2)
+ mask = inputs['index_mask'].unsqueeze(2)
+ batch_inds = list()
+ for i in range(head_outs['tracking'].shape[0]):
+ batch_ind = paddle.full(
+ shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
+ batch_inds.append(batch_ind)
+ batch_inds = paddle.concat(batch_inds, axis=0)
+ index = paddle.concat(x=[batch_inds, index], axis=2)
+
+ # 1.tracking head loss: L1 loss
+ tracking = head_outs['tracking'].transpose([0, 2, 3, 1])
+ tracking_target = inputs['tracking']
+ bs, _, _, c = tracking.shape
+ tracking = tracking.reshape([bs, -1, c])
+ pos_tracking = paddle.gather_nd(tracking, index=index)
+ tracking_mask = paddle.cast(
+ paddle.expand_as(mask, pos_tracking), dtype=pos_tracking.dtype)
+ pos_num = tracking_mask.sum()
+ tracking_mask.stop_gradient = True
+ tracking_target.stop_gradient = True
+ tracking_loss = F.l1_loss(
+ pos_tracking * tracking_mask,
+ tracking_target * tracking_mask,
+ reduction='sum')
+ tracking_loss = tracking_loss / (pos_num + 1e-4)
+
+ # 2.ltrb_amodal head loss (optional): L1 loss
+ if self.add_ltrb_amodal and 'ltrb_amodal' in self.loss_weight:
+ ltrb_amodal = head_outs['ltrb_amodal'].transpose([0, 2, 3, 1])
+ ltrb_amodal_target = inputs['ltrb_amodal']
+ bs, _, _, c = ltrb_amodal.shape
+ ltrb_amodal = ltrb_amodal.reshape([bs, -1, c])
+ pos_ltrb_amodal = paddle.gather_nd(ltrb_amodal, index=index)
+ ltrb_amodal_mask = paddle.cast(
+ paddle.expand_as(mask, pos_ltrb_amodal),
+ dtype=pos_ltrb_amodal.dtype)
+ pos_num = ltrb_amodal_mask.sum()
+ ltrb_amodal_mask.stop_gradient = True
+ ltrb_amodal_target.stop_gradient = True
+ ltrb_amodal_loss = F.l1_loss(
+ pos_ltrb_amodal * ltrb_amodal_mask,
+ ltrb_amodal_target * ltrb_amodal_mask,
+ reduction='sum')
+ ltrb_amodal_loss = ltrb_amodal_loss / (pos_num + 1e-4)
+
+ losses = {'tracking_loss': tracking_loss, }
+ plugin_loss = weights['tracking'] * tracking_loss
+
+ if self.add_ltrb_amodal and 'ltrb_amodal' in self.loss_weight:
+ losses.update({'ltrb_amodal_loss': ltrb_amodal_loss})
+ plugin_loss += weights['ltrb_amodal'] * ltrb_amodal_loss
+ losses.update({'plugin_loss': plugin_loss})
+ return losses
+
+ def generic_decode(self, head_outs, bboxes, bbox_inds, topk_ys, topk_xs):
+ topk_ys = paddle.floor(topk_ys) # note: More accurate
+ topk_xs = paddle.floor(topk_xs)
+ cts = paddle.concat([topk_xs, topk_ys], 1)
+ ret = {'bboxes': bboxes, 'cts': cts}
+
+ regression_heads = ['tracking'] # todo: add more tasks
+ for head in regression_heads:
+ if head in head_outs:
+ ret[head] = _tranpose_and_gather_feat(head_outs[head],
+ bbox_inds)
+
+ if 'ltrb_amodal' in head_outs:
+ ltrb_amodal = head_outs['ltrb_amodal']
+ ltrb_amodal = _tranpose_and_gather_feat(ltrb_amodal, bbox_inds)
+ bboxes_amodal = paddle.concat(
+ [
+ topk_xs * 1.0 + ltrb_amodal[..., 0:1],
+ topk_ys * 1.0 + ltrb_amodal[..., 1:2],
+ topk_xs * 1.0 + ltrb_amodal[..., 2:3],
+ topk_ys * 1.0 + ltrb_amodal[..., 3:4]
+ ],
+ axis=1)
+ ret['bboxes'] = paddle.concat([bboxes[:, 0:2], bboxes_amodal], 1)
+ # cls_id, score, x0, y0, x1, y1
+
+ return ret
+
+ def centertrack_post_process(self, dets, meta, out_thresh):
+ if not ('bboxes' in dets):
+ return [{}]
+
+ preds = []
+ c, s = meta['center'].numpy(), meta['scale'].numpy()
+ h, w = meta['out_height'].numpy(), meta['out_width'].numpy()
+ trans = get_affine_transform(
+ center=c[0],
+ input_size=s[0],
+ rot=0,
+ output_size=[w[0], h[0]],
+ shift=(0., 0.),
+ inv=True).astype(np.float32)
+ for i, dets_bbox in enumerate(dets['bboxes']):
+ if dets_bbox[1] < out_thresh:
+ break
+ item = {}
+ item['score'] = dets_bbox[1]
+ item['class'] = int(dets_bbox[0]) + 1
+ item['ct'] = transform_preds_with_trans(
+ dets['cts'][i].reshape([1, 2]), trans).reshape(2)
+
+ if 'tracking' in dets:
+ tracking = transform_preds_with_trans(
+ (dets['tracking'][i] + dets['cts'][i]).reshape([1, 2]),
+ trans).reshape(2)
+ item['tracking'] = tracking - item['ct']
+
+ if 'bboxes' in dets:
+ bbox = transform_preds_with_trans(
+ dets_bbox[2:6].reshape([2, 2]), trans).reshape(4)
+ item['bbox'] = bbox
+
+ preds.append(item)
+ return preds
+
+
+def transform_preds_with_trans(coords, trans):
+ target_coords = np.ones((coords.shape[0], 3), np.float32)
+ target_coords[:, :2] = coords
+ target_coords = np.dot(trans, target_coords.transpose()).transpose()
+ return target_coords[:, :2]
+
+
+def _tranpose_and_gather_feat(feat, bbox_inds):
+ feat = feat.transpose([0, 2, 3, 1])
+ feat = feat.reshape([-1, feat.shape[3]])
+ feat = paddle.gather(feat, bbox_inds)
+ return feat
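
Note: two decode details in the new head are worth spelling out: `ltrb_amodal` stores signed offsets from the integer center, so the amodal box is recovered as (cx+l, cy+t, cx+r, cy+b) even when the clipped box was truncated at the border, and `tracking` is a displacement pointing from the current center to the matching center in the previous frame. A small pure-Python sketch of both decodes (hypothetical helper name):

```python
def decode_centertrack_box(cx, cy, ltrb_amodal, tracking):
    """Sketch: recover the amodal box and previous-frame center from head outputs."""
    l, t, r, b = ltrb_amodal                   # signed offsets: l and t are usually negative
    bbox_amodal = [cx + l, cy + t, cx + r, cy + b]
    prev_center = [cx + tracking[0], cy + tracking[1]]
    return bbox_amodal, prev_center

bbox, prev_ct = decode_centertrack_box(50., 40., [-12., -20., 12., 20.], [-3., 1.])
# bbox == [38., 20., 62., 60.], prev_ct == [47., 41.]
```
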
diff --git a/ppdet/modeling/mot/tracker/__init__.py b/ppdet/modeling/mot/tracker/__init__.py
index 03a5dd0a169203b86edbc6c81a44a095ebe9b3cc..76ee2a6c99c5b5fd3da0f6749a13b12f935cb588 100644
--- a/ppdet/modeling/mot/tracker/__init__.py
+++ b/ppdet/modeling/mot/tracker/__init__.py
@@ -14,12 +14,16 @@
from . import base_jde_tracker
from . import base_sde_tracker
+
+from .base_jde_tracker import *
+from .base_sde_tracker import *
+
from . import jde_tracker
from . import deepsort_tracker
from . import ocsort_tracker
+from . import center_tracker
-from .base_jde_tracker import *
-from .base_sde_tracker import *
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
+from .center_tracker import *
diff --git a/ppdet/modeling/mot/tracker/center_tracker.py b/ppdet/modeling/mot/tracker/center_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e5734675c42b8f10d12f660b82229696810eba
--- /dev/null
+++ b/ppdet/modeling/mot/tracker/center_tracker.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/xingyizhou/CenterTrack/blob/master/src/lib/utils/tracker.py
+"""
+
+import copy
+import numpy as np
+import sklearn
+
+from ppdet.core.workspace import register, serializable
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+__all__ = ['CenterTracker']
+
+
+@register
+@serializable
+class CenterTracker(object):
+ __shared__ = ['num_classes']
+
+ def __init__(self,
+ num_classes=1,
+ min_box_area=0,
+ vertical_ratio=-1,
+ track_thresh=0.4,
+ pre_thresh=0.5,
+ new_thresh=0.4,
+ out_thresh=0.4,
+ hungarian=False):
+ self.num_classes = num_classes
+ self.min_box_area = min_box_area
+ self.vertical_ratio = vertical_ratio
+
+ self.track_thresh = track_thresh
+ self.pre_thresh = max(track_thresh, pre_thresh)
+ self.new_thresh = max(track_thresh, new_thresh)
+ self.out_thresh = max(track_thresh, out_thresh)
+ self.hungarian = hungarian
+
+ self.reset()
+
+ def init_track(self, results):
+ print('Initialize tracking!')
+ for item in results:
+ if item['score'] > self.new_thresh:
+ self.id_count += 1
+ item['tracking_id'] = self.id_count
+ if not ('ct' in item):
+ bbox = item['bbox']
+ item['ct'] = [(bbox[0] + bbox[2]) / 2,
+ (bbox[1] + bbox[3]) / 2]
+ self.tracks.append(item)
+
+ def reset(self):
+ self.id_count = 0
+ self.tracks = []
+
+ def update(self, results, public_det=None):
+ N = len(results)
+ M = len(self.tracks)
+
+ dets = np.array([det['ct'] + det['tracking'] for det in results],
+ np.float32) # N x 2
+ track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
+ (track['bbox'][3] - track['bbox'][1])) \
+ for track in self.tracks], np.float32) # M
+ track_cat = np.array([track['class'] for track in self.tracks],
+ np.int32) # M
+ item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
+ (item['bbox'][3] - item['bbox'][1])) \
+ for item in results], np.float32) # N
+ item_cat = np.array([item['class'] for item in results], np.int32) # N
+ tracks = np.array([pre_det['ct'] for pre_det in self.tracks],
+ np.float32) # M x 2
+ dist = (((tracks.reshape(1, -1, 2) - \
+ dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
+
+ invalid = ((dist > track_size.reshape(1, M)) + \
+ (dist > item_size.reshape(N, 1)) + \
+ (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
+ dist = dist + invalid * 1e18
+
+ if self.hungarian:
+ item_score = np.array([item['score'] for item in results],
+ np.float32)
+ dist[dist > 1e18] = 1e18
+ from sklearn.utils.linear_assignment_ import linear_assignment
+ matched_indices = linear_assignment(dist)
+ else:
+ matched_indices = greedy_assignment(copy.deepcopy(dist))
+
+ unmatched_dets = [d for d in range(dets.shape[0]) \
+ if not (d in matched_indices[:, 0])]
+ unmatched_tracks = [d for d in range(tracks.shape[0]) \
+ if not (d in matched_indices[:, 1])]
+
+ if self.hungarian:
+ matches = []
+ for m in matched_indices:
+ if dist[m[0], m[1]] > 1e16:
+ unmatched_dets.append(m[0])
+ unmatched_tracks.append(m[1])
+ else:
+ matches.append(m)
+ matches = np.array(matches).reshape(-1, 2)
+ else:
+ matches = matched_indices
+
+ ret = []
+ for m in matches:
+ track = results[m[0]]
+ track['tracking_id'] = self.tracks[m[1]]['tracking_id']
+ ret.append(track)
+
+ # Private detection: create tracks for all un-matched detections
+ for i in unmatched_dets:
+ track = results[i]
+ if track['score'] > self.new_thresh:
+ self.id_count += 1
+ track['tracking_id'] = self.id_count
+ ret.append(track)
+
+ self.tracks = ret
+ return ret
+
+
+def greedy_assignment(dist):
+ matched_indices = []
+ if dist.shape[1] == 0:
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
+ for i in range(dist.shape[0]):
+ j = dist[i].argmin()
+ if dist[i][j] < 1e16:
+ dist[:, j] = 1e18
+ matched_indices.append([i, j])
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
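
Note: `CenterTracker.update()` builds an N x M matrix of squared distances between `detection center + tracking offset` and the previous track centers, gates it by box size and class, and then matches greedily (or with the Hungarian algorithm if `hungarian=True`). The greedy step picks, row by row, the closest still-unclaimed track. A toy usage sketch of `greedy_assignment` (reproduced here so the snippet is self-contained):

```python
import numpy as np

# toy cost matrix: 3 detections x 2 previous tracks (squared center distances)
dist = np.array([[  4., 900.],
                 [870.,   9.],
                 [400., 500.]], dtype=np.float32)
# gate out impossible pairs the same way update() does, e.g. wrong class or too far
dist[2, :] += 1e18

def greedy_assignment(dist):
    matched_indices = []
    if dist.shape[1] == 0:
        return np.array(matched_indices, np.int32).reshape(-1, 2)
    for i in range(dist.shape[0]):
        j = dist[i].argmin()
        if dist[i][j] < 1e16:
            dist[:, j] = 1e18            # claim this track so later rows skip it
            matched_indices.append([i, j])
    return np.array(matched_indices, np.int32).reshape(-1, 2)

print(greedy_assignment(dist.copy()))     # -> [[0 0] [1 1]]; detection 2 stays unmatched
```
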
diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py
index 002a80d953835952b7a1223da1d1780c6e6798a2..791af6ed2782913c5815170a44c64c73c2594333 100644
--- a/ppdet/modeling/post_process.py
+++ b/ppdet/modeling/post_process.py
@@ -18,7 +18,6 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import nonempty_bbox
-from ppdet.modeling.layers import TTFBox
from .transformers import bbox_cxcywh_to_xyxy
try:
from collections.abc import Sequence
@@ -358,55 +357,78 @@ class JDEBBoxPostProcess(nn.Layer):
@register
-class CenterNetPostProcess(TTFBox):
+class CenterNetPostProcess(object):
"""
Postprocess the model outputs to get final prediction:
1. Do NMS for heatmap to get top `max_per_img` bboxes.
2. Decode bboxes using center offset and box size.
3. Rescale decoded bboxes reference to the origin image shape.
-
Args:
max_per_img(int): the maximum number of predicted objects in a image,
500 by default.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default.
- for_mot (bool): whether return other features used in tracking model.
"""
+ __shared__ = ['down_ratio']
- __shared__ = ['down_ratio', 'for_mot']
-
- def __init__(self,
- max_per_img=500,
- down_ratio=4,
- regress_ltrb=True,
- for_mot=False):
- super(TTFBox, self).__init__()
+ def __init__(self, max_per_img=500, down_ratio=4, regress_ltrb=True):
+ super(CenterNetPostProcess, self).__init__()
self.max_per_img = max_per_img
self.down_ratio = down_ratio
self.regress_ltrb = regress_ltrb
- self.for_mot = for_mot
+ # _simple_nms() and _topk() are the same as in TTFBox (ppdet/modeling/layers.py)
+
+ def _simple_nms(self, heat, kernel=3):
+ """ Use maxpool to filter the max score, get local peaks. """
+ pad = (kernel - 1) // 2
+ hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
+ keep = paddle.cast(hmax == heat, 'float32')
+ return heat * keep
+
+ def _topk(self, scores):
+ """ Select top k scores and decode to get xy coordinates. """
+ k = self.max_per_img
+ shape_fm = paddle.shape(scores)
+ shape_fm.stop_gradient = True
+ cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
+ # batch size is 1
+ scores_r = paddle.reshape(scores, [cat, -1])
+ topk_scores, topk_inds = paddle.topk(scores_r, k)
+ topk_ys = topk_inds // width
+ topk_xs = topk_inds % width
+
+ topk_score_r = paddle.reshape(topk_scores, [-1])
+ topk_score, topk_ind = paddle.topk(topk_score_r, k)
+ k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
+ topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')
+
+ topk_inds = paddle.reshape(topk_inds, [-1])
+ topk_ys = paddle.reshape(topk_ys, [-1, 1])
+ topk_xs = paddle.reshape(topk_xs, [-1, 1])
+ topk_inds = paddle.gather(topk_inds, topk_ind)
+ topk_ys = paddle.gather(topk_ys, topk_ind)
+ topk_xs = paddle.gather(topk_xs, topk_ind)
+ return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def __call__(self, hm, wh, reg, im_shape, scale_factor):
+ # 1.get clses and scores, note that hm has already been passed through sigmoid
heat = self._simple_nms(hm)
scores, inds, topk_clses, ys, xs = self._topk(heat)
- scores = scores.unsqueeze(1)
clses = topk_clses.unsqueeze(1)
+ scores = scores.unsqueeze(1)
+ # 2.get bboxes, note that only batch_size=1 is supported now
reg_t = paddle.transpose(reg, [0, 2, 3, 1])
- # Like TTFBox, batch size is 1.
- # TODO: support batch size > 1
reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
reg = paddle.gather(reg, inds)
xs = paddle.cast(xs, 'float32')
ys = paddle.cast(ys, 'float32')
xs = xs + reg[:, 0:1]
ys = ys + reg[:, 1:2]
-
wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
wh = paddle.gather(wh, inds)
-
if self.regress_ltrb:
x1 = xs - wh[:, 0:1]
y1 = ys - wh[:, 1:2]
@@ -417,7 +439,6 @@ class CenterNetPostProcess(TTFBox):
y1 = ys - wh[:, 1:2] / 2
x2 = xs + wh[:, 0:1] / 2
y2 = ys + wh[:, 1:2] / 2
-
n, c, feat_h, feat_w = paddle.shape(hm)
padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
@@ -425,12 +446,10 @@ class CenterNetPostProcess(TTFBox):
y1 = y1 * self.down_ratio
x2 = x2 * self.down_ratio
y2 = y2 * self.down_ratio
-
x1 = x1 - padw
y1 = y1 - padh
x2 = x2 - padw
y2 = y2 - padh
-
bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
scale_y = scale_factor[:, 0:1]
scale_x = scale_factor[:, 1:2]
@@ -439,11 +458,9 @@ class CenterNetPostProcess(TTFBox):
boxes_shape = bboxes.shape[:]
scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand)
+
results = paddle.concat([clses, scores, bboxes], axis=1)
- if self.for_mot:
- return results, inds, topk_clses
- else:
- return results, paddle.shape(results)[0:1], topk_clses
+ return results, paddle.shape(results)[0:1], inds, topk_clses, ys, xs
@register
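
Note: the `_simple_nms`/`_topk` pair implements CenterNet's max-pool NMS: a 3x3 max-pool keeps only the local peaks of the already-sigmoided heatmap, and a top-k over the flattened peak map yields scores, flat indices, classes and integer center coordinates. A simplified paddle sketch of the same idea on a toy heatmap (the real `_topk` does a per-class top-k first):

```python
import paddle
import paddle.nn.functional as F

heat = paddle.rand([1, 2, 16, 16])                     # [bs=1, classes, H, W], already sigmoid-ed
hmax = F.max_pool2d(heat, kernel_size=3, stride=1, padding=1)
peaks = heat * paddle.cast(hmax == heat, 'float32')    # zero out everything but local maxima

k = 5
scores, inds = paddle.topk(peaks.reshape([1, -1]), k)  # flatten classes*H*W and take top-k
topk_clses = inds // (16 * 16)
topk_ys = (inds % (16 * 16)) // 16
topk_xs = (inds % (16 * 16)) % 16
print(scores.numpy(), topk_clses.numpy(), topk_ys.numpy(), topk_xs.numpy())
```
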