Unverified commit 7fc9614a, authored by Feng Ni, committed by GitHub

[MOT] add pptracking mtmct (#4615)

* add mtmct aic_vehicle baseline

* fix clip_bbox, add coco model filter, fix some bugs

* fix infer bugs, fix crops and mtmct vis

* fix engine tracker infer

* fix mtmct images deploy

* fix readme and cfg

* remove mtmct of cfg modeling and tracker

* fix doc cfg
Parent d93c9eb5
README_cn.md
\ No newline at end of file
English | [简体中文](README_cn.md)
# MTMCT (Multi-Target Multi-Camera Tracking)
## Contents
- [Introduction](#introduction)
- [Model Zoo](#model-zoo)
- [Getting Started](#getting-started)
- [Citations](#citations)
## Introduction
MTMCT (Multi-Target Multi-Camera Tracking) performs multi-object tracking across videos captured by different cameras of the same scene. It is an important research topic in the tracking field and plays a key role in security surveillance, autonomous driving, smart cities and other industries. Because MTMCT operates on videos from different cameras of the same scene, its accuracy depends heavily on scene prior knowledge and on information such as the number of cameras, their viewing angles and their topology. PaddleDetection provides here a baseline MTMCT implementation with scene- and camera-specific optimizations removed; to further improve accuracy, post-processing algorithms tailored to the specific scene and camera setup have to be designed. DeepSORT is adopted as the MTMCT solution. To achieve real-time performance, PaddleDetection's own PP-YOLOv2 and PP-PicoDet are used as detectors, and PaddleClas's lightweight PP-LCNet is used as the ReID model.
MTMCT is an important component of the [PP-Tracking](../../../deploy/pptracking) project. [PP-Tracking](../../../deploy/pptracking/README.md) is the industry's first open-source real-time tracking system built on the PaddlePaddle deep learning framework. Targeting the pain points of real-world applications, PP-Tracking provides pedestrian and vehicle tracking, cross-camera tracking, multi-class tracking, small-object tracking and flow counting, together with a visual development interface. It integrates lightweight multi-object tracking, object detection and ReID algorithms to further improve server-side deployment performance, and supports Python and C++ deployment on Linux and NVIDIA Jetson platforms. Please go to that directory for details.
## Model Zoo
### Results of DeepSORT on the test set of the AIC21 MTMCT (CityFlow) cross-camera vehicle tracking dataset
| Detector | Input size | ReID | Scene | Tricks | IDF1 | IDP | IDR | Precision | Recall | FPS | Detector download | ReID download |
| :--------- | :--------- | :------- | :----- | :------ |:----- |:------- |:----- |:--------- |:-------- |:----- |:------ | :------ |
| PP-PicoDet | 640x640 | PP-LCNet | S06 | - | 0.3617 | 0.4417 | 0.3062 | 0.6266 | 0.4343 | - |[Detector](https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_640_aic21mtmct_vehicle.tar) |[ReID](https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet_vehicle.tar) |
| PPYOLOv2 | 640x640 | PP-LCNet | S06 | - | 0.4450 | 0.4611 | 0.4300 | 0.6385 | 0.5954 | - |[Detector](https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle.tar) |[ReID](https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet_vehicle.tar) |
**Note:**
S06 is the scene name of the test set of the AIC21 MTMCT dataset. Scene S06 contains videos from 6 cameras: 'c041, c042, c043, c044, c045, c046'.
## Dataset Preparation
Two model solutions are provided here, one for vehicles and one for pedestrians. For vehicles, the [AIC21 MTMCT](https://www.aicitychallenge.org) (CityFlow) cross-camera vehicle tracking dataset is used; for pedestrians, the [WILDTRACK](https://www.epfl.ch/labs/cvlab/data/data-wildtrack) cross-camera pedestrian tracking dataset is used.
The directory structure of the original AIC21 MTMCT dataset is as follows:
```
|——————AIC21_Track3_MTMC_Tracking
|——————cam_framenum (Number of frames of each camera)
|——————cam_loc (Positional relationship between cameras)
|——————cam_timestamp (Time difference between cameras)
|——————eval (evaluation function and ground_truth.txt)
|——————test
|——————train
|——————validation
|——————DataLicenseAgreement_AICityChallenge_2021.pdf
|——————list_cam.txt (List of all camera paths)
|——————ReadMe.txt (Dataset description)
|——————gen_aicity_mtmct_data.py (Camera data extraction script)
```
It needs to be converted into the following format:
```
├── S01
│ ├── c001
│ ├── roi.jpg (Area mask of the road)
│ ├── img1
│ ├── ...
│ ├── c002
│ ├── roi.jpg
│ ├── img1
│ ├── ...
│ ├── c003
│ ├── roi.jpg
│ ├── img1
│ ├── ...
├── gt
│ ├── ground_truth_train.txt
│ ├── ground_truth_validation.txt
├── zone (only for S06 when use camera track trick)
│ ├── ...
```
#### Generate the validation data of scene S01
```bash
python gen_aicity_mtmct_data.py ./AIC21_Track3_MTMC_Tracking/train/S01
```
## Getting Started
### 1. Export models
Step 1: Download the exported detection model
```bash
wget https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_640_aic21mtmct_vehicle.tar
tar -xvf picodet_l_640_aic21mtmct_vehicle.tar
```
Step 2: Download the exported ReID model
```bash
wget https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet_vehicle.tar
tar -xvf deepsort_pplcnet_vehicle.tar
```
**Note:**
PP-PicoDet is a lightweight detection model. For training it, refer to [configs/picodet](../../picodet/README.md), and remember to modify the number of classes and the dataset paths.
PP-LCNet is a lightweight ReID model. For training it, refer to [PaddleClas](https://github.com/PaddlePaddle/PaddleClas). The provided weights were trained on the VERI-Wild vehicle re-identification dataset, so it is recommended to use them directly without retraining.
### 2. Inference in Python with the exported models
```bash
# Use the exported PicoDet vehicle detection model and PP-LCNet vehicle ReID model
python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --device=GPU --scaled=True --save_mot_txts --save_images
```
**Note:**
The tracking model runs inference on videos and does not support single-image inference. By default, the visualized tracking result is saved as a video; add `--save_mot_txts` to save one txt file per video, or `--save_images` to save the visualized tracking result images.
`--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image. Set it to False when using the JDE YOLOv3 detector and to True when using a general detection model.
`--mtmct_dir` is the folder of the scene used for MTMCT inference. It contains the image folders extracted from the videos of the different cameras of that scene, and there must be at least two of them.
The MTMCT cross-camera tracking results are output as videos and a txt file. A visualized cross-camera tracking result is generated for each image folder; unlike single-camera tracking, whose results are independent across the video folders, the identities here are consistent across cameras. There is only one MTMCT result txt, which has one extra leading column, the camera id, compared with the single-camera result txt.
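As a quick sanity check, that single txt can be parsed directly. The minimal sketch below (the file name is a placeholder; it assumes the nine-column `cid tid fid x y w h -1 -1` format described above) counts the boxes and global ids observed by each camera:
```python
import numpy as np

# Placeholder file name: the single cross-camera result txt described above,
# one line per box in the format: cid tid fid x1 y1 w h -1 -1
res = np.loadtxt('mtmct_result.txt')

for cid in np.unique(res[:, 0]).astype(int):
    cam_rows = res[res[:, 0] == cid]                      # all boxes seen by this camera
    global_ids = np.unique(cam_rows[:, 1]).astype(int)    # cross-camera identities
    print('camera c{:03d}: {} boxes, {} global ids'.format(
        cid, len(cam_rows), len(global_ids)))
```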
MTMCT is an important component of the [PP-Tracking](../../../deploy/pptracking) project; please go to that directory to use it.
## Citations
```
@InProceedings{Tang19CityFlow,
author = {Zheng Tang and Milind Naphade and Ming-Yu Liu and Xiaodong Yang and Stan Birchfield and Shuo Wang and Ratnesh Kumar and David Anastasiu and Jenq-Neng Hwang},
title = {CityFlow: A City-Scale Benchmark for Multi-Target Multi-Camera Vehicle Tracking and Re-Identification},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2019},
pages = {8797–8806}
}
```
......@@ -68,6 +68,30 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/pp
- `--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image. Set it to False when using the JDE YOLOv3 detector and to True when using a general detection model.
## 3. Export and inference of the cross-camera (MTMCT) tracking model
### 3.1 Export the inference models
Step 1: Download the exported detection model
```bash
wget https://paddledet.bj.bcebos.com/models/mot/deepsort/picodet_l_640_aic21mtmct_vehicle.tar
tar -xvf picodet_l_640_aic21mtmct_vehicle.tar
```
Step 2: Download the exported ReID model
```bash
wget https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet_vehicle.tar
tar -xvf deepsort_pplcnet_vehicle.tar
```
### 3.2 Inference in Python with the exported models
```bash
# Use the exported PicoDet vehicle detection model and PP-LCNet vehicle ReID model
python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --save_mot_txts --save_images
```
**Note:**
The tracking model runs inference on videos and does not support single-image inference. By default, the visualized tracking result is saved as a video; add `--save_mot_txts` to save one txt file per video, or `--save_images` to save the visualized tracking result images.
`--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image. Set it to False when using the JDE YOLOv3 detector and to True when using a general detection model.
`--mtmct_dir` is the folder of the scene used for MTMCT inference. It contains the image folders extracted from the videos of the different cameras of that scene, and there must be at least two of them.
## Parameter description:
| Parameter | Required | Description |
......@@ -88,6 +112,8 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/pp
| --trt_calib_mode | Option | Whether TensorRT uses calibration, False by default. It should be True when using TensorRT int8, and False when using a model quantized by PaddleSlim |
| --do_entrance_counting | Option | Whether to count entrance/exit flow, False by default |
| --draw_center_traj | Option | Whether to draw the tracking trajectories, False by default |
| --mtmct_dir | Option | Path of the image folders for MTMCT cross-camera tracking inference, None by default |
| --mtmct_cfg | Option | Path of the config file for MTMCT cross-camera tracking inference, None by default |
Notes:
......
......@@ -16,8 +16,10 @@ from . import matching
from . import tracker
from . import motion
from . import utils
from . import mtmct
from .matching import *
from .tracker import *
from .motion import *
from .utils import *
from .mtmct import *
......@@ -74,7 +74,7 @@ def cython_bbox_ious(atlbrs, btlbrs):
import cython_bbox
except Exception as e:
print('cython_bbox not found, please install cython_bbox.'
'for example: `pip install cython_bbox`.')
'for example: `pip install cython_bbox`.')
exit()
ious = cython_bbox.bbox_overlaps(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
from . import postprocess
from .utils import *
from .postprocess import *
# The following codes are strongly related to zone and camera parameters
from . import camera_utils
from . import zone
from .camera_utils import *
from .zone import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
"""
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from .utils import get_dire, get_match, get_cid_tid, combin_feature, combin_cluster
from .utils import normalize, intracam_ignore, visual_rerank
__all__ = [
'st_filter',
'get_labels_with_camera',
]
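# Pairwise spatio-temporal offsets between cameras c041-c046 of scene S06, indexed by (cid - 41);
# used below to relax the entry/exit time comparisons.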
CAM_DIST = [[0, 40, 55, 100, 120, 145], [40, 0, 15, 60, 80, 105],
[55, 15, 0, 40, 65, 90], [100, 60, 40, 0, 20, 45],
[120, 80, 65, 20, 0, 25], [145, 105, 90, 45, 25, 0]]
def st_filter(st_mask, cid_tids, cid_tid_dict):
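# Zero the st_mask entries of cross-camera tracklet pairs whose zone directions and entry/exit
# times are spatio-temporally incompatible; the camera indexing (cid - 41) assumes scene S06.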
count = len(cid_tids)
for i in range(count):
i_tracklet = cid_tid_dict[cid_tids[i]]
i_cid = i_tracklet['cam']
i_dire = get_dire(i_tracklet['zone_list'], i_cid)
i_iot = i_tracklet['io_time']
for j in range(count):
j_tracklet = cid_tid_dict[cid_tids[j]]
j_cid = j_tracklet['cam']
j_dire = get_dire(j_tracklet['zone_list'], j_cid)
j_iot = j_tracklet['io_time']
match_dire = True
cam_dist = CAM_DIST[i_cid - 41][j_cid - 41]
# if the time windows overlap
if i_iot[0] - cam_dist < j_iot[0] and j_iot[0] < i_iot[
1] + cam_dist:
match_dire = False
if i_iot[0] - cam_dist < j_iot[1] and j_iot[1] < i_iot[
1] + cam_dist:
match_dire = False
# do not match after the target has left
if i_dire[1] in [1, 2]: # i out
if i_iot[0] < j_iot[1] + cam_dist:
match_dire = False
if i_dire[1] in [1, 2]:
if i_dire[0] in [3] and i_cid > j_cid:
match_dire = False
if i_dire[0] in [4] and i_cid < j_cid:
match_dire = False
if i_cid in [41] and i_dire[1] in [4]:
if i_iot[0] < j_iot[1] + cam_dist:
match_dire = False
if i_iot[1] > 199:
match_dire = False
if i_cid in [46] and i_dire[1] in [3]:
if i_iot[0] < j_iot[1] + cam_dist:
match_dire = False
# match after the target has entered
if i_dire[0] in [1, 2]:
if i_iot[1] > j_iot[0] - cam_dist:
match_dire = False
if i_dire[0] in [1, 2]:
if i_dire[1] in [3] and i_cid > j_cid:
match_dire = False
if i_dire[1] in [4] and i_cid < j_cid:
match_dire = False
is_ignore = False
if ((i_dire[0] == i_dire[1] and i_dire[0] in [3, 4]) or
(j_dire[0] == j_dire[1] and j_dire[0] in [3, 4])):
is_ignore = True
if not is_ignore:
# direction conflict
if (i_dire[0] in [3] and j_dire[0] in [4]) or (
i_dire[1] in [3] and j_dire[1] in [4]):
match_dire = False
# filter before going next scene
if i_dire[1] in [3] and i_cid < j_cid:
if i_iot[1] > j_iot[1] - cam_dist:
match_dire = False
if i_dire[1] in [4] and i_cid > j_cid:
if i_iot[1] > j_iot[1] - cam_dist:
match_dire = False
if i_dire[0] in [3] and i_cid < j_cid:
if i_iot[0] < j_iot[0] + cam_dist:
match_dire = False
if i_dire[0] in [4] and i_cid > j_cid:
if i_iot[0] < j_iot[0] + cam_dist:
match_dire = False
## 3-30
## 4-1
if i_dire[0] in [3] and i_cid > j_cid:
if i_iot[1] > j_iot[0] - cam_dist:
match_dire = False
if i_dire[0] in [4] and i_cid < j_cid:
if i_iot[1] > j_iot[0] - cam_dist:
match_dire = False
# filter before going next scene
## 4-7
if i_dire[1] in [3] and i_cid > j_cid:
if i_iot[0] < j_iot[1] + cam_dist:
match_dire = False
if i_dire[1] in [4] and i_cid < j_cid:
if i_iot[0] < j_iot[1] + cam_dist:
match_dire = False
else:
if i_iot[1] > 199:
if i_dire[0] in [3] and i_cid < j_cid:
if i_iot[0] < j_iot[0] + cam_dist:
match_dire = False
if i_dire[0] in [4] and i_cid > j_cid:
if i_iot[0] < j_iot[0] + cam_dist:
match_dire = False
if i_dire[0] in [3] and i_cid > j_cid:
match_dire = False
if i_dire[0] in [4] and i_cid < j_cid:
match_dire = False
if i_iot[0] < 1:
if i_dire[1] in [3] and i_cid > j_cid:
match_dire = False
if i_dire[1] in [4] and i_cid < j_cid:
match_dire = False
if not match_dire:
st_mask[i, j] = 0.0
st_mask[j, i] = 0.0
return st_mask
def subcam_list(cid_tid_dict, cid_tids):
sub_3_4 = dict()
sub_4_3 = dict()
for cid_tid in cid_tids:
cid, tid = cid_tid
tracklet = cid_tid_dict[cid_tid]
zs, ze = get_dire(tracklet['zone_list'], cid)
if zs in [3] and cid not in [46]: # 4 to 3
if not cid + 1 in sub_4_3:
sub_4_3[cid + 1] = []
sub_4_3[cid + 1].append(cid_tid)
if ze in [4] and cid not in [41]: # 4 to 3
if not cid in sub_4_3:
sub_4_3[cid] = []
sub_4_3[cid].append(cid_tid)
if zs in [4] and cid not in [41]: # 3 to 4
if not cid - 1 in sub_3_4:
sub_3_4[cid - 1] = []
sub_3_4[cid - 1].append(cid_tid)
if ze in [3] and cid not in [46]: # 3 to 4
if not cid in sub_3_4:
sub_3_4[cid] = []
sub_3_4[cid].append(cid_tid)
sub_cid_tids = dict()
for i in sub_3_4:
sub_cid_tids[(i, i + 1)] = sub_3_4[i]
for i in sub_4_3:
sub_cid_tids[(i, i - 1)] = sub_4_3[i]
return sub_cid_tids
def subcam_list2(cid_tid_dict, cid_tids):
sub_dict = dict()
for cid_tid in cid_tids:
cid, tid = cid_tid
if cid not in [41]:
if not cid in sub_dict:
sub_dict[cid] = []
sub_dict[cid].append(cid_tid)
if cid not in [46]:
if not cid + 1 in sub_dict:
sub_dict[cid + 1] = []
sub_dict[cid + 1].append(cid_tid)
return sub_dict
def get_sim_matrix(cid_tid_dict,
cid_tids,
use_ff=True,
use_rerank=True,
use_st_filter=False):
# Note: camera-related get_sim_matrix function,
# which is different from the one in utils.py.
count = len(cid_tids)
q_arr = np.array(
[cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
g_arr = np.array(
[cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
q_arr = normalize(q_arr, axis=1)
g_arr = normalize(g_arr, axis=1)
st_mask = np.ones((count, count), dtype=np.float32)
st_mask = intracam_ignore(st_mask, cid_tids)
# different from utils.py
if use_st_filter:
st_mask = st_filter(st_mask, cid_tids, cid_tid_dict)
visual_sim_matrix = visual_rerank(
q_arr, g_arr, cid_tids, use_ff=use_ff, use_rerank=use_rerank)
visual_sim_matrix = visual_sim_matrix.astype('float32')
np.set_printoptions(precision=3)
sim_matrix = visual_sim_matrix * st_mask
np.fill_diagonal(sim_matrix, 0)
return sim_matrix
def get_labels_with_camera(cid_tid_dict,
cid_tids,
use_ff=True,
use_rerank=True,
use_st_filter=False):
# 1st cluster
sub_cid_tids = subcam_list(cid_tid_dict, cid_tids)
sub_labels = dict()
dis_thrs = [0.7, 0.5, 0.5, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.5]
for i, sub_c_to_c in enumerate(sub_cid_tids):
sim_matrix = get_sim_matrix(
cid_tid_dict,
sub_cid_tids[sub_c_to_c],
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
cluster_labels = AgglomerativeClustering(
n_clusters=None,
distance_threshold=1 - dis_thrs[i],
affinity='precomputed',
linkage='complete').fit_predict(1 - sim_matrix)
labels = get_match(cluster_labels)
cluster_cid_tids = get_cid_tid(labels, sub_cid_tids[sub_c_to_c])
sub_labels[sub_c_to_c] = cluster_cid_tids
labels, sub_cluster = combin_cluster(sub_labels, cid_tids)
# 2nd cluster
cid_tid_dict_new = combin_feature(cid_tid_dict, sub_cluster)
sub_cid_tids = subcam_list2(cid_tid_dict_new, cid_tids)
sub_labels = dict()
for i, sub_c_to_c in enumerate(sub_cid_tids):
sim_matrix = get_sim_matrix(
cid_tid_dict_new,
sub_cid_tids[sub_c_to_c],
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
cluster_labels = AgglomerativeClustering(
n_clusters=None,
distance_threshold=1 - 0.1,
affinity='precomputed',
linkage='complete').fit_predict(1 - sim_matrix)
labels = get_match(cluster_labels)
cluster_cid_tids = get_cid_tid(labels, sub_cid_tids[sub_c_to_c])
sub_labels[sub_c_to_c] = cluster_cid_tids
labels, sub_cluster = combin_cluster(sub_labels, cid_tids)
return labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
"""
import re
import cv2
from tqdm import tqdm
import pickle
import os
import os.path as osp
from os.path import join as opj
import numpy as np
import motmetrics as mm
from functools import reduce
from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc
from .utils import get_labels, getData, gen_new_mot
from .camera_utils import get_labels_with_camera
from .zone import Zone
from ..utils import plot_tracking
__all__ = [
'trajectory_fusion',
'sub_cluster',
'gen_res',
'print_mtmct_result',
'get_mtmct_matching_results',
'save_mtmct_crops',
'save_mtmct_vis_results',
]
def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''):
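# Aggregate the per-frame MOT features of one camera into per-tracklet data (mean appearance
# feature, zone list, frame list, entry/exit times), optionally breaking and filtering tracklets
# with zone information.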
cur_bias = cid_bias[cid]
mot_list_break = {}
if use_zone:
zones = Zone(zone_path=zone_path)
zones.set_cam(cid)
mot_list = parse_pt(mot_feature, zones)
else:
mot_list = parse_pt(mot_feature)
if use_zone:
mot_list = zones.break_mot(mot_list, cid)
mot_list = zones.filter_mot(mot_list, cid) # filter by zone
mot_list = zones.filter_bbox(mot_list, cid) # filter bbox
mot_list_break = gen_new_mot(mot_list) # save break feature for gen result
tid_data = dict()
for tid in mot_list:
tracklet = mot_list[tid]
if len(tracklet) <= 1:
continue
frame_list = list(tracklet.keys())
frame_list.sort()
# use features only from boxes with area > 2000 (drop tiny boxes)
zone_list = [tracklet[f]['zone'] for f in frame_list]
feature_list = [
tracklet[f]['feat'] for f in frame_list
if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]
) * (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
]
if len(feature_list) < 2:
feature_list = [tracklet[f]['feat'] for f in frame_list]
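# entry/exit times of the tracklet in seconds (assuming 10 FPS) plus the per-camera time bias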
io_time = [
cur_bias + frame_list[0] / 10., cur_bias + frame_list[-1] / 10.
]
all_feat = np.array([feat for feat in feature_list])
mean_feat = np.mean(all_feat, axis=0)
tid_data[tid] = {
'cam': cid,
'tid': tid,
'mean_feat': mean_feat,
'zone_list': zone_list,
'frame_list': frame_list,
'tracklet': tracklet,
'io_time': io_time
}
return tid_data, mot_list_break
def sub_cluster(cid_tid_dict,
scene_cluster,
use_ff=True,
use_rerank=True,
use_camera=False,
use_st_filter=False):
'''
cid_tid_dict: all camera_id and track_id
scene_cluster: like [41, 42, 43, 44, 45, 46] in AIC21 MTMCT S06 test videos
'''
assert (len(scene_cluster) != 0), "Error: scene_cluster length equals 0"
cid_tids = sorted(
[key for key in cid_tid_dict.keys() if key[0] in scene_cluster])
if use_camera:
clu = get_labels_with_camera(
cid_tid_dict,
cid_tids,
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
else:
clu = get_labels(
cid_tid_dict,
cid_tids,
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
new_clu = list()
for c_list in clu:
if len(c_list) <= 1: continue
cam_list = [cid_tids[c][0] for c in c_list]
if len(cam_list) != len(set(cam_list)): continue
new_clu.append([cid_tids[c] for c in c_list])
all_clu = new_clu
cid_tid_label = dict()
for i, c_list in enumerate(all_clu):
for c in c_list:
cid_tid_label[c] = i + 1
return cid_tid_label
def gen_res(output_dir_filename,
scene_cluster,
map_tid,
mot_list_breaks,
use_roi=False,
roi_dir=''):
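# Write the fused cross-camera result into a single txt file, one line per box in the format
# 'cid tid fid x y w h -1 -1'; each box is slightly enlarged and, when use_roi is set,
# clipped to the ROI image size.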
f_w = open(output_dir_filename, 'w')
for idx, mot_feature in enumerate(mot_list_breaks):
cid = scene_cluster[idx]
img_rects = parse_pt_gt(mot_feature)
if use_roi:
assert (roi_dir != ''), "Error: roi_dir must not be empty!"
roi = cv2.imread(os.path.join(roi_dir, f'c{cid:03d}/roi.jpg'), 0)
height, width = roi.shape
for fid in img_rects:
tid_rects = img_rects[fid]
fid = int(fid) + 1
for tid_rect in tid_rects:
tid = tid_rect[0]
rect = tid_rect[1:]
cx = 0.5 * rect[0] + 0.5 * rect[2]
cy = 0.5 * rect[1] + 0.5 * rect[3]
w = rect[2] - rect[0]
w = min(w * 1.2, w + 40)
h = rect[3] - rect[1]
h = min(h * 1.2, h + 40)
rect[2] -= rect[0]
rect[3] -= rect[1]
rect[0] = max(0, rect[0])
rect[1] = max(0, rect[1])
x1, y1 = max(0, cx - 0.5 * w), max(0, cy - 0.5 * h)
if use_roi:
x2, y2 = min(width, cx + 0.5 * w), min(height, cy + 0.5 * h)
else:
x2, y2 = cx + 0.5 * w, cy + 0.5 * h
w, h = x2 - x1, y2 - y1
new_rect = list(map(int, [x1, y1, w, h]))
rect = list(map(int, rect))
if (cid, tid) in map_tid:
new_tid = map_tid[(cid, tid)]
f_w.write(
str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' +
' '.join(map(str, new_rect)) + ' -1 -1'
'\n')
print('gen_res: write file in {}'.format(output_dir_filename))
f_w.close()
def print_mtmct_result(gt_file, pred_file):
names = [
'CameraId', 'Id', 'FrameId', 'X', 'Y', 'Width', 'Height', 'Xworld',
'Yworld'
]
gt = getData(gt_file, names=names)
pred = getData(pred_file, names=names)
summary = compare_dataframes_mtmc(gt, pred)
print('MTMCT summary: ', summary.columns.tolist())
formatters = {
'idf1': '{:2.2f}'.format,
'idp': '{:2.2f}'.format,
'idr': '{:2.2f}'.format,
'mota': '{:2.2f}'.format
}
summary = summary[['idf1', 'idp', 'idr', 'mota']]
summary.loc[:, 'idp'] *= 100
summary.loc[:, 'idr'] *= 100
summary.loc[:, 'idf1'] *= 100
summary.loc[:, 'mota'] *= 100
print(
mm.io.render_summary(
summary,
formatters=formatters,
namemap=mm.io.motchallenge_metric_names))
def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5,
video_fps=20):
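# Load the fused result txt and, for the track ids that appear in every camera, group their
# boxes per camera / track id / time window of secs_interval seconds.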
res = np.loadtxt(pred_mtmct_file) # 'cid, tid, fid, x1, y1, w, h, -1, -1'
carame_ids = list(map(int, np.unique(res[:, 0])))
num_track_ids = int(np.max(res[:, 1]))
num_frames = int(np.max(res[:, 2]))
res = res[:, :7]
# each line in res: 'cid, tid, fid, x1, y1, w, h'
carame_tids = []
carame_results = dict()
for c_id in carame_ids:
carame_results[c_id] = res[res[:, 0] == c_id]
tids = np.unique(carame_results[c_id][:, 1])
tids = list(map(int, tids))
carame_tids.append(tids)
# select common tids throughout each video
common_tids = reduce(np.intersect1d, carame_tids)
if len(common_tids) == 0:
print(
'No common tracked ids in these videos, please check your MOT result or select new videos.'
)
return None
# get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id]
cid_tid_fid_results = dict()
cid_tid_to_fids = dict()
interval = int(secs_interval * video_fps) # preferably less than 10
for c_id in carame_ids:
cid_tid_fid_results[c_id] = dict()
cid_tid_to_fids[c_id] = dict()
for t_id in common_tids:
tid_mask = carame_results[c_id][:, 1] == t_id
cid_tid_fid_results[c_id][t_id] = dict()
carame_trackid_results = carame_results[c_id][tid_mask]
fids = np.unique(carame_trackid_results[:, 2])
fids = fids[fids % interval == 0]
fids = list(map(int, fids))
cid_tid_to_fids[c_id][t_id] = fids
for f_id in fids:
st_frame = f_id
ed_frame = f_id + interval
st_mask = carame_trackid_results[:, 2] >= st_frame
ed_mask = carame_trackid_results[:, 2] < ed_frame
frame_mask = np.logical_and(st_mask, ed_mask)
cid_tid_fid_results[c_id][t_id][f_id] = carame_trackid_results[
frame_mask]
return carame_results, cid_tid_fid_results
def save_mtmct_crops(cid_tid_fid_res,
images_dir,
crops_dir,
width=300,
height=200):
carame_ids = cid_tid_fid_res.keys()
seqs_folder = os.listdir(images_dir)
seqs = []
for x in seqs_folder:
if os.path.isdir(os.path.join(images_dir, x)):
seqs.append(x)
assert len(seqs) == len(carame_ids)
seqs.sort()
if not os.path.exists(crops_dir):
os.makedirs(crops_dir)
common_tids = list(cid_tid_fid_res[list(carame_ids)[0]].keys())
# save crops named 'tid_cid_fid.jpg'
for t_id in common_tids:
for i, c_id in enumerate(carame_ids):
infer_dir = os.path.join(images_dir, seqs[i])
if os.path.exists(os.path.join(infer_dir, 'img1')):
infer_dir = os.path.join(infer_dir, 'img1')
all_images = os.listdir(infer_dir)
all_images.sort()
for f_id in cid_tid_fid_res[c_id][t_id].keys():
frame_idx = f_id - 1 if f_id > 0 else 0
im_path = os.path.join(infer_dir, all_images[frame_idx])
im = cv2.imread(im_path) # (H, W, 3)
track = cid_tid_fid_res[c_id][t_id][f_id][
0] # only select one track
cid, tid, fid, x1, y1, w, h = [int(v) for v in track]
clip = im[y1:(y1 + h), x1:(x1 + w)]
clip = cv2.resize(clip, (width, height))
cv2.imwrite(
os.path.join(crops_dir,
'tid{:06d}_cid{:06d}_fid{:06d}.jpg'.format(
tid, cid, fid)), clip)
print("Finish cropping image of tracked_id {} in camera: {}".format(
t_id, c_id))
def save_mtmct_vis_results(carame_results,
images_dir,
save_dir,
save_videos=False):
# carame_results: 'cid, tid, fid, x1, y1, w, h'
carame_ids = carame_results.keys()
seqs_folder = os.listdir(images_dir)
seqs = []
for x in seqs_folder:
if os.path.isdir(os.path.join(images_dir, x)):
seqs.append(x)
assert len(seqs) == len(carame_ids)
seqs.sort()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for i, c_id in enumerate(carame_ids):
print("Start visualization for camera {} of sequence {}.".format(
c_id, seqs[i]))
cid_save_dir = os.path.join(save_dir, '{}'.format(seqs[i]))
if not os.path.exists(cid_save_dir):
os.makedirs(cid_save_dir)
infer_dir = os.path.join(images_dir, seqs[i])
if os.path.exists(os.path.join(infer_dir, 'img1')):
infer_dir = os.path.join(infer_dir, 'img1')
all_images = os.listdir(infer_dir)
all_images.sort()
for f_id, im_path in enumerate(all_images):
img = cv2.imread(os.path.join(infer_dir, im_path))
tracks = carame_results[c_id][carame_results[c_id][:, 2] == f_id]
if tracks.shape[0] > 0:
tracked_ids = tracks[:, 1]
xywhs = tracks[:, 3:]
online_im = plot_tracking(
img, xywhs, tracked_ids, scores=None, frame_id=f_id)
else:
online_im = img
print('Frame {} of seq {} has no tracking results'.format(
f_id, seqs[i]))
cv2.imwrite(
os.path.join(cid_save_dir, '{:05d}.jpg'.format(f_id)),
online_im)
if f_id % 40 == 0:
print('Processing frame {}'.format(f_id))
if save_videos:
output_video_path = os.path.join(cid_save_dir, '..',
'{}_mtmct_vis.mp4'.format(seqs[i]))
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
cid_save_dir, output_video_path)
os.system(cmd_str)
print('Save camera {} video in {}.'.format(seqs[i],
output_video_path))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
"""
import os
import re
import cv2
import paddle
import numpy as np
from sklearn import preprocessing
from sklearn.cluster import AgglomerativeClustering
import gc
import motmetrics as mm
import pandas as pd
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
__all__ = [
'parse_pt', 'parse_bias', 'get_dire', 'parse_pt_gt',
'compare_dataframes_mtmc', 'get_sim_matrix', 'get_labels', 'getData',
'gen_new_mot'
]
def parse_pt(mot_feature, zones=None):
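# Group the per-frame MOT features into a nested dict keyed by track id and then frame id,
# optionally attaching the zone of each bbox.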
mot_list = dict()
for line in mot_feature:
fid = int(re.sub('[a-z,A-Z]', "", mot_feature[line]['frame']))
tid = mot_feature[line]['id']
bbox = list(map(lambda x: int(float(x)), mot_feature[line]['bbox']))
if tid not in mot_list:
mot_list[tid] = dict()
out_dict = mot_feature[line]
if zones is not None:
out_dict['zone'] = zones.get_zone(bbox)
else:
out_dict['zone'] = None
mot_list[tid][fid] = out_dict
return mot_list
def gen_new_mot(mot_list):
out_dict = dict()
for tracklet in mot_list:
tracklet = mot_list[tracklet]
for f in tracklet:
out_dict[tracklet[f]['imgname']] = tracklet[f]
return out_dict
def mergesetfeat1_notrk(P, neg_vector, in_feats, in_labels):
out_feats = []
for i in range(in_feats.shape[0]):
camera_id = in_labels[i, 1]
feat = in_feats[i] - neg_vector[camera_id]
feat = P[camera_id].dot(feat)
feat = feat / np.linalg.norm(feat, ord=2)
out_feats.append(feat)
out_feats = np.vstack(out_feats)
return out_feats
def compute_P2(prb_feats, gal_feats, gal_labels, la=3.0):
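# For each camera id in gal_labels, compute the mean gallery feature (neg_vector) and a
# regularized inverse Gram matrix (P); both are used by mergesetfeat1_notrk to re-project features.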
X = gal_feats
neg_vector = {}
u_labels = np.unique(gal_labels[:, 1])
P = {}
for label in u_labels:
curX = gal_feats[gal_labels[:, 1] == label, :]
neg_vector[label] = np.mean(curX, axis=0)
P[label] = np.linalg.inv(
curX.T.dot(curX) + curX.shape[0] * la * np.eye(X.shape[1]))
return P, neg_vector
def parse_bias(cameras_bias):
cid_bias = dict()
for cameras in cameras_bias.keys():
cameras_id = re.sub('[a-z,A-Z]', "", cameras)
cameras_id = int(cameras_id)
bias = cameras_bias[cameras]
cid_bias[cameras_id] = float(bias)
return cid_bias
def get_dire(zone_list, cid):
zs, ze = zone_list[0], zone_list[-1]
return (zs, ze)
def intracam_ignore(st_mask, cid_tids):
count = len(cid_tids)
for i in range(count):
for j in range(count):
if cid_tids[i][0] == cid_tids[j][0]:
st_mask[i, j] = 0.
return st_mask
def mergesetfeat(in_feats, in_labels, in_tracks):
trackset = list(set(list(in_tracks)))
out_feats = []
out_labels = []
for track in trackset:
feat = np.mean(in_feats[in_tracks == track], axis=0)
feat = feat / np.linalg.norm(feat, ord=2)
label = in_labels[in_tracks == track][0]
out_feats.append(feat)
out_labels.append(label)
out_feats = np.vstack(out_feats)
out_labels = np.vstack(out_labels)
return out_feats, out_labels
def mergesetfeat3(X, labels, gX, glabels, beta=0.08, knn=20, lr=0.5):
for i in range(0, X.shape[0]):
if i % 1000 == 0:
print('feat3:%d/%d' % (i, X.shape[0]))
knnX = gX[glabels[:, 1] != labels[i, 1], :]
sim = knnX.dot(X[i, :])
knnX = knnX[sim > 0, :]
sim = sim[sim > 0]
if len(sim) > 0:
idx = np.argsort(-sim)
if len(sim) > 2 * knn:
sim = sim[idx[:2 * knn]]
knnX = knnX[idx[:2 * knn], :]
else:
sim = sim[idx]
knnX = knnX[idx, :]
knn = min(knn, len(sim))
knn_pos_weight = np.exp((sim[:knn] - 1) / beta)
knn_neg_weight = np.ones(len(sim) - knn)
knn_pos_prob = knn_pos_weight / np.sum(knn_pos_weight)
knn_neg_prob = knn_neg_weight / np.sum(knn_neg_weight)
X[i, :] += lr * (knn_pos_prob.dot(knnX[:knn, :]) -
knn_neg_prob.dot(knnX[knn:, :]))
X[i, :] /= np.linalg.norm(X[i, :])
return X
def run_fic(prb_feats, gal_feats, prb_labels, gal_labels, la=3.0):
P, neg_vector = compute_P2(prb_feats, gal_feats, gal_labels, la)
prb_feats_new = mergesetfeat1_notrk(P, neg_vector, prb_feats, prb_labels)
gal_feats_new = mergesetfeat1_notrk(P, neg_vector, gal_feats, gal_labels)
return prb_feats_new, gal_feats_new
def run_fac(prb_feats,
gal_feats,
prb_labels,
gal_labels,
beta=0.08,
knn=20,
lr=0.5,
prb_epoch=2,
gal_epoch=3):
gal_feats_new = gal_feats.copy()
for i in range(prb_epoch):
gal_feats_new = mergesetfeat3(gal_feats_new, gal_labels, gal_feats,
gal_labels, beta, knn, lr)
prb_feats_new = prb_feats.copy()
for i in range(gal_epoch):
prb_feats_new = mergesetfeat3(prb_feats_new, prb_labels, gal_feats_new,
gal_labels, beta, knn, lr)
return prb_feats_new, gal_feats_new
def euclidean_distance(qf, gf):
m = qf.shape[0]
n = gf.shape[0]
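# For L2-normalized features, the squared Euclidean distance reduces to 2 - 2 * cosine similarity.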
dist_mat = 2 - 2 * paddle.matmul(qf, gf.t())
return dist_mat
def batch_paddle_topk(qf, gf, k1, N=6000):
m = qf.shape[0]
n = gf.shape[0]
dist_mat = []
initial_rank = []
for j in range(n // N + 1):
temp_gf = gf[j * N:j * N + N]
temp_qd = []
for i in range(m // N + 1):
temp_qf = qf[i * N:i * N + N]
temp_d = euclidean_distance(temp_qf, temp_gf)
temp_qd.append(temp_d)
temp_qd = paddle.concat(temp_qd, axis=0)
temp_qd = temp_qd / (paddle.max(temp_qd, axis=0)[0])
temp_qd = temp_qd.t()
initial_rank.append(
paddle.topk(
temp_qd, k=k1, axis=1, largest=False, sorted=True)[1])
del temp_qd
del temp_gf
del temp_qf
del temp_d
initial_rank = paddle.concat(initial_rank, axis=0).cpu().numpy()
return initial_rank
def batch_euclidean_distance(qf, gf, N=6000):
m = qf.shape[0]
n = gf.shape[0]
dist_mat = []
for j in range(n // N + 1):
temp_gf = gf[j * N:j * N + N]
temp_qd = []
for i in range(m // N + 1):
temp_qf = qf[i * N:i * N + N]
temp_d = euclidean_distance(temp_qf, temp_gf)
temp_qd.append(temp_d)
temp_qd = paddle.concat(temp_qd, axis=0)
temp_qd = temp_qd / (paddle.max(temp_qd, axis=0)[0])
dist_mat.append(temp_qd.t()) # transpose
del temp_qd
del temp_gf
del temp_qf
del temp_d
dist_mat = paddle.concat(dist_mat, axis=0)
return dist_mat
def batch_v(feat, R, all_num):
V = np.zeros((all_num, all_num), dtype=np.float32)
m = feat.shape[0]
for i in tqdm(range(m)):
temp_gf = feat[i].unsqueeze(0)
temp_qd = euclidean_distance(temp_gf, feat)
temp_qd = temp_qd / (paddle.max(temp_qd))
temp_qd = temp_qd.squeeze()
temp_qd = temp_qd.numpy()[R[i].tolist()]
temp_qd = paddle.to_tensor(temp_qd)
weight = paddle.exp(-temp_qd)
weight = (weight / paddle.sum(weight)).numpy()
V[i, R[i]] = weight.astype(np.float32)
return V
def k_reciprocal_neigh(initial_rank, i, k1):
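# Return the members of i's top-k1 neighbor list whose own top-k1 lists also contain i
# (the k-reciprocal neighbors).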
forward_k_neigh_index = initial_rank[i, :k1 + 1]
backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
fi = np.where(backward_k_neigh_index == i)[0]
return forward_k_neigh_index[fi]
def ReRank2(probFea, galFea, k1=20, k2=6, lambda_value=0.3):
# The following naming, e.g. gallery_num, is different from outer scope.
# Don't care about it.
query_num = probFea.shape[0]
all_num = query_num + galFea.shape[0]
feat = paddle.concat([probFea, galFea], axis=0)
initial_rank = batch_paddle_topk(feat, feat, k1 + 1, N=6000)
# del feat
del probFea
del galFea
gc.collect() # empty memory
R = []
for i in tqdm(range(all_num)):
# k-reciprocal neighbors
k_reciprocal_index = k_reciprocal_neigh(initial_rank, i, k1)
k_reciprocal_expansion_index = k_reciprocal_index
for j in range(len(k_reciprocal_index)):
candidate = k_reciprocal_index[j]
candidate_k_reciprocal_index = k_reciprocal_neigh(
initial_rank, candidate, int(np.around(k1 / 2)))
if len(
np.intersect1d(candidate_k_reciprocal_index,
k_reciprocal_index)) > 2. / 3 * len(
candidate_k_reciprocal_index):
k_reciprocal_expansion_index = np.append(
k_reciprocal_expansion_index, candidate_k_reciprocal_index)
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
R.append(k_reciprocal_expansion_index)
gc.collect() # empty memory
V = batch_v(feat, R, all_num)
del R
gc.collect() # empty memory
initial_rank = initial_rank[:, :k2]
### Faster version
if k2 != 1:
V_qe = np.zeros_like(V, dtype=np.float16)
for i in range(all_num):
V_qe[i, :] = np.mean(V[initial_rank[i], :], axis=0)
V = V_qe
del V_qe
del initial_rank
gc.collect() # empty memory
invIndex = []
for i in range(all_num):
invIndex.append(np.where(V[:, i] != 0)[0])
jaccard_dist = np.zeros((query_num, all_num), dtype=np.float32)
for i in tqdm(range(query_num)):
temp_min = np.zeros(shape=[1, all_num], dtype=np.float32)
indNonZero = np.where(V[i, :] != 0)[0]
indImages = [invIndex[ind] for ind in indNonZero]
for j in range(len(indNonZero)):
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
del V
gc.collect() # empty memory
original_dist = batch_euclidean_distance(feat, feat[:query_num, :]).numpy()
final_dist = jaccard_dist * (1 - lambda_value
) + original_dist * lambda_value
del original_dist
del jaccard_dist
final_dist = final_dist[:query_num, query_num:]
return final_dist
def visual_rerank(prb_feats,
gal_feats,
cid_tids,
use_ff=False,
use_rerank=False):
"""Rerank by visual cures."""
gal_labels = np.array([[0, item[0]] for item in cid_tids])
prb_labels = gal_labels.copy()
if use_ff:
print('current use ff finetuned parameters....')
# Step1-1: fic. finetuned parameters: [la]
prb_feats, gal_feats = run_fic(prb_feats, gal_feats, prb_labels,
gal_labels, 3.0)
# Step1-2: fac. finetuned parameters: [beta,knn,lr,prb_epoch,gal_epoch]
prb_feats, gal_feats = run_fac(prb_feats, gal_feats, prb_labels,
gal_labels, 0.08, 20, 0.5, 1, 1)
if use_rerank:
print('current use rerank finetuned parameters....')
# Step2: k-reciprocal. finetuned parameters: [k1,k2,lambda_value]
sims = ReRank2(
paddle.to_tensor(prb_feats),
paddle.to_tensor(gal_feats), 20, 3, 0.3)
else:
# sims = ComputeEuclid(prb_feats, gal_feats, 1)
sims = 1.0 - np.dot(prb_feats, gal_feats.T)
# NOTE: sims here is actually a distance; the smaller the value, the more similar
return 1.0 - sims
# sub_cluster
def normalize(nparray, axis=0):
nparray = preprocessing.normalize(nparray, norm='l2', axis=axis)
return nparray
def get_match(cluster_labels):
cluster_dict = dict()
cluster = list()
for i, l in enumerate(cluster_labels):
if l in list(cluster_dict.keys()):
cluster_dict[l].append(i)
else:
cluster_dict[l] = [i]
for idx in cluster_dict:
cluster.append(cluster_dict[idx])
return cluster
def get_cid_tid(cluster_labels, cid_tids):
cluster = list()
for labels in cluster_labels:
cid_tid_list = list()
for label in labels:
cid_tid_list.append(cid_tids[label])
cluster.append(cid_tid_list)
return cluster
def combin_feature(cid_tid_dict, sub_cluster):
for sub_ct in sub_cluster:
if len(sub_ct) < 2: continue
mean_feat = np.array([cid_tid_dict[i]['mean_feat'] for i in sub_ct])
for i in sub_ct:
cid_tid_dict[i]['mean_feat'] = mean_feat.mean(axis=0)
return cid_tid_dict
def combin_cluster(sub_labels, cid_tids):
cluster = list()
for sub_c_to_c in sub_labels:
if len(cluster) < 1:
cluster = sub_labels[sub_c_to_c]
continue
for c_ts in sub_labels[sub_c_to_c]:
is_add = False
for i_c, c_set in enumerate(cluster):
if len(set(c_ts) & set(c_set)) > 0:
new_list = list(set(c_ts) | set(c_set))
cluster[i_c] = new_list
is_add = True
break
if not is_add:
cluster.append(c_ts)
labels = list()
num_tr = 0
for c_ts in cluster:
label_list = list()
for c_t in c_ts:
label_list.append(cid_tids.index(c_t))
num_tr += 1
label_list.sort()
labels.append(label_list)
return labels, cluster
def parse_pt_gt(mot_feature):
img_rects = dict()
for line in mot_feature:
fid = int(re.sub('[a-z,A-Z]', "", mot_feature[line]['frame']))
tid = mot_feature[line]['id']
rect = list(map(lambda x: int(float(x)), mot_feature[line]['bbox']))
if fid not in img_rects:
img_rects[fid] = list()
rect.insert(0, tid)
img_rects[fid].append(rect)
return img_rects
# eval result
def compare_dataframes_mtmc(gts, ts):
"""Compute ID-based evaluation metrics for MTMCT
Return:
summary (pandas.DataFrame): Evaluation results containing the MOT challenge metrics, including the 'idf1', 'idp' and 'idr' columns.
"""
gtds = []
tsds = []
gtcams = gts['CameraId'].drop_duplicates().tolist()
tscams = ts['CameraId'].drop_duplicates().tolist()
maxFrameId = 0
for k in sorted(gtcams):
gtd = gts.query('CameraId == %d' % k)
gtd = gtd[['FrameId', 'Id', 'X', 'Y', 'Width', 'Height']]
# max FrameId in gtd only
mfid = gtd['FrameId'].max()
gtd['FrameId'] += maxFrameId
gtd = gtd.set_index(['FrameId', 'Id'])
gtds.append(gtd)
if k in tscams:
tsd = ts.query('CameraId == %d' % k)
tsd = tsd[['FrameId', 'Id', 'X', 'Y', 'Width', 'Height']]
# max FrameId among both gtd and tsd
mfid = max(mfid, tsd['FrameId'].max())
tsd['FrameId'] += maxFrameId
tsd = tsd.set_index(['FrameId', 'Id'])
tsds.append(tsd)
maxFrameId += mfid
# compute multi-camera tracking evaluation stats
multiCamAcc = mm.utils.compare_to_groundtruth(
pd.concat(gtds), pd.concat(tsds), 'iou')
metrics = list(mm.metrics.motchallenge_metrics)
metrics.extend(['num_frames', 'idfp', 'idfn', 'idtp'])
mh = mm.metrics.create()
summary = mh.compute(multiCamAcc, metrics=metrics, name='MultiCam')
return summary
def get_sim_matrix(cid_tid_dict,
cid_tids,
use_ff=True,
use_rerank=True,
use_st_filter=False):
# Note: camera-independent get_sim_matrix function,
# which is different from the one in camera_utils.py.
count = len(cid_tids)
q_arr = np.array(
[cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
g_arr = np.array(
[cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
q_arr = normalize(q_arr, axis=1)
g_arr = normalize(g_arr, axis=1)
st_mask = np.ones((count, count), dtype=np.float32)
st_mask = intracam_ignore(st_mask, cid_tids)
visual_sim_matrix = visual_rerank(
q_arr, g_arr, cid_tids, use_ff=use_ff, use_rerank=use_rerank)
visual_sim_matrix = visual_sim_matrix.astype('float32')
np.set_printoptions(precision=3)
sim_matrix = visual_sim_matrix * st_mask
np.fill_diagonal(sim_matrix, 0)
return sim_matrix
def get_labels(cid_tid_dict,
cid_tids,
use_ff=True,
use_rerank=True,
use_st_filter=False):
# 1st cluster
sub_cid_tids = list(cid_tid_dict.keys())
sub_labels = dict()
dis_thrs = [0.7, 0.5, 0.5, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.5]
sim_matrix = get_sim_matrix(
cid_tid_dict,
cid_tids,
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
cluster_labels = AgglomerativeClustering(
n_clusters=None,
distance_threshold=0.5,
affinity='precomputed',
linkage='complete').fit_predict(1 - sim_matrix)
labels = get_match(cluster_labels)
sub_cluster = get_cid_tid(labels, cid_tids)
# 2nd cluster
cid_tid_dict_new = combin_feature(cid_tid_dict, sub_cluster)
sub_labels = dict()
sim_matrix = get_sim_matrix(
cid_tid_dict_new,
cid_tids,
use_ff=use_ff,
use_rerank=use_rerank,
use_st_filter=use_st_filter)
cluster_labels = AgglomerativeClustering(
n_clusters=None,
distance_threshold=0.9,
affinity='precomputed',
linkage='complete').fit_predict(1 - sim_matrix)
labels = get_match(cluster_labels)
sub_cluster = get_cid_tid(labels, cid_tids)
return labels
def getData(fpath, names=None, sep='\s+|\t+|,'):
""" Get the necessary track data from a file handle.
Args:
fpath (str) : Original path of file reading from.
names (list[str]): List of column names for the data.
sep (str): Allowed separators regular expression string.
Return:
df (pandas.DataFrame): Data frame containing the data loaded from the
stream with optionally assigned column names. No index is set on the data.
"""
try:
df = pd.read_csv(
fpath,
sep=sep,
index_col=None,
skipinitialspace=True,
header=None,
names=names,
engine='python')
return df
except Exception as e:
raise ValueError("Could not read input from %s. Error: %s" %
(fpath, repr(e)))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
"""
import os
import cv2
import numpy as np
from sklearn.cluster import AgglomerativeClustering
BBOX_B = 10 / 15
class Zone(object):
def __init__(self, zone_path='datasets/zone'):
# 0: b 1: g 3: r 123:w
# w r not high speed
# b g high speed
assert zone_path != '', "Error: zone_path must not be empty!"
zones = {}
for img_name in os.listdir(zone_path):
camnum = int(img_name.split('.')[0][-3:])
zone_img = cv2.imread(os.path.join(zone_path, img_name))
zones[camnum] = zone_img
self.zones = zones
self.current_cam = 0
def set_cam(self, cam):
self.current_cam = cam
def get_zone(self, bbox):
cx = int((bbox[0] + bbox[2]) / 2)
cy = int((bbox[1] + bbox[3]) / 2)
pix = self.zones[self.current_cam][max(cy - 1, 0), max(cx - 1, 0), :]
zone_num = 0
if pix[0] > 50 and pix[1] > 50 and pix[2] > 50: # w
zone_num = 1
if pix[0] < 50 and pix[1] < 50 and pix[2] > 50: # r
zone_num = 2
if pix[0] < 50 and pix[1] > 50 and pix[2] < 50: # g
zone_num = 3
if pix[0] > 50 and pix[1] < 50 and pix[2] < 50: # b
zone_num = 4
return zone_num
def is_ignore(self, zone_list, frame_list, cid):
# zone 0: not in any crossroad; 1: white, 2: red, 3: green, 4: blue
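# Return code (as consumed by filter_mot): 0 keep the tracklet, 1 move it to the sub list, 2 discard it.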
zs, ze = zone_list[0], zone_list[-1]
fs, fe = frame_list[0], frame_list[-1]
if zs == ze:
# if the tracklet always stays in one zone, decide whether to exclude it
if ze in [1, 2]:
return 2
if zs != 0 and 0 in zone_list:
return 0
if fe - fs > 1500:
return 2
if fs < 2:
if cid in [45]:
if ze in [3, 4]:
return 1
else:
return 2
if fe > 1999:
if cid in [41]:
if ze not in [3]:
return 2
else:
return 0
if fs < 2 or fe > 1999:
if ze in [3, 4]:
return 0
if ze in [3, 4]:
return 1
return 2
else:
# if the zone changes within the camera view
if cid in [41, 42, 43, 44, 45, 46]:
# comes from the road extension, exclude it
if zs == 1 and ze == 2:
return 2
if zs == 2 and ze == 1:
return 2
if cid in [41]:
# On camera 41, no vehicle goes into camera 42
if (zs in [1, 2]) and ze == 4:
return 2
if zs == 4 and (ze in [1, 2]):
return 2
if cid in [46]:
# On camera 46, no vehicle goes into camera 45
if (zs in [1, 2]) and ze == 3:
return 2
if zs == 3 and (ze in [1, 2]):
return 2
return 0
def filter_mot(self, mot_list, cid):
new_mot_list = dict()
sub_mot_list = dict()
for tracklet in mot_list:
tracklet_dict = mot_list[tracklet]
frame_list = list(tracklet_dict.keys())
frame_list.sort()
zone_list = []
for f in frame_list:
zone_list.append(tracklet_dict[f]['zone'])
if self.is_ignore(zone_list, frame_list, cid) == 0:
new_mot_list[tracklet] = tracklet_dict
if self.is_ignore(zone_list, frame_list, cid) == 1:
sub_mot_list[tracklet] = tracklet_dict
return new_mot_list
def filter_bbox(self, mot_list, cid):
new_mot_list = dict()
yh = self.zones[cid].shape[0]
for tracklet in mot_list:
tracklet_dict = mot_list[tracklet]
frame_list = list(tracklet_dict.keys())
frame_list.sort()
bbox_list = []
for f in frame_list:
bbox_list.append(tracklet_dict[f]['bbox'])
bbox_x = [b[0] for b in bbox_list]
bbox_y = [b[1] for b in bbox_list]
bbox_w = [b[2] - b[0] for b in bbox_list]
bbox_h = [b[3] - b[1] for b in bbox_list]
new_frame_list = list()
if 0 in bbox_x or 0 in bbox_y:
b0 = [
i for i, f in enumerate(frame_list)
if bbox_x[i] < 5 or bbox_y[i] + bbox_h[i] > yh - 5
]
if len(b0) == len(frame_list):
if cid in [41, 42, 44, 45, 46]:
continue
max_w = max(bbox_w)
max_h = max(bbox_h)
for i, f in enumerate(frame_list):
if bbox_w[i] > max_w * BBOX_B and bbox_h[
i] > max_h * BBOX_B:
new_frame_list.append(f)
else:
l_i, r_i = 0, len(frame_list) - 1
if len(b0) == 0:
continue
if b0[0] == 0:
for i in range(len(b0) - 1):
if b0[i] + 1 == b0[i + 1]:
l_i = b0[i + 1]
else:
break
if b0[-1] == len(frame_list) - 1:
for i in range(len(b0) - 1):
i = len(b0) - 1 - i
if b0[i] - 1 == b0[i - 1]:
r_i = b0[i - 1]
else:
break
max_lw, max_lh = bbox_w[l_i], bbox_h[l_i]
max_rw, max_rh = bbox_w[r_i], bbox_h[r_i]
for i, f in enumerate(frame_list):
if i < l_i:
if bbox_w[i] > max_lw * BBOX_B and bbox_h[
i] > max_lh * BBOX_B:
new_frame_list.append(f)
elif i > r_i:
if bbox_w[i] > max_rw * BBOX_B and bbox_h[
i] > max_rh * BBOX_B:
new_frame_list.append(f)
else:
new_frame_list.append(f)
new_tracklet_dict = dict()
for f in new_frame_list:
new_tracklet_dict[f] = tracklet_dict[f]
new_mot_list[tracklet] = new_tracklet_dict
else:
new_mot_list[tracklet] = tracklet_dict
return new_mot_list
def break_mot(self, mot_list, cid):
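# Split tracklets that reverse direction across zones (back_tracklet) or that contain a gap of
# more than 100 frames (time_break) into new tracklets with new track ids.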
new_mot_list = dict()
new_num_tracklets = max(mot_list) + 1
for tracklet in mot_list:
tracklet_dict = mot_list[tracklet]
frame_list = list(tracklet_dict.keys())
frame_list.sort()
zone_list = []
back_tracklet = False
new_zone_f = 0
pre_frame = frame_list[0]
time_break = False
for f in frame_list:
if f - pre_frame > 100:
if cid in [44, 45]:
time_break = True
break
if not cid in [41, 44, 45, 46]:
break
pre_frame = f
new_zone = tracklet_dict[f]['zone']
if len(zone_list) > 0 and zone_list[-1] == new_zone:
continue
if new_zone_f > 1:
if len(zone_list) > 1 and new_zone in zone_list:
back_tracklet = True
zone_list.append(new_zone)
new_zone_f = 0
else:
new_zone_f += 1
if back_tracklet:
new_tracklet_dict = dict()
pre_bbox = -1
pre_arrow = 0
have_break = False
for f in frame_list:
now_bbox = tracklet_dict[f]['bbox']
if type(pre_bbox) == int:
if pre_bbox == -1:
pre_bbox = now_bbox
now_arrow = now_bbox[0] - pre_bbox[0]
if pre_arrow * now_arrow < 0 and len(
new_tracklet_dict) > 15 and not have_break:
new_mot_list[tracklet] = new_tracklet_dict
new_tracklet_dict = dict()
have_break = True
if have_break:
tracklet_dict[f]['id'] = new_num_tracklets
new_tracklet_dict[f] = tracklet_dict[f]
pre_bbox, pre_arrow = now_bbox, now_arrow
if have_break:
new_mot_list[new_num_tracklets] = new_tracklet_dict
new_num_tracklets += 1
else:
new_mot_list[tracklet] = new_tracklet_dict
elif time_break:
new_tracklet_dict = dict()
have_break = False
pre_frame = frame_list[0]
for f in frame_list:
if f - pre_frame > 100:
new_mot_list[tracklet] = new_tracklet_dict
new_tracklet_dict = dict()
have_break = True
new_tracklet_dict[f] = tracklet_dict[f]
pre_frame = f
if have_break:
new_mot_list[new_num_tracklets] = new_tracklet_dict
new_num_tracklets += 1
else:
new_mot_list[tracklet] = new_tracklet_dict
else:
new_mot_list[tracklet] = tracklet_dict
return new_mot_list
def intra_matching(self, mot_list, sub_mot_list):
sub_zone_dict = dict()
new_mot_list = dict()
new_mot_list, new_sub_mot_list = self.do_intra_matching2(mot_list,
sub_mot_list)
return new_mot_list
def do_intra_matching2(self, mot_list, sub_list):
new_zone_dict = dict()
def get_trac_info(tracklet1):
t1_f = list(tracklet1)
t1_f.sort()
t1_fs = t1_f[0]
t1_fe = t1_f[-1]
t1_zs = tracklet1[t1_fs]['zone']
t1_ze = tracklet1[t1_fe]['zone']
t1_boxs = tracklet1[t1_fs]['bbox']
t1_boxe = tracklet1[t1_fe]['bbox']
t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
(t1_boxs[3] + t1_boxs[1]) / 2]
t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
(t1_boxe[3] + t1_boxe[1]) / 2]
return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
for t1id in sub_list:
tracklet1 = sub_list[t1id]
if tracklet1 == -1:
continue
t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
tracklet1)
sim_dict = dict()
for t2id in mot_list:
tracklet2 = mot_list[t2id]
t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
tracklet2)
if t1_ze == t2_zs:
if abs(t2_fs - t1_fe) < 5 and abs(t2_boxe[0] - t1_boxs[
0]) < 50 and abs(t2_boxe[1] - t1_boxs[1]) < 50:
t1_feat = tracklet1[t1_fe]['feat']
t2_feat = tracklet2[t2_fs]['feat']
sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
if t1_zs == t2_ze:
if abs(t2_fe - t1_fs) < 5 and abs(t2_boxs[0] - t1_boxe[
0]) < 50 and abs(t2_boxs[1] - t1_boxe[1]) < 50:
t1_feat = tracklet1[t1_fs]['feat']
t2_feat = tracklet2[t2_fe]['feat']
sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
if len(sim_dict) > 0:
max_sim = 0
max_id = 0
for t2id in sim_dict:
if sim_dict[t2id] > max_sim:
max_sim = sim_dict[t2id]
max_id = t2id
if max_sim > 0.5:
t2 = mot_list[max_id]
for t1f in tracklet1:
if t1f not in t2:
tracklet1[t1f]['id'] = max_id
t2[t1f] = tracklet1[t1f]
mot_list[max_id] = t2
sub_list[t1id] = -1
return mot_list, sub_list
def do_intra_matching(self, sub_zone_dict, sub_zone):
new_zone_dict = dict()
id_list = list(sub_zone_dict)
id2index = dict()
for index, id in enumerate(id_list):
id2index[id] = index
def get_trac_info(tracklet1):
t1_f = list(tracklet1)
t1_f.sort()
t1_fs = t1_f[0]
t1_fe = t1_f[-1]
t1_zs = tracklet1[t1_fs]['zone']
t1_ze = tracklet1[t1_fe]['zone']
t1_boxs = tracklet1[t1_fs]['bbox']
t1_boxe = tracklet1[t1_fe]['bbox']
t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
(t1_boxs[3] + t1_boxs[1]) / 2]
t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
(t1_boxe[3] + t1_boxe[1]) / 2]
return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
sim_matrix = np.zeros([len(id_list), len(id_list)])
for t1id in sub_zone_dict:
tracklet1 = sub_zone_dict[t1id]
t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
tracklet1)
t1_feat = tracklet1[t1_fe]['feat']
for t2id in sub_zone_dict:
if t1id == t2id:
continue
tracklet2 = sub_zone_dict[t2id]
t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
tracklet2)
if t1_zs != t1_ze and t2_ze != t2_zs or t1_fe > t2_fs:
continue
if abs(t1_boxe[0] - t2_boxs[0]) > 50 or abs(t1_boxe[1] -
t2_boxs[1]) > 50:
continue
if t2_fs - t1_fe > 5:
continue
t2_feat = tracklet2[t2_fs]['feat']
sim_matrix[id2index[t1id], id2index[t2id]] = np.matmul(t1_feat,
t2_feat)
sim_matrix[id2index[t2id], id2index[t1id]] = np.matmul(t1_feat,
t2_feat)
sim_matrix = 1 - sim_matrix
cluster_labels = AgglomerativeClustering(
n_clusters=None,
distance_threshold=0.7,
affinity='precomputed',
linkage='complete').fit_predict(sim_matrix)
new_zone_dict = dict()
label2id = dict()
for index, label in enumerate(cluster_labels):
tracklet = sub_zone_dict[id_list[index]]
if label not in label2id:
new_id = tracklet[list(tracklet)[0]]
new_tracklet = dict()
else:
new_id = label2id[label]
new_tracklet = new_zone_dict[label2id[label]]
for tf in tracklet:
tracklet[tf]['id'] = new_id
new_tracklet[tf] = tracklet[tf]
new_zone_dict[label] = new_tracklet
return new_zone_dict
......@@ -84,6 +84,7 @@ class Track(object):
self.state = TrackState.Tentative
self.features = []
self.feat = feature
if feature is not None:
self.features.append(feature)
......@@ -122,6 +123,7 @@ class Track(object):
self.covariance,
detection.to_xyah())
self.features.append(detection.feature)
self.feat = detection.feature
self.cls_id = detection.cls_id
self.score = detection.score
......
......@@ -15,13 +15,13 @@
import os
import cv2
import time
import paddle
import numpy as np
import collections
__all__ = [
'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic'
'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic',
'plot_tracking'
]
......@@ -107,21 +107,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1):
f = open(filename, 'w')
for cls_id in range(num_classes):
for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
if data_type == 'kitti':
frame_id -= 1
elif data_type == 'mot':
if data_type == 'mot':
cls_id = -1
elif data_type == 'mcmot':
cls_id = cls_id
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score,
......@@ -144,45 +145,45 @@ def load_det_results(det_file, num_frames):
# [frame_id],[x0],[y0],[w],[h],[score],[class_id]
for l in lables_with_frame:
results['bbox'].append(l[1:5])
results['score'].append(l[5])
results['cls_id'].append(l[6])
results['score'].append(l[5:6])
results['cls_id'].append(l[6:7])
results_list.append(results)
return results_list
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor[0][0]
# Note: ratio has only one value, scale_factor[0] == scale_factor[1]
#
# This function is only used for JDE YOLOv3 or other detectors with
# LetterBoxResize and JDEBBoxPostProcess, whose output coords have not
# been scaled back to the original image.
ratio = scale_factor[0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords = paddle.cast(coords, 'float32')
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max())
coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor.numpy()[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1])
xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0])
def clip_box(xyxy, ori_image_shape):
H, W = ori_image_shape
xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=W)
xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=H)
w = xyxy[:, 2:3] - xyxy[:, 0:1]
h = xyxy[:, 3:4] - xyxy[:, 1:2]
mask = paddle.logical_and(h > 0, w > 0)
keep_idx = paddle.nonzero(mask)
xyxy = paddle.gather_nd(xyxy, keep_idx[:, :1])
return xyxy, keep_idx
mask = np.logical_and(h > 0, w > 0)
keep_idx = np.nonzero(mask)
return xyxy[keep_idx[0]], keep_idx
def get_crops(xyxy, ori_img, w, h):
crops = []
xyxy = xyxy.numpy().astype(np.int64)
xyxy = xyxy.astype(np.int64)
ori_img = ori_img.numpy()
ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)
ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2) # [h,w,3]->[w,h,3]
for i, bbox in enumerate(xyxy):
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
......@@ -285,3 +286,77 @@ def flow_statistic(result,
"prev_center": prev_center,
"records": records
}
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def plot_tracking(image,
tlwhs,
obj_ids,
scores=None,
frame_id=0,
fps=0.,
ids2names=[],
do_entrance_counting=False,
entrance=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
if fps > 0:
_line = 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs))
else:
_line = 'frame: %d num: %d' % (frame_id, len(tlwhs))
cv2.putText(
im,
_line,
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2names != []:
assert len(
                ids2names) == 1, "plot_tracking only supports a single class."
id_text = '{}_'.format(ids2names[0]) + id_text
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
if do_entrance_counting:
entrance_line = tuple(map(int, entrance))
cv2.rectangle(
im,
entrance_line[0:2],
entrance_line[2:4],
color=(0, 255, 255),
thickness=line_thickness)
return im
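# A small sketch exercising plot_tracking above on a blank frame with dummy boxes
# and ids; the output path in the commented line is an assumption.
import numpy as np

frame = np.full((720, 1280, 3), 255, dtype=np.uint8)
tlwhs = [[100., 200., 80., 160.], [600., 300., 90., 180.]]
vis = plot_tracking(frame, tlwhs, obj_ids=[1, 2], scores=[0.91, 0.85],
                    frame_id=0, fps=25., ids2names=['vehicle'])
# cv2.imwrite('output/plot_tracking_demo.jpg', vis)  # write out to inspect the drawing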
......@@ -16,6 +16,7 @@ import os
import time
import yaml
import cv2
import re
import numpy as np
from collections import defaultdict
......@@ -24,7 +25,7 @@ from paddle.inference import Config
from paddle.inference import create_predictor
from picodet_postprocess import PicoDetPostProcess
from utils import argsparser, Timer, get_current_memory_mb, _is_valid_video, video2frames
from det_infer import Detector, DetectorPicoDet, get_test_images, print_arguments, PredictConfig
from det_infer import load_predictor
from benchmark_utils import PaddleInferBenchmark
......@@ -33,6 +34,10 @@ from visualize import plot_tracking
from mot.tracker import DeepSORTTracker
from mot.utils import MOTTimer, write_mot_results, flow_statistic
from mot.mtmct.utils import parse_bias
from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result
from mot.mtmct.postprocess import get_mtmct_matching_results, save_mtmct_crops, save_mtmct_vis_results
# Global dictionary
MOT_SUPPORT_MODELS = {'DeepSORT'}
......@@ -444,9 +449,62 @@ class SDE_ReID(object):
online_scores.append(tscore)
online_ids.append(tid)
tracking_outs = {
'online_tlwhs': online_tlwhs,
'online_scores': online_scores,
'online_ids': online_ids,
}
return tracking_outs
def postprocess_mtmct(self, pred_dets, pred_embs, frame_id, seq_name):
tracker = self.tracker
tracker.predict()
online_targets = tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], []
online_tlbrs, online_feats = [], []
for t in online_targets:
if not t.is_confirmed() or t.time_since_update > 1:
continue
tlwh = t.to_tlwh()
tscore = t.score
tid = t.track_id
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_scores.append(tscore)
online_ids.append(tid)
online_tlbrs.append(t.to_tlbr())
online_feats.append(t.feat)
tracking_outs = {
'online_tlwhs': online_tlwhs,
'online_scores': online_scores,
'online_ids': online_ids,
'feat_data': {},
}
for _tlbr, _id, _feat in zip(online_tlbrs, online_ids, online_feats):
feat_data = {}
feat_data['bbox'] = _tlbr
feat_data['frame'] = f"{frame_id:06d}"
feat_data['id'] = _id
_imgname = f'{seq_name}_{_id}_{frame_id}.jpg'
feat_data['imgname'] = _imgname
feat_data['feat'] = _feat
tracking_outs['feat_data'].update({_imgname: feat_data})
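        # Illustrative contents of one feat_data entry (dummy key and values), i.e.
        # what the MTMCT postprocess hands to trajectory_fusion for each confirmed
        # track in each frame:
        #   tracking_outs['feat_data']['c041_12_305.jpg'] = {
        #       'bbox': [x1, y1, x2, y2],   # tlbr box of the track in this frame
        #       'frame': '000305',          # zero-padded frame id
        #       'id': 12,                   # track id within this camera sequence
        #       'imgname': 'c041_12_305.jpg',
        #       'feat': <ReID embedding>,   # latest feature kept on the track
        #   }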
return tracking_outs
def predict(self,
crops,
pred_dets,
warmup=0,
repeats=1,
MTMCT=False,
frame_id=0,
seq_name=''):
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(crops)
self.det_times.preprocess_time_s.end()
......@@ -471,12 +529,15 @@ class SDE_ReID(object):
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
        if MTMCT == False:
            tracking_outs = self.postprocess(pred_dets, pred_embs)
        else:
            tracking_outs = self.postprocess_mtmct(pred_dets, pred_embs,
                                                   frame_id, seq_name)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
        return tracking_outs
def predict_image(detector, reid_model, image_list):
......@@ -504,11 +565,15 @@ def predict_image(detector, reid_model, image_list):
crops = reid_model.get_crops(pred_xyxys, frame)
if FLAGS.run_benchmark:
                tracking_outs = reid_model.predict(
                    crops, pred_dets, warmup=10, repeats=10)
            else:
                tracking_outs = reid_model.predict(crops, pred_dets)
online_tlwhs = tracking_outs['online_tlwhs']
online_scores = tracking_outs['online_scores']
online_ids = tracking_outs['online_ids']
online_im = plot_tracking(
frame, online_tlwhs, online_ids, online_scores, frame_id=i)
......@@ -570,8 +635,12 @@ def predict_video(detector, reid_model, camera_id):
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
            tracking_outs = reid_model.predict(crops, pred_dets)
online_tlwhs = tracking_outs['online_tlwhs']
online_scores = tracking_outs['online_scores']
online_ids = tracking_outs['online_ids']
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
# NOTE: just implement flow statistic for one class
......@@ -640,6 +709,170 @@ def predict_video(detector, reid_model, camera_id):
writer.release()
def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
fpath = os.path.join(FLAGS.mtmct_dir, seq_name)
if os.path.exists(os.path.join(fpath, 'img1')):
fpath = os.path.join(fpath, 'img1')
assert os.path.isdir(fpath), '{} should be a directory'.format(fpath)
image_list = os.listdir(fpath)
image_list.sort()
assert len(image_list) > 0, '{} has no images.'.format(fpath)
results = defaultdict(list)
mot_features_dict = {} # cid_tid_fid feats
    print('Found {} frames in total in seq {}.'.format(len(image_list), seq_name))
for frame_id, img_file in enumerate(image_list):
if frame_id % 40 == 0:
print('Processing frame {} of seq {}.'.format(frame_id, seq_name))
frame = cv2.imread(os.path.join(fpath, img_file))
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
format(frame_id))
online_im = frame
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
tracking_outs = reid_model.predict(
crops,
pred_dets,
MTMCT=True,
frame_id=frame_id,
seq_name=seq_name)
feat_data_dict = tracking_outs['feat_data']
mot_features_dict = dict(mot_features_dict, **feat_data_dict)
online_tlwhs = tracking_outs['online_tlwhs']
online_scores = tracking_outs['online_scores']
online_ids = tracking_outs['online_ids']
online_im = plot_tracking(frame, online_tlwhs, online_ids,
online_scores, frame_id)
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
if FLAGS.save_images:
save_dir = os.path.join(output_dir, seq_name)
if not os.path.exists(save_dir): os.makedirs(save_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(save_dir, img_name)
cv2.imwrite(out_path, online_im)
if FLAGS.save_mot_txts:
result_filename = os.path.join(output_dir, seq_name + '.txt')
write_mot_results(result_filename, results)
return mot_features_dict
def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
MTMCT = mtmct_cfg['MTMCT']
assert MTMCT == True, 'predict_mtmct should be used for MTMCT.'
cameras_bias = mtmct_cfg['cameras_bias']
cid_bias = parse_bias(cameras_bias)
scene_cluster = list(cid_bias.keys())
    # 1. zone related parameters
    use_zone = mtmct_cfg['use_zone']
    zone_path = mtmct_cfg['zone_path']
    # 2. trick parameters, can also be used for other MTMCT datasets
    use_ff = mtmct_cfg['use_ff']
    use_rerank = mtmct_cfg['use_rerank']
    # 3. camera related parameters
    use_camera = mtmct_cfg['use_camera']
    use_st_filter = mtmct_cfg['use_st_filter']
    # 4. ROI related parameters
    use_roi = mtmct_cfg['use_roi']
    roi_dir = mtmct_cfg['roi_dir']
mot_list_breaks = []
cid_tid_dict = dict()
output_dir = FLAGS.output_dir
if not os.path.exists(output_dir): os.makedirs(output_dir)
seqs = os.listdir(mtmct_dir)
seqs.sort()
for seq in seqs:
fpath = os.path.join(mtmct_dir, seq)
if os.path.isfile(fpath) and _is_valid_video(fpath):
ext = seq.split('.')[-1]
seq = seq.split('.')[-2]
print('ffmpeg processing of video {}'.format(fpath))
frames_path = video2frames(video_path=fpath, outpath=mtmct_dir, frame_rate=25)
fpath = os.path.join(mtmct_dir, seq)
if os.path.isdir(fpath) == False:
            print('{} is not an image folder.'.format(fpath))
continue
mot_features_dict = predict_mtmct_seq(detector, reid_model,
seq, output_dir)
        cid = int(re.sub('[a-zA-Z]', "", seq))
tid_data, mot_list_break = trajectory_fusion(
mot_features_dict,
cid,
cid_bias,
use_zone=use_zone,
zone_path=zone_path)
mot_list_breaks.append(mot_list_break)
# single seq process
for line in tid_data:
tracklet = tid_data[line]
tid = tracklet['tid']
if (cid, tid) not in cid_tid_dict:
cid_tid_dict[(cid, tid)] = tracklet
map_tid = sub_cluster(
cid_tid_dict,
scene_cluster,
use_ff=use_ff,
use_rerank=use_rerank,
use_camera=use_camera,
use_st_filter=use_st_filter)
pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
if use_camera:
gen_res(pred_mtmct_file, scene_cluster, map_tid, mot_list_breaks)
else:
gen_res(
pred_mtmct_file,
scene_cluster,
map_tid,
mot_list_breaks,
use_roi=use_roi,
roi_dir=roi_dir)
if FLAGS.save_images:
        camera_results, cid_tid_fid_res = get_mtmct_matching_results(
            pred_mtmct_file)
crops_dir = os.path.join(output_dir, 'mtmct_crops')
save_mtmct_crops(
cid_tid_fid_res, images_dir=mtmct_dir, crops_dir=crops_dir)
save_dir = os.path.join(output_dir, 'mtmct_vis')
save_mtmct_vis_results(
            camera_results,
images_dir=mtmct_dir,
save_dir=save_dir,
save_videos=FLAGS.save_images)
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector_func = 'SDE_Detector'
......@@ -675,6 +908,13 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, reid_model, FLAGS.camera_id)
elif FLAGS.mtmct_dir is not None:
mtmct_cfg_file = FLAGS.mtmct_cfg
with open(mtmct_cfg_file) as f:
mtmct_cfg = yaml.safe_load(f)
predict_mtmct(detector, reid_model, FLAGS.mtmct_dir, mtmct_cfg)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
......
# config for MTMCT
MTMCT: True
cameras_bias:
c041: 0
c042: 0
# 1. zone related parameters
use_zone: True
zone_path: dataset/mot/aic21mtmct_vehicle/S06/zone
# 2. trick parameters, can also be used for other MTMCT datasets
use_ff: True
use_rerank: True
# 3. camera related parameters
use_camera: True
use_st_filter: False
# 4. ROI related parameters
use_roi: True
roi_dir: dataset/mot/aic21mtmct_vehicle/S06
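# A minimal sketch of how this config is consumed by the deploy script above
# (the file name mtmct_cfg.yml is an assumption):
import yaml

with open('mtmct_cfg.yml') as f:
    mtmct_cfg = yaml.safe_load(f)
assert mtmct_cfg['MTMCT'] is True
print(mtmct_cfg['cameras_bias'])  # {'c041': 0, 'c042': 0}
print(mtmct_cfg['use_zone'], mtmct_cfg['zone_path'])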
......@@ -14,6 +14,7 @@
import time
import os
import sys
import ast
import argparse
......@@ -135,6 +136,13 @@ def argsparser():
"--draw_center_traj",
action='store_true',
help="Whether drawing the trajectory of center")
parser.add_argument(
"--mtmct_dir",
type=str,
default=None,
help="The MTMCT scene video folder.")
parser.add_argument(
"--mtmct_cfg", type=str, default=None, help="The MTMCT config.")
return parser
......@@ -243,3 +251,38 @@ def get_current_memory_mb():
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
def video2frames(video_path, outpath, frame_rate=25, **kargs):
def _dict2str(kargs):
cmd_str = ''
for k, v in kargs.items():
cmd_str += (' ' + str(k) + ' ' + str(v))
return cmd_str
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = os.path.basename(video_path).split('.')[0]
out_full_path = os.path.join(outpath, vid_name)
if not os.path.exists(out_full_path):
os.makedirs(out_full_path)
    # output frame file name pattern, e.g. 00001.jpg
    outformat = os.path.join(out_full_path, '%05d.jpg')
    cmd = ffmpeg + [
' -i ', video_path, ' -r ', str(frame_rate), ' -f image2 ', outformat
]
cmd = ''.join(cmd) + _dict2str(kargs)
if os.system(cmd) != 0:
raise RuntimeError('ffmpeg process video: {} error'.format(video_path))
sys.stdout.flush()
return out_full_path
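# A usage sketch for video2frames (paths are assumptions; ffmpeg must be on PATH):
# it writes numbered JPEGs under <outpath>/<video name>/ and returns that folder.
frames_dir = video2frames('mtmct/S06/c041.mp4', outpath='mtmct/S06', frame_rate=25)
print(frames_dir)  # mtmct/S06/c041, containing 00001.jpg, 00002.jpg, ...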
def _is_valid_video(f, extensions=('.mp4', '.avi', '.mov', '.rmvb', '.flv')):
return f.lower().endswith(extensions)
# coding: utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,7 +18,6 @@ import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
import math
from collections import deque
......@@ -135,13 +133,10 @@ def plot_tracking(image,
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
......@@ -205,14 +200,10 @@ def plot_tracking_dict(image,
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
if num_classes == 1:
start = records[-1].find('Total')
end = records[-1].find('In')
......
......@@ -19,8 +19,10 @@ from __future__ import print_function
import os
import cv2
import glob
import re
import paddle
import numpy as np
import os.path as osp
from collections import defaultdict
from ppdet.core.workspace import create
......@@ -176,6 +178,7 @@ class Tracker(object):
save_dir=None,
show_image=False,
frame_rate=30,
seq_name='',
scaled=False,
det_file='',
draw_threshold=0):
......@@ -200,23 +203,31 @@ class Tracker(object):
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
            ori_image = data['ori_image']  # [bs, H, W, 3]
ori_image_shape = data['ori_image'].shape[1:3]
# ori_image_shape: [H, W]
            input_shape = data['image'].shape[2:]
            # input_shape: [h, w], before data transforms, set in model config
            im_shape = data['im_shape'][0].numpy()
            # im_shape: [new_h, new_w], after data transforms
            scale_factor = data['scale_factor'][0].numpy()
empty_detections = False
# when it has no detected bboxes, will not inference reid model
# and if visualize, use original image instead
# forward
timer.tic()
if not use_detector:
dets = dets_list[frame_id]
                bbox_tlwh = np.array(dets['bbox'], dtype='float32')
                if bbox_tlwh.shape[0] > 0:
                    # detector outputs: pred_cls_ids, pred_scores, pred_bboxes
                    pred_cls_ids = np.array(dets['cls_id'], dtype='float32')
                    pred_scores = np.array(dets['score'], dtype='float32')
                    pred_bboxes = np.concatenate(
(bbox_tlwh[:, 0:2],
bbox_tlwh[:, 2:4] + bbox_tlwh[:, 0:2]),
axis=1)
......@@ -224,16 +235,21 @@ class Tracker(object):
logger.warning(
                        'Frame {} has no object, try to modify score threshold.'.
                        format(frame_id))
                    empty_detections = True
else:
outs = self.model.detector(data)
                outs['bbox'] = outs['bbox'].numpy()
                outs['bbox_num'] = outs['bbox_num'].numpy()

                if outs['bbox_num'] > 0 and empty_detections == False:
# detector outputs: pred_cls_ids, pred_scores, pred_bboxes
pred_cls_ids = outs['bbox'][:, 0:1]
pred_scores = outs['bbox'][:, 1:2]
if not scaled:
                        # Note: scaled=False only in JDE YOLOv3 or other detectors
                        # with LetterBoxResize and JDEBBoxPostProcess.
                        #
                        # 'scaled' means whether the coords output by the detector
                        # have been scaled back to the original image: set True
                        # for general detectors, False for JDE YOLOv3.
pred_bboxes = scale_coords(outs['bbox'][:, 2:],
......@@ -243,20 +259,36 @@ class Tracker(object):
pred_bboxes = outs['bbox'][:, 2:]
else:
logger.warning(
                        'Frame {} has no detected object, try to modify score threshold.'.
                        format(frame_id))
                    empty_detections = True
if not empty_detections:
pred_xyxys, keep_idx = clip_box(pred_bboxes, ori_image_shape)
if len(keep_idx[0]) == 0:
logger.warning(
                        'Frame {} has no detected object left after clip_box.'.
format(frame_id))
empty_detections = True
if empty_detections:
timer.toc()
# if visualize, use original image instead
online_ids, online_tlwhs, online_scores = None, None, None
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
frame_id += 1
# thus will not inference reid model
continue
pred_scores = pred_scores[keep_idx[0]]
pred_cls_ids = pred_cls_ids[keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2],
pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
tracker = self.model.tracker
......@@ -268,8 +300,7 @@ class Tracker(object):
crops = paddle.to_tensor(crops)
data.update({'crops': crops})
pred_embs = self.model(data).numpy()
tracker.predict()
online_targets = tracker.update(pred_dets, pred_embs)
......@@ -361,6 +392,7 @@ class Tracker(object):
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate,
seq_name=seq,
scaled=scaled,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)))
......@@ -417,19 +449,19 @@ class Tracker(object):
logger.info("Found {} inference images in total.".format(len(images)))
return images
def mot_predict_seq(self,
video_file,
frame_rate,
image_dir,
output_dir,
data_type='mot',
model_type='JDE',
save_images=False,
save_videos=True,
show_image=False,
scaled=False,
det_results_dir='',
draw_threshold=0.5):
assert video_file is not None or image_dir is not None, \
"--video_file or --image_dir should be set."
assert video_file is None or os.path.isfile(video_file), \
......@@ -452,6 +484,8 @@ class Tracker(object):
logger.info('Starting tracking video {}'.format(video_file))
elif image_dir:
seq = image_dir.split('/')[-1].split('.')[0]
if os.path.exists(os.path.join(image_dir, 'img1')):
image_dir = os.path.join(image_dir, 'img1')
images = [
'{}/{}'.format(image_dir, x) for x in os.listdir(image_dir)
]
......@@ -484,6 +518,7 @@ class Tracker(object):
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate,
seq_name=seq,
scaled=scaled,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)),
......@@ -491,9 +526,6 @@ class Tracker(object):
else:
raise ValueError(model_type)
if save_videos:
output_video_path = os.path.join(save_dir, '..',
'{}_vis.mp4'.format(seq))
......@@ -501,3 +533,6 @@ class Tracker(object):
save_dir, output_video_path)
os.system(cmd_str)
logger.info('Save video in {}'.format(output_video_path))
write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
......@@ -87,6 +87,7 @@ class Track(object):
self.state = TrackState.Tentative
self.features = []
self.feat = feature
if feature is not None:
self.features.append(feature)
......@@ -125,6 +126,7 @@ class Track(object):
self.covariance,
detection.to_xyah())
self.features.append(detection.feature)
self.feat = detection.feature
self.cls_id = detection.cls_id
self.score = detection.score
......
......@@ -15,9 +15,8 @@
import os
import cv2
import time
import numpy as np
from .visualization import plot_tracking_dict, plot_tracking
__all__ = [
'MOTTimer',
......@@ -157,14 +156,26 @@ def save_vis_results(data,
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
if online_ids is None:
online_im = img0
else:
if isinstance(online_tlwhs, dict):
online_im = plot_tracking_dict(
img0,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
else:
online_im = plot_tracking(
img0,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
......@@ -186,45 +197,45 @@ def load_det_results(det_file, num_frames):
# [frame_id],[x0],[y0],[w],[h],[score],[class_id]
for l in lables_with_frame:
            results['bbox'].append(l[1:5])
            results['score'].append(l[5:6])
            results['cls_id'].append(l[6:7])
results_list.append(results)
return results_list
def scale_coords(coords, input_shape, im_shape, scale_factor):
    # Note: ratio has only one value, scale_factor[0] == scale_factor[1]
    #
    # This function is only used for JDE YOLOv3 or other detectors with
    # LetterBoxResize and JDEBBoxPostProcess, whose coords output from the
    # detector has not been scaled back to the original image.
    ratio = scale_factor[0]
    pad_w = (input_shape[1] - int(im_shape[1])) / 2
    pad_h = (input_shape[0] - int(im_shape[0])) / 2
    coords[:, 0::2] -= pad_w
    coords[:, 1::2] -= pad_h
    coords[:, 0:4] /= ratio
    coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
    return coords.round()
def clip_box(xyxy, ori_image_shape):
    H, W = ori_image_shape
    xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=W)
    xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=H)
    w = xyxy[:, 2:3] - xyxy[:, 0:1]
    h = xyxy[:, 3:4] - xyxy[:, 1:2]
    mask = np.logical_and(h > 0, w > 0)
    keep_idx = np.nonzero(mask)
    return xyxy[keep_idx[0]], keep_idx
def get_crops(xyxy, ori_img, w, h):
crops = []
    xyxy = xyxy.astype(np.int64)
    ori_img = ori_img.numpy()
    ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)  # [h,w,3]->[w,h,3]
for i, bbox in enumerate(xyxy):
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
......
......@@ -53,7 +53,8 @@ def plot_tracking(image,
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2names != []:
            assert len(
                ids2names) == 1, "plot_tracking only supports a single class."
id_text = '{}_'.format(ids2names[0]) + id_text
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
......
......@@ -103,7 +103,7 @@ def run(FLAGS, cfg):
tracker.load_weights_jde(cfg.weights)
# inference
    tracker.mot_predict_seq(
video_file=FLAGS.video_file,
frame_rate=FLAGS.frame_rate,
image_dir=FLAGS.image_dir,
......