diff --git a/configs/mot/bytetrack/README_cn.md b/configs/mot/bytetrack/README_cn.md
index 1313dfc7138a81b3c8037c425cc388399c9edf20..8e93134f83bac05abfa85ea8551f9571913df0f9 100644
--- a/configs/mot/bytetrack/README_cn.md
+++ b/configs/mot/bytetrack/README_cn.md
@@ -20,10 +20,9 @@
 | MOT-17 half train | YOLOv3 | 608x608 | - | 42.7 | 49.5 | 54.8 | - |[Config](./bytetrack_yolov3.yml) |
 | MOT-17 half train | PP-YOLOE-l | 640x640 | - | 52.9 | 50.4 | 59.7 | - |[Config](./bytetrack_ppyoloe.yml) |
 | MOT-17 half train | PP-YOLOE-l | 640x640 |PPLCNet| 52.9 | 51.7 | 58.8 | - |[Config](./bytetrack_ppyoloe_pplcnet.yml) |
-| **mot17_ch** | YOLOX-x | 800x1440| - | 61.9 | 77.3 | 71.6 | - |[Config](./bytetrack_yolox.yml) |
+| **mix_mot_ch** | YOLOX-x | 800x1440| - | 61.9 | 77.3 | 71.6 | - |[Config](./bytetrack_yolox.yml) |
 | **mix_det** | YOLOX-x | 800x1440| - | 65.4 | 84.5 | 77.4 | - |[Config](./bytetrack_yolox.yml) |
-
 **Notes:**
  - For configs and docs of the detection task, see [detector](detector/)
@@ -43,7 +42,7 @@ **Notes:**
 - The model weight download URLs are the ```det_weights``` and ```reid_weights``` entries in the config file; they are downloaded automatically when you run the ```tools/eval_mot.py``` evaluation command. A ```reid_weights``` of None means ReID weights are not needed; ByteTrack does not use ReID weights by default.
 - **MOT17-half train** is a dataset built from the images and annotations of the first half of the frames of each video in the MOT17 train sequences (7 in total). To verify accuracy, evaluate on **MOT17-half val**, built from the second half of the frames of each video. The dataset can be downloaded from [this link](https://dataset.bj.bcebos.com/mot/MOT17.zip) and extracted into the `dataset/mot/` folder.
-  - **mix_det** is a joint dataset composed of MOT17, crowdhuman, Cityscapes and ETHZ. The expected format and directory layout can be found at [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place it under `dataset/mot/`. To verify accuracy, evaluate on **MOT17-half val**.
+  - The **mix_mot_ch** dataset is a joint dataset composed of MOT17 and CrowdHuman, and **mix_det** is a joint dataset composed of MOT17, CrowdHuman, Cityscapes and ETHZ. The expected format and directory layout can be found at [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place them under `dataset/mot/`. To verify accuracy, evaluate on **MOT17-half val**.
 - ByteTrack trains the detector alone on the MOT dataset; inference assembles the tracker to evaluate MOT metrics, and the standalone detection model can also be evaluated on detection metrics.
 - For export and deployment, ByteTrack exports the detection model alone and runs it with the tracker assembled; see [PP-Tracking](../../../deploy/pptracking/python/README.md).
@@ -122,7 +121,7 @@ Step 2: Export the ReID model (optional, not needed by default)
 CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/reid/deepsort_pplcnet.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
 ```
-### 4. Run Python prediction with the exported model
+### 5. Run Python prediction with the exported model
 ```bash
 python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --save_mot_txts
diff --git a/configs/mot/ocsort/README.md b/configs/mot/ocsort/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6dfcf9fb76410b85f8fb9447961a6c881c838183
--- /dev/null
+++ b/configs/mot/ocsort/README.md
@@ -0,0 +1,101 @@
+简体中文 | [English](README.md)
+
+# OC_SORT (Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking)
+
+## Contents
+- [Introduction](#introduction)
+- [Model Zoo](#model-zoo)
+- [Getting Started](#getting-started)
+- [Citation](#citation)
+
+## Introduction
+[OC_SORT](https://arxiv.org/abs/2203.14360) (Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking). Configs for several commonly used detectors are provided here for reference. Since differences in training dataset, input scale, number of training epochs, NMS threshold settings and so on all lead to differences in model accuracy and performance, please adapt them to your needs.
+
+## Model Zoo
+
+### OC_SORT results on the MOT-17 half Val Set
+
+| Detector training set | Detector | Input size | ReID | Det mAP | MOTA | IDF1 | FPS | Config |
+| :-------- | :----- | :----: | :----:|:------: | :----: |:-----: |:----:|:----: |
+| MOT-17 half train | PP-YOLOE-l | 640x640 | - | 52.9 | 50.1 | 62.6 | - |[Config](./ocsort_ppyoloe.yml) |
+| **mix_mot_ch** | YOLOX-x | 800x1440| - | 61.9 | 75.5 | 77.0 | - |[Config](./ocsort_yolox.yml) |
+
+**Notes:**
+  - The model weight download URLs are the ```det_weights``` and ```reid_weights``` entries in the config file; they are downloaded automatically when you run the evaluation command. OC_SORT does not need ```reid_weights``` by default.
+  - **MOT17-half train** is a dataset built from the images and annotations of the first half of the frames of each video in the MOT17 train sequences (7 in total). To verify accuracy, evaluate on **MOT17-half val**, built from the second half of the frames of each video. The dataset can be downloaded from [this link](https://dataset.bj.bcebos.com/mot/MOT17.zip) and extracted into the `dataset/mot/` folder.
+  - The **mix_mot_ch** dataset is a joint dataset composed of MOT17 and CrowdHuman, and **mix_det** is a joint dataset composed of MOT17, CrowdHuman, Cityscapes and ETHZ. The expected format and directory layout can be found at [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place them under `dataset/mot/`. To verify accuracy, evaluate on **MOT17-half val**.
+  - OC_SORT trains the detector alone on the MOT dataset; inference assembles the tracker to evaluate MOT metrics, and the standalone detection model can also be evaluated on detection metrics.
+  - For export and deployment, OC_SORT exports the detection model alone and runs it with the tracker assembled; see [PP-Tracking](../../../deploy/pptracking/python).
+  - OC_SORT is the main tracking solution of pipeline analysis projects such as PP-Human and PP-Vehicle; for usage see [Pipeline](../../../deploy/pipeline) and [MOT](../../../deploy/pipeline/docs/tutorials/mot.md).
+
+
+## Getting Started
+
+### 1. Training
+Start training and evaluation with the following command
+```bash
+python -m paddle.distributed.launch --log_dir=ppyoloe --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
+```
+
+### 2. Evaluation
+#### 2.1 Evaluate detection
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
+```
+
+**Notes:**
+  - Detection is evaluated with ```tools/eval.py```, while tracking is evaluated with ```tools/eval_mot.py```.
+
+#### 2.2 Evaluate tracking
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/ocsort/ocsort_ppyoloe.yml --scaled=True
+# or
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/ocsort/ocsort_yolox.yml --scaled=True
+```
+**Notes:**
+  - `--scaled` indicates whether the coordinates in the model output are already scaled back to the original image. Set it to False for the JDE YOLOv3 detection model and to True for general detection models; the default is False.
+  - Tracking results are saved in `{output_dir}/mot_results/`, one txt per video sequence. Each line of a txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`, and `{output_dir}` can be set with `--output_dir`.

+### 3. Prediction
+
+Predict a video on a single GPU with the following command and save the result as a video
+
+```bash
+# download the demo video
+wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
+
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/ocsort/ocsort_yolox.yml --video_file=mot17_demo.mp4 --scaled=True --save_videos
+```
+
+**Notes:**
+  - Make sure [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first; on Linux (Ubuntu) it can be installed with: `apt-get update && apt-get install -y ffmpeg`.
+  - `--scaled` indicates whether the coordinates in the model output are already scaled back to the original image. Set it to False for the JDE YOLOv3 detection model and to True for general detection models.
+
+
+### 4. Export the inference model
+
+Step 1: Export the detection model
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/bytetrack/detector/yolox_x_24e_800x1440_mix_det.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/yolox_x_24e_800x1440_mix_det.pdparams
+```
+
+### 5. Run Python prediction with the exported model
+
+```bash
+python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/yolox_x_24e_800x1440_mix_det/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --save_mot_txts
+```
+**Notes:**
+  - Before running, manually change the tracker type in `tracker_config.yml` to `type: OCSORTTracker`.
+  - The tracking model predicts videos; single-image prediction is not supported. By default the visualized tracking result is saved as a video. Add `--save_mot_txts` (one txt per video) or `--save_mot_txt_per_img` (one txt per image) to save the tracking results as txt files, or `--save_images` to save the visualized images.
+  - Each line of the tracking result txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
+
+
+## Citation
+```
+@article{cao2022observation,
+  title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
+  author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
+  journal={arXiv preprint arXiv:2203.14360},
+  year={2022}
+}
+```
diff --git a/configs/mot/ocsort/ocsort_ppyoloe.yml b/configs/mot/ocsort/ocsort_ppyoloe.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0d81b2d1c0b0def8cd4458a96c6a352e04c16456
--- /dev/null
+++ b/configs/mot/ocsort/ocsort_ppyoloe.yml
@@ -0,0 +1,75 @@
+# This config is an assembled config for OC_SORT MOT (built on the ByteTrack architecture), used as eval/infer mode for MOT.
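+# Everything except the tracker section is inherited from the ByteTrack detector,
+# dataset and reader configs in _BASE_; only the tracker is swapped to OCSORTTracker.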
+_BASE_: [
+  '../bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml',
+  '../bytetrack/_base_/mot17.yml',
+  '../bytetrack/_base_/ppyoloe_mot_reader_640x640.yml'
+]
+weights: output/ocsort_ppyoloe/model_final
+log_iter: 20
+snapshot_epoch: 2
+
+metric: MOT # eval/infer mode; set 'COCO' for training mode
+num_classes: 1
+
+architecture: ByteTrack
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
+ByteTrack:
+  detector: YOLOv3 # PP-YOLOE version
+  reid: None
+  tracker: OCSORTTracker
+det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+reid_weights: None
+
+YOLOv3:
+  backbone: CSPResNet
+  neck: CustomCSPPAN
+  yolo_head: PPYOLOEHead
+  post_process: ~
+
+# Tracking requires higher quality boxes, so NMS score_threshold will be higher
+PPYOLOEHead:
+  fpn_strides: [32, 16, 8]
+  grid_cell_scale: 5.0
+  grid_cell_offset: 0.5
+  static_assigner_epoch: -1 # 100
+  use_varifocal_loss: True
+  loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+  static_assigner:
+    name: ATSSAssigner
+    topk: 9
+  assigner:
+    name: TaskAlignedAssigner
+    topk: 13
+    alpha: 1.0
+    beta: 6.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.1 # 0.01 in original detector
+    nms_threshold: 0.4 # 0.6 in original detector
+
+
+OCSORTTracker:
+  det_thresh: 0.4 # 0.6 in yolox ocsort
+  max_age: 30
+  min_hits: 3
+  iou_threshold: 0.3
+  delta_t: 3
+  inertia: 0.2
+  vertical_ratio: 0
+  min_box_area: 0
+  use_byte: False
+
+
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset:
+  !MOTImageFolder
+    dataset_dir: dataset/mot
+    data_root: MOT17/images/half
+    keep_ori_im: True # set as True in DeepSORT and ByteTrack
+
+TestMOTDataset:
+  !MOTImageFolder
+    dataset_dir: dataset/mot
+    keep_ori_im: True # set True if save visualization images or video
diff --git a/configs/mot/ocsort/ocsort_yolox.yml b/configs/mot/ocsort/ocsort_yolox.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4f05e2d04ce1d83c98e54b35d21217915c5ee8f4
--- /dev/null
+++ b/configs/mot/ocsort/ocsort_yolox.yml
@@ -0,0 +1,83 @@
+# This config is an assembled config for OC_SORT MOT (built on the ByteTrack architecture), used as eval/infer mode for MOT.
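+# Compared with ocsort_ppyoloe.yml, the YOLOX-x detector below is trained on the
+# mix_mot_ch joint dataset (see det_weights), and the tracker uses the stricter
+# pedestrian settings det_thresh=0.6, vertical_ratio=1.6 and min_box_area=100.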
+_BASE_: [
+  '../bytetrack/detector/yolox_x_24e_800x1440_mix_det.yml',
+  '../bytetrack/_base_/mix_det.yml',
+  '../bytetrack/_base_/yolox_mot_reader_800x1440.yml'
+]
+weights: output/ocsort_yolox/model_final
+log_iter: 20
+snapshot_epoch: 2
+
+metric: MOT # eval/infer mode
+num_classes: 1
+
+architecture: ByteTrack
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/yolox_x_300e_coco.pdparams
+ByteTrack:
+  detector: YOLOX
+  reid: None
+  tracker: OCSORTTracker
+det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/yolox_x_24e_800x1440_mix_mot_ch.pdparams
+reid_weights: None
+
+depth_mult: 1.33
+width_mult: 1.25
+
+YOLOX:
+  backbone: CSPDarkNet
+  neck: YOLOCSPPAN
+  head: YOLOXHead
+  input_size: [800, 1440]
+  size_stride: 32
+  size_range: [18, 22] # multi-scale range [576*1024 ~ 800*1440], w/h ratio=1.8
+
+CSPDarkNet:
+  arch: "X"
+  return_idx: [2, 3, 4]
+  depthwise: False
+
+YOLOCSPPAN:
+  depthwise: False
+
+# Tracking requires higher quality boxes, so NMS score_threshold will be higher
+YOLOXHead:
+  l1_epoch: 20
+  depthwise: False
+  loss_weight: {cls: 1.0, obj: 1.0, iou: 5.0, l1: 1.0}
+  assigner:
+    name: SimOTAAssigner
+    candidate_topk: 10
+    use_vfl: False
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.1
+    nms_threshold: 0.7
+  # For speed while keeping high mAP, you can modify 'nms_top_k' to 1000 and 'keep_top_k' to 100; the mAP will drop about 0.1%.
+  # For a high speed demo, you can modify 'score_threshold' to 0.25 and 'nms_threshold' to 0.45, but the mAP will drop a lot.
+
+
+OCSORTTracker:
+  det_thresh: 0.6
+  max_age: 30
+  min_hits: 3
+  iou_threshold: 0.3
+  delta_t: 3
+  inertia: 0.2
+  vertical_ratio: 1.6
+  min_box_area: 100
+  use_byte: False
+
+
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset:
+  !MOTImageFolder
+    dataset_dir: dataset/mot
+    data_root: MOT17/images/half
+    keep_ori_im: True # set as True in DeepSORT and ByteTrack
+
+TestMOTDataset:
+  !MOTImageFolder
+    dataset_dir: dataset/mot
+    keep_ori_im: True # set True if save visualization images or video
diff --git a/deploy/pipeline/config/tracker_config.yml b/deploy/pipeline/config/tracker_config.yml
index 50e92f7d7aecf94534c406667347cc0f45681966..c4f3c894655c3a8c58bdfb1b124d98427eeef4df 100644
--- a/deploy/pipeline/config/tracker_config.yml
+++ b/deploy/pipeline/config/tracker_config.yml
@@ -2,7 +2,8 @@
 # The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
 # Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
 
-type: JDETracker
+type: OCSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker']
+
 
 # BYTETracker
 JDETracker:
@@ -13,3 +14,15 @@ JDETracker:
   match_thres: 0.9
   min_box_area: 0
   vertical_ratio: 0 # 1.6 for pedestrian
+
+
+OCSORTTracker:
+  det_thresh: 0.4
+  max_age: 30
+  min_hits: 3
+  iou_threshold: 0.3
+  delta_t: 3
+  inertia: 0.2
+  vertical_ratio: 0
+  min_box_area: 0
+  use_byte: False
diff --git a/deploy/pptracking/python/README.md b/deploy/pptracking/python/README.md
index dc70b767bfa8f78c1caa94e4e67cbb9c75afb826..e56a69e3fe96aba11ed3a50d1173b39c30fc83cf 100644
--- a/deploy/pptracking/python/README.md
+++ b/deploy/pptracking/python/README.md
@@ -117,7 +117,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=mot_ppyoloe_l_36e_p
 
-## 3. Export and prediction of the ByteTrack model
+## 3. Export and prediction of the ByteTrack and OC_SORT models
 ### 3.1 Export the inference model
 ```bash
 # export the PP-YOLOE pedestrian detection model
@@ -136,7 +136,8 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --reid_model_dir=output_inference/deepsort_pplcnet/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images
 ```
 **Notes:**
-  - Before running, make sure the tracker type in `tracker_config.yml` is `type: JDETracker`.
+  - To run the ByteTrack model, make sure the tracker type in `tracker_config.yml` is `type: JDETracker`.
+  - Switch the tracker type in `tracker_config.yml` to `type: OCSORTTracker` to run the OC_SORT model.
   - The ByteTrack model runs by loading the exported detector together with the separately configured `--tracker_config` file; no ReID model is needed for real-time tracking. `--reid_model_dir` is the path of the exported ReID model, empty by default, and whether to add it depends on the observed results;
   - The tracking model predicts videos; single-image prediction is not supported. By default the visualized tracking result is saved as a video. Add `--save_mot_txts` (one txt per video) to save the tracking results as txt files, or `--save_images` to save the visualized images.
   - Each line of the tracking result txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
diff --git a/deploy/pptracking/python/mot/matching/__init__.py b/deploy/pptracking/python/mot/matching/__init__.py
index 54c6680f79f16247c562a9da1024dd3e1de4c57f..f6a88c5673a50452415b1f86f7b18bac12297f49 100644
--- a/deploy/pptracking/python/mot/matching/__init__.py
+++ b/deploy/pptracking/python/mot/matching/__init__.py
@@ -14,6 +14,8 @@
 
 from . import jde_matching
 from . import deepsort_matching
+from . import ocsort_matching
 
 from .jde_matching import *
 from .deepsort_matching import *
+from .ocsort_matching import *
diff --git a/deploy/pptracking/python/mot/matching/ocsort_matching.py b/deploy/pptracking/python/mot/matching/ocsort_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2428d020731a6f91e28f9962168390cf1b3a12f
--- /dev/null
+++ b/deploy/pptracking/python/mot/matching/ocsort_matching.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
+"""
+
+import os
+import numpy as np
+
+
+def iou_batch(bboxes1, bboxes2):
+    """
+    From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
+    """
+    bboxes2 = np.expand_dims(bboxes2, 0)
+    bboxes1 = np.expand_dims(bboxes1, 1)
+
+    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
+    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
+    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
+    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
+    w = np.maximum(0., xx2 - xx1)
+    h = np.maximum(0., yy2 - yy1)
+    wh = w * h
+    o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) *
+              (bboxes1[..., 3] - bboxes1[..., 1]) +
+              (bboxes2[..., 2] - bboxes2[..., 0]) *
+              (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
+    return o
+
+
+def speed_direction_batch(dets, tracks):
+    tracks = tracks[..., np.newaxis]
+    CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
+    CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (
+        tracks[:, 1] + tracks[:, 3]) / 2.0
+    dx = CX1 - CX2
+    dy = CY1 - CY2
+    norm = np.sqrt(dx**2 + dy**2) + 1e-6
+    dx = dx / norm
+    dy = dy / norm
+    return dy, dx  # size: num_track x num_det
+
+
+def linear_assignment(cost_matrix):
+    try:
+        import lap
+        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
+        return np.array([[y[i], i] for i in x if i >= 0])
+    except ImportError:
+        from scipy.optimize import linear_sum_assignment
+        x, y = linear_sum_assignment(cost_matrix)
+        return np.array(list(zip(x, y)))
+
+
+def associate(detections, trackers, iou_threshold, velocities, previous_obs,
+              vdc_weight):
+    if (len(trackers) == 0):
+        return np.empty(
+            (0, 2), dtype=int), np.arange(len(detections)), np.empty(
+                (0, 5), dtype=int)
+
+    Y, X = speed_direction_batch(detections, previous_obs)
+    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
+    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
+    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
+    diff_angle_cos = inertia_X * X + inertia_Y * Y
+    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
+    diff_angle = np.arccos(diff_angle_cos)
+    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
+
+    valid_mask = np.ones(previous_obs.shape[0])
+    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
+
+    iou_matrix = iou_batch(detections, trackers)
+    scores = np.repeat(
+        detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
+    # iou_matrix = iou_matrix * scores  # a trick that sometimes works, we don't encourage this
+    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
+
+    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
+    angle_diff_cost = angle_diff_cost.T
+    angle_diff_cost = angle_diff_cost * scores
+
+    if min(iou_matrix.shape) > 0:
+        a = (iou_matrix > iou_threshold).astype(np.int32)
+        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
+            matched_indices = np.stack(np.where(a), axis=1)
+        else:
+            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
+    else:
+        matched_indices = np.empty(shape=(0, 2))
+
+    unmatched_detections = []
+    for d, det in enumerate(detections):
+        if (d not in matched_indices[:, 0]):
+            unmatched_detections.append(d)
+    unmatched_trackers = []
+    for t, trk in enumerate(trackers):
+        if (t not in matched_indices[:, 1]):
+            unmatched_trackers.append(t)
+
+    # filter out matched with low IOU
+    matches = []
+    for m in matched_indices:
+        if (iou_matrix[m[0], m[1]] < iou_threshold):
+            unmatched_detections.append(m[0])
+            
unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if (len(matches) == 0): + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) diff --git a/deploy/pptracking/python/mot/tracker/__init__.py b/deploy/pptracking/python/mot/tracker/__init__.py index b74593b4126d878cd655326e58369f5b6f76a2ae..03a5dd0a169203b86edbc6c81a44a095ebe9b3cc 100644 --- a/deploy/pptracking/python/mot/tracker/__init__.py +++ b/deploy/pptracking/python/mot/tracker/__init__.py @@ -16,8 +16,10 @@ from . import base_jde_tracker from . import base_sde_tracker from . import jde_tracker from . import deepsort_tracker +from . import ocsort_tracker from .base_jde_tracker import * from .base_sde_tracker import * from .jde_tracker import * from .deepsort_tracker import * +from .ocsort_tracker import * diff --git a/deploy/pptracking/python/mot/tracker/ocsort_tracker.py b/deploy/pptracking/python/mot/tracker/ocsort_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..ea97014cf1e62d85daf2aeaa91d3f9c0bf2f0260 --- /dev/null +++ b/deploy/pptracking/python/mot/tracker/ocsort_tracker.py @@ -0,0 +1,354 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py +""" + +import numpy as np +from filterpy.kalman import KalmanFilter + +from ..matching.ocsort_matching import associate, linear_assignment, iou_batch + + +def k_previous_obs(observations, cur_age, k): + if len(observations) == 0: + return [-1, -1, -1, -1, -1] + for i in range(k): + dt = k - i + if cur_age - dt in observations: + return observations[cur_age - dt] + max_age = max(observations.keys()) + return observations[max_age] + + +def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2. + y = bbox[1] + h / 2. 
+    s = w * h  # scale is just area
+    r = w / float(h + 1e-6)
+    return np.array([x, y, s, r]).reshape((4, 1))
+
+
+def convert_x_to_bbox(x, score=None):
+    """
+    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
+    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
+    """
+    w = np.sqrt(x[2] * x[3])
+    h = x[2] / w
+    if score is None:
+        return np.array(
+            [x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
+             x[1] + h / 2.]).reshape((1, 4))
+    else:
+        score = np.array([score])
+        return np.array([
+            x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
+        ]).reshape((1, 5))
+
+
+def speed_direction(bbox1, bbox2):
+    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
+    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
+    speed = np.array([cy2 - cy1, cx2 - cx1])
+    norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
+    return speed / norm
+
+
+class KalmanBoxTracker(object):
+    """
+    This class represents the internal state of individual tracked objects observed as bbox.
+
+    Args:
+        bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
+        delta_t (int): delta_t of previous observation
+    """
+    count = 0
+
+    def __init__(self, bbox, delta_t=3):
+        self.kf = KalmanFilter(dim_x=7, dim_z=4)
+        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
+                              [0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
+                              [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
+                              [0, 0, 0, 0, 0, 0, 1]])
+        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
+                              [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
+        self.kf.R[2:, 2:] *= 10.
+        self.kf.P[4:, 4:] *= 1000.
+        # give high uncertainty to the unobservable initial velocities
+        self.kf.P *= 10.
+        self.kf.Q[-1, -1] *= 0.01
+        self.kf.Q[4:, 4:] *= 0.01
+
+        self.score = bbox[4]
+        self.kf.x[:4] = convert_bbox_to_z(bbox)
+        self.time_since_update = 0
+        self.id = KalmanBoxTracker.count
+        KalmanBoxTracker.count += 1
+        self.history = []
+        self.hits = 0
+        self.hit_streak = 0
+        self.age = 0
+        """
+        NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same as the
+        return of the function k_previous_obs. It is ugly, but to support generating the observation array
+        in a fast and unified way (see k_observations = np.array([k_previous_obs(...)]) below), let's bear
+        it for now.
+        """
+        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
+        self.observations = dict()
+        self.history_observations = []
+        self.velocity = None
+        self.delta_t = delta_t
+
+    def update(self, bbox):
+        """
+        Updates the state vector with observed bbox.
+        """
+        if bbox is not None:
+            if self.last_observation.sum() >= 0:  # a previous observation exists
+                previous_box = None
+                for i in range(self.delta_t):
+                    dt = self.delta_t - i
+                    if self.age - dt in self.observations:
+                        previous_box = self.observations[self.age - dt]
+                        break
+                if previous_box is None:
+                    previous_box = self.last_observation
+                """
+                Estimate the track speed direction with observations \Delta t steps away
+                """
+                self.velocity = speed_direction(previous_box, bbox)
+            """
+            Insert new observations. This is an ugly way to maintain both self.observations
+            and self.history_observations. Bear it for the moment.
+            """
+            self.last_observation = bbox
+            self.observations[self.age] = bbox
+            self.history_observations.append(bbox)
+
+            self.time_since_update = 0
+            self.history = []
+            self.hits += 1
+            self.hit_streak += 1
+            self.kf.update(convert_bbox_to_z(bbox))
+        else:
+            self.kf.update(bbox)
+
+    def predict(self):
+        """
+        Advances the state vector and returns the predicted bounding box estimate.
+        """
+        if ((self.kf.x[6] + self.kf.x[2]) <= 0):
+            self.kf.x[6] *= 0.0
+
+        self.kf.predict()
+        self.age += 1
+        if (self.time_since_update > 0):
+            self.hit_streak = 0
+        self.time_since_update += 1
+        self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
+        return self.history[-1]
+
+    def get_state(self):
+        return convert_x_to_bbox(self.kf.x, score=self.score)
+
+
+class OCSORTTracker(object):
+    """
+    OCSORT tracker, support single class
+
+    Args:
+        det_thresh (float): threshold of detection score
+        max_age (int): maximum number of missed frames before a track is deleted
+        min_hits (int): minimum hits for associate
+        iou_threshold (float): iou threshold for associate
+        delta_t (int): delta_t of previous observation
+        inertia (float): vdc_weight of angle_diff_cost for associate
+        vertical_ratio (float): w/h, the vertical ratio of the bbox used to filter
+            bad results. If set <= 0, no filtering is applied; usually set to 1.6
+            for pedestrian tracking.
+        min_box_area (int): min box area to filter out low quality boxes
+        use_byte (bool): Whether use ByteTracker, default False
+    """
+
+    def __init__(self,
+                 det_thresh=0.6,
+                 max_age=30,
+                 min_hits=3,
+                 iou_threshold=0.3,
+                 delta_t=3,
+                 inertia=0.2,
+                 vertical_ratio=-1,
+                 min_box_area=0,
+                 use_byte=False):
+        self.det_thresh = det_thresh
+        self.max_age = max_age
+        self.min_hits = min_hits
+        self.iou_threshold = iou_threshold
+        self.delta_t = delta_t
+        self.inertia = inertia
+        self.vertical_ratio = vertical_ratio
+        self.min_box_area = min_box_area
+        self.use_byte = use_byte
+
+        self.trackers = []
+        self.frame_count = 0
+        KalmanBoxTracker.count = 0
+
+    def update(self, pred_dets, pred_embs=None):
+        """
+        Args:
+            pred_dets (np.array): Detection results of the image, the shape is
+                [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
+            pred_embs (np.array): Embedding results of the image, the shape is
+                [N, 128] or [N, 512], default as None.
+
+        Return:
+            tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
+        """
+        if pred_dets is None:
+            return np.empty((0, 6))
+
+        self.frame_count += 1
+
+        bboxes = pred_dets[:, 2:]
+        scores = pred_dets[:, 1:2]
+        dets = np.concatenate((bboxes, scores), axis=1)
+        scores = scores.squeeze(-1)
+
+        inds_low = scores > 0.1
+        inds_high = scores < self.det_thresh
+        inds_second = np.logical_and(inds_low, inds_high)
+        # self.det_thresh > score > 0.1, for second matching
+        dets_second = dets[inds_second]  # detections for second matching
+        remain_inds = scores > self.det_thresh
+        dets = dets[remain_inds]
+
+        # get predicted locations from existing trackers.
+        trks = np.zeros((len(self.trackers), 5))
+        to_del = []
+        ret = []
+        for t, trk in enumerate(trks):
+            pos = self.trackers[t].predict()[0]
+            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
+            if np.any(np.isnan(pos)):
+                to_del.append(t)
+        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
+        for t in reversed(to_del):
+            self.trackers.pop(t)
+
+        velocities = np.array([
+            trk.velocity if trk.velocity is not None else np.array((0, 0))
+            for trk in self.trackers
+        ])
+        last_boxes = np.array([trk.last_observation for trk in self.trackers])
+        k_observations = np.array([
+            k_previous_obs(trk.observations, trk.age, self.delta_t)
+            for trk in self.trackers
+        ])
+        """
+        First round of association
+        """
+        matched, unmatched_dets, unmatched_trks = associate(
+            dets, trks, self.iou_threshold, velocities, k_observations,
+            self.inertia)
+        for m in matched:
+            self.trackers[m[1]].update(dets[m[0], :])
+        """
+        Second round of association by OCR
+        """
+        # BYTE association
+        if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
+                0] > 0:
+            u_trks = trks[unmatched_trks]
+            iou_left = iou_batch(
+                dets_second,
+                u_trks)  # iou between low score detections and unmatched tracks
+            iou_left = np.array(iou_left)
+            if iou_left.max() > self.iou_threshold:
+                """
+                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
+                uniform here for simplicity
+                """
+                matched_indices = linear_assignment(-iou_left)
+                to_remove_trk_indices = []
+                for m in matched_indices:
+                    det_ind, trk_ind = m[0], unmatched_trks[m[1]]
+                    if iou_left[m[0], m[1]] < self.iou_threshold:
+                        continue
+                    self.trackers[trk_ind].update(dets_second[det_ind, :])
+                    to_remove_trk_indices.append(trk_ind)
+                unmatched_trks = np.setdiff1d(unmatched_trks,
+                                              np.array(to_remove_trk_indices))
+
+        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
+            left_dets = dets[unmatched_dets]
+            left_trks = last_boxes[unmatched_trks]
+            iou_left = iou_batch(left_dets, left_trks)
+            iou_left = np.array(iou_left)
+            if iou_left.max() > self.iou_threshold:
+                """
+                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+                get a higher performance especially on MOT17/MOT20 datasets. 
But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[ + 1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.trackers[trk_ind].update(dets[det_ind, :]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, + np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, + np.array(to_remove_trk_indices)) + + for m in unmatched_trks: + self.trackers[m].update(None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t) + self.trackers.append(trk) + i = len(self.trackers) + for trk in reversed(self.trackers): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + d = trk.last_observation # tlbr + score + if (trk.time_since_update < 1) and ( + trk.hit_streak >= self.min_hits or + self.frame_count <= self.min_hits): + # +1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) + i -= 1 + # remove dead tracklet + if (trk.time_since_update > self.max_age): + self.trackers.pop(i) + if (len(ret) > 0): + return np.concatenate(ret) + return np.empty((0, 6)) diff --git a/deploy/pptracking/python/mot_sde_infer.py b/deploy/pptracking/python/mot_sde_infer.py index 7b5c8852defeea0fc51bd6d45c25490a794011a2..f7d62ec89d24e4d5981a24554ea10271cac01c6d 100644 --- a/deploy/pptracking/python/mot_sde_infer.py +++ b/deploy/pptracking/python/mot_sde_infer.py @@ -32,7 +32,7 @@ sys.path.insert(0, parent_path) from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video -from mot.tracker import JDETracker, DeepSORTTracker +from mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic from mot.visualize import plot_tracking, plot_tracking_dict @@ -142,8 +142,10 @@ class SDE_Detector(Detector): # tracker config self.use_deepsort_tracker = True if tracker_cfg[ 'type'] == 'DeepSORTTracker' else False + self.use_ocsort_tracker = True if tracker_cfg[ + 'type'] == 'OCSORTTracker' else False + if self.use_deepsort_tracker: - # use DeepSORTTracker if self.reid_pred_config is not None and hasattr( self.reid_pred_config, 'tracker'): cfg = self.reid_pred_config.tracker @@ -161,6 +163,28 @@ class SDE_Detector(Detector): matching_threshold=matching_threshold, min_box_area=min_box_area, vertical_ratio=vertical_ratio, ) + + elif self.use_ocsort_tracker: + det_thresh = cfg.get('det_thresh', 0.4) + max_age = cfg.get('max_age', 30) + min_hits = cfg.get('min_hits', 3) + iou_threshold = cfg.get('iou_threshold', 0.3) + delta_t = cfg.get('delta_t', 3) + inertia = cfg.get('inertia', 0.2) + min_box_area = cfg.get('min_box_area', 0) + vertical_ratio = cfg.get('vertical_ratio', 0) + use_byte = cfg.get('use_byte', False) + + self.tracker = OCSORTTracker( + det_thresh=det_thresh, + max_age=max_age, + min_hits=min_hits, + iou_threshold=iou_threshold, + delta_t=delta_t, + inertia=inertia, + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, + use_byte=use_byte) else: # use ByteTracker use_byte = cfg.get('use_byte', False) @@ -283,6 +307,32 @@ class 
SDE_Detector(Detector):
             feat_data['feat'] = _feat
             tracking_outs['feat_data'].update({_imgname: feat_data})
             return tracking_outs
+
+        elif self.use_ocsort_tracker:
+            # use OCSORTTracker, only support single class
+            online_targets = self.tracker.update(pred_dets, pred_embs)
+            online_tlwhs = defaultdict(list)
+            online_scores = defaultdict(list)
+            online_ids = defaultdict(list)
+            for t in online_targets:
+                tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
+                tscore = float(t[4])
+                tid = int(t[5])
+                if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
+                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
+                        3] > self.tracker.vertical_ratio:
+                    continue
+                if tlwh[2] * tlwh[3] > 0:
+                    online_tlwhs[0].append(tlwh)
+                    online_ids[0].append(tid)
+                    online_scores[0].append(tscore)
+            tracking_outs = {
+                'online_tlwhs': online_tlwhs,
+                'online_scores': online_scores,
+                'online_ids': online_ids,
+            }
+            return tracking_outs
+
         else:
             # use ByteTracker, support multiple class
             online_tlwhs = defaultdict(list)
@@ -523,7 +573,7 @@ class SDE_Detector(Detector):
         online_tlwhs, online_scores, online_ids = mot_results[0]
         # flow statistic for one class, and only for bytetracker
-        if num_classes == 1 and not self.use_deepsort_tracker:
+        if num_classes == 1 and not self.use_deepsort_tracker and not self.use_ocsort_tracker:
             result = (frame_id + 1, online_tlwhs[0], online_scores[0],
                       online_ids[0])
             statistic = flow_statistic(
@@ -533,8 +583,8 @@ class SDE_Detector(Detector):
             records = statistic['records']
 
         fps = 1. / timer.duration
-        if self.use_deepsort_tracker:
-            # use DeepSORTTracker, only support single class
+        if self.use_deepsort_tracker or self.use_ocsort_tracker:
+            # use DeepSORTTracker or OCSORTTracker, only support single class
             results[0].append(
                 (frame_id + 1, online_tlwhs, online_scores, online_ids))
             im = plot_tracking(
diff --git a/deploy/pptracking/python/tracker_config.yml b/deploy/pptracking/python/tracker_config.yml
index ddd55e8653870ed9bdfe9734995e8af5b56f49e2..464337702a4a6b5d8d32fc87eeb8df9bcc01c57a 100644
--- a/deploy/pptracking/python/tracker_config.yml
+++ b/deploy/pptracking/python/tracker_config.yml
@@ -2,7 +2,7 @@
 # The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
 # Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
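 # 'type' selects which tracker mot_sde_infer.py builds from this file.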
-type: JDETracker # 'JDETracker' or 'DeepSORTTracker' +type: OCSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker', 'DeepSORTTracker'] # BYTETracker JDETracker: @@ -14,6 +14,19 @@ JDETracker: min_box_area: 0 vertical_ratio: 0 # 1.6 for pedestrian + +OCSORTTracker: + det_thresh: 0.4 + max_age: 30 + min_hits: 3 + iou_threshold: 0.3 + delta_t: 3 + inertia: 0.2 + min_box_area: 0 + vertical_ratio: 0 + use_byte: False + + DeepSORTTracker: input_size: [64, 192] min_box_area: 0 diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index 70c3552ac2ee29d7f9ca7de4086497e2d3da9cc4..090c72c7eb86562beee9257cf9c32c966af093bc 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -29,7 +29,7 @@ from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results -from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker +from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker from ppdet.modeling.architectures import YOLOX from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric import ppdet.utils.stats as stats @@ -370,7 +370,29 @@ class Tracker(object): save_vis_results(data, frame_id, online_ids, online_tlwhs, online_scores, timer.average_time, show_image, save_dir, self.cfg.num_classes) - + elif isinstance(tracker, OCSORTTracker): + # OC_SORT Tracker + online_targets = tracker.update(pred_dets_old, pred_embs) + online_tlwhs = [] + online_ids = [] + online_scores = [] + for t in online_targets: + tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]] + tscore = float(t[4]) + tid = int(t[5]) + if tlwh[2] * tlwh[3] > 0: + online_tlwhs.append(tlwh) + online_ids.append(tid) + online_scores.append(tscore) + timer.toc() + # save results + results[0].append( + (frame_id + 1, online_tlwhs, online_scores, online_ids)) + save_vis_results(data, frame_id, online_ids, online_tlwhs, + online_scores, timer.average_time, show_image, + save_dir, self.cfg.num_classes) + else: + raise ValueError(tracker) frame_id += 1 return results, frame_id, timer.average_time, timer.calls diff --git a/ppdet/modeling/mot/matching/__init__.py b/ppdet/modeling/mot/matching/__init__.py index 54c6680f79f16247c562a9da1024dd3e1de4c57f..f6a88c5673a50452415b1f86f7b18bac12297f49 100644 --- a/ppdet/modeling/mot/matching/__init__.py +++ b/ppdet/modeling/mot/matching/__init__.py @@ -14,6 +14,8 @@ from . import jde_matching from . import deepsort_matching +from . import ocsort_matching from .jde_matching import * from .deepsort_matching import * +from .ocsort_matching import * diff --git a/ppdet/modeling/mot/matching/ocsort_matching.py b/ppdet/modeling/mot/matching/ocsort_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a32d76155b985f03b8ecbaedd88df70eaa9fd0fa --- /dev/null +++ b/ppdet/modeling/mot/matching/ocsort_matching.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
+"""
+
+import os
+import numpy as np
+
+
+def iou_batch(bboxes1, bboxes2):
+    bboxes2 = np.expand_dims(bboxes2, 0)
+    bboxes1 = np.expand_dims(bboxes1, 1)
+
+    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
+    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
+    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
+    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
+    w = np.maximum(0., xx2 - xx1)
+    h = np.maximum(0., yy2 - yy1)
+    area = w * h
+    iou_matrix = area / ((bboxes1[..., 2] - bboxes1[..., 0]) *
+                         (bboxes1[..., 3] - bboxes1[..., 1]) +
+                         (bboxes2[..., 2] - bboxes2[..., 0]) *
+                         (bboxes2[..., 3] - bboxes2[..., 1]) - area)
+    return iou_matrix
+
+
+def speed_direction_batch(dets, tracks):
+    tracks = tracks[..., np.newaxis]
+    CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
+    CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (
+        tracks[:, 1] + tracks[:, 3]) / 2.0
+    dx = CX1 - CX2
+    dy = CY1 - CY2
+    norm = np.sqrt(dx**2 + dy**2) + 1e-6
+    dx = dx / norm
+    dy = dy / norm
+    return dy, dx
+
+
+def linear_assignment(cost_matrix):
+    try:
+        import lap
+        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
+        return np.array([[y[i], i] for i in x if i >= 0])
+    except ImportError:
+        from scipy.optimize import linear_sum_assignment
+        x, y = linear_sum_assignment(cost_matrix)
+        return np.array(list(zip(x, y)))
+
+
+def associate(detections, trackers, iou_threshold, velocities, previous_obs,
+              vdc_weight):
+    if (len(trackers) == 0):
+        return np.empty(
+            (0, 2), dtype=int), np.arange(len(detections)), np.empty(
+                (0, 5), dtype=int)
+
+    Y, X = speed_direction_batch(detections, previous_obs)
+    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
+    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
+    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
+    diff_angle_cos = inertia_X * X + inertia_Y * Y
+    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
+    diff_angle = np.arccos(diff_angle_cos)
+    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
+
+    valid_mask = np.ones(previous_obs.shape[0])
+    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
+
+    iou_matrix = iou_batch(detections, trackers)
+    scores = np.repeat(
+        detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
+    # iou_matrix = iou_matrix * scores  # a trick that sometimes works, we don't encourage this
+    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
+
+    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
+    angle_diff_cost = angle_diff_cost.T
+    angle_diff_cost = angle_diff_cost * scores
+
+    if min(iou_matrix.shape) > 0:
+        a = (iou_matrix > iou_threshold).astype(np.int32)
+        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
+            matched_indices = np.stack(np.where(a), axis=1)
+        else:
+            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
+    else:
+        matched_indices = np.empty(shape=(0, 2))
+
+    unmatched_detections = []
+    for d, det in enumerate(detections):
+        if (d not in matched_indices[:, 0]):
+            
unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if (t not in matched_indices[:, 1]): + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if (iou_matrix[m[0], m[1]] < iou_threshold): + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if (len(matches) == 0): + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) diff --git a/ppdet/modeling/mot/tracker/__init__.py b/ppdet/modeling/mot/tracker/__init__.py index b74593b4126d878cd655326e58369f5b6f76a2ae..03a5dd0a169203b86edbc6c81a44a095ebe9b3cc 100644 --- a/ppdet/modeling/mot/tracker/__init__.py +++ b/ppdet/modeling/mot/tracker/__init__.py @@ -16,8 +16,10 @@ from . import base_jde_tracker from . import base_sde_tracker from . import jde_tracker from . import deepsort_tracker +from . import ocsort_tracker from .base_jde_tracker import * from .base_sde_tracker import * from .jde_tracker import * from .deepsort_tracker import * +from .ocsort_tracker import * diff --git a/ppdet/modeling/mot/tracker/ocsort_tracker.py b/ppdet/modeling/mot/tracker/ocsort_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..93ce082ac940f806174fd4c9d8d971d698702c9d --- /dev/null +++ b/ppdet/modeling/mot/tracker/ocsort_tracker.py @@ -0,0 +1,357 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py +""" + +import numpy as np +from filterpy.kalman import KalmanFilter + +from ..matching.ocsort_matching import associate, linear_assignment, iou_batch +from ppdet.core.workspace import register, serializable + + +def k_previous_obs(observations, cur_age, k): + if len(observations) == 0: + return [-1, -1, -1, -1, -1] + for i in range(k): + dt = k - i + if cur_age - dt in observations: + return observations[cur_age - dt] + max_age = max(observations.keys()) + return observations[max_age] + + +def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2. + y = bbox[1] + h / 2. 
+    s = w * h  # scale is just area
+    r = w / float(h + 1e-6)
+    return np.array([x, y, s, r]).reshape((4, 1))
+
+
+def convert_x_to_bbox(x, score=None):
+    """
+    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
+    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
+    """
+    w = np.sqrt(x[2] * x[3])
+    h = x[2] / w
+    if score is None:
+        return np.array(
+            [x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
+             x[1] + h / 2.]).reshape((1, 4))
+    else:
+        score = np.array([score])
+        return np.array([
+            x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
+        ]).reshape((1, 5))
+
+
+def speed_direction(bbox1, bbox2):
+    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
+    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
+    speed = np.array([cy2 - cy1, cx2 - cx1])
+    norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
+    return speed / norm
+
+
+class KalmanBoxTracker(object):
+    """
+    This class represents the internal state of individual tracked objects observed as bbox.
+
+    Args:
+        bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
+        delta_t (int): delta_t of previous observation
+    """
+    count = 0
+
+    def __init__(self, bbox, delta_t=3):
+        self.kf = KalmanFilter(dim_x=7, dim_z=4)
+        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
+                              [0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
+                              [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
+                              [0, 0, 0, 0, 0, 0, 1]])
+        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
+                              [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
+        self.kf.R[2:, 2:] *= 10.
+        self.kf.P[4:, 4:] *= 1000.
+        # give high uncertainty to the unobservable initial velocities
+        self.kf.P *= 10.
+        self.kf.Q[-1, -1] *= 0.01
+        self.kf.Q[4:, 4:] *= 0.01
+
+        self.score = bbox[4]
+        self.kf.x[:4] = convert_bbox_to_z(bbox)
+        self.time_since_update = 0
+        self.id = KalmanBoxTracker.count
+        KalmanBoxTracker.count += 1
+        self.history = []
+        self.hits = 0
+        self.hit_streak = 0
+        self.age = 0
+        """
+        NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same as the
+        return of the function k_previous_obs. It is ugly, but to support generating the observation array
+        in a fast and unified way (see k_observations = np.array([k_previous_obs(...)]) below), let's bear
+        it for now.
+        """
+        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
+        self.observations = dict()
+        self.history_observations = []
+        self.velocity = None
+        self.delta_t = delta_t
+
+    def update(self, bbox):
+        """
+        Updates the state vector with observed bbox.
+        """
+        if bbox is not None:
+            if self.last_observation.sum() >= 0:  # a previous observation exists
+                previous_box = None
+                for i in range(self.delta_t):
+                    dt = self.delta_t - i
+                    if self.age - dt in self.observations:
+                        previous_box = self.observations[self.age - dt]
+                        break
+                if previous_box is None:
+                    previous_box = self.last_observation
+                """
+                Estimate the track speed direction with observations \Delta t steps away
+                """
+                self.velocity = speed_direction(previous_box, bbox)
+            """
+            Insert new observations. This is an ugly way to maintain both self.observations
+            and self.history_observations. Bear it for the moment.
+            """
+            self.last_observation = bbox
+            self.observations[self.age] = bbox
+            self.history_observations.append(bbox)
+
+            self.time_since_update = 0
+            self.history = []
+            self.hits += 1
+            self.hit_streak += 1
+            self.kf.update(convert_bbox_to_z(bbox))
+        else:
+            self.kf.update(bbox)
+
+    def predict(self):
+        """
+        Advances the state vector and returns the predicted bounding box estimate.
+        """
+        if ((self.kf.x[6] + self.kf.x[2]) <= 0):
+            self.kf.x[6] *= 0.0
+
+        self.kf.predict()
+        self.age += 1
+        if (self.time_since_update > 0):
+            self.hit_streak = 0
+        self.time_since_update += 1
+        self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
+        return self.history[-1]
+
+    def get_state(self):
+        return convert_x_to_bbox(self.kf.x, score=self.score)
+
+
+@register
+@serializable
+class OCSORTTracker(object):
+    """
+    OCSORT tracker, support single class
+
+    Args:
+        det_thresh (float): threshold of detection score
+        max_age (int): maximum number of missed frames before a track is deleted
+        min_hits (int): minimum hits for associate
+        iou_threshold (float): iou threshold for associate
+        delta_t (int): delta_t of previous observation
+        inertia (float): vdc_weight of angle_diff_cost for associate
+        vertical_ratio (float): w/h, the vertical ratio of the bbox used to filter
+            bad results. If set <= 0, no filtering is applied; usually set to 1.6
+            for pedestrian tracking.
+        min_box_area (int): min box area to filter out low quality boxes
+        use_byte (bool): Whether use ByteTracker, default False
+    """
+
+    def __init__(self,
+                 det_thresh=0.6,
+                 max_age=30,
+                 min_hits=3,
+                 iou_threshold=0.3,
+                 delta_t=3,
+                 inertia=0.2,
+                 vertical_ratio=-1,
+                 min_box_area=0,
+                 use_byte=False):
+        self.det_thresh = det_thresh
+        self.max_age = max_age
+        self.min_hits = min_hits
+        self.iou_threshold = iou_threshold
+        self.delta_t = delta_t
+        self.inertia = inertia
+        self.vertical_ratio = vertical_ratio
+        self.min_box_area = min_box_area
+        self.use_byte = use_byte
+
+        self.trackers = []
+        self.frame_count = 0
+        KalmanBoxTracker.count = 0
+
+    def update(self, pred_dets, pred_embs=None):
+        """
+        Args:
+            pred_dets (np.array): Detection results of the image, the shape is
+                [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
+            pred_embs (np.array): Embedding results of the image, the shape is
+                [N, 128] or [N, 512], default as None.
+
+        Return:
+            tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
+        """
+        if pred_dets is None:
+            return np.empty((0, 6))
+
+        self.frame_count += 1
+
+        bboxes = pred_dets[:, 2:]
+        scores = pred_dets[:, 1:2]
+        dets = np.concatenate((bboxes, scores), axis=1)
+        scores = scores.squeeze(-1)
+
+        inds_low = scores > 0.1
+        inds_high = scores < self.det_thresh
+        inds_second = np.logical_and(inds_low, inds_high)
+        # self.det_thresh > score > 0.1, for second matching
+        dets_second = dets[inds_second]  # detections for second matching
+        remain_inds = scores > self.det_thresh
+        dets = dets[remain_inds]
+
+        # get predicted locations from existing trackers.
+        # each predicted row of trks is [x1, y1, x2, y2, 0]; trackers whose
+        # Kalman prediction turns NaN are collected in to_del and dropped below
+        trks = np.zeros((len(self.trackers), 5))
+        to_del = []
+        ret = []
+        for t, trk in enumerate(trks):
+            pos = self.trackers[t].predict()[0]
+            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
+            if np.any(np.isnan(pos)):
+                to_del.append(t)
+        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
+        for t in reversed(to_del):
+            self.trackers.pop(t)
+
+        velocities = np.array([
+            trk.velocity if trk.velocity is not None else np.array((0, 0))
+            for trk in self.trackers
+        ])
+        last_boxes = np.array([trk.last_observation for trk in self.trackers])
+        k_observations = np.array([
+            k_previous_obs(trk.observations, trk.age, self.delta_t)
+            for trk in self.trackers
+        ])
+        """
+        First round of association
+        """
+        matched, unmatched_dets, unmatched_trks = associate(
+            dets, trks, self.iou_threshold, velocities, k_observations,
+            self.inertia)
+        for m in matched:
+            self.trackers[m[1]].update(dets[m[0], :])
+        """
+        Second round of association by OCR
+        """
+        # BYTE association
+        if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
+                0] > 0:
+            u_trks = trks[unmatched_trks]
+            iou_left = iou_batch(
+                dets_second,
+                u_trks)  # iou between low score detections and unmatched tracks
+            iou_left = np.array(iou_left)
+            if iou_left.max() > self.iou_threshold:
+                """
+                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
+                uniform here for simplicity
+                """
+                matched_indices = linear_assignment(-iou_left)
+                to_remove_trk_indices = []
+                for m in matched_indices:
+                    det_ind, trk_ind = m[0], unmatched_trks[m[1]]
+                    if iou_left[m[0], m[1]] < self.iou_threshold:
+                        continue
+                    self.trackers[trk_ind].update(dets_second[det_ind, :])
+                    to_remove_trk_indices.append(trk_ind)
+                unmatched_trks = np.setdiff1d(unmatched_trks,
+                                              np.array(to_remove_trk_indices))
+
+        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
+            left_dets = dets[unmatched_dets]
+            left_trks = last_boxes[unmatched_trks]
+            iou_left = iou_batch(left_dets, left_trks)
+            iou_left = np.array(iou_left)
+            if iou_left.max() > self.iou_threshold:
+                """
+                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+                get a higher performance especially on MOT17/MOT20 datasets. 
But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[ + 1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.trackers[trk_ind].update(dets[det_ind, :]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, + np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, + np.array(to_remove_trk_indices)) + + for m in unmatched_trks: + self.trackers[m].update(None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t) + self.trackers.append(trk) + i = len(self.trackers) + for trk in reversed(self.trackers): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + d = trk.last_observation # tlbr + score + if (trk.time_since_update < 1) and ( + trk.hit_streak >= self.min_hits or + self.frame_count <= self.min_hits): + # +1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) + i -= 1 + # remove dead tracklet + if (trk.time_since_update > self.max_age): + self.trackers.pop(i) + if (len(ret) > 0): + return np.concatenate(ret) + return np.empty((0, 6)) diff --git a/requirements.txt b/requirements.txt index 6fa692bf0c5cb6a110e99d976a30c1709362e098..8563997322d3ccedfe2e08b9814ec0a03ef2387a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,11 @@ Cython pycocotools #xtcocotools==1.6 #only for crowdpose setuptools>=42.0.0 +pyclipper + +# for mot lap sklearn motmetrics openpyxl -pyclipper +filterpy
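---

For reference, a minimal usage sketch of the deploy-side tracker added in this patch. It assumes the script is run from `deploy/pptracking/python/` (so that `mot.tracker.ocsort_tracker` is importable) with `filterpy` from the updated requirements installed, and it feeds hand-made detection boxes rather than real detector output; the `[cls_id, score, x0, y0, x1, y1]` input layout and `[x0, y0, x1, y1, score, id]` output layout follow the `update()` docstring.

```python
# Hypothetical smoke test for the new OCSORTTracker; the boxes below are made up.
import numpy as np
from mot.tracker.ocsort_tracker import OCSORTTracker

tracker = OCSORTTracker(det_thresh=0.4, max_age=30, min_hits=3)

# Rows are [cls_id, score, x0, y0, x1, y1]; the 0.25-score box is below
# det_thresh and would only join the BYTE second matching when use_byte=True.
pred_dets = np.array([
    [0, 0.92, 100., 120., 160., 300.],
    [0, 0.25, 50., 60., 90., 200.],
])

for frame_id in range(3):
    online_targets = tracker.update(pred_dets, pred_embs=None)
    for t in online_targets:
        x0, y0, x1, y1, score, tid = t
        print(frame_id, int(tid), [x0, y0, x1, y1], float(score))
```

Feeding the same confident box each frame keeps a stable track id once the track survives association, which is the expected behavior given the `min_hits`/`frame_count` output condition in `update()`.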