未验证 提交 8a87d99e 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] unify tracker for single class and multi-calss MOT (#4403)

* add mcfairmot train eval export code

* fix eval and export

* update modelzoo

* add hrnet18_dlafpn

* unify jdetracker

* fix deploy infer for single class

* fix multi class fairmot deploy

* fix mcmot data source and deploy

* fix num_identities

* fix num_identities_dict
上级 56705e1e
metric: MCMOT
num_classes: 10
# using VisDrone2019 MOT dataset with 10 classes as default, you can modify it for your needs.
# for MCMOT training
TrainDataset:
!MCMOTDataSet
dataset_dir: dataset/mot
image_lists: ['visdrone_mcmot.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
label_list: label_list.txt
# for MCMOT evaluation
# If you want to change the MCMOT evaluation dataset, please modify 'data_root'
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: visdrone_mcmot/images/val
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
# for MCMOT video inference
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True if save visualization images or video
......@@ -33,7 +33,6 @@ CenterNetHead:
FairMOTEmbeddingHead:
ch_head: 256
ch_emb: 128
num_identifiers: 14455 # for mix dataset (Caltech, CityPersons, CUHK-SYSU, PRW, ETHZ and MOT16)
CenterNetPostProcess:
max_per_img: 500
......
English | [简体中文](README_cn.md)
# MCFairMOT (Multi-class FairMOT)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Start](#Getting_Start)
- [Citations](#Citations)
## Introduction
MCFairMOT is the Multi-class extended version of [FairMOT](https://arxiv.org/abs/2004.01888).
## Model Zoo
### MCFairMOT DLA-34 Results on VisDrone2019 Val Set
| backbone | input shape | MOTA | IDF1 | IDS | FPS | download | config |
| :--------------| :------- | :----: | :----: | :---: | :------: | :----: |:----: |
| DLA-34 | 1088x608 | 24.3 | 41.6 | 2314 | - |[model](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams) | [config](./mcfairmot_dla34_30e_1088x608_visdrone.yml) |
| HRNetV2-W18 | 1088x608 | 20.4 | 39.9 | 2603 | - |[model](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.pdparams) | [config](./mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.yml) |
**Notes:**
MOTA is the average MOTA of 10 catecories in the VisDrone2019 MOT dataset, and its value is also equal to the average MOTA of all the evaluated video sequences.
## Getting Start
### 1. Training
Training MCFairMOT on 8 GPUs with following command
```bash
python -m paddle.distributed.launch --log_dir=./mcfairmot_dla34_30e_1088x608_visdrone/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
```
### 2. Evaluation
Evaluating the track performance of MCFairMOT on val dataset in single GPU with following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=output/mcfairmot_dla34_30e_1088x608_visdrone/model_final.pdparams
```
**Notes:**
The default evaluation dataset is VisDrone2019 MOT val-set. If you want to change the evaluation dataset, please refer to the following code and modify `configs/datasets/mcmot.yml`
```
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: your_dataset/images/val
keep_ori_im: False # set True if save visualization images or video
```
### 3. Inference
Inference a vidoe on single GPU with following command:
```bash
# inference on video and save a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams --video_file={your video name}.mp4 --save_videos
```
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first, on Linux(Ubuntu) platform you can directly install it by the following command:`apt-get update && apt-get install -y ffmpeg`.
### 4. Export model
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
```
### 5. Using exported model for python inference
```bash
python deploy/python/mot_jde_infer.py --model_dir=output_inference/mcfairmot_dla34_30e_1088x608_visdrone --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualization images.
## Citations
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
@ARTICLE{9573394,
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Detection and Tracking Meet Drones Challenge},
year={2021},
volume={},
number={},
pages={1-1},
doi={10.1109/TPAMI.2021.3119563}
}
```
简体中文 | [English](README.md)
# MCFairMOT (Multi-class FairMOT)
## 内容
- [简介](#简介)
- [模型库](#模型库)
- [快速开始](#快速开始)
- [引用](#引用)
## 内容
MCFairMOT是[FairMOT](https://arxiv.org/abs/2004.01888)的多类别扩展版本。
## 模型库
### MCFairMOT DLA-34 在VisDrone2019 MOT val-set上结果
| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FPS | 下载链接 | 配置文件 |
| :--------------| :------- | :----: | :----: | :---: | :------: | :----: |:----: |
| DLA-34 | 1088x608 | 24.3 | 41.6 | 2314 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams) | [配置文件](./mcfairmot_dla34_30e_1088x608_visdrone.yml) |
| HRNetV2-W18 | 1088x608 | 20.4 | 39.9 | 2603 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.pdparams) | [配置文件](./mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.yml) |
**注意:**
MOTA是VisDrone2019 MOT数据集10类目标的平均MOTA, 其值也等于所有评估的视频序列的平均MOTA。
## 快速开始
### 1. 训练
使用8个GPU通过如下命令一键式启动训练
```bash
python -m paddle.distributed.launch --log_dir=./mcfairmot_dla34_30e_1088x608_visdrone/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
```
### 2. 评估
使用单张GPU通过如下命令一键式启动评估
```bash
# 使用PaddleDetection发布的权重
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
# 使用训练保存的checkpoint
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=output/mcfairmot_dla34_30e_1088x608_visdrone/model_final.pdparams
```
**注意:**
默认评估的是VisDrone2019 MOT val-set数据集, 如需换评估数据集可参照以下代码修改`configs/datasets/mcmot.yml`
```
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: your_dataset/images/val
keep_ori_im: False # set True if save visualization images or video
```
### 3. 预测
使用单个GPU通过如下命令预测一个视频,并保存为视频
```bash
# 预测一个视频
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams --video_file={your video name}.mp4 --save_videos
```
**注意:**
请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`
### 4. 导出预测模型
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
```
### 5. 用导出的模型基于Python去预测
```bash
python deploy/python/mot_jde_infer.py --model_dir=output_inference/mcfairmot_dla34_30e_1088x608_visdrone --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**注意:**
跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
## 引用
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
@ARTICLE{9573394,
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Detection and Tracking Meet Drones Challenge},
year={2021},
volume={},
number={},
pages={1-1},
doi={10.1109/TPAMI.2021.3119563}
}
```
_BASE_: [
'../fairmot/fairmot_dla34_30e_1088x608.yml',
'../../datasets/mcmot.yml'
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker # multi-class tracker
CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
for_mot: True
regress_ltrb: False
max_per_img: 200
JDETracker:
min_box_area: 0
vertical_ratio: 0 # no need to filter bboxes according to w/h
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
weights: output/mcfairmot_dla34_30e_1088x608_visdrone/model_final
epoch: 30
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [10, 20]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
_BASE_: [
'../fairmot/fairmot_hrnetv2_w18_dlafpn_30e_1088x608.yml',
'../../datasets/mcmot.yml'
]
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/HRNet_W18_C_pretrained.pdparams
for_mot: True
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker # multi-class tracker
CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
regress_ltrb: False
max_per_img: 200
JDETracker:
min_box_area: 0
vertical_ratio: 0 # no need to filter bboxes according to w/h
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
weights: output/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone/model_final
epoch: 30
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [10, 20]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
TrainReader:
batch_size: 8
......@@ -17,18 +17,20 @@ import time
import yaml
import cv2
import numpy as np
import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess
from tracker import JDETracker
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as MOTTimer
from collections import defaultdict
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from preprocess import preprocess
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark
from ppdet.modeling.mot.tracker import JDETracker
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results
# Global dictionary
MOT_SUPPORT_MODELS = {
......@@ -80,13 +82,17 @@ class JDE_Detector(Detector):
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
assert pred_config.tracker, "Tracking model should have tracker"
self.num_classes = len(pred_config.labels)
tp = pred_config.tracker
min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
self.tracker = JDETracker(
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
conf_thres=conf_thres,
......@@ -94,25 +100,25 @@ class JDE_Detector(Detector):
metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs)
if online_targets == []:
# First few frames, the model may have no tracking results but have
# detection results,use the detection results instead, and set id -1.
return [pred_dets[0][:4]], [pred_dets[0][4]], [-1]
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
......@@ -121,7 +127,7 @@ class JDE_Detector(Detector):
image_list (list): list of image
threshold (float): threshold of predicted box' score
Returns:
online_tlwhs, online_scores, online_ids (np.ndarray)
online_tlwhs, online_scores, online_ids (dict[np.array])
'''
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
......@@ -157,38 +163,12 @@ class JDE_Detector(Detector):
return online_tlwhs, online_scores, online_ids
def write_mot_results(filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score)
f.write(line)
def predict_image(detector, image_list):
results = []
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
image_list.sort()
for i, img_file in enumerate(image_list):
for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
detector.predict([frame], FLAGS.threshold, warmup=10, repeats=10)
......@@ -196,12 +176,12 @@ def predict_image(detector, image_list):
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
print('Test iter {}, file name:{}'.format(frame_id, img_file))
else:
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
online_im = mot_vis.plot_tracking(
frame, online_tlwhs, online_ids, online_scores, frame_id=i)
online_im = plot_tracking_dict(frame, num_classes, online_tlwhs,
online_ids, online_scores, frame_id)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
......@@ -233,7 +213,9 @@ def predict_video(detector, camera_id):
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
results = []
results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
while (1):
ret, frame = capture.read()
if not ret:
......@@ -243,10 +225,14 @@ def predict_video(detector, camera_id):
[frame], FLAGS.threshold)
timer.toc()
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
for cls_id in range(num_classes):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
fps = 1. / timer.average_time
im = mot_vis.plot_tracking(
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
......@@ -261,14 +247,6 @@ def predict_video(detector, camera_id):
else:
writer.write(im)
if FLAGS.save_mot_txt_per_img:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
result_filename = os.path.join(save_dir,
'{:05d}.txt'.format(frame_id))
write_mot_results(result_filename, [results[-1]])
frame_id += 1
print('detect frame: %d' % (frame_id))
if camera_id != -1:
......@@ -278,7 +256,8 @@ def predict_video(detector, camera_id):
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
write_mot_results(result_filename, results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
......@@ -15,23 +15,23 @@
import os
import cv2
import math
import copy
import numpy as np
from collections import defaultdict
import paddle
import copy
from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from utils import get_current_memory_mb
from infer import Detector, PredictConfig, print_arguments, get_test_images
from visualize import draw_pose
from benchmark_utils import PaddleInferBenchmark
from utils import Timer
from tracker import JDETracker
from mot_jde_infer import JDE_Detector, write_mot_results
from infer import Detector, PredictConfig, print_arguments, get_test_images
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as FPSTimer
from utils import get_current_memory_mb
from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from det_keypoint_unite_infer import predict_with_given_det, bench_log
from mot_jde_infer import JDE_Detector
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer as FPSTimer
from ppdet.modeling.mot.utils import write_mot_results
# Global dictionary
KEYPOINT_SUPPORT_MODELS = {
......@@ -56,6 +56,9 @@ def mot_keypoint_unite_predict_image(mot_model,
keypoint_model,
image_list,
keypoint_batch_size=1):
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
image_list.sort()
for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
......@@ -104,9 +107,13 @@ def mot_keypoint_unite_predict_image(mot_model,
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'
else None)
online_im = mot_vis.plot_tracking(
im, online_tlwhs, online_ids, online_scores, frame_id=i)
online_im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=i)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
......@@ -143,7 +150,13 @@ def mot_keypoint_unite_predict_video(mot_model,
timer_mot = FPSTimer()
timer_kp = FPSTimer()
timer_mot_kp = FPSTimer()
mot_results = []
# support single class and multi classes, but should be single class here
mot_results = defaultdict(list)
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
while (1):
ret, frame = capture.read()
if not ret:
......@@ -153,15 +166,15 @@ def mot_keypoint_unite_predict_video(mot_model,
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold)
timer_mot.toc()
mot_results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
mot_results[0].append(
(frame_id + 1, online_tlwhs[0], online_scores[0], online_ids[0]))
mot_fps = 1. / timer_mot.average_time
timer_kp.tic()
keypoint_arch = keypoint_model.pred_config.arch
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown':
results = convert_mot_to_det(online_tlwhs, online_scores)
results = convert_mot_to_det(online_tlwhs[0], online_scores[0])
keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold,
......@@ -184,8 +197,9 @@ def mot_keypoint_unite_predict_video(mot_model,
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else
None)
online_im = mot_vis.plot_tracking(
online_im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
......@@ -212,7 +226,7 @@ def mot_keypoint_unite_predict_video(mot_model,
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results)
write_mot_results(result_filename, mot_results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
......@@ -22,7 +22,7 @@ from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess
from tracker import DeepSORTTracker
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as MOTTimer
from ppdet.modeling.mot.utils import MOTTimer
from paddle.inference import Config
from paddle.inference import create_predictor
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import jde_tracker
from . import deepsort_tracker
from .jde_tracker import *
from .deepsort_tracker import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from ppdet.modeling.mot.matching import jde_matching as matching
from ppdet.modeling.mot.motion import KalmanFilter
from ppdet.modeling.mot.tracker.base_jde_tracker import TrackState, BaseTrack, STrack
from ppdet.modeling.mot.tracker.base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
__all__ = ['JDETracker']
class JDETracker(object):
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
det_thresh=0.3,
track_buffer=30,
min_box_area=200,
vertical_ratio=1.6,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.motion = KalmanFilter()
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_stracks = []
self.lost_stracks = []
self.removed_stracks = []
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self.frame_id += 1
activated_starcks = []
# for storing active tracks, for the current frame
refind_stracks = []
# Lost Tracks whose detections are obtained in the current frame
lost_stracks = []
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks = []
remain_inds = np.nonzero(pred_dets[:, 4] > self.conf_thres)
if len(remain_inds) == 0:
pred_dets = np.zeros([0, 1])
pred_embs = np.zeros([0, 1])
else:
pred_dets = pred_dets[remain_inds]
pred_embs = pred_embs[remain_inds]
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred = True if len(pred_dets) == 1 and np.sum(
pred_dets) == 0.0 else False
""" Step 1: Network forward, get detections & embeddings"""
if len(pred_dets) > 0 and not empty_pred:
detections = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
for (tlbrs, f) in zip(pred_dets, pred_embs)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed.append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks.append(track)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.motion)
dists = matching.embedding_distance(
strack_pool, detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists, strack_pool,
detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for itracked, idet in matches:
# itracked is the id of the track and idet is the detection
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection]
# detections is now a list of the unmatched detections
r_tracked_stracks = []
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for i in u_track:
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=self.unconfirmed_thresh)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
# The tracks which are yet not matched
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self.tracked_stracks = [
t for t in self.tracked_stracks if t.state == TrackState.Tracked
]
self.tracked_stracks = joint_stracks(self.tracked_stracks,
activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks,
refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
self.tracked_stracks, self.lost_stracks)
# get scores of lost tracks
output_stracks = [
track for track in self.tracked_stracks if track.is_activated
]
return output_stracks
......@@ -90,9 +90,12 @@ def get_categories(metric_type, anno_file=None, arch=None):
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
return _mot_category()
elif metric_type.lower() in ['kitti', 'bdd100k']:
elif metric_type.lower() in ['kitti', 'bdd100kmot']:
return _mot_category(category='car')
elif metric_type.lower() in ['mcmot']:
return _visdrone_category()
else:
raise ValueError("unknown metric type {}".format(metric_type))
......@@ -825,3 +828,21 @@ def _oid19_category():
}
return clsid2catid, catid2name
def _visdrone_category():
clsid2catid = {i: i for i in range(10)}
catid2name = {
0: 'pedestrian',
1: 'people',
2: 'bicycle',
3: 'car',
4: 'van',
5: 'truck',
6: 'tricycle',
7: 'awning-tricycle',
8: 'bus',
9: 'motor'
}
return clsid2catid, catid2name
......@@ -17,7 +17,7 @@ import sys
import cv2
import glob
import numpy as np
from collections import OrderedDict
from collections import OrderedDict, defaultdict
try:
from collections.abc import Sequence
except Exception:
......@@ -32,7 +32,8 @@ logger = setup_logger(__name__)
@serializable
class MOTDataSet(DetDataset):
"""
Load dataset with MOT format.
Load dataset with MOT format, only support single class MOT.
Args:
dataset_dir (str): root directory for dataset.
image_lists (str|list): mot data image lists, muiti-source mot dataset.
......@@ -152,18 +153,17 @@ class MOTDataSet(DetDataset):
self.tid_start_index[k] = last_index
last_index += v
self.total_identities = int(last_index + 1)
self.num_identities_dict = defaultdict(int)
self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
logger.info('=' * 80)
logger.info('MOT dataset summary: ')
logger.info(self.tid_num)
logger.info('total images: {}'.format(self.total_imgs))
logger.info('image start index: {}'.format(self.img_start_index))
logger.info('total identities: {}'.format(self.total_identities))
logger.info('identity start index: {}'.format(self.tid_start_index))
logger.info('=' * 80)
logger.info('Total images: {}'.format(self.total_imgs))
logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities: {}'.format(self.num_identities_dict[0]))
logger.info('Identity start index: {}'.format(self.tid_start_index))
records = []
cname2cid = mot_label()
......@@ -222,9 +222,223 @@ class MOTDataSet(DetDataset):
self.roidbs, self.cname2cid = records, cname2cid
def mot_label():
labels_map = {'person': 0}
return labels_map
@register
@serializable
class MCMOTDataSet(DetDataset):
"""
Load dataset with MOT format, support multi-class MOT.
Args:
dataset_dir (str): root directory for dataset.
image_lists (list(str)): mcmot data image lists, muiti-source mcmot dataset.
data_fields (list): key name of data dictionary, at least have 'image'.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
sample_num (int): number of samples to load, -1 means all.
Notes:
MCMOT datasets root directory following this:
dataset/mot
|——————image_lists
| |——————visdrone_mcmot.train
| |——————visdrone_mcmot.val
visdrone_mcmot
|——————images
| └——————train
| └——————val
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
data_fields=['image'],
label_list=None,
sample_num=-1):
super(MCMOTDataSet, self).__init__(
dataset_dir=dataset_dir,
data_fields=data_fields,
sample_num=sample_num)
self.dataset_dir = dataset_dir
self.image_lists = image_lists
if isinstance(self.image_lists, str):
self.image_lists = [self.image_lists]
self.label_list = label_list
self.roidbs = None
self.cname2cid = None
def get_anno(self):
if self.image_lists == []:
return
# only used to get categories and metric
return os.path.join(self.dataset_dir, 'image_lists',
self.image_lists[0])
def parse_dataset(self):
self.img_files = OrderedDict()
self.img_start_index = OrderedDict()
self.label_files = OrderedDict()
self.tid_num = OrderedDict()
self.tid_start_idx_of_cls_ids = defaultdict(dict) # for MCMOT
img_index = 0
for data_name in self.image_lists:
# check every data image list
image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
assert os.path.isdir(image_lists_dir), \
"The {} is not a directory.".format(image_lists_dir)
list_path = os.path.join(image_lists_dir, data_name)
assert os.path.exists(list_path), \
"The list path {} does not exist.".format(list_path)
# record img_files, filter out empty ones
with open(list_path, 'r') as file:
self.img_files[data_name] = file.readlines()
self.img_files[data_name] = [
os.path.join(self.dataset_dir, x.strip())
for x in self.img_files[data_name]
]
self.img_files[data_name] = list(
filter(lambda x: len(x) > 0, self.img_files[data_name]))
self.img_start_index[data_name] = img_index
img_index += len(self.img_files[data_name])
# record label_files
self.label_files[data_name] = [
x.replace('images', 'labels_with_ids').replace(
'.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files[data_name]
]
for data_name, label_paths in self.label_files.items():
# using max_ids_dict rather than max_index
max_ids_dict = defaultdict(int)
for lp in label_paths:
lb = np.loadtxt(lp)
if len(lb) < 1:
continue
lb = lb.reshape(-1, 6)
for item in lb:
if item[1] > max_ids_dict[int(item[0])]:
# item[0]: cls_id
# item[1]: track id
max_ids_dict[int(item[0])] = int(item[1])
# track id number
self.tid_num[data_name] = max_ids_dict
last_idx_dict = defaultdict(int)
for i, (k, v) in enumerate(self.tid_num.items()): # each sub dataset
for cls_id, id_num in v.items(): # v is a max_ids_dict
self.tid_start_idx_of_cls_ids[k][cls_id] = last_idx_dict[cls_id]
last_idx_dict[cls_id] += id_num
self.num_identities_dict = defaultdict(int)
for k, v in last_idx_dict.items():
self.num_identities_dict[k] = int(v) # total ids of each category
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
# cname2cid and cid2cname
cname2cid = {}
if self.label_list:
# if use label_list for multi source mix dataset,
# please make sure label_list in the first sub_dataset at least.
sub_dataset = self.image_lists[0].split('.')[0]
label_path = os.path.join(self.dataset_dir, sub_dataset,
self.label_list)
if not os.path.exists(label_path):
raise ValueError("label_list {} does not exists".format(
label_path))
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = visdrone_mcmot_label()
cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])
logger.info('MCMOT dataset summary: ')
logger.info(self.tid_num)
logger.info('Total images: {}'.format(self.total_imgs))
logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities of each category: ')
self.num_identities_dict = sorted(
self.num_identities_dict.items(), key=lambda x: x[0])
total_IDs_all_cats = 0
for (k, v) in self.num_identities_dict:
logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k],
v))
total_IDs_all_cats += v
logger.info('Total identities of all categories: {}'.format(
total_IDs_all_cats))
logger.info('Identity start index of each category: ')
for k, v in self.tid_start_idx_of_cls_ids.items():
sorted_v = sorted(v.items(), key=lambda x: x[0])
for (cls_id, start_idx) in sorted_v:
logger.info('Start index of dataset {} category {:d} is {:d}'
.format(k, cls_id, start_idx))
records = []
for img_index in range(self.total_imgs):
for i, (k, v) in enumerate(self.img_start_index.items()):
if img_index >= v:
data_name = list(self.label_files.keys())[i]
start_index = v
img_file = self.img_files[data_name][img_index - start_index]
lbl_file = self.label_files[data_name][img_index - start_index]
if not os.path.exists(img_file):
logger.warning('Illegal image file: {}, and it will be ignored'.
format(img_file))
continue
if not os.path.isfile(lbl_file):
logger.warning('Illegal label file: {}, and it will be ignored'.
format(lbl_file))
continue
labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
# each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]
cx, cy = labels[:, 2], labels[:, 3]
w, h = labels[:, 4], labels[:, 5]
gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
gt_class = labels[:, 0:1].astype('int32')
gt_score = np.ones((len(labels), 1)).astype('float32')
gt_ide = labels[:, 1:2].astype('int32')
for i, _ in enumerate(gt_ide):
if gt_ide[i] > -1:
cls_id = int(gt_class[i])
start_idx = self.tid_start_idx_of_cls_ids[data_name][cls_id]
gt_ide[i] += start_idx
mot_rec = {
'im_file': img_file,
'im_id': img_index,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_ide': gt_ide,
}
for k, v in gt_rec.items():
if k in self.data_fields:
mot_rec[k] = v
records.append(mot_rec)
if self.sample_num > 0 and img_index >= self.sample_num:
break
assert len(records) > 0, 'not found any mot record in %s' % (
self.image_lists)
self.roidbs, self.cname2cid = records, cname2cid
@register
......@@ -382,3 +596,24 @@ def video2frames(video_path, outpath, frame_rate, **kargs):
sys.stdout.flush()
return out_full_path
def mot_label():
labels_map = {'person': 0}
return labels_map
def visdrone_mcmot_label():
labels_map = {
'pedestrian': 0,
'people': 1,
'bicycle': 2,
'car': 3,
'van': 4,
'truck': 5,
'tricycle': 6,
'awning-tricycle': 7,
'bus': 8,
'motor': 9,
}
return labels_map
......@@ -556,6 +556,11 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
reid = np.zeros((self.max_objs, ), dtype=np.int64)
bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
if self.num_classes > 1:
# each category corresponds to a set of track ids
cls_tr_ids = np.zeros(
(self.num_classes, output_h, output_w), dtype=np.int64)
cls_id_map = np.full((output_h, output_w), -1, dtype=np.int64)
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
......@@ -598,6 +603,10 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
index_mask[k] = 1
reid[k] = ide
bbox_xys[k] = bbox_xy
if self.num_classes > 1:
cls_id_map[ct_int[1], ct_int[0]] = cls_id
cls_tr_ids[cls_id][ct_int[1]][ct_int[0]] = ide - 1
# track id start from 0
sample['heatmap'] = heatmap
sample['index'] = index
......@@ -605,6 +614,9 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
sample['size'] = bbox_size
sample['index_mask'] = index_mask
sample['reid'] = reid
if self.num_classes > 1:
sample['cls_id_map'] = cls_id_map
sample['cls_tr_ids'] = cls_tr_ids
sample['bbox_xys'] = bbox_xys
sample.pop('is_crowd', None)
sample.pop('difficult', None)
......
......@@ -21,14 +21,15 @@ import cv2
import glob
import paddle
import numpy as np
from collections import defaultdict
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
from ppdet.metrics import MCMOTMetric
import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback
......@@ -74,6 +75,8 @@ class Tracker(object):
if self.cfg.metric == 'MOT':
self._metrics = [MOTMetric(), ]
elif self.cfg.metric == 'MCMOT':
self._metrics = [MCMOTMetric(self.cfg.num_classes), ]
elif self.cfg.metric == 'KITTI':
self._metrics = [KITTIMOTMetric(), ]
else:
......@@ -121,43 +124,49 @@ class Tracker(object):
tracker = self.model.tracker
tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)
timer = Timer()
results = []
timer = MOTTimer()
frame_id = 0
self.status['mode'] = 'track'
self.model.eval()
results = defaultdict(list) # support single class and multi classes
for step_id, data in enumerate(dataloader):
self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
# forward
timer.tic()
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < draw_threshold: continue
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
pred_dets, pred_embs = pred_dets.numpy(), pred_embs.numpy()
online_targets_dict = self.model.tracker.update(pred_dets,
pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for cls_id in range(self.cfg.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
# save results
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
# save results
results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir)
timer.toc()
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
......@@ -174,7 +183,7 @@ class Tracker(object):
if not os.path.exists(save_dir): os.makedirs(save_dir)
use_detector = False if not self.model.detector else True
timer = Timer()
timer = MOTTimer()
results = []
frame_id = 0
self.status['mode'] = 'track'
......@@ -284,9 +293,9 @@ class Tracker(object):
# save results
results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir)
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
......@@ -305,37 +314,39 @@ class Tracker(object):
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \
"data_type should be 'mot' or 'kitti'"
assert data_type in ['mot', 'mcmot', 'kitti'], \
"data_type should be 'mot', 'mcmot' or 'kitti'"
assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
"model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
# run tracking
n_frame = 0
timer_avgs, timer_calls = [], []
for seq in seqs:
if not os.path.isdir(os.path.join(data_root, seq)):
infer_dir = os.path.join(data_root, seq)
if not os.path.exists(infer_dir) or not os.path.isdir(infer_dir):
logger.warning("Seq {} error, {} has no images.".format(
seq, infer_dir))
continue
infer_dir = os.path.join(data_root, seq, 'img1')
if os.path.exists(os.path.join(infer_dir, 'img1')):
infer_dir = os.path.join(infer_dir, 'img1')
frame_rate = 30
seqinfo = os.path.join(data_root, seq, 'seqinfo.ini')
if not os.path.exists(seqinfo) or not os.path.exists(
infer_dir) or not os.path.isdir(infer_dir):
continue
if os.path.exists(seqinfo):
meta_info = open(seqinfo).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq))
images = self.get_infer_images(infer_dir)
self.dataset.set_images(images)
self.dataset.set_images(self.get_infer_images(infer_dir))
dataloader = create('EvalMOTReader')(self.dataset, 0)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(seqinfo).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde(
......@@ -355,7 +366,8 @@ class Tracker(object):
else:
raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type)
write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
n_frame += nf
timer_avgs.append(ta)
timer_calls.append(tc)
......@@ -427,8 +439,8 @@ class Tracker(object):
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \
"data_type should be 'mot' or 'kitti'"
assert data_type in ['mot', 'mcmot', 'kitti'], \
"data_type should be 'mot', 'mcmot' or 'kitti'"
assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
"model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
......@@ -478,7 +490,8 @@ class Tracker(object):
else:
raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type)
write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
if save_videos:
output_video_path = os.path.join(save_dir, '..',
......@@ -487,52 +500,3 @@ class Tracker(object):
save_dir, output_video_path)
os.system(cmd_str)
logger.info('Save video in {}'.format(output_video_path))
def write_mot_results(self, filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score)
f.write(line)
logger.info('MOT results save in {}'.format(filename))
def save_results(self, data, frame_id, online_ids, online_tlwhs,
online_scores, average_time, show_image, save_dir):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
online_im = mot_vis.plot_tracking(
img0,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
......@@ -41,7 +41,7 @@ from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
from ppdet.utils import profiler
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter,SniperProposalsGenerator
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec
from ppdet.utils.logger import setup_logger
......@@ -77,11 +77,12 @@ class Trainer(object):
if cfg.architecture == 'JDE' and self.mode == 'train':
cfg['JDEEmbeddingHead'][
'num_identifiers'] = self.dataset.total_identities
'num_identities'] = self.dataset.num_identities_dict[0]
# JDE only support single class MOT now.
if cfg.architecture == 'FairMOT' and self.mode == 'train':
cfg['FairMOTEmbeddingHead'][
'num_identifiers'] = self.dataset.total_identities
cfg['FairMOTEmbeddingHead']['num_identities_dict'] = self.dataset.num_identities_dict
# FairMOT support single class and multi-class MOT now.
# build model
if 'model' not in self.cfg:
......@@ -192,7 +193,7 @@ class Trainer(object):
IouType=IouType,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == "SNIPERCOCO": # sniper
elif self.cfg.metric == "SNIPERCOCO": # sniper
self._metrics = [
SNIPERCOCOMetric(
anno_file=anno_file,
......@@ -202,8 +203,7 @@ class Trainer(object):
output_eval=output_eval,
bias=bias,
IouType=IouType,
save_prediction_only=save_prediction_only
)
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'RBOX':
# TODO: bias should be unified
......@@ -516,7 +516,8 @@ class Trainer(object):
results.append(outs)
# sniper
if type(self.dataset) == SniperCOCODataSet:
results = self.dataset.anno_cropper.aggregate_chips_detections(results)
results = self.dataset.anno_cropper.aggregate_chips_detections(
results)
for outs in results:
batch_res = get_infer_results(outs, clsid2catid)
......
......@@ -22,4 +22,8 @@ __all__ = metrics.__all__ + keypoint_metrics.__all__
from . import mot_metrics
from .mot_metrics import *
__all__ = __all__ + mot_metrics.__all__
__all__ = metrics.__all__ + mot_metrics.__all__
from . import mcmot_metrics
from .mcmot_metrics import *
__all__ = metrics.__all__ + mcmot_metrics.__all__
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import sys
import math
from collections import defaultdict
from motmetrics.math_util import quiet_divide
import numpy as np
import pandas as pd
import paddle
import paddle.nn.functional as F
from .metrics import Metric
import motmetrics as mm
import openpyxl
metrics = mm.metrics.motchallenge_metrics
mh = mm.metrics.create()
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['MCMOTEvaluator', 'MCMOTMetric']
METRICS_LIST = [
'num_frames', 'num_matches', 'num_switches', 'num_transfer', 'num_ascend',
'num_migrate', 'num_false_positives', 'num_misses', 'num_detections',
'num_objects', 'num_predictions', 'num_unique_objects', 'mostly_tracked',
'partially_tracked', 'mostly_lost', 'num_fragmentations', 'motp', 'mota',
'precision', 'recall', 'idfp', 'idfn', 'idtp', 'idp', 'idr', 'idf1'
]
NAME_MAP = {
'num_frames': 'num_frames',
'num_matches': 'num_matches',
'num_switches': 'IDs',
'num_transfer': 'IDt',
'num_ascend': 'IDa',
'num_migrate': 'IDm',
'num_false_positives': 'FP',
'num_misses': 'FN',
'num_detections': 'num_detections',
'num_objects': 'num_objects',
'num_predictions': 'num_predictions',
'num_unique_objects': 'GT',
'mostly_tracked': 'MT',
'partially_tracked': 'partially_tracked',
'mostly_lost': 'ML',
'num_fragmentations': 'FM',
'motp': 'MOTP',
'mota': 'MOTA',
'precision': 'Prcn',
'recall': 'Rcll',
'idfp': 'idfp',
'idfn': 'idfn',
'idtp': 'idtp',
'idp': 'IDP',
'idr': 'IDR',
'idf1': 'IDF1'
}
def parse_accs_metrics(seq_acc, index_name, verbose=False):
"""
Parse the evaluation indicators of multiple MOTAccumulator
"""
mh = mm.metrics.create()
summary = MCMOTEvaluator.get_summary(seq_acc, index_name, METRICS_LIST)
summary.loc['OVERALL', 'motp'] = (summary['motp'] * summary['num_detections']).sum() / \
summary.loc['OVERALL', 'num_detections']
if verbose:
strsummary = mm.io.render_summary(
summary, formatters=mh.formatters, namemap=NAME_MAP)
print(strsummary)
return summary
def seqs_overall_metrics(summary_df, verbose=False):
"""
Calculate overall metrics for multiple sequences
"""
add_col = [
'num_frames', 'num_matches', 'num_switches', 'num_transfer',
'num_ascend', 'num_migrate', 'num_false_positives', 'num_misses',
'num_detections', 'num_objects', 'num_predictions',
'num_unique_objects', 'mostly_tracked', 'partially_tracked',
'mostly_lost', 'num_fragmentations', 'idfp', 'idfn', 'idtp'
]
calc_col = ['motp', 'mota', 'precision', 'recall', 'idp', 'idr', 'idf1']
calc_df = summary_df.copy()
overall_dic = {}
for col in add_col:
overall_dic[col] = calc_df[col].sum()
for col in calc_col:
overall_dic[col] = getattr(MCMOTMetricOverall, col + '_overall')(
calc_df, overall_dic)
overall_df = pd.DataFrame(overall_dic, index=['overall_calc'])
calc_df = pd.concat([calc_df, overall_df])
if verbose:
mh = mm.metrics.create()
str_calc_df = mm.io.render_summary(
calc_df, formatters=mh.formatters, namemap=NAME_MAP)
print(str_calc_df)
return calc_df
class MCMOTMetricOverall(object):
def motp_overall(summary_df, overall_dic):
motp = quiet_divide((summary_df['motp'] *
summary_df['num_detections']).sum(),
overall_dic['num_detections'])
return motp
def mota_overall(summary_df, overall_dic):
del summary_df
mota = 1. - quiet_divide(
(overall_dic['num_misses'] + overall_dic['num_switches'] +
overall_dic['num_false_positives']), overall_dic['num_objects'])
return mota
def precision_overall(summary_df, overall_dic):
del summary_df
precision = quiet_divide(overall_dic['num_detections'], (
overall_dic['num_false_positives'] + overall_dic['num_detections']))
return precision
def recall_overall(summary_df, overall_dic):
del summary_df
recall = quiet_divide(overall_dic['num_detections'],
overall_dic['num_objects'])
return recall
def idp_overall(summary_df, overall_dic):
del summary_df
idp = quiet_divide(overall_dic['idtp'],
(overall_dic['idtp'] + overall_dic['idfp']))
return idp
def idr_overall(summary_df, overall_dic):
del summary_df
idr = quiet_divide(overall_dic['idtp'],
(overall_dic['idtp'] + overall_dic['idfn']))
return idr
def idf1_overall(summary_df, overall_dic):
del summary_df
idf1 = quiet_divide(2. * overall_dic['idtp'], (
overall_dic['num_objects'] + overall_dic['num_predictions']))
return idf1
def read_mcmot_results_union(filename, is_gt, is_ignore):
results_dict = dict()
if os.path.isfile(filename):
all_result = np.loadtxt(filename, delimiter=',')
if all_result.shape[0] == 0 or all_result.shape[1] < 7:
return results_dict
if is_ignore:
return results_dict
if is_gt:
# only for test use
all_result = all_result[all_result[:, 7] != 0]
all_result[:, 7] = all_result[:, 7] - 1
if all_result.shape[0] == 0:
return results_dict
class_unique = np.unique(all_result[:, 7])
last_max_id = 0
result_cls_list = []
for cls in class_unique:
result_cls_split = all_result[all_result[:, 7] == cls]
result_cls_split[:, 1] = result_cls_split[:, 1] + last_max_id
# make sure track id different between every category
last_max_id = max(np.unique(result_cls_split[:, 1])) + 1
result_cls_list.append(result_cls_split)
results_con = np.concatenate(result_cls_list)
for line in range(len(results_con)):
linelist = results_con[line]
fid = int(linelist[0])
if fid < 1:
continue
results_dict.setdefault(fid, list())
if is_gt:
score = 1
else:
score = float(linelist[6])
tlwh = tuple(map(float, linelist[2:6]))
target_id = int(linelist[1])
cls = int(linelist[7])
results_dict[fid].append((tlwh, target_id, cls, score))
return results_dict
def read_mcmot_results(filename, is_gt, is_ignore):
results_dict = dict()
if os.path.isfile(filename):
with open(filename, 'r') as f:
for line in f.readlines():
linelist = line.strip().split(',')
if len(linelist) < 7:
continue
fid = int(linelist[0])
if fid < 1:
continue
cid = int(linelist[7])
if is_gt:
score = 1
# only for test use
cid -= 1
else:
score = float(linelist[6])
cls_result_dict = results_dict.setdefault(cid, dict())
cls_result_dict.setdefault(fid, list())
tlwh = tuple(map(float, linelist[2:6]))
target_id = int(linelist[1])
cls_result_dict[fid].append((tlwh, target_id, score))
return results_dict
def read_results(filename,
data_type,
is_gt=False,
is_ignore=False,
multi_class=False,
union=False):
if data_type in ['mcmot', 'lab']:
if multi_class:
if union:
# The results are evaluated by union all the categories.
# Track IDs between different categories cannot be duplicate.
read_fun = read_mcmot_results_union
else:
# The results are evaluated separately by category.
read_fun = read_mcmot_results
else:
raise ValueError('multi_class: {}, MCMOT should have cls_id.'.
format(multi_class))
else:
raise ValueError('Unknown data type: {}'.format(data_type))
return read_fun(filename, is_gt, is_ignore)
def unzip_objs(objs):
if len(objs) > 0:
tlwhs, ids, scores = zip(*objs)
else:
tlwhs, ids, scores = [], [], []
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
return tlwhs, ids, scores
def unzip_objs_cls(objs):
if len(objs) > 0:
tlwhs, ids, cls, scores = zip(*objs)
else:
tlwhs, ids, cls, scores = [], [], [], []
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
ids = np.array(ids)
cls = np.array(cls)
scores = np.array(scores)
return tlwhs, ids, cls, scores
class MCMOTEvaluator(object):
def __init__(self, data_root, seq_name, data_type, num_classes):
self.data_root = data_root
self.seq_name = seq_name
self.data_type = data_type
self.num_classes = num_classes
self.load_annotations()
self.reset_accumulator()
self.class_accs = []
def load_annotations(self):
assert self.data_type == 'mcmot'
self.gt_filename = os.path.join(self.data_root, '../', '../',
'sequences',
'{}.txt'.format(self.seq_name))
def reset_accumulator(self):
import motmetrics as mm
mm.lap.default_solver = 'lap'
self.acc = mm.MOTAccumulator(auto_id=True)
def eval_frame_dict(self, trk_objs, gt_objs, rtn_events=False, union=False):
import motmetrics as mm
mm.lap.default_solver = 'lap'
if union:
trk_tlwhs, trk_ids, trk_cls = unzip_objs_cls(trk_objs)[:3]
gt_tlwhs, gt_ids, gt_cls = unzip_objs_cls(gt_objs)[:3]
# get distance matrix
iou_distance = mm.distances.iou_matrix(
gt_tlwhs, trk_tlwhs, max_iou=0.5)
# Set the distance between objects of different categories to nan
gt_cls_len = len(gt_cls)
trk_cls_len = len(trk_cls)
# When the number of GT or Trk is 0, iou_distance dimension is (0,0)
if gt_cls_len != 0 and trk_cls_len != 0:
gt_cls = gt_cls.reshape(gt_cls_len, 1)
gt_cls = np.repeat(gt_cls, trk_cls_len, axis=1)
trk_cls = trk_cls.reshape(1, trk_cls_len)
trk_cls = np.repeat(trk_cls, gt_cls_len, axis=0)
iou_distance = np.where(gt_cls == trk_cls, iou_distance, np.nan)
else:
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
# get distance matrix
iou_distance = mm.distances.iou_matrix(
gt_tlwhs, trk_tlwhs, max_iou=0.5)
self.acc.update(gt_ids, trk_ids, iou_distance)
if rtn_events and iou_distance.size > 0 and hasattr(self.acc,
'mot_events'):
events = self.acc.mot_events # only supported by https://github.com/longcw/py-motmetrics
else:
events = None
return events
def eval_file(self, result_filename):
# evaluation of each category
gt_frame_dict = read_results(
self.gt_filename,
self.data_type,
is_gt=True,
multi_class=True,
union=False)
result_frame_dict = read_results(
result_filename,
self.data_type,
is_gt=False,
multi_class=True,
union=False)
for cid in range(self.num_classes):
self.reset_accumulator()
cls_result_frame_dict = result_frame_dict.setdefault(cid, dict())
cls_gt_frame_dict = gt_frame_dict.setdefault(cid, dict())
# only labeled frames will be evaluated
frames = sorted(list(set(cls_gt_frame_dict.keys())))
for frame_id in frames:
trk_objs = cls_result_frame_dict.get(frame_id, [])
gt_objs = cls_gt_frame_dict.get(frame_id, [])
self.eval_frame_dict(trk_objs, gt_objs, rtn_events=False)
self.class_accs.append(self.acc)
return self.class_accs
@staticmethod
def get_summary(accs,
names,
metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1',
'precision', 'recall')):
import motmetrics as mm
mm.lap.default_solver = 'lap'
names = copy.deepcopy(names)
if metrics is None:
metrics = mm.metrics.motchallenge_metrics
metrics = copy.deepcopy(metrics)
mh = mm.metrics.create()
summary = mh.compute_many(
accs, metrics=metrics, names=names, generate_overall=True)
return summary
@staticmethod
def save_summary(summary, filename):
import pandas as pd
writer = pd.ExcelWriter(filename)
summary.to_excel(writer)
writer.save()
class MCMOTMetric(Metric):
def __init__(self, num_classes, save_summary=False):
self.num_classes = num_classes
self.save_summary = save_summary
self.MCMOTEvaluator = MCMOTEvaluator
self.result_root = None
self.reset()
self.seqs_overall = defaultdict(list)
def reset(self):
self.accs = []
self.seqs = []
def update(self, data_root, seq, data_type, result_root, result_filename):
evaluator = self.MCMOTEvaluator(data_root, seq, data_type,
self.num_classes)
seq_acc = evaluator.eval_file(result_filename)
self.accs.append(seq_acc)
self.seqs.append(seq)
self.result_root = result_root
cls_index_name = [
'{}_{}'.format(seq, i) for i in range(self.num_classes)
]
summary = parse_accs_metrics(seq_acc, cls_index_name)
summary.rename(
index={'OVERALL': '{}_OVERALL'.format(seq)}, inplace=True)
for row in range(len(summary)):
self.seqs_overall[row].append(summary.iloc[row:row + 1])
def accumulate(self):
self.cls_summary_list = []
for row in range(self.num_classes):
seqs_cls_df = pd.concat(self.seqs_overall[row])
seqs_cls_summary = seqs_overall_metrics(seqs_cls_df)
cls_summary_overall = seqs_cls_summary.iloc[-1:].copy()
cls_summary_overall.rename(
index={'overall_calc': 'overall_calc_{}'.format(row)},
inplace=True)
self.cls_summary_list.append(cls_summary_overall)
def log(self):
seqs_summary = seqs_overall_metrics(
pd.concat(self.seqs_overall[self.num_classes]), verbose=True)
class_summary = seqs_overall_metrics(
pd.concat(self.cls_summary_list), verbose=True)
def get_results(self):
return 1
......@@ -79,7 +79,7 @@ class CenterNet(BaseArch):
def get_pred(self):
head_out = self._forward()
if self.for_mot:
bbox, bbox_inds = self.post_process(
bbox, bbox_inds, topk_clses = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
......@@ -88,10 +88,11 @@ class CenterNet(BaseArch):
output = {
"bbox": bbox,
"bbox_inds": bbox_inds,
"topk_clses": topk_clses,
"neck_feat": head_out['neck_feat']
}
else:
bbox, bbox_num = self.post_process(
bbox, bbox_num, _ = self.post_process(
head_out['heatmap'],
head_out['size'],
head_out['offset'],
......
......@@ -86,13 +86,9 @@ class FairMOT(BaseArch):
loss.update({'reid_loss': reid_loss})
return loss
else:
embedding = self.reid(neck_feat, self.inputs)
bbox_inds = det_outs['bbox_inds']
embedding = paddle.transpose(embedding, [0, 2, 3, 1])
embedding = paddle.reshape(embedding,
[-1, paddle.shape(embedding)[-1]])
pred_embs = paddle.gather(embedding, bbox_inds)
pred_dets = det_outs['bbox']
pred_dets, pred_embs = self.reid(
neck_feat, self.inputs, det_outs['bbox'], det_outs['bbox_inds'],
det_outs['topk_clses'])
return pred_dets, pred_embs
def get_pred(self):
......
......@@ -59,7 +59,7 @@ class CenterNetHead(nn.Layer):
"""
Args:
in_channels (int): the channel number of input to CenterNetHead.
num_classes (int): the number of classes, 80 by default.
num_classes (int): the number of classes, 80 (COCO dataset) by default.
head_planes (int): the channel number in all head, 256 by default.
heatmap_weight (float): the weight of heatmap loss, 1 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or
......@@ -83,6 +83,7 @@ class CenterNetHead(nn.Layer):
offset_weight=1,
iou_weight=0):
super(CenterNetHead, self).__init__()
self.regress_ltrb = regress_ltrb
self.weights = {
'heatmap': heatmap_weight,
'size': size_weight,
......@@ -196,7 +197,14 @@ class CenterNetHead(nn.Layer):
pos_num = size_mask.sum()
size_mask.stop_gradient = True
if self.size_loss == 'L1':
size_target = inputs['size']
if self.regress_ltrb:
size_target = inputs['size']
# shape: [bs, max_per_img, 4]
else:
size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :,
2:]
# shape: [bs, max_per_img, 2]
size_target.stop_gradient = True
size_loss = F.l1_loss(
pos_size * size_mask, size_target * size_mask, reduction='sum')
......
......@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import defaultdict
from collections import deque, OrderedDict
from ..matching import jde_matching as matching
from ppdet.core.workspace import register, serializable
......@@ -40,7 +38,7 @@ class TrackState(object):
@register
@serializable
class BaseTrack(object):
_count = 0
_count_dict = defaultdict(int) # support single class and multi classes
track_id = 0
is_activated = False
......@@ -62,9 +60,23 @@ class BaseTrack(object):
return self.frame_id
@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count
def next_id(cls_id):
BaseTrack._count_dict[cls_id] += 1
return BaseTrack._count_dict[cls_id]
# @even: reset track id
@staticmethod
def init_count(num_classes):
"""
Initiate _count for all object classes
:param num_classes:
"""
for cls_id in range(num_classes):
BaseTrack._count_dict[cls_id] = 0
@staticmethod
def reset_track_count(cls_id):
BaseTrack._count_dict[cls_id] = 0
def activate(self, *args):
raise NotImplementedError
......@@ -85,7 +97,15 @@ class BaseTrack(object):
@register
@serializable
class STrack(BaseTrack):
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
def __init__(self,
tlwh,
score,
temp_feat,
num_classes,
cls_id,
buff_size=30):
# object class id
self.cls_id = cls_id
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
self.kalman_filter = None
......@@ -93,20 +113,21 @@ class STrack(BaseTrack):
self.is_activated = False
self.score = score
self.tracklet_len = 0
self.track_len = 0
self.smooth_feat = None
self.update_features(temp_feat)
self.features = deque([], maxlen=buffer_size)
self.features = deque([], maxlen=buff_size)
self.alpha = 0.9
def update_features(self, feat):
# L2 normalizing
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha
self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
......@@ -119,54 +140,60 @@ class STrack(BaseTrack):
self.covariance)
@staticmethod
def multi_predict(stracks, kalman_filter):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
def multi_predict(tracks, kalman_filter):
if len(tracks) > 0:
multi_mean = np.asarray([track.mean.copy() for track in tracks])
multi_covariance = np.asarray(
[track.covariance for track in tracks])
for i, st in enumerate(tracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = kalman_filter.multi_predict(
multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
tracks[i].mean = mean
tracks[i].covariance = cov
def reset_track_id(self):
self.reset_track_count(self.cls_id)
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
"""Start a new track"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
# update track id for the object class
self.track_id = self.next_id(self.cls_id)
self.mean, self.covariance = self.kalman_filter.initiate(
self.tlwh_to_xyah(self._tlwh))
self.tracklet_len = 0
self.state = TrackState.Tracked
if frame_id == 1:
self.track_len = 0
self.state = TrackState.Tracked # set flag 'tracked'
if frame_id == 1: # to record the first frame's detection result
self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat)
self.tracklet_len = 0
self.track_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
if new_id: # update track id for the object class
self.track_id = self.next_id(self.cls_id)
def update(self, new_track, frame_id, update_feature=True):
self.frame_id = frame_id
self.tracklet_len += 1
self.track_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked
self.is_activated = True
self.state = TrackState.Tracked # set flag 'tracked'
self.is_activated = True # set flag 'activated'
self.score = new_track.score
if update_feature:
......@@ -174,12 +201,12 @@ class STrack(BaseTrack):
@property
def tlwh(self):
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
......@@ -187,8 +214,7 @@ class STrack(BaseTrack):
@property
def tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
......@@ -197,8 +223,7 @@ class STrack(BaseTrack):
@staticmethod
def tlwh_to_xyah(tlwh):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
"""Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
......@@ -222,8 +247,8 @@ class STrack(BaseTrack):
return ret
def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame,
self.end_frame)
return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
self.start_frame, self.end_frame)
def joint_stracks(tlista, tlistb):
......
......@@ -17,10 +17,13 @@ import cv2
import time
import paddle
import numpy as np
from .visualization import plot_tracking_dict
__all__ = [
'Timer',
'MOTTimer',
'Detection',
'write_mot_results',
'save_vis_results',
'load_det_results',
'preprocess_reid',
'get_crops',
......@@ -29,7 +32,7 @@ __all__ = [
]
class Timer(object):
class MOTTimer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
......@@ -106,6 +109,68 @@ class Detection(object):
return ret
def write_mot_results(filename, results, data_type='mot', num_classes=1):
# support single and multi classes
if data_type in ['mot', 'mcmot']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
f = open(filename, 'w')
for cls_id in range(num_classes):
for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
if data_type == 'kitti':
frame_id -= 1
elif data_type == 'mot':
cls_id = -1
elif data_type == 'mcmot':
cls_id = cls_id
x1, y1, w, h = tlwh
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
w=w,
h=h,
score=score,
cls_id=cls_id)
f.write(line)
print('MOT results save in {}'.format(filename))
def save_vis_results(data,
frame_id,
online_ids,
online_tlwhs,
online_scores,
average_time,
show_image,
save_dir,
num_classes=1):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
online_im = plot_tracking_dict(
img0,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
def load_det_results(det_file, num_frames):
assert os.path.exists(det_file) and os.path.isfile(det_file), \
'{} is not exist or not a file.'.format(det_file)
......
......@@ -16,28 +16,12 @@ import cv2
import numpy as np
def tlwhs_to_tlbrs(tlwhs):
tlbrs = np.copy(tlwhs)
if len(tlbrs) == 0:
return tlbrs
tlbrs[:, 2] += tlwhs[:, 0]
tlbrs[:, 3] += tlwhs[:, 1]
return tlbrs
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def resize_image(image, max_size=800):
if max(image.shape[:2]) > max_size:
scale = float(max_size) / max(image.shape[:2])
image = cv2.resize(image, None, fx=scale, fy=scale)
return image
def plot_tracking(image,
tlwhs,
obj_ids,
......@@ -92,44 +76,67 @@ def plot_tracking(image,
return im
def plot_trajectory(image, tlwhs, track_ids):
image = image.copy()
for one_tlwhs, track_id in zip(tlwhs, track_ids):
color = get_color(int(track_id))
for tlwh in one_tlwhs:
x1, y1, w, h = tuple(map(int, tlwh))
cv2.circle(
image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2)
return image
def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None):
im = np.copy(image)
text_scale = max(1, image.shape[1] / 800.)
thickness = 2 if text_scale > 1.3 else 1
for i, det in enumerate(tlbrs):
x1, y1, x2, y2 = np.asarray(det[:4], dtype=np.int)
if len(det) >= 7:
label = 'det' if det[5] > 0 else 'trk'
if ids is not None:
text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i])
cv2.putText(
im,
text, (x1, y1 + 30),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=thickness)
def plot_tracking_dict(image,
num_classes,
tlwhs_dict,
obj_ids_dict,
scores_dict,
frame_id=0,
fps=0.,
ids2=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
for cls_id in range(num_classes):
tlwhs = tlwhs_dict[cls_id]
obj_ids = obj_ids_dict[cls_id]
scores = scores_dict[cls_id]
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
if num_classes == 1:
id_text = '{}'.format(int(obj_id))
else:
text = '{}# {:.2f}'.format(label, det[6])
id_text = 'class{}_id{}'.format(cls_id, int(obj_id))
if scores is not None:
text = '{:.2f}'.format(scores[i])
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(
im,
intbox[0:2],
intbox[2:4],
color=color,
thickness=line_thickness)
cv2.putText(
im,
text, (x1, y1 + 30),
id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=thickness)
text_scale, (0, 0, 255),
thickness=text_thickness)
cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
......@@ -415,7 +415,6 @@ class CenterNetPostProcess(TTFBox):
regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default.
for_mot (bool): whether return other features used in tracking model.
"""
__shared__ = ['down_ratio', 'for_mot']
......@@ -433,9 +432,9 @@ class CenterNetPostProcess(TTFBox):
def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm)
scores, inds, clses, ys, xs = self._topk(heat)
scores, inds, topk_clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(clses, [1])
clses = paddle.tensor.unsqueeze(topk_clses, [1])
reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1.
......@@ -486,10 +485,10 @@ class CenterNetPostProcess(TTFBox):
bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot:
results = paddle.concat([bboxes, scores, clses], axis=1)
return results, inds
return results, inds, topk_clses
else:
results = paddle.concat([clses, scores, bboxes], axis=1)
return results, paddle.shape(results)[0:1]
return results, paddle.shape(results)[0:1], topk_clses
@register
......
......@@ -26,21 +26,27 @@ __all__ = ['FairMOTEmbeddingHead']
@register
class FairMOTEmbeddingHead(nn.Layer):
__shared__ = ['num_classes']
"""
Args:
in_channels (int): the channel number of input to FairMOTEmbeddingHead.
ch_head (int): the channel of features before fed into embedding, 256 by default.
ch_emb (int): the channel of the embedding feature, 128 by default.
num_identifiers (int): the number of identifiers, 14455 by default.
num_identities_dict (dict): the number of identities of each category,
support single class and multi-calss, {0: 14455} as default.
"""
def __init__(self,
in_channels,
ch_head=256,
ch_emb=128,
num_identifiers=14455):
num_classes=1,
num_identities_dict={0: 14455}):
super(FairMOTEmbeddingHead, self).__init__()
assert num_classes >= 1
self.num_classes = num_classes
self.ch_emb = ch_emb
self.num_identities_dict = num_identities_dict
self.reid = nn.Sequential(
ConvLayer(
in_channels, ch_head, kernel_size=3, padding=1, bias=True),
......@@ -50,15 +56,27 @@ class FairMOTEmbeddingHead(nn.Layer):
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
bound = 1 / math.sqrt(ch_emb)
bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
self.classifier = nn.Linear(
ch_emb,
num_identifiers,
weight_attr=param_attr,
bias_attr=bias_attr)
self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
# When num_identifiers is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(
num_identifiers - 1) if num_identifiers > 1 else 1
if num_classes == 1:
nID = self.num_identities_dict[0] # single class
self.classifier = nn.Linear(
ch_emb,
nID,
weight_attr=param_attr,
bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(
nID - 1) if nID > 1 else 1
else:
self.classifiers = dict()
self.emb_scale_dict = dict()
for cls_id, nID in self.num_identities_dict.items():
self.classifiers[str(cls_id)] = nn.Linear(
ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale_dict[str(cls_id)] = math.sqrt(2) * math.log(
nID - 1) if nID > 1 else 1
@classmethod
def from_config(cls, cfg, input_shape):
......@@ -66,14 +84,56 @@ class FairMOTEmbeddingHead(nn.Layer):
input_shape = input_shape[0]
return {'in_channels': input_shape.channels}
def forward(self, feat, inputs):
def process_by_class(self, det_outs, embedding, bbox_inds, topk_clses):
pred_dets, pred_embs = [], []
for cls_id in range(self.num_classes):
inds_masks = topk_clses == cls_id
inds_masks = paddle.cast(inds_masks, 'float32')
pos_num = inds_masks.sum().numpy()
if pos_num == 0:
continue
cls_inds_mask = inds_masks > 0
bbox_mask = paddle.nonzero(cls_inds_mask)
cls_det_outs = paddle.gather_nd(det_outs, bbox_mask)
pred_dets.append(cls_det_outs)
cls_inds = paddle.masked_select(bbox_inds, cls_inds_mask)
cls_inds = cls_inds.unsqueeze(-1)
cls_embedding = paddle.gather_nd(embedding, cls_inds)
pred_embs.append(cls_embedding)
return paddle.concat(pred_dets), paddle.concat(pred_embs)
def forward(self,
feat,
inputs,
det_outs=None,
bbox_inds=None,
topk_clses=None):
reid_feat = self.reid(feat)
if self.training:
loss = self.get_loss(reid_feat, inputs)
if self.num_classes == 1:
loss = self.get_loss(reid_feat, inputs)
else:
loss = self.get_mc_loss(reid_feat, inputs)
return loss
else:
assert det_outs is not None and bbox_inds is not None
reid_feat = F.normalize(reid_feat)
return reid_feat
embedding = paddle.transpose(reid_feat, [0, 2, 3, 1])
embedding = paddle.reshape(embedding, [-1, self.ch_emb])
# embedding shape: [bs * h * w, ch_emb]
if self.num_classes == 1:
pred_dets = det_outs
pred_embs = paddle.gather(embedding, bbox_inds)
else:
pred_dets, pred_embs = self.process_by_class(
det_outs, embedding, bbox_inds, topk_clses)
return pred_dets, pred_embs
def get_loss(self, feat, inputs):
index = inputs['index']
......@@ -113,3 +173,56 @@ class FairMOTEmbeddingHead(nn.Layer):
loss = loss / count
return loss
def get_mc_loss(self, feat, inputs):
# feat.shape = [bs, ch_emb, h, w]
assert 'cls_id_map' in inputs and 'cls_tr_ids' in inputs
index = inputs['index']
mask = inputs['index_mask']
cls_id_map = inputs['cls_id_map'] # [bs, h, w]
cls_tr_ids = inputs['cls_tr_ids'] # [bs, num_classes, h, w]
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
reid_losses = 0
for cls_id, id_num in self.num_identities_dict.items():
# target
cur_cls_tr_ids = paddle.reshape(
cls_tr_ids[:, cls_id, :, :], shape=[feat_n, -1]) # [bs, h*w]
cls_id_target = paddle.gather_nd(cur_cls_tr_ids, index=index)
mask = inputs['index_mask']
cls_id_target = paddle.masked_select(cls_id_target, mask > 0)
cls_id_target.stop_gradient = True
# feat
cls_id_feat = self.emb_scale_dict[str(cls_id)] * F.normalize(feat)
cls_id_pred = self.classifiers[str(cls_id)](cls_id_feat)
loss = self.reid_loss(cls_id_pred, cls_id_target)
valid = (cls_id_target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
reid_losses += loss
return reid_losses
......@@ -49,7 +49,7 @@ class JDEEmbeddingHead(nn.Layer):
JDEEmbeddingHead
Args:
num_classes(int): Number of classes. Only support one class tracking.
num_identifiers(int): Number of identifiers.
num_identities(int): Number of identities.
anchor_levels(int): Number of anchor levels, same as FPN levels.
anchor_scales(int): Number of anchor scales on each FPN level.
embedding_dim(int): Embedding dimension. Default: 512.
......@@ -60,7 +60,7 @@ class JDEEmbeddingHead(nn.Layer):
def __init__(
self,
num_classes=1,
num_identifiers=14455, # defined by dataset.total_identities when training
num_identities=14455, # dataset.num_identities_dict[0]
anchor_levels=3,
anchor_scales=4,
embedding_dim=512,
......@@ -68,7 +68,7 @@ class JDEEmbeddingHead(nn.Layer):
jde_loss='JDELoss'):
super(JDEEmbeddingHead, self).__init__()
self.num_classes = num_classes
self.num_identifiers = num_identifiers
self.num_identities = num_identities
self.anchor_levels = anchor_levels
self.anchor_scales = anchor_scales
self.embedding_dim = embedding_dim
......@@ -76,7 +76,7 @@ class JDEEmbeddingHead(nn.Layer):
self.jde_loss = jde_loss
self.emb_scale = math.sqrt(2) * math.log(
self.num_identifiers - 1) if self.num_identifiers > 1 else 1
self.num_identities - 1) if self.num_identities > 1 else 1
self.identify_outputs = []
self.loss_params_cls = []
......@@ -106,7 +106,7 @@ class JDEEmbeddingHead(nn.Layer):
'classifier',
nn.Linear(
self.embedding_dim,
self.num_identifiers,
self.num_identities,
weight_attr=ParamAttr(
learning_rate=1., initializer=Normal(
mean=0.0, std=0.01)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册